github.com/torch/nn.git
author     nicholas-leonard <nick@nikopia.org>  2014-07-17 23:58:10 +0400
committer  nicholas-leonard <nick@nikopia.org>  2014-07-17 23:58:10 +0400
commit     703932c5aca5e9f664e51e905991fd71f8bb46d9 (patch)
tree       3967324407b77d5f34617ec06ea67b63a05a393d /test
parent     e074440aaf6f02491c9685d4e254ecb159622414 (diff)
parent     adbd1ef0638358f2b54980112ec10cd322544de0 (diff)
Merge branch 'master' of github.com:torch/nn into mixture
Diffstat (limited to 'test')
-rw-r--r--  test/test.lua | 77
1 file changed, 36 insertions(+), 41 deletions(-)
diff --git a/test/test.lua b/test/test.lua
index 8c17e90..4c93040 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -13,6 +13,16 @@ local expprecision = 1e-4
local nntest = {}
+local function equal(t1, t2, msg)
+ if (torch.type(t1) == "table") then
+ for k, v in pairs(t2) do
+ equal(t1[k], t2[k])
+ end
+ else
+ mytester:assertTensorEq(t1, t2, 0.00001, msg)
+ end
+end
+
function nntest.Add()
local ini = math.random(10,20)
local inj = math.random(10,20)
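For reference, the shared equal helper introduced above walks nested tables recursively and compares tensor leaves with mytester:assertTensorEq at a 1e-5 tolerance; note that the recursive call does not forward msg, so nested mismatches are reported without the caller's message. A minimal usage sketch, not part of the patch, assuming require 'nn' and the mytester = torch.Tester() set up at the top of test.lua:

-- illustrative only: compare a nested table of tensors against an expected copy
local expected = {torch.zeros(2,3), {torch.ones(4)}}
local actual   = {torch.zeros(2,3), {torch.ones(4)}}
equal(expected, actual, "nested table/tensor comparison")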
@@ -481,8 +491,9 @@ function nntest.WeightedMSECriterion()
end
function nntest.BCECriterion()
- local input = torch.rand(100)
- local target = input:clone():add(torch.rand(100))
+ local eps = 1e-2
+ local input = torch.rand(100)*(1-eps) + eps/2
+ local target = torch.rand(100)*(1-eps) + eps/2
local cri = nn.BCECriterion()
criterionJacobianTest1D(cri, input, target)
end
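The rescaling above keeps both input and target inside (eps/2, 1 - eps/2). BCE is -(t*log(x) + (1-t)*log(1-x)), so values at or near 0 or 1 make the loss, and therefore the finite-difference Jacobian check, numerically unstable. A minimal sketch of the same construction, illustrative only and not part of the patch:

require 'nn'
local eps = 1e-2
-- map uniform samples from [0,1) into [eps/2, 1 - eps/2)
local input  = torch.rand(100)*(1 - eps) + eps/2
local target = torch.rand(100)*(1 - eps) + eps/2
-- the loss stays finite because log() never sees 0 or 1
print(nn.BCECriterion():forward(input, target))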
@@ -1926,15 +1937,6 @@ function nntest.SelectTable()
{torch.Tensor(3,4,5):zero()},
{torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
}
- local function equal(t1, t2, msg)
- if (torch.type(t1) == "table") then
- for k, v in pairs(t2) do
- equal(t1[k], t2[k])
- end
- else
- mytester:assertTensorEq(t1, t2, 0.00001, msg)
- end
- end
local nonIdx = {2,3,4,1}
local module
for idx = 1,#input do
@@ -2154,36 +2156,29 @@ function nntest.ConcatTable()
mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
-- Now test a table input
- -- jac needs a tensor input, so we have to form a network that creates
- -- a table internally: Do this using a Reshape and a SplitTable
- m = nn.Sequential()
- m:add(nn.Reshape(1,10,10,10))
- m:add(nn.SplitTable(1)) -- output of Split table is a table of length 1
-
- concat = nn.ConcatTable()
- concat:add(nn.JoinTable(1))
-
- m:add(concat)
- m:add(nn.JoinTable(1))
-
- err = jac.testJacobian(m, input)
- mytester:assertlt(err, precision, ' error on state ')
-
- ferr, berr = jac.testIO(m, input)
- mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
- mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
-
- -- As per Soumith's suggestion, make sure getParameters works:
- m = nn.ConcatTable()
- local l = nn.Linear(16,16)
- m:add(l)
- mparams = m:getParameters()
- -- I don't know of a way to make sure that the storage is equal, however
- -- the linear weight and bias will be randomly initialized, so just make
- -- sure both parameter sets are equal
- lparams = l:getParameters()
- err = (mparams - lparams):abs():max()
- mytester:assertlt(err, precision, ' getParameters error ')
+ local input = {
+ torch.randn(3,4):float(), torch.randn(3,4):float(), {torch.randn(3,4):float()}
+ }
+ local _gradOutput = {
+ torch.randn(3,3,4):float(), torch.randn(3,3,4):float(), torch.randn(3,3,4):float()
+ }
+ local gradOutput = {
+ {_gradOutput[1][1], _gradOutput[2][1], {_gradOutput[3][1]}},
+ {_gradOutput[1][2], _gradOutput[2][2], {_gradOutput[3][2]}},
+ {_gradOutput[1][3], _gradOutput[2][3], {_gradOutput[3][3]}}
+ }
+ local module = nn.ConcatTable()
+ module:add(nn.Identity())
+ module:add(nn.Identity())
+ module:add(nn.Identity())
+ module:float()
+
+ local output = module:forward(input)
+ local output2 = {input, input, input}
+ equal(output2, output, "ConcatTable table output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradInput2 = {_gradOutput[1]:sum(1), _gradOutput[2]:sum(1), {_gradOutput[3]:sum(1)}}
+ equal(gradInput, gradInput2, "ConcatTable table gradInput")
end
mytester:add(nntest)
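The rewritten ConcatTable test relies on two behaviours: every branch receives the whole input (so three Identity branches simply triplicate it), and backward accumulates the gradOutputs of all branches into a single gradInput. A minimal standalone sketch of that behaviour, illustrative only and not part of the patch:

require 'nn'
-- two Identity branches: forward duplicates the input, backward sums the gradOutputs
local m = nn.ConcatTable():add(nn.Identity()):add(nn.Identity())
local x = torch.randn(3, 4)
local y = m:forward(x)      -- y = {x, x}
local gx = m:backward(x, {torch.ones(3, 4), torch.ones(3, 4)})
print(gx)                   -- every entry equals 2: gradients summed across branches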