Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-04-30 17:57:37 +0300
committerSoumith Chintala <soumith@gmail.com>2015-04-30 17:57:37 +0300
commitdc3831e730493afd49de586476aaf942f8999e2b (patch)
tree2decf9626e72ea2cbb4055897037c1edb8d67b91 /Sequential.lua
parent485dd619695c47e49ab56ff518edad52b74475fc (diff)
parent73f762296407110e2725fb5b9caf672afb32acc3 (diff)
Merge pull request #234 from torch/backward
adding direct :backward to Concat, DepthConcat, Sequential
Diffstat (limited to 'Sequential.lua')
-rw-r--r--Sequential.lua23
1 files changed, 19 insertions, 4 deletions
diff --git a/Sequential.lua b/Sequential.lua
index 5145727..b08f7df 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -25,9 +25,9 @@ end
function Sequential:updateOutput(input)
local currentOutput = input
- for i=1,#self.modules do
+ for i=1,#self.modules do
currentOutput = self.modules[i]:updateOutput(currentOutput)
- end
+ end
self.output = currentOutput
return currentOutput
end
@@ -56,10 +56,25 @@ function Sequential:accGradParameters(input, gradOutput, scale)
currentGradOutput = currentModule.gradInput
currentModule = previousModule
end
-
+
currentModule:accGradParameters(input, currentGradOutput, scale)
end
+function Sequential:backward(input, gradOutput, scale)
+ scale = scale or 1
+ local currentGradOutput = gradOutput
+ local currentModule = self.modules[#self.modules]
+ for i=#self.modules-1,1,-1 do
+ local previousModule = self.modules[i]
+ currentGradOutput = currentModule:backward(previousModule.output, currentGradOutput, scale)
+ currentModule.gradInput = currentGradOutput
+ currentModule = previousModule
+ end
+ currentGradOutput = currentModule:backward(input, currentGradOutput, scale)
+ self.gradInput = currentGradOutput
+ return currentGradOutput
+end
+
function Sequential:accUpdateGradParameters(input, gradOutput, lr)
local currentGradOutput = gradOutput
local currentModule = self.modules[#self.modules]
@@ -69,7 +84,7 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr)
currentGradOutput = currentModule.gradInput
currentModule = previousModule
end
-
+
currentModule:accUpdateGradParameters(input, currentGradOutput, lr)
end