From e7024d5b41e5efa9ef36b20fe8fbf5b93df25667 Mon Sep 17 00:00:00 2001 From: Clement Farabet Date: Wed, 24 Aug 2011 11:39:52 -0400 Subject: More complete hook system for Optimizations --- LBFGSOptimization.lua | 11 ++++++++--- SGDOptimization.lua | 13 ++++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/LBFGSOptimization.lua b/LBFGSOptimization.lua index f901a57..671ce7a 100644 --- a/LBFGSOptimization.lua +++ b/LBFGSOptimization.lua @@ -13,7 +13,8 @@ function LBFGS:__init(...) self.gradParametersT = nnx.getGradParameters(self.module) end -function LBFGS:forward(inputs, targets) +function LBFGS:forward(inputs, targets, options) + options = options or {} -- (1) construct a closure that compute f(inputs) + df/dW -- after each call to that function: -- + self.parameters contains the current X vector @@ -29,6 +30,10 @@ function LBFGS:forward(inputs, targets) self.output = 0 -- given all inputs, evaluate gradients for i = 1,#inputs do + -- user hook + if self.prehook then + self.prehook(self, {inputs[i], targets[i], options[i]}) + end -- estimate f local output = self.module:forward(inputs[i]) local err = self.criterion:forward(output, targets[i]) @@ -37,8 +42,8 @@ function LBFGS:forward(inputs, targets) local df_do = self.criterion:backward(output, targets[i]) self.module:backward(inputs[i], df_do) -- user hook - if self.hook then - self.hook(self, {inputs[i], targets[i]}) + if self.posthook then + self.posthook(self, {inputs[i], targets[i], options[i]}) end end -- update state from computed parameters diff --git a/SGDOptimization.lua b/SGDOptimization.lua index 52d780e..f59dc02 100644 --- a/SGDOptimization.lua +++ b/SGDOptimization.lua @@ -14,7 +14,9 @@ function SGD:__init(...) self.gradParametersT = nnx.getGradParameters(self.module) end -function SGD:forward(inputs, targets) +function SGD:forward(inputs, targets, options) + options = options or {} + -- reset gradients self.module:zeroGradParameters() @@ -23,6 +25,11 @@ function SGD:forward(inputs, targets) -- given all inputs, evaluate gradients for i = 1,#inputs do + -- user hook + if self.prehook then + self.prehook(self, {inputs[i], targets[i], options[i]}) + end + -- estimate f local output = self.module:forward(inputs[i]) local err = self.criterion:forward(output, targets[i]) @@ -33,8 +40,8 @@ function SGD:forward(inputs, targets) self.module:backward(inputs[i], df_do) -- user hook - if self.hook then - self.hook(self, {inputs[i], targets[i]}) + if self.posthook then + self.posthook(self, {inputs[i], targets[i], options[i]}) end end -- cgit v1.2.3