Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2014-07-08 19:42:43 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2014-07-08 23:28:30 +0400
commit75a2279ef3dac76046f128a2d77e1ffd2dcd5397 (patch)
tree8453a26dedbd624768af327a469c51986f009600 /ConcatTable.lua
parent4f01eb0e7359e59ed435fe4bbd45a19bf5df9b17 (diff)
updated ConcatTable so that it works with table inputs as well as tensors.
removed a temporary line. added a test for getParameters to ConcatTable.
Diffstat (limited to 'ConcatTable.lua')
-rw-r--r--ConcatTable.lua27
1 files changed, 25 insertions, 2 deletions
diff --git a/ConcatTable.lua b/ConcatTable.lua
index c42776d..de61ca5 100644
--- a/ConcatTable.lua
+++ b/ConcatTable.lua
@@ -30,9 +30,32 @@ function ConcatTable:updateGradInput(input, gradOutput)
for i,module in ipairs(self.modules) do
local currentGradInput = module:updateGradInput(input, gradOutput[i])
if i == 1 then
- self.gradInput:resizeAs(currentGradInput):copy(currentGradInput)
+ if type(input) == 'table' then
+ assert(type(currentGradInput) == 'table',
+ 'currentGradInput is not a table!')
+ assert(#input == #currentGradInput,
+ 'table size mismatch')
+ -- gradInput is also a table
+ self.gradInput = {}
+ for j = 1, #currentGradInput do
+ self.gradInput[j] = currentGradInput[j]:clone()
+ end
+ else
+ -- gradInput is a tensor
+ self.gradInput:resizeAs(currentGradInput):copy(currentGradInput)
+ end
else
- self.gradInput:add(currentGradInput)
+ if type(input) == 'table' then
+ assert(type(currentGradInput) == 'table',
+ 'currentGradInput is not a table!')
+ assert(#input == #currentGradInput,
+ 'table size mismatch')
+ for j = 1, #self.gradInput do
+ self.gradInput[j]:add(currentGradInput[j])
+ end
+ else
+ self.gradInput:add(currentGradInput)
+ end
end
end
return self.gradInput