diff options
Diffstat (limited to 'SGDOptimization.lua')
-rw-r--r-- | SGDOptimization.lua | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/SGDOptimization.lua b/SGDOptimization.lua index 52d780e..f59dc02 100644 --- a/SGDOptimization.lua +++ b/SGDOptimization.lua @@ -14,7 +14,9 @@ function SGD:__init(...) self.gradParametersT = nnx.getGradParameters(self.module) end -function SGD:forward(inputs, targets) +function SGD:forward(inputs, targets, options) + options = options or {} + -- reset gradients self.module:zeroGradParameters() @@ -23,6 +25,11 @@ function SGD:forward(inputs, targets) -- given all inputs, evaluate gradients for i = 1,#inputs do + -- user hook + if self.prehook then + self.prehook(self, {inputs[i], targets[i], options[i]}) + end + -- estimate f local output = self.module:forward(inputs[i]) local err = self.criterion:forward(output, targets[i]) @@ -33,8 +40,8 @@ function SGD:forward(inputs, targets) self.module:backward(inputs[i], df_do) -- user hook - if self.hook then - self.hook(self, {inputs[i], targets[i]}) + if self.posthook then + self.posthook(self, {inputs[i], targets[i], options[i]}) end end |