Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-08-24 18:44:38 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-08-24 18:44:38 +0400
commit5ac570666ccba10bcc1e4cd1bc2b9846ccda7f07 (patch)
treed872525f12646f569928ad8fe370c3bd809004f0
parent52568ca8072eed52ef784262144802dfc62d296a (diff)
Added a max iterations to lBFGS.
-rw-r--r--LBFGSOptimization.lua5
-rw-r--r--lbfgs.c2
2 files changed, 4 insertions, 3 deletions
diff --git a/LBFGSOptimization.lua b/LBFGSOptimization.lua
index 7d41844..3d9a9ed 100644
--- a/LBFGSOptimization.lua
+++ b/LBFGSOptimization.lua
@@ -6,7 +6,8 @@ function LBFGS:__init(...)
xlua.unpack_class(self, {...},
'LBFGSOptimization', nil,
{arg='module', type='nn.Module', help='a module to train', req=true},
- {arg='criterion', type='nn.Criterion', help='a criterion to estimate the error', req=true}
+ {arg='criterion', type='nn.Criterion', help='a criterion to estimate the error', req=true},
+ {arg='maxIterations', type='number', help='maximum nb of iterations per pass (0 = no max)', default=0}
)
self.parametersT = nnx.getParameters(self.module)
self.gradParametersT = nnx.getGradParameters(self.module)
@@ -47,7 +48,7 @@ function LBFGS:forward(inputs, targets)
-- (3) the magic function: will update the parameter vector
-- according to the l-BFGS method
- self.output = lbfgs.run(self.parameters, self.gradParameters)
+ self.output = lbfgs.run(self.parameters, self.gradParameters, self.maxIterations)
-- (4) last: read parameters back into the model
self:unflatten(self.parametersT, self.gradParametersT)
diff --git a/lbfgs.c b/lbfgs.c
index 63fe2a1..13c8cef 100644
--- a/lbfgs.c
+++ b/lbfgs.c
@@ -1421,7 +1421,7 @@ int lbfgs_run(lua_State *L) {
// initialize the parameters for the L-BFGS optimization
lbfgs_parameter_init(&param);
- //param.linesearch = LBFGS_LINESEARCH_BACKTRACKING;
+ param.max_iterations = lua_tonumber(L, 3);
// Start the L-BFGS optimization; this will invoke the callback functions
// evaluate() and progress() when necessary.