Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-04-21 03:26:06 +0300
committersoumith <soumith@fb.com>2015-04-21 04:09:54 +0300
commit73f762296407110e2725fb5b9caf672afb32acc3 (patch)
tree00099f05d9c39f84adbc9bf52a61924a688ca3a6 /Sequential.lua
parenta2db5ec31f2dd236186c376a04daa31af319e39d (diff)
adding direct :backward to Concat, DepthConcat, Sequential
Diffstat (limited to 'Sequential.lua')
-rw-r--r--Sequential.lua23
1 files changed, 19 insertions, 4 deletions
diff --git a/Sequential.lua b/Sequential.lua
index 3288e6d..4ab82bf 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -21,9 +21,9 @@ end
function Sequential:updateOutput(input)
local currentOutput = input
- for i=1,#self.modules do
+ for i=1,#self.modules do
currentOutput = self.modules[i]:updateOutput(currentOutput)
- end
+ end
self.output = currentOutput
return currentOutput
end
@@ -52,10 +52,25 @@ function Sequential:accGradParameters(input, gradOutput, scale)
currentGradOutput = currentModule.gradInput
currentModule = previousModule
end
-
+
currentModule:accGradParameters(input, currentGradOutput, scale)
end
+function Sequential:backward(input, gradOutput, scale)
+ scale = scale or 1
+ local currentGradOutput = gradOutput
+ local currentModule = self.modules[#self.modules]
+ for i=#self.modules-1,1,-1 do
+ local previousModule = self.modules[i]
+ currentGradOutput = currentModule:backward(previousModule.output, currentGradOutput, scale)
+ currentModule.gradInput = currentGradOutput
+ currentModule = previousModule
+ end
+ currentGradOutput = currentModule:backward(input, currentGradOutput, scale)
+ self.gradInput = currentGradOutput
+ return currentGradOutput
+end
+
function Sequential:accUpdateGradParameters(input, gradOutput, lr)
local currentGradOutput = gradOutput
local currentModule = self.modules[#self.modules]
@@ -65,7 +80,7 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr)
currentGradOutput = currentModule.gradInput
currentModule = previousModule
end
-
+
currentModule:accUpdateGradParameters(input, currentGradOutput, lr)
end