Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-09 21:48:33 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-09 21:48:33 +0400
commit02b7e7205db5359f1df63f7bee3439a320e004bf (patch)
tree2c732327fcbdba8e7e072f0ac036699a2cb12b22 /test
parent157004f72ff7fa9f9c7216ca6b0b766476b4696c (diff)
ElementTable unit tests
Diffstat (limited to 'test')
-rw-r--r--test/test.lua43
1 files changed, 43 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index c88c908..0c9a43c 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1869,6 +1869,49 @@ function nntest.SplitTable()
end
end
+function nntest.ElementTable()
+ local input = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local gradOutputs = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local zeros = {
+ torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(),
+ {torch.Tensor(3,4,5):zero()},
+ {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
+ }
+ local function equal(t1, t2, msg)
+ if (torch.type(t1) == "table") then
+ for k, v in pairs(t2) do
+ equal(t1[k], t2[k])
+ end
+ else
+ mytester:assertTensorEq(t1, t2, 0.00001, msg)
+ end
+ end
+ local nonIdx = {2,3,4,1}
+ local module
+ for idx = 1,#input do
+ module = nn.ElementTable(idx)
+ local output = module:forward(input)
+ equal(output, input[idx], "output dimension " .. idx)
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+ end
+ module:float()
+ local idx = #input
+ local output = module:forward(input)
+ equal(output, input[idx], "type output")
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+end
function nntest.View()
local input = torch.rand(10)