diff options
author | Sam Gross <sgross@fb.com> | 2016-02-20 02:16:04 +0300 |
---|---|---|
committer | Sam Gross <sgross@fb.com> | 2016-02-20 02:16:04 +0300 |
commit | 69b2dff62e026a07cd56c76408e653fb9dee0b9d (patch) | |
tree | 9a912bb73707197b72d640b588355dc816f135ff /ConcatTable.lua | |
parent | afa8d1a6cf0977d476d9195a984a55587682f5cc (diff) |
Override Module:backward in ConcatTable.
Overrides backward to recursively call backward on child modules. This
makes it easier to share gradInput storages between modules to save
memory.
Diffstat (limited to 'ConcatTable.lua')
-rw-r--r-- | ConcatTable.lua | 14 |
1 file changed, 11 insertions, 3 deletions
diff --git a/ConcatTable.lua b/ConcatTable.lua index 61e9011..eddf6cc 100644 --- a/ConcatTable.lua +++ b/ConcatTable.lua @@ -27,12 +27,12 @@ local function retable(t1, t2, f) return t1 end -function ConcatTable:updateGradInput(input, gradOutput) +local function backward(self, method, input, gradOutput, scale) local isTable = torch.type(input) == 'table' local wasTable = torch.type(self.gradInput) == 'table' if isTable then for i,module in ipairs(self.modules) do - local currentGradInput = module:updateGradInput(input, gradOutput[i]) + local currentGradInput = module[method](module, input, gradOutput[i], scale) if torch.type(currentGradInput) ~= 'table' then error"currentGradInput is not a table!" end @@ -63,7 +63,7 @@ function ConcatTable:updateGradInput(input, gradOutput) else self.gradInput = (not wasTable) and self.gradInput or input:clone() for i,module in ipairs(self.modules) do - local currentGradInput = module:updateGradInput(input, gradOutput[i]) + local currentGradInput = module[method](module, input, gradOutput[i], scale) if i == 1 then self.gradInput:resizeAs(currentGradInput):copy(currentGradInput) else @@ -74,6 +74,14 @@ function ConcatTable:updateGradInput(input, gradOutput) return self.gradInput end +function ConcatTable:updateGradInput(input, gradOutput) + return backward(self, 'updateGradInput', input, gradOutput) +end + +function ConcatTable:backward(input, gradOutput, scale) + return backward(self, 'backward', input, gradOutput, scale) +end + function ConcatTable:accGradParameters(input, gradOutput, scale) scale = scale or 1 for i,module in ipairs(self.modules) do |