Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'SGDOptimization.lua')
-rw-r--r--SGDOptimization.lua13
1 files changed, 10 insertions, 3 deletions
diff --git a/SGDOptimization.lua b/SGDOptimization.lua
index 52d780e..f59dc02 100644
--- a/SGDOptimization.lua
+++ b/SGDOptimization.lua
@@ -14,7 +14,9 @@ function SGD:__init(...)
self.gradParametersT = nnx.getGradParameters(self.module)
end
-function SGD:forward(inputs, targets)
+function SGD:forward(inputs, targets, options)
+ options = options or {}
+
-- reset gradients
self.module:zeroGradParameters()
@@ -23,6 +25,11 @@ function SGD:forward(inputs, targets)
-- given all inputs, evaluate gradients
for i = 1,#inputs do
+ -- user hook
+ if self.prehook then
+ self.prehook(self, {inputs[i], targets[i], options[i]})
+ end
+
-- estimate f
local output = self.module:forward(inputs[i])
local err = self.criterion:forward(output, targets[i])
@@ -33,8 +40,8 @@ function SGD:forward(inputs, targets)
self.module:backward(inputs[i], df_do)
-- user hook
- if self.hook then
- self.hook(self, {inputs[i], targets[i]})
+ if self.posthook then
+ self.posthook(self, {inputs[i], targets[i], options[i]})
end
end