Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-08-24 19:03:00 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-08-24 19:03:00 +0400
commitdcfcc7b199aebfac80e1c2caf3b5af4d66906f2d (patch)
tree67aebcf32da0c64aa7a31968372d2ee9c5712318
parent5ac570666ccba10bcc1e4cd1bc2b9846ccda7f07 (diff)
Added user hook to Optimization modules.
-rw-r--r--LBFGSOptimization.lua4
-rw-r--r--SGDOptimization.lua5
2 files changed, 9 insertions, 0 deletions
diff --git a/LBFGSOptimization.lua b/LBFGSOptimization.lua
index 3d9a9ed..f901a57 100644
--- a/LBFGSOptimization.lua
+++ b/LBFGSOptimization.lua
@@ -36,6 +36,10 @@ function LBFGS:forward(inputs, targets)
-- estimate df/dW
local df_do = self.criterion:backward(output, targets[i])
self.module:backward(inputs[i], df_do)
+ -- user hook
+ if self.hook then
+ self.hook(self, {inputs[i], targets[i]})
+ end
end
-- update state from computed parameters
self:flatten(self.parametersT, self.gradParametersT)
diff --git a/SGDOptimization.lua b/SGDOptimization.lua
index 514c1a8..52d780e 100644
--- a/SGDOptimization.lua
+++ b/SGDOptimization.lua
@@ -31,6 +31,11 @@ function SGD:forward(inputs, targets)
-- estimate df/dW
local df_do = self.criterion:backward(output, targets[i])
self.module:backward(inputs[i], df_do)
+
+ -- user hook
+ if self.hook then
+ self.hook(self, {inputs[i], targets[i]})
+ end
end
-- renorm f