github.com/torch/optim.git

author     Soumith Chintala <soumith@gmail.com>  2016-06-02 19:39:43 +0300
committer  Soumith Chintala <soumith@gmail.com>  2016-06-02 19:39:43 +0300
commit     76e666f78bde9522ca20ac3426370416e23d0283 (patch)
tree       282c9fca91f206c164777a4f51e9fe7f9c23dd4e
parent     894081857c59f38ef619282bfdded19ad757d2be (diff)
parent     eb3d1ee8f209db6e9158dbf84304fa75295795d6 (diff)
Merge pull request #113 from gpleiss/sgd-lrs-fix
Fix bug with sgd individual learning rates
-rw-r--r--  sgd.lua | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/sgd.lua b/sgd.lua
index ea13c55..d96bd4b 100644
--- a/sgd.lua
+++ b/sgd.lua
@@ -69,15 +69,20 @@ function optim.sgd(opfunc, x, config, state)
    end
 
    -- (4) learning rate decay (annealing)
-   local clr = lr / (1 + nevals*lrd)
+   local clr, clrs
+   if lrs then
+      clrs = lrs / (1 + nevals*lrd)
+   else
+      clr = lr / (1 + nevals*lrd)
+   end
 
    -- (5) parameter update with single or individual learning rates
    if lrs then
       if not state.deltaParameters then
          state.deltaParameters = torch.Tensor():typeAs(x):resizeAs(dfdx)
       end
-      state.deltaParameters:copy(lrs):cmul(dfdx)
-      x:add(-clr, state.deltaParameters)
+      state.deltaParameters:copy(clrs):cmul(dfdx)
+      x:add(-state.deltaParameters)
    else
       x:add(-clr, dfdx)
    end
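
What the patch changes: before this merge, per-parameter rates given in config.learningRates were applied on top of the decayed global rate, so each step was clr * lrs[i] * dfdx[i] with clr = lr / (1 + nevals*lrd); since config.learningRate defaults to 1e-3, supplying learningRates silently scaled every individual rate by that global factor. After the fix, decay is applied to lrs directly and each step is (lrs[i] / (1 + nevals*lrd)) * dfdx[i]. Below is a minimal usage sketch in Lua/Torch; the quadratic objective, tensor sizes, and hyperparameter values are illustrative assumptions, not taken from the repository.

-- Minimal usage sketch (not part of the patch). Assumes the torch and optim
-- packages are installed; the objective and all values are illustrative only.
require 'torch'
require 'optim'

-- f(x) = 0.5 * ||x||^2, whose gradient at x is x itself.
local function opfunc(x)
   return 0.5 * x:dot(x), x:clone()
end

local x = torch.randn(5)
local config = {
   learningRates     = torch.linspace(0.01, 0.05, 5), -- per-parameter rates
   learningRateDecay = 1e-4,
}
local state = {}

for i = 1, 100 do
   optim.sgd(opfunc, x, config, state)
end

-- With this fix, each iteration scales gradient entry j by
--   learningRates[j] / (1 + state.evalCounter * learningRateDecay),
-- with no extra multiplication by the global config.learningRate.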