Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-09 21:48:33 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-09 21:48:33 +0400
commit02b7e7205db5359f1df63f7bee3439a320e004bf (patch)
tree2c732327fcbdba8e7e072f0ac036699a2cb12b22
parent157004f72ff7fa9f9c7216ca6b0b766476b4696c (diff)
ElementTable unit tests
-rw-r--r--ElementTable.lua5
-rw-r--r--test/test.lua43
2 files changed, 45 insertions, 3 deletions
diff --git a/ElementTable.lua b/ElementTable.lua
index b1b28d0..cb3ff0f 100644
--- a/ElementTable.lua
+++ b/ElementTable.lua
@@ -7,12 +7,12 @@ function ElementTable:__init(index)
end
function ElementTable:updateOutput(input)
- self.output:set(input[self.index])
+ self.output = input[self.index]
return self.output
end
function ElementTable:updateGradInput(input, gradOutput)
- if #gradInput == 0 then
+ if #self.gradInput == 0 then
local function zeroTableCopy(t1, t2)
for k, v in pairs(t2) do
if (torch.type(v) == "table") then
@@ -31,5 +31,4 @@ end
function ElementTable:type(type)
self.gradInput = {}
- self.output = self.output:type(type)
end
diff --git a/test/test.lua b/test/test.lua
index c88c908..0c9a43c 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1869,6 +1869,49 @@ function nntest.SplitTable()
end
end
+function nntest.ElementTable()
+ local input = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local gradOutputs = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local zeros = {
+ torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(),
+ {torch.Tensor(3,4,5):zero()},
+ {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
+ }
+ local function equal(t1, t2, msg)
+ if (torch.type(t1) == "table") then
+ for k, v in pairs(t2) do
+ equal(t1[k], t2[k])
+ end
+ else
+ mytester:assertTensorEq(t1, t2, 0.00001, msg)
+ end
+ end
+ local nonIdx = {2,3,4,1}
+ local module
+ for idx = 1,#input do
+ module = nn.ElementTable(idx)
+ local output = module:forward(input)
+ equal(output, input[idx], "output dimension " .. idx)
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+ end
+ module:float()
+ local idx = #input
+ local output = module:forward(input)
+ equal(output, input[idx], "type output")
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+end
function nntest.View()
local input = torch.rand(10)