Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWill Williams <gwillwill@gmail.com>2015-03-23 18:38:20 +0300
committerWill Williams <willw@cantabresearch.com>2015-03-23 18:38:36 +0300
commit7ea43915a2b0da2d651c779d4ea246d7a1ccee77 (patch)
treea0a788a808101679f491e319b0fec55a0712f261
parent5d8da2a49082353882ec0e0025bdfe676a28f379 (diff)
added rmsprop to init.lua, removed one malloc entirely and placed the update int state.tmp to remove the other malloc
-rw-r--r--init.lua1
-rw-r--r--rmsprop.lua20
2 files changed, 9 insertions, 12 deletions
diff --git a/init.lua b/init.lua
index 391258b..55cb2a4 100644
--- a/init.lua
+++ b/init.lua
@@ -13,6 +13,7 @@ torch.include('optim', 'lbfgs.lua')
torch.include('optim', 'adagrad.lua')
torch.include('optim', 'rprop.lua')
torch.include('optim', 'adam.lua')
+torch.include('optim', 'rmsprop.lua')
-- line search functions
torch.include('optim', 'lswolfe.lua')
diff --git a/rmsprop.lua b/rmsprop.lua
index c8c5ddd..b6baf96 100644
--- a/rmsprop.lua
+++ b/rmsprop.lua
@@ -25,11 +25,11 @@ function optim.rmsprop(opfunc, x, config, state)
-- (0) get/update state
local config = config or {}
local state = state or config
- local lr = config.learningRate or 1e-3
+ local lr = config.learningRate or 1e-4
local alpha = config.alpha or 0.998
local epsilon = config.epsilon or 1e-8
- local epsilon2 = config.epsilon2 or 1e-4
- local max_gain = config.max_gain or 100
+ local epsilon2 = config.epsilon2 or 1e-8
+ local max_gain = config.max_gain or 1000
local min_gain = config.min_gain or 1e-8
-- (1) evaluate f(x) and df/dx
@@ -37,21 +37,17 @@ function optim.rmsprop(opfunc, x, config, state)
-- (2) initialize mean square values and square gradient storage
state.m = state.m or torch.Tensor():typeAs(dfdx):resizeAs(dfdx):fill(epsilon)
- state.dfdx_sq = state.dfdx_sq or torch.Tensor():typeAs(dfdx):resizeAs(dfdx)
+ state.tmp = state.tmp or x.new(dfdx:size()):zero()
-- (3) calculate new mean squared values
- torch.cmul(state.dfdx_sq, dfdx, dfdx)
state.m:mul(alpha)
- state.m:add(state.dfdx_sq:mul(1.0-alpha))
- state.m:add(epsilon2)
+ state.m:addcmul(1.0-alpha,dfdx,dfdx):add(epsilon2)
-- (4) perform update
- local one_over_rms = torch.pow(state.m, -0.5)
- one_over_rms:clamp(min_gain, max_gain)
- local update = torch.cmul(torch.mul(dfdx,lr), one_over_rms)
- x:add(-update)
+ state.tmp:copy(state.m):pow(-0.5):clamp(min_gain, max_gain)
+ x:add(-lr, state.tmp)
-- return x*, f(x) before optimization
- return x, {fx}, update
+ return x, {fx}, state.tmp
end