diff options
-rw-r--r-- | ElementTable.lua | 5 | ||||
-rw-r--r-- | test/test.lua | 43 |
2 files changed, 45 insertions, 3 deletions
diff --git a/ElementTable.lua b/ElementTable.lua index b1b28d0..cb3ff0f 100644 --- a/ElementTable.lua +++ b/ElementTable.lua @@ -7,12 +7,12 @@ function ElementTable:__init(index) end function ElementTable:updateOutput(input) - self.output:set(input[self.index]) + self.output = input[self.index] return self.output end function ElementTable:updateGradInput(input, gradOutput) - if #gradInput == 0 then + if #self.gradInput == 0 then local function zeroTableCopy(t1, t2) for k, v in pairs(t2) do if (torch.type(v) == "table") then @@ -31,5 +31,4 @@ end function ElementTable:type(type) self.gradInput = {} - self.output = self.output:type(type) end diff --git a/test/test.lua b/test/test.lua index c88c908..0c9a43c 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1869,6 +1869,49 @@ function nntest.SplitTable() end end +function nntest.ElementTable() + local input = { + torch.rand(3,4,5), torch.rand(3,4,5), + {torch.rand(3,4,5)}, + {torch.rand(3,4,5), {torch.rand(3,4,5)}} + } + local gradOutputs = { + torch.rand(3,4,5), torch.rand(3,4,5), + {torch.rand(3,4,5)}, + {torch.rand(3,4,5), {torch.rand(3,4,5)}} + } + local zeros = { + torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(), + {torch.Tensor(3,4,5):zero()}, + {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}} + } + local function equal(t1, t2, msg) + if (torch.type(t1) == "table") then + for k, v in pairs(t2) do + equal(t1[k], t2[k]) + end + else + mytester:assertTensorEq(t1, t2, 0.00001, msg) + end + end + local nonIdx = {2,3,4,1} + local module + for idx = 1,#input do + module = nn.ElementTable(idx) + local output = module:forward(input) + equal(output, input[idx], "output dimension " .. idx) + local gradInput = module:backward(input, gradOutputs[idx]) + equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) + end + module:float() + local idx = #input + local output = module:forward(input) + equal(output, input[idx], "type output") + local gradInput = module:backward(input, gradOutputs[idx]) + equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) +end function nntest.View() local input = torch.rand(10) |