
github.com/torch/nn.git
author     Soumith Chintala <soumith@gmail.com>  2014-07-17 19:24:13 +0400
committer  Soumith Chintala <soumith@gmail.com>  2014-07-17 19:24:13 +0400
commit     ab40dc77545eac562d3eb2babe25d3522c706954 (patch)
tree       f03ab950027433897971052eaf531b87b8133324 /test
parent     2fc7c68c9adeb1c97790a6c9e1684759bfc23958 (diff)
parent     4baec0485c839ac7017048d88c4b125de6ad72a5 (diff)
Merge pull request #38 from nicholas-leonard/concat
ConcatTable nested table input
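Illustrative sketch (not part of the commit): with this change, nn.ConcatTable accepts a table, including a nested table, as input, and every branch receives the whole input table, as exercised by the new test in the diff below. The nn.Identity branches here are chosen only for illustration.

-- Minimal usage sketch, assuming the behaviour added by this merge:
-- a ConcatTable fed a nested table forwards that table to each branch.
require 'nn'

local input = { torch.randn(3,4), { torch.randn(3,4) } }  -- nested table input

local ct = nn.ConcatTable()
ct:add(nn.Identity())
ct:add(nn.Identity())

local output = ct:forward(input)
-- one output entry per branch; each entry is the (nested) input table
assert(#output == 2)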
Diffstat (limited to 'test')
-rw-r--r--  test/test.lua | 72
1 file changed, 33 insertions(+), 39 deletions(-)
diff --git a/test/test.lua b/test/test.lua
index a5550b7..1330607 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -9,6 +9,16 @@ local expprecision = 1e-4
local nntest = {}
+local function equal(t1, t2, msg)
+ if (torch.type(t1) == "table") then
+ for k, v in pairs(t2) do
+ equal(t1[k], t2[k])
+ end
+ else
+ mytester:assertTensorEq(t1, t2, 0.00001, msg)
+ end
+end
+
function nntest.Add()
local ini = math.random(10,20)
local inj = math.random(10,20)
@@ -1916,15 +1926,6 @@ function nntest.SelectTable()
{torch.Tensor(3,4,5):zero()},
{torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
}
- local function equal(t1, t2, msg)
- if (torch.type(t1) == "table") then
- for k, v in pairs(t2) do
- equal(t1[k], t2[k])
- end
- else
- mytester:assertTensorEq(t1, t2, 0.00001, msg)
- end
- end
local nonIdx = {2,3,4,1}
local module
for idx = 1,#input do
@@ -2013,36 +2014,29 @@ function nntest.ConcatTable()
mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
-- Now test a table input
- -- jac needs a tensor input, so we have to form a network that creates
- -- a table internally: Do this using a Reshape and a SplitTable
- m = nn.Sequential()
- m:add(nn.Reshape(1,10,10,10))
- m:add(nn.SplitTable(1)) -- output of Split table is a table of length 1
-
- concat = nn.ConcatTable()
- concat:add(nn.JoinTable(1))
-
- m:add(concat)
- m:add(nn.JoinTable(1))
-
- err = jac.testJacobian(m, input)
- mytester:assertlt(err, precision, ' error on state ')
-
- ferr, berr = jac.testIO(m, input)
- mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
- mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
-
- -- As per Soumith's suggestion, make sure getParameters works:
- m = nn.ConcatTable()
- local l = nn.Linear(16,16)
- m:add(l)
- mparams = m:getParameters()
- -- I don't know of a way to make sure that the storage is equal, however
- -- the linear weight and bias will be randomly initialized, so just make
- -- sure both parameter sets are equal
- lparams = l:getParameters()
- err = (mparams - lparams):abs():max()
- mytester:assertlt(err, precision, ' getParameters error ')
+ local input = {
+ torch.randn(3,4):float(), torch.randn(3,4):float(), {torch.randn(3,4):float()}
+ }
+ local _gradOutput = {
+ torch.randn(3,3,4):float(), torch.randn(3,3,4):float(), torch.randn(3,3,4):float()
+ }
+ local gradOutput = {
+ {_gradOutput[1][1], _gradOutput[2][1], {_gradOutput[3][1]}},
+ {_gradOutput[1][2], _gradOutput[2][2], {_gradOutput[3][2]}},
+ {_gradOutput[1][3], _gradOutput[2][3], {_gradOutput[3][3]}}
+ }
+ local module = nn.ConcatTable()
+ module:add(nn.Identity())
+ module:add(nn.Identity())
+ module:add(nn.Identity())
+ module:float()
+
+ local output = module:forward(input)
+ local output2 = {input, input, input}
+ equal(output2, output, "ConcatTable table output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradInput2 = {_gradOutput[1]:sum(1), _gradOutput[2]:sum(1), {_gradOutput[3]:sum(1)}}
+ equal(gradInput, gradInput2, "ConcatTable table gradInput")
end
mytester:add(nntest)