Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2014-07-08 19:42:43 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2014-07-08 23:28:30 +0400
commit75a2279ef3dac76046f128a2d77e1ffd2dcd5397 (patch)
tree8453a26dedbd624768af327a469c51986f009600 /test
parent4f01eb0e7359e59ed435fe4bbd45a19bf5df9b17 (diff)
updated ConcatTable so that it works with table inputs as well as tensors.
removed a temporary line. added a test for getParameters to ConcatTable.
Diffstat (limited to 'test')
-rw-r--r--test/test.lua51
1 files changed, 51 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index c88c908..9ecc923 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1920,6 +1920,57 @@ function nntest.SpatialUpSamplingNearest()
end
end
+function nntest.ConcatTable()
+ -- Test tensor input
+ local input = torch.rand(10, 10, 10)
+ local m = nn.Sequential()
+
+ local concat = nn.ConcatTable()
+ concat:add(nn.Identity())
+
+ m:add(concat) -- Output of concat is a table of length 1
+ m:add(nn.JoinTable(1)) -- jac needs a tensor tensor output
+
+ local err = jac.testJacobian(m, input)
+ mytester:assertlt(err, precision, ' error on state ')
+
+ local ferr, berr = jac.testIO(m, input)
+ mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
+
+ -- Now test a table input
+ -- jac needs a tensor input, so we have to form a network that creates
+ -- a table internally: Do this using a Reshape and a SplitTable
+ m = nn.Sequential()
+ m:add(nn.Reshape(1,10,10,10))
+ m:add(nn.SplitTable(1)) -- output of Split table is a table of length 1
+
+ concat = nn.ConcatTable()
+ concat:add(nn.JoinTable(1))
+
+ m:add(concat)
+ m:add(nn.JoinTable(1))
+
+ err = jac.testJacobian(m, input)
+ mytester:assertlt(err, precision, ' error on state ')
+
+ ferr, berr = jac.testIO(m, input)
+ mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
+
+ -- As per Soumith's suggestion, make sure getParameters works:
+ m = nn.ConcatTable()
+ local l = nn.Linear(16,16)
+ m:add(l)
+ mparams = m:getParameters()
+ -- I don't know of a way to make sure that the storage is equal, however
+ -- the linear weight and bias will be randomly initialized, so just make
+ -- sure both parameter sets are equal
+ lparams = l:getParameters()
+ err = (mparams - lparams):abs():max()
+ mytester:assertlt(err, precision, ' getParameters error ')
+end
+
mytester:add(nntest)
if not nn then