diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-07-17 19:24:13 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-07-17 19:24:13 +0400 |
commit | ab40dc77545eac562d3eb2babe25d3522c706954 (patch) | |
tree | f03ab950027433897971052eaf531b87b8133324 /test | |
parent | 2fc7c68c9adeb1c97790a6c9e1684759bfc23958 (diff) | |
parent | 4baec0485c839ac7017048d88c4b125de6ad72a5 (diff) |
Merge pull request #38 from nicholas-leonard/concat
ConcatTable nested table input
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 72 |
1 files changed, 33 insertions, 39 deletions
diff --git a/test/test.lua b/test/test.lua index a5550b7..1330607 100644 --- a/test/test.lua +++ b/test/test.lua @@ -9,6 +9,16 @@ local expprecision = 1e-4 local nntest = {} +local function equal(t1, t2, msg) + if (torch.type(t1) == "table") then + for k, v in pairs(t2) do + equal(t1[k], t2[k]) + end + else + mytester:assertTensorEq(t1, t2, 0.00001, msg) + end +end + function nntest.Add() local ini = math.random(10,20) local inj = math.random(10,20) @@ -1916,15 +1926,6 @@ function nntest.SelectTable() {torch.Tensor(3,4,5):zero()}, {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}} } - local function equal(t1, t2, msg) - if (torch.type(t1) == "table") then - for k, v in pairs(t2) do - equal(t1[k], t2[k]) - end - else - mytester:assertTensorEq(t1, t2, 0.00001, msg) - end - end local nonIdx = {2,3,4,1} local module for idx = 1,#input do @@ -2013,36 +2014,29 @@ function nntest.ConcatTable() mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ') -- Now test a table input - -- jac needs a tensor input, so we have to form a network that creates - -- a table internally: Do this using a Reshape and a SplitTable - m = nn.Sequential() - m:add(nn.Reshape(1,10,10,10)) - m:add(nn.SplitTable(1)) -- output of Split table is a table of length 1 - - concat = nn.ConcatTable() - concat:add(nn.JoinTable(1)) - - m:add(concat) - m:add(nn.JoinTable(1)) - - err = jac.testJacobian(m, input) - mytester:assertlt(err, precision, ' error on state ') - - ferr, berr = jac.testIO(m, input) - mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ') - mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ') - - -- As per Soumith's suggestion, make sure getParameters works: - m = nn.ConcatTable() - local l = nn.Linear(16,16) - m:add(l) - mparams = m:getParameters() - -- I don't know of a way to make sure that the storage is equal, however - -- the linear weight and bias will be randomly initialized, so just make - -- sure both parameter sets are equal - lparams = l:getParameters() - err = (mparams - lparams):abs():max() - mytester:assertlt(err, precision, ' getParameters error ') + local input = { + torch.randn(3,4):float(), torch.randn(3,4):float(), {torch.randn(3,4):float()} + } + local _gradOutput = { + torch.randn(3,3,4):float(), torch.randn(3,3,4):float(), torch.randn(3,3,4):float() + } + local gradOutput = { + {_gradOutput[1][1], _gradOutput[2][1], {_gradOutput[3][1]}}, + {_gradOutput[1][2], _gradOutput[2][2], {_gradOutput[3][2]}}, + {_gradOutput[1][3], _gradOutput[2][3], {_gradOutput[3][3]}} + } + local module = nn.ConcatTable() + module:add(nn.Identity()) + module:add(nn.Identity()) + module:add(nn.Identity()) + module:float() + + local output = module:forward(input) + local output2 = {input, input, input} + equal(output2, output, "ConcatTable table output") + local gradInput = module:backward(input, gradOutput) + local gradInput2 = {_gradOutput[1]:sum(1), _gradOutput[2]:sum(1), {_gradOutput[3]:sum(1)}} + equal(gradInput, gradInput2, "ConcatTable table gradInput") end mytester:add(nntest) |