diff options
author | soumith <soumith@fb.com> | 2015-04-21 03:26:06 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-04-21 04:09:54 +0300 |
commit | 73f762296407110e2725fb5b9caf672afb32acc3 (patch) | |
tree | 00099f05d9c39f84adbc9bf52a61924a688ca3a6 /Sequential.lua | |
parent | a2db5ec31f2dd236186c376a04daa31af319e39d (diff) |
adding direct :backward to Concat, DepthConcat, Sequential
Diffstat (limited to 'Sequential.lua')
-rw-r--r-- | Sequential.lua | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/Sequential.lua b/Sequential.lua index 3288e6d..4ab82bf 100644 --- a/Sequential.lua +++ b/Sequential.lua @@ -21,9 +21,9 @@ end function Sequential:updateOutput(input) local currentOutput = input - for i=1,#self.modules do + for i=1,#self.modules do currentOutput = self.modules[i]:updateOutput(currentOutput) - end + end self.output = currentOutput return currentOutput end @@ -52,10 +52,25 @@ function Sequential:accGradParameters(input, gradOutput, scale) currentGradOutput = currentModule.gradInput currentModule = previousModule end - + currentModule:accGradParameters(input, currentGradOutput, scale) end +function Sequential:backward(input, gradOutput, scale) + scale = scale or 1 + local currentGradOutput = gradOutput + local currentModule = self.modules[#self.modules] + for i=#self.modules-1,1,-1 do + local previousModule = self.modules[i] + currentGradOutput = currentModule:backward(previousModule.output, currentGradOutput, scale) + currentModule.gradInput = currentGradOutput + currentModule = previousModule + end + currentGradOutput = currentModule:backward(input, currentGradOutput, scale) + self.gradInput = currentGradOutput + return currentGradOutput +end + function Sequential:accUpdateGradParameters(input, gradOutput, lr) local currentGradOutput = gradOutput local currentModule = self.modules[#self.modules] @@ -65,7 +80,7 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr) currentGradOutput = currentModule.gradInput currentModule = previousModule end - + currentModule:accUpdateGradParameters(input, currentGradOutput, lr) end |