Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-09 22:42:50 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-09 22:42:50 +0400
commitb6f924b855d8ed79387832d1c196bcb8d805555e (patch)
tree39ba412ace3802323f6b799973de11e1b434ea6b /test
parent12355083c7d11328bdb1b610c76c936eaa6ed293 (diff)
parent3f35d9bce7e0335b801f3bfe36d8a86cd53ba4ed (diff)
Merge branch 'master' of github.com:torch/nn into nnx
Diffstat (limited to 'test')
-rw-r--r--test/test.lua43
1 files changed, 43 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 8b4cfb9..a3c816f 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1879,6 +1879,49 @@ function nntest.SplitTable()
end
end
+function nntest.ElementTable()
+ local input = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local gradOutputs = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local zeros = {
+ torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(),
+ {torch.Tensor(3,4,5):zero()},
+ {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
+ }
+ local function equal(t1, t2, msg)
+ if (torch.type(t1) == "table") then
+ for k, v in pairs(t2) do
+ equal(t1[k], t2[k])
+ end
+ else
+ mytester:assertTensorEq(t1, t2, 0.00001, msg)
+ end
+ end
+ local nonIdx = {2,3,4,1}
+ local module
+ for idx = 1,#input do
+ module = nn.ElementTable(idx)
+ local output = module:forward(input)
+ equal(output, input[idx], "output dimension " .. idx)
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+ end
+ module:float()
+ local idx = #input
+ local output = module:forward(input)
+ equal(output, input[idx], "type output")
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+end
function nntest.View()
local input = torch.rand(10)