diff options
author | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-07-08 19:42:43 +0400 |
---|---|---|
committer | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-07-08 23:28:30 +0400 |
commit | 75a2279ef3dac76046f128a2d77e1ffd2dcd5397 (patch) | |
tree | 8453a26dedbd624768af327a469c51986f009600 /test | |
parent | 4f01eb0e7359e59ed435fe4bbd45a19bf5df9b17 (diff) |
updated ConcatTable so that it works with table inputs as well as tensors.
removed a temporary line.
added a test for getParameters to ConcatTable.
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 51 |
1 file changed, 51 insertions, 0 deletions
-- Tests nn.ConcatTable in three configurations:
--   1. a plain tensor input,
--   2. a table input (built internally via Reshape + SplitTable, since the
--      Jacobian tester can only feed tensors), and
--   3. getParameters() flattening of a single Linear submodule.
-- Relies on the test harness globals: nn, torch, jac, mytester, precision.
function nntest.ConcatTable()
   -- Test tensor input
   local input = torch.rand(10, 10, 10)
   local m = nn.Sequential()

   local concat = nn.ConcatTable()
   concat:add(nn.Identity())

   m:add(concat)          -- output of concat is a table of length 1
   m:add(nn.JoinTable(1)) -- jac needs a tensor output

   local err = jac.testJacobian(m, input)
   mytester:assertlt(err, precision, ' error on state ')

   local ferr, berr = jac.testIO(m, input)
   mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
   mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')

   -- Now test a table input.
   -- jac needs a tensor input, so we have to form a network that creates
   -- a table internally: do this using a Reshape and a SplitTable.
   m = nn.Sequential()
   m:add(nn.Reshape(1, 10, 10, 10))
   m:add(nn.SplitTable(1)) -- output of SplitTable is a table of length 1

   concat = nn.ConcatTable()
   concat:add(nn.JoinTable(1))

   m:add(concat)
   m:add(nn.JoinTable(1))

   err = jac.testJacobian(m, input)
   mytester:assertlt(err, precision, ' error on state ')

   ferr, berr = jac.testIO(m, input)
   mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
   mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')

   -- As per Soumith's suggestion, make sure getParameters works:
   m = nn.ConcatTable()
   local l = nn.Linear(16, 16)
   m:add(l)
   -- BUGFIX: mparams/lparams were assigned without 'local', leaking
   -- accidental globals into the test environment. Declare them local.
   local mparams = m:getParameters()
   -- I don't know of a way to make sure that the storage is equal, however
   -- the linear weight and bias will be randomly initialized, so just make
   -- sure both parameter sets are equal.
   local lparams = l:getParameters()
   err = (mparams - lparams):abs():max()
   mytester:assertlt(err, precision, ' getParameters error ')
end