diff options
author | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-07-08 19:42:43 +0400 |
---|---|---|
committer | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-07-08 23:28:30 +0400 |
commit | 75a2279ef3dac76046f128a2d77e1ffd2dcd5397 (patch) | |
tree | 8453a26dedbd624768af327a469c51986f009600 /ConcatTable.lua | |
parent | 4f01eb0e7359e59ed435fe4bbd45a19bf5df9b17 (diff) |
updated ConcatTable so that it works with table inputs as well as tensors.
removed a temporary line.
added a test for getParameters to ConcatTable.
Diffstat (limited to 'ConcatTable.lua')
-rw-r--r-- | ConcatTable.lua | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/ConcatTable.lua b/ConcatTable.lua index c42776d..de61ca5 100644 --- a/ConcatTable.lua +++ b/ConcatTable.lua @@ -30,9 +30,32 @@ function ConcatTable:updateGradInput(input, gradOutput) for i,module in ipairs(self.modules) do local currentGradInput = module:updateGradInput(input, gradOutput[i]) if i == 1 then - self.gradInput:resizeAs(currentGradInput):copy(currentGradInput) + if type(input) == 'table' then + assert(type(currentGradInput) == 'table', + 'currentGradInput is not a table!') + assert(#input == #currentGradInput, + 'table size mismatch') + -- gradInput is also a table + self.gradInput = {} + for j = 1, #currentGradInput do + self.gradInput[j] = currentGradInput[j]:clone() + end + else + -- gradInput is a tensor + self.gradInput:resizeAs(currentGradInput):copy(currentGradInput) + end else - self.gradInput:add(currentGradInput) + if type(input) == 'table' then + assert(type(currentGradInput) == 'table', + 'currentGradInput is not a table!') + assert(#input == #currentGradInput, + 'table size mismatch') + for j = 1, #self.gradInput do + self.gradInput[j]:add(currentGradInput[j]) + end + else + self.gradInput:add(currentGradInput) + end end end return self.gradInput |