Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2014-07-19 01:45:28 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2014-07-19 01:45:28 +0400
commit7a5870de32b7ca2285a41ee41d490a93b1152e12 (patch)
tree783d1f1df75082767b015acee486f8b1ca6aa772 /test
parentf19810729f293a0ecbd094572e87972a4adf9502 (diff)
parent2cab3858fc05685593a0f7e56fa2d7fef194b228 (diff)
Merge branch 'master' of git://github.com/torch/nn into flatten_table
Diffstat (limited to 'test')
-rw-r--r--test/test.lua214
1 files changed, 175 insertions, 39 deletions
diff --git a/test/test.lua b/test/test.lua
index b1b4f90..55fe7b6 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1,5 +1,9 @@
require 'torch'
+-- you can easily test specific units like this:
+-- luajit -lnn -e "nn.test{'LookupTable'}"
+-- luajit -lnn -e "nn.test{'LookupTable', 'Add'}"
+
local mytester = torch.Tester()
local jac
local sjac
@@ -9,6 +13,16 @@ local expprecision = 1e-4
local nntest = {}
+local function equal(t1, t2, msg)
+ if (torch.type(t1) == "table") then
+ for k, v in pairs(t2) do
+ equal(t1[k], t2[k])
+ end
+ else
+ mytester:assertTensorEq(t1, t2, 0.00001, msg)
+ end
+end
+
function nntest.Add()
local ini = math.random(10,20)
local inj = math.random(10,20)
@@ -1822,6 +1836,13 @@ function nntest.LookupTable()
local ferr,berr = jac.testIO(module,input,minval,maxval)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ -- accUpdate
+ module:accUpdateOnly()
+ mytester:assert(not module.gradWeight, 'gradWeight is nil')
+ module:float()
+ local output = module:forward(input)
+ module:backwardUpdate(input, output, 0.1)
end
function nntest.AddConstant()
@@ -1916,15 +1937,6 @@ function nntest.SelectTable()
{torch.Tensor(3,4,5):zero()},
{torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
}
- local function equal(t1, t2, msg)
- if (torch.type(t1) == "table") then
- for k, v in pairs(t2) do
- equal(t1[k], t2[k])
- end
- else
- mytester:assertTensorEq(t1, t2, 0.00001, msg)
- end
- end
local nonIdx = {2,3,4,1}
local module
for idx = 1,#input do
@@ -1944,6 +1956,137 @@ function nntest.SelectTable()
equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
end
+function nntest.MixtureTable()
+ --[[ 2D ]]--
+ -- expertInput is a Table:
+ local expertInput = torch.randn(5,3,6)
+ local gradOutput = torch.randn(5,6)
+ local input = {
+ torch.rand(5,3),
+ {expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)}
+ }
+ local module = nn.MixtureTable()
+ local output = module:forward(input)
+ local output2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), expertInput):sum(2)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradOutput2 = torch.view(gradOutput, 5, 1, 6):expandAs(expertInput)
+ local gaterGradInput2 = torch.cmul(gradOutput2, expertInput):sum(3):select(3,1)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture gater gradInput")
+ local expertGradInput2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), gradOutput:view(5,1,6):expand(5,3,6))
+ for i, expertGradInput in ipairs(gradInput[2]) do
+ mytester:assertTensorEq(expertGradInput, expertGradInput2:select(2,i), 0.000001, "mixture expert "..i.." gradInput")
+ end
+ -- expertInput is a Tensor:
+ local input = {input[1], expertInput}
+ local module = nn.MixtureTable(2)
+ local output = module:forward(input)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture2 output")
+ local gradInput = module:backward(input, gradOutput)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture2 gater gradInput")
+ mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture2 expert gradInput")
+
+ --[[ 3D ]]--
+ local expertInput = torch.randn(5,6,3,2)
+ local gradOutput = torch.randn(5,6,2)
+ -- expertInput is a Table:
+ local input = {
+ torch.rand(5,3),
+ {expertInput:select(3,1), expertInput:select(3,2), expertInput:select(3,3)}
+ }
+ local module = nn.MixtureTable()
+ local output = module:forward(input)
+ local output2 = torch.cmul(input[1]:view(5,1,3,1):expand(5,6,3,2), expertInput):sum(3)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture3 output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradOutput2 = torch.view(gradOutput,5,6,1,2):expandAs(expertInput)
+ local gaterGradInput2 = torch.cmul(gradOutput2, expertInput):sum(4):select(4,1):sum(2):select(2,1)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture3 gater gradInput")
+ local expertGradInput2 = torch.cmul(input[1]:view(5,1,3,1):expand(5,6,3,2), gradOutput2)
+ for i, expertGradInput in ipairs(gradInput[2]) do
+ mytester:assertTensorEq(expertGradInput, expertGradInput2:select(3,i), 0.000001, "mixture3 expert "..i.." gradInput")
+ end
+ -- expertInput is a Tensor
+ local input = {input[1], expertInput}
+ local module = nn.MixtureTable(3)
+ local output = module:forward(input)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture4 output")
+ local gradInput = module:backward(input, gradOutput)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture4 gater gradInput")
+ mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture4 expert gradInput")
+
+ --[[ 1D ]]--
+ -- expertInput is a Table:
+ local expertInput = torch.randn(3,6)
+ local gradOutput = torch.randn(6)
+ local input = {
+ torch.rand(3),
+ {expertInput:select(1,1), expertInput:select(1,2), expertInput:select(1,3)}
+ }
+ local module = nn.MixtureTable()
+ local output = module:forward(input)
+ local output2 = torch.cmul(input[1]:view(3,1):expand(3,6), expertInput):sum(1)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture5 output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradOutput2 = torch.view(gradOutput, 1, 6):expandAs(expertInput)
+ local gaterGradInput2 = torch.cmul(gradOutput2, expertInput):sum(2):select(2,1)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture5 gater gradInput")
+ local expertGradInput2 = torch.cmul(input[1]:view(3,1):expand(3,6), gradOutput:view(1,6):expand(3,6))
+ for i, expertGradInput in ipairs(gradInput[2]) do
+ mytester:assertTensorEq(expertGradInput, expertGradInput2:select(1,i), 0.000001, "mixture5 expert "..i.." gradInput")
+ end
+ -- test type-cast
+ module:float()
+ local input2 = {
+ input[1]:float(),
+ {input[2][1]:float(), input[2][2]:float(), input[2][3]:float()}
+ }
+ local output = module:forward(input2)
+ mytester:assertTensorEq(output, output2:float(), 0.000001, "mixture5B output")
+ local gradInput = module:backward(input2, gradOutput:float())
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2:float(), 0.000001, "mixture5B gater gradInput")
+ for i, expertGradInput in ipairs(gradInput[2]) do
+ mytester:assertTensorEq(expertGradInput, expertGradInput2:select(1,i):float(), 0.000001, "mixture5B expert "..i.." gradInput")
+ end
+ -- expertInput is a Tensor:
+ local input = {input[1], expertInput}
+ local module = nn.MixtureTable(1)
+ local output = module:forward(input)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture6 output")
+ local gradInput = module:backward(input, gradOutput)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture6 gater gradInput")
+ mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture6 expert gradInput")
+ -- test type-cast:
+ module:float()
+ local input2 = {input[1]:float(), expertInput:float()}
+ local output = module:forward(input2)
+ mytester:assertTensorEq(output, output2:float(), 0.000001, "mixture6B output")
+ local gradInput = module:backward(input2, gradOutput:float())
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2:float(), 0.000001, "mixture6B gater gradInput")
+ mytester:assertTensorEq(gradInput[2], expertGradInput2:float(), 0.000001, "mixture6B expert gradInput")
+
+ --[[ 2D gater, 1D expert]]--
+ -- expertInput is a Table:
+ local expertInput = torch.randn(5,3)
+ local gradOutput = torch.randn(5)
+ local input = {
+ torch.rand(5,3),
+ {expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)}
+ }
+ local module = nn.MixtureTable()
+ local output = module:forward(input)
+ local output2 = torch.cmul(input[1], expertInput):sum(2)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture7 output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradOutput2 = torch.view(gradOutput, 5, 1):expandAs(expertInput)
+ local gaterGradInput2 = torch.cmul(gradOutput2, expertInput)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture7 gater gradInput")
+ local expertGradInput2 = torch.cmul(input[1], gradOutput:view(5,1):expand(5,3))
+ for i, expertGradInput in ipairs(gradInput[2]) do
+ mytester:assertTensorEq(expertGradInput, expertGradInput2:select(2,i), 0.000001, "mixture7 expert "..i.." gradInput")
+ end
+end
+
function nntest.View()
local input = torch.rand(10)
local template = torch.rand(5,2)
@@ -2013,36 +2156,29 @@ function nntest.ConcatTable()
mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
-- Now test a table input
- -- jac needs a tensor input, so we have to form a network that creates
- -- a table internally: Do this using a Reshape and a SplitTable
- m = nn.Sequential()
- m:add(nn.Reshape(1,10,10,10))
- m:add(nn.SplitTable(1)) -- output of Split table is a table of length 1
-
- concat = nn.ConcatTable()
- concat:add(nn.JoinTable(1))
-
- m:add(concat)
- m:add(nn.JoinTable(1))
-
- err = jac.testJacobian(m, input)
- mytester:assertlt(err, precision, ' error on state ')
-
- ferr, berr = jac.testIO(m, input)
- mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
- mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
-
- -- As per Soumith's suggestion, make sure getParameters works:
- m = nn.ConcatTable()
- local l = nn.Linear(16,16)
- m:add(l)
- mparams = m:getParameters()
- -- I don't know of a way to make sure that the storage is equal, however
- -- the linear weight and bias will be randomly initialized, so just make
- -- sure both parameter sets are equal
- lparams = l:getParameters()
- err = (mparams - lparams):abs():max()
- mytester:assertlt(err, precision, ' getParameters error ')
+ local input = {
+ torch.randn(3,4):float(), torch.randn(3,4):float(), {torch.randn(3,4):float()}
+ }
+ local _gradOutput = {
+ torch.randn(3,3,4):float(), torch.randn(3,3,4):float(), torch.randn(3,3,4):float()
+ }
+ local gradOutput = {
+ {_gradOutput[1][1], _gradOutput[2][1], {_gradOutput[3][1]}},
+ {_gradOutput[1][2], _gradOutput[2][2], {_gradOutput[3][2]}},
+ {_gradOutput[1][3], _gradOutput[2][3], {_gradOutput[3][3]}}
+ }
+ local module = nn.ConcatTable()
+ module:add(nn.Identity())
+ module:add(nn.Identity())
+ module:add(nn.Identity())
+ module:float()
+
+ local output = module:forward(input)
+ local output2 = {input, input, input}
+ equal(output2, output, "ConcatTable table output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradInput2 = {_gradOutput[1]:sum(1), _gradOutput[2]:sum(1), {_gradOutput[3]:sum(1)}}
+ equal(gradInput, gradInput2, "ConcatTable table gradInput")
end
function nntest.FlattenTable()