diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-24 19:03:00 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-24 19:03:00 +0400 |
commit | dcfcc7b199aebfac80e1c2caf3b5af4d66906f2d (patch) | |
tree | 67aebcf32da0c64aa7a31968372d2ee9c5712318 | |
parent | 5ac570666ccba10bcc1e4cd1bc2b9846ccda7f07 (diff) |
Added user hook to Optimization modules.
-rw-r--r-- | LBFGSOptimization.lua | 4 | ||||
-rw-r--r-- | SGDOptimization.lua | 5 |
2 files changed, 9 insertions, 0 deletions
diff --git a/LBFGSOptimization.lua b/LBFGSOptimization.lua index 3d9a9ed..f901a57 100644 --- a/LBFGSOptimization.lua +++ b/LBFGSOptimization.lua @@ -36,6 +36,10 @@ function LBFGS:forward(inputs, targets) -- estimate df/dW local df_do = self.criterion:backward(output, targets[i]) self.module:backward(inputs[i], df_do) + -- user hook + if self.hook then + self.hook(self, {inputs[i], targets[i]}) + end end -- update state from computed parameters self:flatten(self.parametersT, self.gradParametersT) diff --git a/SGDOptimization.lua b/SGDOptimization.lua index 514c1a8..52d780e 100644 --- a/SGDOptimization.lua +++ b/SGDOptimization.lua @@ -31,6 +31,11 @@ function SGD:forward(inputs, targets) -- estimate df/dW local df_do = self.criterion:backward(output, targets[i]) self.module:backward(inputs[i], df_do) + + -- user hook + if self.hook then + self.hook(self, {inputs[i], targets[i]}) + end end -- renorm f |