diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-10-12 09:29:42 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-10-12 09:29:42 +0400 |
commit | fa2f719f009204ecc7430b351694bdf54f956a58 (patch) | |
tree | 32560c26ef51349ccfd0f4fa577a878675875323 | |
parent | c42ff17b16bca9c4aa6d264cc2a79b3347a4b81f (diff) |
Added ConcatTable to diaghessian.
-rw-r--r-- | DiagHessian.lua | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/DiagHessian.lua b/DiagHessian.lua index d163051..278aa0e 100644 --- a/DiagHessian.lua +++ b/DiagHessian.lua @@ -101,3 +101,29 @@ function nn.Sequential.accDiagHessianParameters(self, input, diagHessianOutput, end currentModule:accDiagHessianParameters(input, currentDiagHessianOutput, scale) end + +-- ConcatTable +function nn.ConcatTable.backwardDiagHessian(self, input, diagHessianOutput) + for i,module in ipairs(self.modules) do + local currentDiagHessianInput = module:backward(input, diagHessianOutput[i]) + if i == 1 then + self.diagHessianInput:resizeAs(currentDiagHessianInput):copy(currentDiagHessianInput) + else + self.diagHessianInput:add(currentDiagHessianInput) + end + end + return self.diagHessianInput +end + +function nn.ConcatTable.initDiagHessianParameters(self) + for i=1,#self.modules do + self.modules[i]:initDiagHessianParameters() + end +end + +function nn.ConcatTable.accDiagHessianParameters(self, input, diagHessianOutput, scale) + scale = scale or 1 + for i,module in ipairs(self.modules) do + module:accDiagHessianParameters(input, diagHessianOutput[i], scale) + end +end |