Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'LBFGSOptimization.lua')
-rw-r--r--LBFGSOptimization.lua11
1 files changed, 8 insertions, 3 deletions
diff --git a/LBFGSOptimization.lua b/LBFGSOptimization.lua
index f901a57..671ce7a 100644
--- a/LBFGSOptimization.lua
+++ b/LBFGSOptimization.lua
@@ -13,7 +13,8 @@ function LBFGS:__init(...)
self.gradParametersT = nnx.getGradParameters(self.module)
end
-function LBFGS:forward(inputs, targets)
+function LBFGS:forward(inputs, targets, options)
+ options = options or {}
-- (1) construct a closure that compute f(inputs) + df/dW
-- after each call to that function:
-- + self.parameters contains the current X vector
@@ -29,6 +30,10 @@ function LBFGS:forward(inputs, targets)
self.output = 0
-- given all inputs, evaluate gradients
for i = 1,#inputs do
+ -- user hook
+ if self.prehook then
+ self.prehook(self, {inputs[i], targets[i], options[i]})
+ end
-- estimate f
local output = self.module:forward(inputs[i])
local err = self.criterion:forward(output, targets[i])
@@ -37,8 +42,8 @@ function LBFGS:forward(inputs, targets)
local df_do = self.criterion:backward(output, targets[i])
self.module:backward(inputs[i], df_do)
-- user hook
- if self.hook then
- self.hook(self, {inputs[i], targets[i]})
+ if self.posthook then
+ self.posthook(self, {inputs[i], targets[i], options[i]})
end
end
-- update state from computed parameters