Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <sgross@fb.com>2016-02-20 02:16:04 +0300
committerSam Gross <sgross@fb.com>2016-02-20 02:16:04 +0300
commit69b2dff62e026a07cd56c76408e653fb9dee0b9d (patch)
tree9a912bb73707197b72d640b588355dc816f135ff /ConcatTable.lua
parentafa8d1a6cf0977d476d9195a984a55587682f5cc (diff)
Override Module:backwards in ConcatTable.
Overrides backwards to recursively call backwards on child modules. This makes it easier to share gradInput storages between modules to save memory.
Diffstat (limited to 'ConcatTable.lua')
-rw-r--r--ConcatTable.lua14
1 files changed, 11 insertions, 3 deletions
diff --git a/ConcatTable.lua b/ConcatTable.lua
index 61e9011..eddf6cc 100644
--- a/ConcatTable.lua
+++ b/ConcatTable.lua
@@ -27,12 +27,12 @@ local function retable(t1, t2, f)
return t1
end
-function ConcatTable:updateGradInput(input, gradOutput)
+local function backward(self, method, input, gradOutput, scale)
local isTable = torch.type(input) == 'table'
local wasTable = torch.type(self.gradInput) == 'table'
if isTable then
for i,module in ipairs(self.modules) do
- local currentGradInput = module:updateGradInput(input, gradOutput[i])
+ local currentGradInput = module[method](module, input, gradOutput[i], scale)
if torch.type(currentGradInput) ~= 'table' then
error"currentGradInput is not a table!"
end
@@ -63,7 +63,7 @@ function ConcatTable:updateGradInput(input, gradOutput)
else
self.gradInput = (not wasTable) and self.gradInput or input:clone()
for i,module in ipairs(self.modules) do
- local currentGradInput = module:updateGradInput(input, gradOutput[i])
+ local currentGradInput = module[method](module, input, gradOutput[i], scale)
if i == 1 then
self.gradInput:resizeAs(currentGradInput):copy(currentGradInput)
else
@@ -74,6 +74,14 @@ function ConcatTable:updateGradInput(input, gradOutput)
return self.gradInput
end
+function ConcatTable:updateGradInput(input, gradOutput)
+ return backward(self, 'updateGradInput', input, gradOutput)
+end
+
+function ConcatTable:backward(input, gradOutput, scale)
+ return backward(self, 'backward', input, gradOutput, scale)
+end
+
function ConcatTable:accGradParameters(input, gradOutput, scale)
scale = scale or 1
for i,module in ipairs(self.modules) do