diff options
Diffstat (limited to 'LBFGSOptimization.lua')
-rw-r--r-- | LBFGSOptimization.lua | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/LBFGSOptimization.lua b/LBFGSOptimization.lua index f901a57..671ce7a 100644 --- a/LBFGSOptimization.lua +++ b/LBFGSOptimization.lua @@ -13,7 +13,8 @@ function LBFGS:__init(...) self.gradParametersT = nnx.getGradParameters(self.module) end -function LBFGS:forward(inputs, targets) +function LBFGS:forward(inputs, targets, options) + options = options or {} -- (1) construct a closure that compute f(inputs) + df/dW -- after each call to that function: -- + self.parameters contains the current X vector @@ -29,6 +30,10 @@ function LBFGS:forward(inputs, targets) self.output = 0 -- given all inputs, evaluate gradients for i = 1,#inputs do + -- user hook + if self.prehook then + self.prehook(self, {inputs[i], targets[i], options[i]}) + end -- estimate f local output = self.module:forward(inputs[i]) local err = self.criterion:forward(output, targets[i]) @@ -37,8 +42,8 @@ function LBFGS:forward(inputs, targets) local df_do = self.criterion:backward(output, targets[i]) self.module:backward(inputs[i], df_do) -- user hook - if self.hook then - self.hook(self, {inputs[i], targets[i]}) + if self.posthook then + self.posthook(self, {inputs[i], targets[i], options[i]}) end end -- update state from computed parameters |