Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-10-12 09:29:42 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-10-12 09:29:42 +0400
commitfa2f719f009204ecc7430b351694bdf54f956a58 (patch)
tree32560c26ef51349ccfd0f4fa577a878675875323
parentc42ff17b16bca9c4aa6d264cc2a79b3347a4b81f (diff)
Added ConcatTable to diaghessian.
-rw-r--r--DiagHessian.lua26
1 files changed, 26 insertions, 0 deletions
diff --git a/DiagHessian.lua b/DiagHessian.lua
index d163051..278aa0e 100644
--- a/DiagHessian.lua
+++ b/DiagHessian.lua
@@ -101,3 +101,29 @@ function nn.Sequential.accDiagHessianParameters(self, input, diagHessianOutput,
end
currentModule:accDiagHessianParameters(input, currentDiagHessianOutput, scale)
end
+
+-- ConcatTable
+function nn.ConcatTable.backwardDiagHessian(self, input, diagHessianOutput)
+ for i,module in ipairs(self.modules) do
+ local currentDiagHessianInput = module:backward(input, diagHessianOutput[i])
+ if i == 1 then
+ self.diagHessianInput:resizeAs(currentDiagHessianInput):copy(currentDiagHessianInput)
+ else
+ self.diagHessianInput:add(currentDiagHessianInput)
+ end
+ end
+ return self.diagHessianInput
+end
+
+function nn.ConcatTable.initDiagHessianParameters(self)
+ for i=1,#self.modules do
+ self.modules[i]:initDiagHessianParameters()
+ end
+end
+
+function nn.ConcatTable.accDiagHessianParameters(self, input, diagHessianOutput, scale)
+ scale = scale or 1
+ for i,module in ipairs(self.modules) do
+ module:accDiagHessianParameters(input, diagHessianOutput[i], scale)
+ end
+end