
github.com/clementfarabet/lua---nnx.git
author    Clement Farabet <clement.farabet@gmail.com>  2011-08-24 19:39:52 +0400
committer Clement Farabet <clement.farabet@gmail.com>  2011-08-24 19:39:52 +0400
commit    e7024d5b41e5efa9ef36b20fe8fbf5b93df25667 (patch)
tree      cfdc6b47be8e420b89b79e2e08c697f7bc4ceb3b
parent    dcfcc7b199aebfac80e1c2caf3b5af4d66906f2d (diff)

More complete hook system for Optimizations (branch: lbfgs)
-rw-r--r--  LBFGSOptimization.lua  11
-rw-r--r--  SGDOptimization.lua    13
2 files changed, 18 insertions(+), 6 deletions(-)
diff --git a/LBFGSOptimization.lua b/LBFGSOptimization.lua
index f901a57..671ce7a 100644
--- a/LBFGSOptimization.lua
+++ b/LBFGSOptimization.lua
@@ -13,7 +13,8 @@ function LBFGS:__init(...)
    self.gradParametersT = nnx.getGradParameters(self.module)
 end

-function LBFGS:forward(inputs, targets)
+function LBFGS:forward(inputs, targets, options)
+   options = options or {}
    -- (1) construct a closure that compute f(inputs) + df/dW
    --     after each call to that function:
    --     + self.parameters contains the current X vector
@@ -29,6 +30,10 @@ function LBFGS:forward(inputs, targets)
    self.output = 0
    -- given all inputs, evaluate gradients
    for i = 1,#inputs do
+      -- user hook
+      if self.prehook then
+         self.prehook(self, {inputs[i], targets[i], options[i]})
+      end
       -- estimate f
       local output = self.module:forward(inputs[i])
       local err = self.criterion:forward(output, targets[i])
@@ -37,8 +42,8 @@ function LBFGS:forward(inputs, targets)
       local df_do = self.criterion:backward(output, targets[i])
       self.module:backward(inputs[i], df_do)
       -- user hook
-      if self.hook then
-         self.hook(self, {inputs[i], targets[i]})
+      if self.posthook then
+         self.posthook(self, {inputs[i], targets[i], options[i]})
       end
    end
    -- update state from computed parameters
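The LBFGS change above replaces the single `hook` callback with a `prehook`/`posthook` pair, each receiving the optimizer and a `{input, target, option}` triplet, and threads a per-sample `options` table through `forward`. A minimal usage sketch follows; the `nn.LBFGSOptimization` constructor call with named `module`/`criterion` arguments is assumed from the nnx convention, and `mlp`, `loss`, `inputs`, `targets`, and `opts` are hypothetical placeholders:

   -- usage sketch (assumed constructor style; hook behavior from the diff above)
   local optimizer = nn.LBFGSOptimization{module = mlp, criterion = loss}

   -- prehook fires before each sample is forwarded through the module
   optimizer.prehook = function(self, sample)
      local input, target, option = sample[1], sample[2], sample[3]
      -- e.g. reconfigure the module for this particular sample
   end

   -- posthook fires after the sample's gradients are accumulated;
   -- it replaces the old single `hook` callback
   optimizer.posthook = function(self, sample)
      local input, target, option = sample[1], sample[2], sample[3]
      -- e.g. record per-sample statistics
   end

   optimizer:forward(inputs, targets, opts)  -- opts[i] becomes sample[3] in both hooks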
diff --git a/SGDOptimization.lua b/SGDOptimization.lua
index 52d780e..f59dc02 100644
--- a/SGDOptimization.lua
+++ b/SGDOptimization.lua
@@ -14,7 +14,9 @@ function SGD:__init(...)
    self.gradParametersT = nnx.getGradParameters(self.module)
 end

-function SGD:forward(inputs, targets)
+function SGD:forward(inputs, targets, options)
+   options = options or {}
+
    -- reset gradients
    self.module:zeroGradParameters()
@@ -23,6 +25,11 @@ function SGD:forward(inputs, targets)
    -- given all inputs, evaluate gradients
    for i = 1,#inputs do
+      -- user hook
+      if self.prehook then
+         self.prehook(self, {inputs[i], targets[i], options[i]})
+      end
+
       -- estimate f
       local output = self.module:forward(inputs[i])
       local err = self.criterion:forward(output, targets[i])
@@ -33,8 +40,8 @@ function SGD:forward(inputs, targets)
       self.module:backward(inputs[i], df_do)
       -- user hook
-      if self.hook then
-         self.hook(self, {inputs[i], targets[i]})
+      if self.posthook then
+         self.posthook(self, {inputs[i], targets[i], options[i]})
       end
    end
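SGD gets the identical pre/post split. Note that `options = options or {}` keeps old call sites working: when `forward` is called without an options argument, `options[i]` is simply `nil` inside the hooks. A sketch of the per-sample threading, with the same hypothetical placeholder names as above (the `weight` field is invented for illustration):

   -- one options entry per training sample
   local opts = {}
   for i = 1,#inputs do
      opts[i] = {weight = 1/i}   -- hypothetical per-sample payload
   end

   local sgd = nn.SGDOptimization{module = mlp, criterion = loss}
   sgd.posthook = function(self, sample)
      -- sample[3] is opts[i], or nil when forward() was called without options
      print('sample weight: ' .. tostring(sample[3] and sample[3].weight))
   end

   sgd:forward(inputs, targets)        -- legacy call: hooks see sample[3] == nil
   sgd:forward(inputs, targets, opts)  -- new call: opts[i] reaches the hooks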