diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-24 19:39:52 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-24 19:39:52 +0400 |
commit | e7024d5b41e5efa9ef36b20fe8fbf5b93df25667 (patch) | |
tree | cfdc6b47be8e420b89b79e2e08c697f7bc4ceb3b | |
parent | dcfcc7b199aebfac80e1c2caf3b5af4d66906f2d (diff) |
More complete hook system for Optimizationslbfgs
-rw-r--r-- | LBFGSOptimization.lua | 11 | ||||
-rw-r--r-- | SGDOptimization.lua | 13 |
2 files changed, 18 insertions, 6 deletions
diff --git a/LBFGSOptimization.lua b/LBFGSOptimization.lua index f901a57..671ce7a 100644 --- a/LBFGSOptimization.lua +++ b/LBFGSOptimization.lua @@ -13,7 +13,8 @@ function LBFGS:__init(...) self.gradParametersT = nnx.getGradParameters(self.module) end -function LBFGS:forward(inputs, targets) +function LBFGS:forward(inputs, targets, options) + options = options or {} -- (1) construct a closure that compute f(inputs) + df/dW -- after each call to that function: -- + self.parameters contains the current X vector @@ -29,6 +30,10 @@ function LBFGS:forward(inputs, targets) self.output = 0 -- given all inputs, evaluate gradients for i = 1,#inputs do + -- user hook + if self.prehook then + self.prehook(self, {inputs[i], targets[i], options[i]}) + end -- estimate f local output = self.module:forward(inputs[i]) local err = self.criterion:forward(output, targets[i]) @@ -37,8 +42,8 @@ function LBFGS:forward(inputs, targets) local df_do = self.criterion:backward(output, targets[i]) self.module:backward(inputs[i], df_do) -- user hook - if self.hook then - self.hook(self, {inputs[i], targets[i]}) + if self.posthook then + self.posthook(self, {inputs[i], targets[i], options[i]}) end end -- update state from computed parameters diff --git a/SGDOptimization.lua b/SGDOptimization.lua index 52d780e..f59dc02 100644 --- a/SGDOptimization.lua +++ b/SGDOptimization.lua @@ -14,7 +14,9 @@ function SGD:__init(...) self.gradParametersT = nnx.getGradParameters(self.module) end -function SGD:forward(inputs, targets) +function SGD:forward(inputs, targets, options) + options = options or {} + -- reset gradients self.module:zeroGradParameters() @@ -23,6 +25,11 @@ function SGD:forward(inputs, targets) -- given all inputs, evaluate gradients for i = 1,#inputs do + -- user hook + if self.prehook then + self.prehook(self, {inputs[i], targets[i], options[i]}) + end + -- estimate f local output = self.module:forward(inputs[i]) local err = self.criterion:forward(output, targets[i]) @@ -33,8 +40,8 @@ function SGD:forward(inputs, targets) self.module:backward(inputs[i], df_do) -- user hook - if self.hook then - self.hook(self, {inputs[i], targets[i]}) + if self.posthook then + self.posthook(self, {inputs[i], targets[i], options[i]}) end end |