Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-06-30 15:15:40 +0300
committerGitHub <noreply@github.com>2016-06-30 15:15:40 +0300
commite24fd8550d366668473d0b5a89c00546d2145c81 (patch)
treee02ff522a1e8c03aa8c3f8e186a79c0a0420ffa6
parentf6be4bb195e3e128ab027326255172ff36b6c63c (diff)
parent06a42e2af40697bd2c95843aee1b75bc51d4270d (diff)
Merge pull request #121 from Atcold/doc-fix
Documentation and code refactoring
-rw-r--r--README.md47
-rw-r--r--adam.lua70
-rw-r--r--adamax.lua70
-rw-r--r--asgd.lua4
-rw-r--r--cmaes.lua62
-rw-r--r--doc/algos.md363
-rw-r--r--doc/index.md409
-rw-r--r--doc/intro.md41
-rw-r--r--doc/logger.md73
-rw-r--r--doc/logger_plot.pngbin0 -> 45532 bytes
-rw-r--r--fista.lua14
-rw-r--r--lbfgs.lua6
-rw-r--r--nag.lua6
-rw-r--r--rmsprop.lua50
-rw-r--r--rprop.lua154
-rw-r--r--sgd.lua2
16 files changed, 701 insertions, 670 deletions
diff --git a/README.md b/README.md
index 572a67a..561621b 100644
--- a/README.md
+++ b/README.md
@@ -1,45 +1,8 @@
+<a name='optim.dok'></a>
# Optimization package
-This package contains several optimization routines for [Torch](https://github.com/torch/torch7/blob/master/README.md).
-Each optimization algorithm is based on the same interface:
+This package contains several optimization routines and a logger for [Torch](https://github.com/torch/torch7/blob/master/README.md):
-```lua
-x*, {f}, ... = optim.method(func, x, state)
-```
-
-where:
-
-* `func`: a user-defined closure that respects this API: `f, df/dx = func(x)`
-* `x`: the current parameter vector (a 1D `torch.Tensor`)
-* `state`: a table of parameters, and state variables, dependent upon the algorithm
-* `x*`: the new parameter vector that minimizes `f, x* = argmin_x f(x)`
-* `{f}`: a table of all f values, in the order they've been evaluated
- (for some simple algorithms, like SGD, `#f == 1`)
-
-## Available algorithms
-
-Please check [this file](doc/index.md) for the full list of
-optimization algorithms available and examples. Get also into the
-[`test`](test/) directory for straightforward examples using the
-[Rosenbrock's](test/rosenbrock.lua) function.
-
-## Important Note
-
-The state table is used to hold the state of the algorithm.
-It's usually initialized once, by the user, and then passed to the optim function
-as a black box. Example:
-
-```lua
-state = {
- learningRate = 1e-3,
- momentum = 0.5
-}
-
-for i,sample in ipairs(training_samples) do
- local func = function(x)
- -- define eval function
- return f,df_dx
- end
- optim.sgd(func,x,state)
-end
-```
+ * [Overview](doc/intro.md);
+ * [Optimization algorithms](doc/algos.md);
+ * [Logger](doc/logger.md).
diff --git a/adam.lua b/adam.lua
index a6ad588..505a779 100644
--- a/adam.lua
+++ b/adam.lua
@@ -21,47 +21,47 @@ RETURN:
]]
function optim.adam(opfunc, x, config, state)
- -- (0) get/update state
- local config = config or {}
- local state = state or config
- local lr = config.learningRate or 0.001
+ -- (0) get/update state
+ local config = config or {}
+ local state = state or config
+ local lr = config.learningRate or 0.001
- local beta1 = config.beta1 or 0.9
- local beta2 = config.beta2 or 0.999
- local epsilon = config.epsilon or 1e-8
- local wd = config.weightDecay or 0
+ local beta1 = config.beta1 or 0.9
+ local beta2 = config.beta2 or 0.999
+ local epsilon = config.epsilon or 1e-8
+ local wd = config.weightDecay or 0
- -- (1) evaluate f(x) and df/dx
- local fx, dfdx = opfunc(x)
+ -- (1) evaluate f(x) and df/dx
+ local fx, dfdx = opfunc(x)
- -- (2) weight decay
- if wd ~= 0 then
- dfdx:add(wd, x)
- end
+ -- (2) weight decay
+ if wd ~= 0 then
+ dfdx:add(wd, x)
+ end
- -- Initialization
- state.t = state.t or 0
- -- Exponential moving average of gradient values
- state.m = state.m or x.new(dfdx:size()):zero()
- -- Exponential moving average of squared gradient values
- state.v = state.v or x.new(dfdx:size()):zero()
- -- A tmp tensor to hold the sqrt(v) + epsilon
- state.denom = state.denom or x.new(dfdx:size()):zero()
+ -- Initialization
+ state.t = state.t or 0
+ -- Exponential moving average of gradient values
+ state.m = state.m or x.new(dfdx:size()):zero()
+ -- Exponential moving average of squared gradient values
+ state.v = state.v or x.new(dfdx:size()):zero()
+ -- A tmp tensor to hold the sqrt(v) + epsilon
+ state.denom = state.denom or x.new(dfdx:size()):zero()
- state.t = state.t + 1
-
- -- Decay the first and second moment running average coefficient
- state.m:mul(beta1):add(1-beta1, dfdx)
- state.v:mul(beta2):addcmul(1-beta2, dfdx, dfdx)
+ state.t = state.t + 1
- state.denom:copy(state.v):sqrt():add(epsilon)
+ -- Decay the first and second moment running average coefficient
+ state.m:mul(beta1):add(1-beta1, dfdx)
+ state.v:mul(beta2):addcmul(1-beta2, dfdx, dfdx)
- local biasCorrection1 = 1 - beta1^state.t
- local biasCorrection2 = 1 - beta2^state.t
- local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1
- -- (3) update x
- x:addcdiv(-stepSize, state.m, state.denom)
+ state.denom:copy(state.v):sqrt():add(epsilon)
- -- return x*, f(x) before optimization
- return x, {fx}
+ local biasCorrection1 = 1 - beta1^state.t
+ local biasCorrection2 = 1 - beta2^state.t
+ local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1
+ -- (3) update x
+ x:addcdiv(-stepSize, state.m, state.denom)
+
+ -- return x*, f(x) before optimization
+ return x, {fx}
end
diff --git a/adamax.lua b/adamax.lua
index c06fddd..2b64877 100644
--- a/adamax.lua
+++ b/adamax.lua
@@ -20,47 +20,47 @@ RETURN:
]]
function optim.adamax(opfunc, x, config, state)
- -- (0) get/update state
- local config = config or {}
- local state = state or config
- local lr = config.learningRate or 0.002
+ -- (0) get/update state
+ local config = config or {}
+ local state = state or config
+ local lr = config.learningRate or 0.002
- local beta1 = config.beta1 or 0.9
- local beta2 = config.beta2 or 0.999
- local epsilon = config.epsilon or 1e-38
- local wd = config.weightDecay or 0
+ local beta1 = config.beta1 or 0.9
+ local beta2 = config.beta2 or 0.999
+ local epsilon = config.epsilon or 1e-38
+ local wd = config.weightDecay or 0
- -- (1) evaluate f(x) and df/dx
- local fx, dfdx = opfunc(x)
+ -- (1) evaluate f(x) and df/dx
+ local fx, dfdx = opfunc(x)
- -- (2) weight decay
- if wd ~= 0 then
- dfdx:add(wd, x)
- end
+ -- (2) weight decay
+ if wd ~= 0 then
+ dfdx:add(wd, x)
+ end
- -- Initialization
- state.t = state.t or 0
- -- Exponential moving average of gradient values
- state.m = state.m or x.new(dfdx:size()):zero()
- -- Exponential moving average of the infinity norm
- state.u = state.u or x.new(dfdx:size()):zero()
- -- A tmp tensor to hold the input to max()
- state.max = state.max or x.new(2, unpack(dfdx:size():totable())):zero()
+ -- Initialization
+ state.t = state.t or 0
+ -- Exponential moving average of gradient values
+ state.m = state.m or x.new(dfdx:size()):zero()
+ -- Exponential moving average of the infinity norm
+ state.u = state.u or x.new(dfdx:size()):zero()
+ -- A tmp tensor to hold the input to max()
+ state.max = state.max or x.new(2, unpack(dfdx:size():totable())):zero()
- state.t = state.t + 1
+ state.t = state.t + 1
- -- Update biased first moment estimate.
- state.m:mul(beta1):add(1-beta1, dfdx)
- -- Update the exponentially weighted infinity norm.
- state.max[1]:copy(state.u):mul(beta2)
- state.max[2]:copy(dfdx):abs():add(epsilon)
- state.u:max(state.max, 1)
+ -- Update biased first moment estimate.
+ state.m:mul(beta1):add(1-beta1, dfdx)
+ -- Update the exponentially weighted infinity norm.
+ state.max[1]:copy(state.u):mul(beta2)
+ state.max[2]:copy(dfdx):abs():add(epsilon)
+ state.u:max(state.max, 1)
- local biasCorrection1 = 1 - beta1^state.t
- local stepSize = lr/biasCorrection1
- -- (2) update x
- x:addcdiv(-stepSize, state.m, state.u)
+ local biasCorrection1 = 1 - beta1^state.t
+ local stepSize = lr/biasCorrection1
+ -- (2) update x
+ x:addcdiv(-stepSize, state.m, state.u)
- -- return x*, f(x) before optimization
- return x, {fx}
+ -- return x*, f(x) before optimization
+ return x, {fx}
end
diff --git a/asgd.lua b/asgd.lua
index 659db22..cc1c459 100644
--- a/asgd.lua
+++ b/asgd.lua
@@ -1,6 +1,6 @@
--[[ An implementation of ASGD
-ASGD:
+ASGD:
x := (1 - lambda eta_t) x - eta_t df/dx(z,x)
a := a + mu_t [ x - a ]
@@ -12,7 +12,7 @@ implements ASGD algoritm as in L.Bottou's sgd-2.0
ARGS:
-- `opfunc` : a function that takes a single input (X), the point of
+- `opfunc` : a function that takes a single input (X), the point of
evaluation, and returns f(X) and df/dX
- `x` : the initial point
- `state` : a table describing the state of the optimizer; after each
diff --git a/cmaes.lua b/cmaes.lua
index 1045a48..74cd58a 100644
--- a/cmaes.lua
+++ b/cmaes.lua
@@ -1,16 +1,16 @@
require 'torch'
require 'math'
-local BestSolution = {}
---[[ An implementation of `CMAES` (Covariance Matrix Adaptation Evolution Strategy),
+local BestSolution = {}
+--[[ An implementation of `CMAES` (Covariance Matrix Adaptation Evolution Strategy),
ported from https://www.lri.fr/~hansen/barecmaes2.html.
-
+
Parameters
----------
ARGS:
-- `opfunc` : a function that takes a single input (X), the point of
- evaluation, and returns f(X) and df/dX. Note that df/dX is not used
+- `opfunc` : a function that takes a single input (X), the point of
+ evaluation, and returns f(X) and df/dX. Note that df/dX is not used
- `x` : the initial point
- `state.sigma`
float, initial step-size (standard deviation in each
@@ -20,16 +20,16 @@ ARGS:
- `state.ftarget`
float, target function value
- `state.popsize`
- population size. If this is left empty,
+ population size. If this is left empty,
4 + int(3 * log(|x|)) will be used
-- `state.ftarget`
+- `state.ftarget`
stop if fitness < ftarget
- `state.verb_disp`
int, display on console every verb_disp iteration, 0 for never
RETURN:
- `x*` : the new `x` vector, at the optimal point
-- `f` : a table of all function values:
+- `f` : a table of all function values:
`f[1]` is the value of the function before any optimization and
`f[#f]` is the final fully optimized value, at `x*`
--]]
@@ -50,13 +50,13 @@ function optim.cmaes(opfunc, x, config, state)
local min_iterations = state.min_iterations or 1
local lambda = state.popsize -- population size, offspring number
- -- Strategy parameter setting: Selection
+ -- Strategy parameter setting: Selection
if state.popsize == nil then
lambda = 4 + math.floor(3 * math.log(N))
end
local mu = lambda / 2 -- number of parents/points for recombination
- local weights = torch.range(0,mu-1):apply(function(i)
+ local weights = torch.range(0,mu-1):apply(function(i)
return math.log(mu+0.5) - math.log(i+1) end) -- recombination weights
weights:div(weights:sum()) -- normalize recombination weights array
local mueff = weights:sum()^2 / torch.pow(weights,2):sum() -- variance-effectiveness of sum w_i x_i
@@ -69,18 +69,18 @@ function optim.cmaes(opfunc, x, config, state)
local cmu = math.min(1 - c1, 2 * (mueff - 2 + 1/mueff) / ((N + 2)^2 + mueff)) -- and for rank-mu update
local damps = 2 * mueff/lambda + 0.3 + cs -- damping for sigma, usually close to 1
- -- Initialize dynamic (internal) state variables
+ -- Initialize dynamic (internal) state variables
local pc = torch.Tensor(N):zero():typeAs(x) -- evolution paths for C
local ps = torch.Tensor(N):zero():typeAs(x) -- evolution paths for sigma
- local B = torch.eye(N):typeAs(x) -- B defines the coordinate system
+ local B = torch.eye(N):typeAs(x) -- B defines the coordinate system
local D = torch.Tensor(N):fill(1):typeAs(x) -- diagonal D defines the scaling
- local C = torch.eye(N):typeAs(x) -- covariance matrix
+ local C = torch.eye(N):typeAs(x) -- covariance matrix
if not pcall(function () torch.symeig(C,'V') end) then -- if error occurs trying to use symeig
- error('torch.symeig not available for ' .. x:type() ..
+ error('torch.symeig not available for ' .. x:type() ..
" please use Float- or DoubleTensor for x")
end
local candidates = torch.Tensor(lambda,N):typeAs(x)
- local invsqrtC = torch.eye(N):typeAs(x) -- C^-1/2
+ local invsqrtC = torch.eye(N):typeAs(x) -- C^-1/2
local eigeneval = 0 -- tracking the update of B and D
local counteval = 0
local f_hist = {[1]=opfunc(x)} -- for bookkeeping output and termination
@@ -90,7 +90,7 @@ function optim.cmaes(opfunc, x, config, state)
local function ask()
- --[[return a list of lambda candidate solutions according to
+ --[[return a list of lambda candidate solutions according to
m + sig * Normal(0,C) = m + sig * B * D * Normal(0,I)
--]]
-- Eigendecomposition: first update B, D and invsqrtC from C
@@ -117,9 +117,9 @@ function optim.cmaes(opfunc, x, config, state)
Parameters
----------
- `arx`
+ `arx`
a list of solutions, presumably from `ask()`
- `fitvals`
+ `fitvals`
the corresponding objective function values --]]
-- bookkeeping, preparation
counteval = counteval + lambda -- slightly artificial to do here
@@ -142,7 +142,7 @@ function optim.cmaes(opfunc, x, config, state)
local c = (cs * (2-cs) * mueff)^0.5 / sigma
ps = ps - ps * cs + z * c -- exponential decay on ps
- local hsig = (torch.sum(torch.pow(ps,2)) /
+ local hsig = (torch.sum(torch.pow(ps,2)) /
(1-(1-cs)^(2*counteval/lambda)) / N < 2 + 4./(N+1))
hsig = hsig and 1.0 or 0.0 --use binary numbers
@@ -155,23 +155,23 @@ function optim.cmaes(opfunc, x, config, state)
for i=1,N do
for j=1,N do
local r = torch.range(1,mu)
- r:apply(function(k)
+ r:apply(function(k)
return weights[k] * (arx[k][i]-xold[i]) * (arx[k][j]-xold[j]) end)
local Cmuij = torch.sum(r) / sigma^2 -- rank-mu update
- C[i][j] = C[i][j] + ((-c1a - cmu) * C[i][j] +
+ C[i][j] = C[i][j] + ((-c1a - cmu) * C[i][j] +
c1 * pc[i]*pc[j] + cmu * Cmuij)
end
end
-- Adapt step-size sigma with factor <= exp(0.6) \approx 1.82
- sigma = sigma * math.exp(math.min(0.6,
+ sigma = sigma * math.exp(math.min(0.6,
(cs / damps) * (torch.sum(torch.pow(ps,2))/N - 1)/2))
end
- local function stop()
- --[[return satisfied termination conditions in a table like
- {'termination reason':value, ...}, for example {'tolfun':1e-12},
- or the empty table {}--]]
+ local function stop()
+ --[[return satisfied termination conditions in a table like
+ {'termination reason':value, ...}, for example {'tolfun':1e-12},
+ or the empty table {}--]]
local res = {}
if counteval > 0 then
if counteval >= maxEval then
@@ -184,7 +184,7 @@ function optim.cmaes(opfunc, x, config, state)
res['condition'] = 1e7
end
if fitvals:nElement() > 1 and fitvals[fitvals:nElement()] - fitvals[1] < 1e-12 then
- res['tolfun'] = 1e-12
+ res['tolfun'] = 1e-12
end
if sigma * torch.max(D) < 1e-11 then
-- remark: max(D) >= max(diag(C))^0.5
@@ -206,8 +206,8 @@ function optim.cmaes(opfunc, x, config, state)
end
if iteration <= 2 or iteration % verb_modulo == 0 then
local max_std = math.sqrt(torch.max(torch.diag(C)))
- print(tostring(counteval).. ': ' ..
- string.format(' %6.1f %8.1e ', torch.max(D) / torch.min(D), sigma * max_std)
+ print(tostring(counteval).. ': ' ..
+ string.format(' %6.1f %8.1e ', torch.max(D) / torch.min(D), sigma * max_std)
.. tostring(fitvals[1]))
end
@@ -224,7 +224,7 @@ function optim.cmaes(opfunc, x, config, state)
fitvals[i] = objfunc(candidate)
end
- tell(X)
+ tell(X)
disp(verb_disp)
end
@@ -245,7 +245,7 @@ end
-BestSolution.__index = BestSolution
+BestSolution.__index = BestSolution
function BestSolution.new(x, f, evals)
local self = setmetatable({}, BestSolution)
self.x = x
diff --git a/doc/algos.md b/doc/algos.md
new file mode 100644
index 0000000..a671420
--- /dev/null
+++ b/doc/algos.md
@@ -0,0 +1,363 @@
+<a name='optim.algos'></a>
+# Optimization algorithms
+
+The following algorithms are provided:
+
+ * [*Stochastic Gradient Descent*](#optim.sgd)
+ * [*Averaged Stochastic Gradient Descent*](#optim.asgd)
+ * [*L-BFGS*](#optim.lbfgs)
+ * [*Congugate Gradients*](#optim.cg)
+ * [*AdaDelta*](#optim.adadelta)
+ * [*AdaGrad*](#optim.adagrad)
+ * [*Adam*](#optim.adam)
+ * [*AdaMax*](#optim.adamax)
+ * [*FISTA with backtracking line search*](#optim.FistaLS)
+ * [*Nesterov's Accelerated Gradient method*](#optim.nag)
+ * [*RMSprop*](#optim.rmsprop)
+ * [*Rprop*](#optim.rprop)
+ * [*CMAES*](#optim.cmaes)
+
+All these algorithms are designed to support batch optimization as well as stochastic optimization.
+It's up to the user to construct an objective function that represents the batch, mini-batch, or single sample on which to evaluate the objective.
+
+Some of these algorithms support a line search, which can be passed as a function (*L-BFGS*), whereas others only support a learning rate (*SGD*).
+
+General interface:
+
+```lua
+x*, {f}, ... = optim.method(opfunc, x[, config][, state])
+```
+
+
+<a name='optim.sgd'></a>
+## sgd(opfunc, x[, config][, state])
+
+An implementation of *Stochastic Gradient Descent* (*SGD*).
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of a evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.learningRate`: learning rate
+ * `config.learningRateDecay`: learning rate decay
+ * `config.weightDecay`: weight decay
+ * `config.weightDecays`: vector of individual weight decays
+ * `config.momentum`: momentum
+ * `config.dampening`: dampening for momentum
+ * `config.nesterov`: enables Nesterov momentum
+ * `state`: a table describing the state of the optimizer; after each call the state is modified
+ * `state.learningRates`: vector of individual learning rates
+
+Returns:
+
+ * `x*`: the new x vector
+ * `f(x)`: the function, evaluated before the update
+
+
+<a name='optim.asgd'></a>
+## asgd(opfunc, x[, config][, state])
+
+An implementation of *Averaged Stochastic Gradient Descent* (*ASGD*):
+
+```lua
+x = (1 - lambda eta_t) x - eta_t df / dx(z, x)
+a = a + mu_t [ x - a ]
+
+eta_t = eta0 / (1 + lambda eta0 t) ^ 0.75
+mu_t = 1 / max(1, t - t0)
+```
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.eta0`: learning rate
+ * `config.lambda`: decay term
+ * `config.alpha`: power for eta update
+ * `config.t0`: point at which to start averaging
+
+Returns:
+
+ * `x*`: the new x vector
+ * `f(x)`: the function, evaluated before the update
+ * `ax`: the averaged x vector
+
+
+<a name='optim.lbfgs'></a>
+## lbfgs(opfunc, x[, config][, state])
+
+An implementation of *L-BFGS* that relies on a user-provided line search function (`state.lineSearch`).
+If this function is not provided, then a simple learning rate is used to produce fixed size steps.
+Fixed size steps are much less costly than line searches, and can be useful for stochastic problems.
+
+The learning rate is used even when a line search is provided.
+This is also useful for large-scale stochastic problems, where opfunc is a noisy approximation of `f(x)`.
+In that case, the learning rate allows a reduction of confidence in the step size.
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.maxIter`: Maximum number of iterations allowed
+ * `config.maxEval`: Maximum number of function evaluations
+ * `config.tolFun`: Termination tolerance on the first-order optimality
+ * `config.tolX`: Termination tol on progress in terms of func/param changes
+ * `config.lineSearch`: A line search function
+ * `config.learningRate`: If no line search provided, then a fixed step size is used
+
+Returns:
+ * `x*`: the new `x` vector, at the optimal point
+ * `f`: a table of all function values:
+ * `f[1]` is the value of the function before any optimization and
+ * `f[#f]` is the final fully optimized value, at `x*`
+
+
+<a name='optim.cg'></a>
+## cg(opfunc, x[, config][, state])
+
+An implementation of the *Conjugate Gradient* method which is a rewrite of `minimize.m` written by Carl E. Rasmussen.
+It is supposed to produce exactly same results (give or take numerical accuracy due to some changed order of operations).
+You can compare the result on rosenbrock with [`minimize.m`](http://www.gatsby.ucl.ac.uk/~edward/code/minimize/example.html).
+
+```lua
+x, fx, c = minimize([0, 0]', 'rosenbrock', -25)
+```
+
+Note that we limit the number of function evaluations only, it seems much more important in practical use.
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input, the point of evaluation.
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.maxEval`: max number of function evaluations
+ * `config.maxIter`: max number of iterations
+ * `state`: a table of parameters and temporary allocations.
+ * `state.df[0, 1, 2, 3]`: if you pass `Tensor` they will be used for temp storage
+ * `state.[s, x0]`: if you pass `Tensor` they will be used for temp storage
+
+Returns:
+
+ * `x*`: the new `x` vector, at the optimal point
+ * `f`: a table of all function values where
+ * `f[1]` is the value of the function before any optimization and
+ * `f[#f]` is the final fully optimized value, at `x*`
+
+
+<a name='optim.adadelta'></a>
+## adadelta(opfunc, x[, config][, state])
+
+*AdaDelta* implementation for *SGD* http://arxiv.org/abs/1212.5701.
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table of hyper-parameters
+ * `config.rho`: interpolation parameter
+ * `config.eps`: for numerical stability
+ * `state`: a table describing the state of the optimizer; after each call the state is modified
+ * `state.paramVariance`: vector of temporal variances of parameters
+ * `state.accDelta`: vector of accummulated delta of gradients
+
+Returns:
+
+ * `x*`: the new x vector
+ * `f(x)`: the function, evaluated before the update
+
+
+<a name='optim.adagrad'></a>
+## adagrad(opfunc, x[, config][, state])
+
+*AdaGrad* implementation for *SGD*.
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.learningRate`: learning rate
+ * `state`: a table describing the state of the optimizer; after each call the state is modified
+ * `state.paramVariance`: vector of temporal variances of parameters
+
+Returns:
+
+ * `x*`: the new `x` vector
+ * `f(x)`: the function, evaluated before the update
+
+
+<a name='optim.adam'></a>
+## adam(opfunc, x[, config][, state])
+
+An implementation of *Adam* from http://arxiv.org/pdf/1412.6980.pdf.
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of a evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.learningRate`: learning rate
+ * `config.beta1`: first moment coefficient
+ * `config.beta2`: second moment coefficient
+ * `config.epsilon`: for numerical stability
+ * `state`: a table describing the state of the optimizer; after each call the state is modified
+
+Returns:
+
+ * `x*`: the new x vector
+ * `f(x)`: the function, evaluated before the update
+
+
+<a name='optim.adamax'></a>
+## adamax(opfunc, x[, config][, state])
+
+An implementation of *AdaMax* http://arxiv.org/pdf/1412.6980.pdf.
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of a evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.learningRate`: learning rate
+ * `config.beta1`: first moment coefficient
+ * `config.beta2`: second moment coefficient
+ * `config.epsilon`: for numerical stability
+ * `state`: a table describing the state of the optimizer; after each call the state is modified
+
+Returns:
+
+ * `x*`: the new `x` vector
+ * `f(x)`: the function, evaluated before the update
+
+
+<a name='optim.FistaLS'></a>
+## FistaLS(f, g, pl, xinit[, params])
+
+*Fista* with backtracking *Line Search*:
+
+ * `f`: smooth function
+ * `g`: non-smooth function
+ * `pl`: minimizer of intermediate problem Q(x, y)
+ * `xinit`: initial point
+ * `params`: table of parameters (**optional**)
+ * `params.L`: 1/(step size) for ISTA/FISTA iteration (0.1)
+ * `params.Lstep`: step size multiplier at each iteration (1.5)
+ * `params.maxiter`: max number of iterations (50)
+ * `params.maxline`: max number of line search iterations per iteration (20)
+ * `params.errthres`: Error thershold for convergence check (1e-4)
+ * `params.doFistaUpdate`: true : use FISTA, false: use ISTA (true)
+ * `params.verbose`: store each iteration solution and print detailed info (false)
+
+On output, `params` will contain these additional fields that can be reused.
+ * `params.L`: last used L value will be written.
+
+These are temporary storages needed by the algo and if the same params object is
+passed a second time, these same storages will be used without new allocation.
+ * `params.xkm`: previous iterarion point
+ * `params.y`: fista iteration
+ * `params.ply`: `ply = pl(y * 1/L grad(f))`
+
+Returns the solution `x` and history of `{function evals, number of line search , ...}`.
+
+Algorithm is published in http://epubs.siam.org/doi/abs/10.1137/080716542
+
+
+<a name='optim.nag'></a>
+## nag(opfunc, x[, config][, state])
+
+An implementation of *SGD* adapted with features of *Nesterov's Accelerated Gradient method*, based on the paper "On the Importance of Initialization and Momentum in Deep Learning" (Sutsveker et. al., ICML 2013) http://www.cs.toronto.edu/~fritz/absps/momentum.pdf.
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.learningRate`: learning rate
+ * `config.learningRateDecay`: learning rate decay
+ * `config.weightDecay`: weight decay
+ * `config.momentum`: momentum
+ * `config.learningRates`: vector of individual learning rates
+
+Returns:
+
+ * `x*`: the new `x` vector
+ * `f(x)`: the function, evaluated before the update
+
+
+<a name='optim.rmsprop'></a>
+## rmsprop(opfunc, x[, config][, state])
+
+An implementation of *RMSprop*.
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of a evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.learningRate`: learning rate
+ * `config.alpha`: smoothing constant
+ * `config.epsilon`: value with which to initialise m
+ * `state`: a table describing the state of the optimizer; after each call the state is modified
+ * `state.m`: leaky sum of squares of parameter gradients,
+ * `state.tmp`: and the square root (with epsilon smoothing)
+
+Returns:
+
+ * `x*`: the new x vector
+ * `f(x)`: the function, evaluated before the update
+
+
+<a name='optim.rprop'></a>
+## rprop(opfunc, x[, config][, state])
+
+A plain implementation of *Rprop* (Martin Riedmiller, Koray Kavukcuoglu 2013).
+
+Arguments:
+
+ * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX`
+ * `x`: the initial point
+ * `config`: a table with configuration parameters for the optimizer
+ * `config.stepsize`: initial step size, common to all components
+ * `config.etaplus`: multiplicative increase factor, > 1 (default 1.2)
+ * `config.etaminus`: multiplicative decrease factor, < 1 (default 0.5)
+ * `config.stepsizemax`: maximum stepsize allowed (default 50)
+ * `config.stepsizemin`: minimum stepsize allowed (default 1e-6)
+ * `config.niter`: number of iterations (default 1)
+
+Returns:
+
+ * `x*`: the new x vector
+ * `f(x)`: the function, evaluated before the update
+
+
+<a name='optim.cmaes'></a>
+## cmaes(opfunc, x[, config][, state])
+
+An implementation of *CMAES* (*Covariance Matrix Adaptation Evolution Strategy*), ported from https://www.lri.fr/~hansen/barecmaes2.html.
+
+*CMAES* is a stochastic, derivative-free method for heuristic global optimization of non-linear or non-convex continuous optimization problems.
+Note that this method will on average take much more function evaluations to converge then a gradient based method.
+
+Arguments:
+
+If `state` is specified, then `config` is not used at all.
+Otherwise `state` is `config`.
+
+ * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX`. Note that `df/dX` is not used and can be left 0
+ * `x`: the initial point
+ * `state`: a table describing the state of the optimizer; after each call the state is modified
+ * `state.sigma`: float, initial step-size (standard deviation in each coordinate)
+ * `state.maxEval`: int, maximal number of function evaluations
+ * `state.ftarget`: float, target function value
+ * `state.popsize`: population size. If this is left empty, `4 + int(3 * log(|x|))` will be used
+ * `state.ftarget`: stop if `fitness < ftarget`
+ * `state.verb_disp`: display info on console every verb_disp iteration, 0 for never
+
+Returns:
+ * `x*`: the new `x` vector, at the optimal point
+ * `f`: a table of all function values:
+ * `f[1]` is the value of the function before any optimization and
+ * `f[#f]` is the final fully optimized value, at `x*`
diff --git a/doc/index.md b/doc/index.md
deleted file mode 100644
index a399206..0000000
--- a/doc/index.md
+++ /dev/null
@@ -1,409 +0,0 @@
-<a name='optim.dok'></a>
-# Optim Package
-
-This package provides a set of optimization algorithms, which all follow
-a unified, closure-based API.
-
-This package is fully compatible with the [nn](http://nn.readthedocs.org) package, but can also be
-used to optimize arbitrary objective functions.
-
-For now, the following algorithms are provided:
-
- * [Stochastic Gradient Descent](#optim.sgd)
- * [Averaged Stochastic Gradient Descent](#optim.asgd)
- * [L-BFGS](#optim.lbfgs)
- * [Congugate Gradients](#optim.cg)
- * [AdaDelta](#optim.adadelta)
- * [AdaGrad](#optim.adagrad)
- * [Adam](#optim.adam)
- * [AdaMax](#optim.adamax)
- * [FISTA with backtracking line search](#optim.FistaLS)
- * [Nesterov's Accelerated Gradient method](#optim.nag)
- * [RMSprop](#optim.rmsprop)
- * [Rprop](#optim.rprop)
- * [CMAES](#optim.cmaes)
-
-All these algorithms are designed to support batch optimization as
-well as stochastic optimization. It's up to the user to construct an
-objective function that represents the batch, mini-batch, or single sample
-on which to evaluate the objective.
-
-Some of these algorithms support a line search, which can be passed as
-a function (L-BFGS), whereas others only support a learning rate (SGD).
-
-<a name='optim.overview'></a>
-## Overview
-
-This package contains several optimization routines for [Torch](https://github.com/torch/torch7/blob/master/README.md).
-Most optimization algorithms has the following interface:
-
-```lua
-x*, {f}, ... = optim.method(opfunc, x, state)
-```
-
-where:
-
-* `opfunc`: a user-defined closure that respects this API: `f, df/dx = func(x)`
-* `x`: the current parameter vector (a 1D `torch.Tensor`)
-* `state`: a table of parameters, and state variables, dependent upon the algorithm
-* `x*`: the new parameter vector that minimizes `f, x* = argmin_x f(x)`
-* `{f}`: a table of all f values, in the order they've been evaluated (for some simple algorithms, like SGD, `#f == 1`)
-
-<a name='optim.example'></a>
-## Example
-
-The state table is used to hold the state of the algorihtm.
-It's usually initialized once, by the user, and then passed to the optim function
-as a black box. Example:
-
-```lua
-state = {
- learningRate = 1e-3,
- momentum = 0.5
-}
-
-for i,sample in ipairs(training_samples) do
- local func = function(x)
- -- define eval function
- return f,df_dx
- end
- optim.sgd(func,x,state)
-end
-```
-
-<a name='optim.algorithms'></a>
-## Algorithms
-
-Most algorithms provided rely on a unified interface:
-```lua
-x_new,fs = optim.method(opfunc, x, state)
-```
-where:
-x is the trainable/adjustable parameter vector,
-state contains both options for the algorithm and the state of the algorihtm,
-opfunc is a closure that has the following interface:
-```lua
-f,df_dx = opfunc(x)
-```
-x_new is the new parameter vector (after optimization),
-fs is a a table containing all the values of the objective, as evaluated during
-the optimization procedure: fs[1] is the value before optimization, and fs[#fs]
-is the most optimized one (the lowest).
-
-<a name='optim.sgd'></a>
-### [x] sgd(opfunc, x, state)
-
-An implementation of Stochastic Gradient Descent (SGD).
-
-Arguments:
-
- * `opfunc` : a function that takes a single input (`X`), the point of a evaluation, and returns `f(X)` and `df/dX`
- * `x` : the initial point
- * `config` : a table with configuration parameters for the optimizer
- * `config.learningRate` : learning rate
- * `config.learningRateDecay` : learning rate decay
- * `config.weightDecay` : weight decay
- * `config.weightDecays` : vector of individual weight decays
- * `config.momentum` : momentum
- * `config.dampening` : dampening for momentum
- * `config.nesterov` : enables Nesterov momentum
- * `state` : a table describing the state of the optimizer; after each call the state is modified
- * `state.learningRates` : vector of individual learning rates
-
-Returns :
-
- * `x` : the new x vector
- * `f(x)` : the function, evaluated before the update
-
-<a name='optim.asgd'></a>
-### [x] asgd(opfunc, x, state)
-
-An implementation of Averaged Stochastic Gradient Descent (ASGD):
-
-```
-x = (1 - lambda eta_t) x - eta_t df/dx(z,x)
-a = a + mu_t [ x - a ]
-
-eta_t = eta0 / (1 + lambda eta0 t) ^ 0.75
-mu_t = 1/max(1,t-t0)
-```
-
-Arguments:
-
- * `opfunc` : a function that takes a single input (`X`), the point of evaluation, and returns `f(X)` and `df/dX`
- * `x` : the initial point
- * `state` : a table describing the state of the optimizer; after each call the state is modified
- * `state.eta0` : learning rate
- * `state.lambda` : decay term
- * `state.alpha` : power for eta update
- * `state.t0` : point at which to start averaging
-
-Returns:
-
- * `x` : the new x vector
- * `f(x)` : the function, evaluated before the update
- * `ax` : the averaged x vector
-
-
-<a name='optim.lbfgs'></a>
-### [x] lbfgs(opfunc, x, state)
-
-An implementation of L-BFGS that relies on a user-provided line
-search function (`state.lineSearch`). If this function is not
-provided, then a simple learningRate is used to produce fixed
-size steps. Fixed size steps are much less costly than line
-searches, and can be useful for stochastic problems.
-
-The learning rate is used even when a line search is provided.
-This is also useful for large-scale stochastic problems, where
-opfunc is a noisy approximation of `f(x)`. In that case, the learning
-rate allows a reduction of confidence in the step size.
-
-Arguments :
-
- * `opfunc` : a function that takes a single input (`X`), the point of evaluation, and returns `f(X)` and `df/dX`
- * `x` : the initial point
- * `state` : a table describing the state of the optimizer; after each call the state is modified
- * `state.maxIter` : Maximum number of iterations allowed
- * `state.maxEval` : Maximum number of function evaluations
- * `state.tolFun` : Termination tolerance on the first-order optimality
- * `state.tolX` : Termination tol on progress in terms of func/param changes
- * `state.lineSearch` : A line search function
- * `state.learningRate` : If no line search provided, then a fixed step size is used
-
-Returns :
- * `x*` : the new `x` vector, at the optimal point
- * `f` : a table of all function values:
- * `f[1]` is the value of the function before any optimization and
- * `f[#f]` is the final fully optimized value, at `x*`
-
-
-<a name='optim.cg'></a>
-### [x] cg(opfunc, x, state)
-
-An implementation of the Conjugate Gradient method which is a rewrite of
-`minimize.m` written by Carl E. Rasmussen.
-It is supposed to produce exactly same results (give
-or take numerical accuracy due to some changed order of
-operations). You can compare the result on rosenbrock with
-[minimize.m](http://www.gatsby.ucl.ac.uk/~edward/code/minimize/example.html).
-```
-[x fx c] = minimize([0 0]', 'rosenbrock', -25)
-```
-
-Note that we limit the number of function evaluations only, it seems much
-more important in practical use.
-
-Arguments :
-
- * `opfunc` : a function that takes a single input, the point of evaluation.
- * `x` : the initial point
- * `state` : a table of parameters and temporary allocations.
- * `state.maxEval` : max number of function evaluations
- * `state.maxIter` : max number of iterations
- * `state.df[0,1,2,3]` : if you pass torch.Tensor they will be used for temp storage
- * `state.[s,x0]` : if you pass torch.Tensor they will be used for temp storage
-
-Returns :
-
- * `x*` : the new x vector, at the optimal point
- * `f` : a table of all function values where
- * `f[1]` is the value of the function before any optimization and
- * `f[#f]` is the final fully optimized value, at x*
-
-<a name='optim.adadelta'></a>
-### [x] adadelta(opfunc, x, config, state)
-ADADELTA implementation for SGD http://arxiv.org/abs/1212.5701
-
-Arguments :
-
-* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX
-* `x` : the initial point
-* `config` : a table of hyper-parameters
-* `config.rho` : interpolation parameter
-* `config.eps` : for numerical stability
-* `state` : a table describing the state of the optimizer; after each call the state is modified
-* `state.paramVariance` : vector of temporal variances of parameters
-* `state.accDelta` : vector of accummulated delta of gradients
-
-Returns :
-
-* `x` : the new x vector
-* `f(x)` : the function, evaluated before the update
-
-<a name='optim.adagrad'></a>
-### [x] adagrad(opfunc, x, config, state)
-AdaGrad implementation for SGD
-
-Arguments :
-
-* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX
-* `x` : the initial point
-* `state` : a table describing the state of the optimizer; after each call the state is modified
-* `state.learningRate` : learning rate
-* `state.paramVariance` : vector of temporal variances of parameters
-
-Returns :
-
-* `x` : the new x vector
-* `f(x)` : the function, evaluated before the update
-
-<a name='optim.adam'></a>
-### [x] adam(opfunc, x, config, state)
-An implementation of Adam from http://arxiv.org/pdf/1412.6980.pdf
-
-Arguments :
-
-* `opfunc` : a function that takes a single input (X), the point of a evaluation, and returns f(X) and df/dX
-* `x` : the initial point
-* `config` : a table with configuration parameters for the optimizer
-* `config.learningRate` : learning rate
-* `config.beta1` : first moment coefficient
-* `config.beta2` : second moment coefficient
-* `config.epsilon` : for numerical stability
-* `state` : a table describing the state of the optimizer; after each call the state is modified
-
-Returns :
-
-* `x` : the new x vector
-* `f(x)` : the function, evaluated before the update
-
-<a name='optim.adamax'></a>
-### [x] adamax(opfunc, x, config, state)
-An implementation of AdaMax http://arxiv.org/pdf/1412.6980.pdf
-
-Arguments :
-
-* `opfunc` : a function that takes a single input (X), the point of a evaluation, and returns f(X) and df/dX
-* `x` : the initial point
-* `config` : a table with configuration parameters for the optimizer
-* `config.learningRate` : learning rate
-* `config.beta1` : first moment coefficient
-* `config.beta2` : second moment coefficient
-* `config.epsilon` : for numerical stability
-* `state` : a table describing the state of the optimizer; after each call the state is modified.
-
-Returns :
-
-* `x` : the new x vector
-* `f(x)` : the function, evaluated before the update
-
-<a name='optim.FistaLS'></a>
-### [x] FistaLS(f, g, pl, xinit, params)
-FISTA with backtracking line search
-* `f` : smooth function
-* `g` : non-smooth function
-* `pl` : minimizer of intermediate problem Q(x,y)
-* `xinit` : initial point
-* `params` : table of parameters (**optional**)
-* `params.L` : 1/(step size) for ISTA/FISTA iteration (0.1)
-* `params.Lstep` : step size multiplier at each iteration (1.5)
-* `params.maxiter` : max number of iterations (50)
-* `params.maxline` : max number of line search iterations per iteration (20)
-* `params.errthres`: Error thershold for convergence check (1e-4)
-* `params.doFistaUpdate` : true : use FISTA, false: use ISTA (true)
-* `params.verbose` : store each iteration solution and print detailed info (false)
-
-On output, `params` will contain these additional fields that can be reused.
-* `params.L` : last used L value will be written.
-
-These are temporary storages needed by the algo and if the same params object is
-passed a second time, these same storages will be used without new allocation.
-* `params.xkm` : previous iterarion point
-* `params.y` : fista iteration
-* `params.ply` : ply = pl(y * 1/L grad(f))
-
-Returns the solution x and history of {function evals, number of line search ,...}
-
-Algorithm is published in http://epubs.siam.org/doi/abs/10.1137/080716542
-
-<a name='optim.nag'></a>
-### [x] nag(opfunc, x, config, state)
-An implementation of SGD adapted with features of Nesterov's
-Accelerated Gradient method, based on the paper "On the Importance of Initialization and Momentum in Deep Learning" (Sutsveker et. al., ICML 2013).
-
-Arguments :
-
-* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX
-* `x` : the initial point
-* `state` : a table describing the state of the optimizer; after each call the state is modified
-* `state.learningRate` : learning rate
-* `state.learningRateDecay` : learning rate decay
-* `astate.weightDecay` : weight decay
-* `state.momentum` : momentum
-* `state.learningRates` : vector of individual learning rates
-
-Returns :
-
-* `x` : the new x vector
-* `f(x)` : the function, evaluated before the update
-
-<a name='optim.rmsprop'></a>
-### [x] rmsprop(opfunc, x, config, state)
-An implementation of RMSprop
-
-Arguments :
-
-* `opfunc` : a function that takes a single input (X), the point of a evaluation, and returns f(X) and df/dX
-* `x` : the initial point
-* `config` : a table with configuration parameters for the optimizer
-* `config.learningRate` : learning rate
-* `config.alpha` : smoothing constant
-* `config.epsilon` : value with which to initialise m
-* `state` : a table describing the state of the optimizer; after each call the state is modified
-* `state.m` : leaky sum of squares of parameter gradients,
-* `state.tmp` : and the square root (with epsilon smoothing)
-
-Returns :
-
-* `x` : the new x vector
-* `f(x)` : the function, evaluated before the update
-
-<a name='optim.rprop'></a>
-### [x] rprop(opfunc, x, config, state)
-A plain implementation of Rprop
-(Martin Riedmiller, Koray Kavukcuoglu 2013)
-
-Arguments :
-
-* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX
-* `x` : the initial point
-* `state` : a table describing the state of the optimizer; after each call the state is modified
-* `state.stepsize` : initial step size, common to all components
-* `state.etaplus` : multiplicative increase factor, > 1 (default 1.2)
-* `state.etaminus` : multiplicative decrease factor, < 1 (default 0.5)
-* `state.stepsizemax` : maximum stepsize allowed (default 50)
-* `state.stepsizemin` : minimum stepsize allowed (default 1e-6)
-* `state.niter` : number of iterations (default 1)
-
-Returns :
-
-* `x` : the new x vector
-* `f(x)` : the function, evaluated before the update
-
-
-
-
-<a name='optim.cmaes'></a>
-### [x] cmaes(opfunc, x, config, state)
-An implementation of `CMAES` (Covariance Matrix Adaptation Evolution Strategy),
-ported from https://www.lri.fr/~hansen/barecmaes2.html.
-
-CMAES is a stochastic, derivative-free method for heuristic global optimization of non-linear or non-convex continuous optimization problems. Note that this method will on average take much more function evaluations to converge then a gradient based method.
-
-Arguments:
-
-* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX. Note that df/dX is not used and can be left 0
-* `x` : the initial point
-* `state.sigma` : float, initial step-size (standard deviation in each coordinate)
-* `state.maxEval` : int, maximal number of function evaluations
-* `state.ftarget` : float, target function value
-* `state.popsize` : population size. If this is left empty, 4 + int(3 * log(|x|)) will be used
-* `state.ftarget` : stop if fitness < ftarget
-* `state.verb_disp` : display info on console every verb_disp iteration, 0 for never
-
-Returns:
-* `x*` : the new `x` vector, at the optimal point
-* `f` : a table of all function values:
- * `f[1]` is the value of the function before any optimization and
- * `f[#f]` is the final fully optimized value, at `x*`
diff --git a/doc/intro.md b/doc/intro.md
new file mode 100644
index 0000000..b387235
--- /dev/null
+++ b/doc/intro.md
@@ -0,0 +1,41 @@
+<a name='optim.overview'></a>
+# Overview
+
+Most optimization algorithms have the following interface:
+
+```lua
+x*, {f}, ... = optim.method(opfunc, x[, config][, state])
+```
+
+where:
+
+* `opfunc`: a user-defined closure that respects this API: `f, df/dx = func(x)`
+* `x`: the current parameter vector (a 1D `Tensor`)
+* `config`: a table of parameters, dependent upon the algorithm
+* `state`: a table of state variables, if `nil`, `config` will contain the state
+* `x*`: the new parameter vector that minimizes `f, x* = argmin_x f(x)`
+* `{f}`: a table of all `f` values, in the order they've been evaluated (for some simple algorithms, like SGD, `#f == 1`)
+
+
+<a name='optim.example'></a>
+## Example
+
+The state table is used to hold the state of the algorihtm.
+It's usually initialized once, by the user, and then passed to the optim function as a black box.
+Example:
+
+```lua
+config = {
+ learningRate = 1e-3,
+ momentum = 0.5
+}
+
+for i, sample in ipairs(training_samples) do
+ local func = function(x)
+ -- define eval function
+ return f, df_dx
+ end
+ optim.sgd(func, x, config)
+end
+```
+
diff --git a/doc/logger.md b/doc/logger.md
new file mode 100644
index 0000000..b7797d2
--- /dev/null
+++ b/doc/logger.md
@@ -0,0 +1,73 @@
+<a name='optim.logger'></a>
+# Logger
+
+`optim` provides also logging and live plotting capabilities via the `optim.Logger()` function.
+
+Live logging is essential to monitor the *network accuracy* and *cost function* during training and testing, for spotting *under-* and *over-fitting*, for *early stopping* or just for monitoring the health of the current optimisation task.
+
+
+## Logging data
+
+Let walk through an example to see how it works.
+
+We start with initialising our logger connected to a text file `accuracy.log`.
+
+```lua
+logger = optim.Logger('accuracy.log')
+```
+
+We can decide to log on it, for example, *training* and *testing accuracies*.
+
+```lua
+logger:setNames{'Training acc.', 'Test acc.'}
+```
+
+And now we can populate our logger randomly.
+
+```lua
+for i = 1, 10 do
+ trainAcc = math.random(0, 100)
+ testAcc = math.random(0, 100)
+ logger:add{trainAcc, testAcc}
+end
+```
+
+We can `cat` `accuracy.log` and see what's in it.
+
+```
+Training acc. Test acc.
+ 7.0000e+01 5.9000e+01
+ 7.6000e+01 8.0000e+00
+ 6.6000e+01 3.4000e+01
+ 7.4000e+01 4.3000e+01
+ 5.7000e+01 1.1000e+01
+ 5.0000e+00 9.8000e+01
+ 7.1000e+01 1.7000e+01
+ 9.8000e+01 2.7000e+01
+ 3.5000e+01 4.7000e+01
+ 6.8000e+01 5.8000e+01
+```
+
+## Visualising logs
+
+OK, cool, but how can we actually see what's going on?
+
+To have a better grasp of what's happening, we can plot our curves.
+We need first to specify the plotting style, choosing from:
+
+ * `.` for dots
+ * `+` for points
+ * `-` for lines
+ * `+-` for points and lines
+ * `~` for using smoothed lines with cubic interpolation
+ * `|` for using boxes
+ * custom string, one can also pass custom strings to use full capability of gnuplot.
+
+```lua
+logger:style{'+-', '+-'}
+logger:plot()
+```
+
+![Logging plot](logger_plot.png)
+
+If we'd like an interactive visualisation, we can put the `logger:plot()` instruction within the `for` loop, and the chart will be updated at every iteration.
diff --git a/doc/logger_plot.png b/doc/logger_plot.png
new file mode 100644
index 0000000..c5e86ae
--- /dev/null
+++ b/doc/logger_plot.png
Binary files differ
diff --git a/fista.lua b/fista.lua
index 7fba128..c8c6f5e 100644
--- a/fista.lua
+++ b/fista.lua
@@ -17,7 +17,7 @@ On output, `params` will contain these additional fields that can be reused.
- `params.L` : last used L value will be written.
-These are temporary storages needed by the algo and if the same params object is
+These are temporary storages needed by the algo and if the same params object is
passed a second time, these same storages will be used without new allocation.
- `params.xkm` : previous iterarion point
@@ -26,7 +26,7 @@ passed a second time, these same storages will be used without new allocation.
Returns the solution x and history of {function evals, number of line search ,...}
-Algorithm is published in
+Algorithm is published in
@article{beck-fista-09,
Author = {Beck, Amir and Teboulle, Marc},
@@ -38,7 +38,7 @@ Algorithm is published in
Year = {2009}}
]]
function optim.FistaLS(f, g, pl, xinit, params)
-
+
local params = params or {}
local L = params.L or 0.1
local Lstep = params.Lstep or 1.5
@@ -46,7 +46,7 @@ function optim.FistaLS(f, g, pl, xinit, params)
local maxline = params.maxline or 20
local errthres = params.errthres or 1e-4
local doFistaUpdate = params.doFistaUpdate
- local verbose = params.verbose
+ local verbose = params.verbose
-- temporary allocations
params.xkm = params.xkm or torch.Tensor()
@@ -77,11 +77,11 @@ function optim.FistaLS(f, g, pl, xinit, params)
-- get derivatives from smooth function
local fy,gfy = f(y,'dx')
--local gfy = f(y)
-
+
local fply = 0
local gply = 0
local Q = 0
-
+
----------------------------------------------
-- do line search to find new current location starting from fista loc
local nline = 0
@@ -98,7 +98,7 @@ function optim.FistaLS(f, g, pl, xinit, params)
-- evaluate this point F(ply)
fply = f(ply)
-
+
-- ply - y
ply:add(-1, y)
-- <ply-y , \Grad(f(y))>
diff --git a/lbfgs.lua b/lbfgs.lua
index 4c7a0b8..d850fcb 100644
--- a/lbfgs.lua
+++ b/lbfgs.lua
@@ -27,7 +27,7 @@ ARGS:
RETURN:
- `x*` : the new `x` vector, at the optimal point
-- `f` : a table of all function values:
+- `f` : a table of all function values:
`f[1]` is the value of the function before any optimization and
`f[#f]` is the final fully optimized value, at `x*`
@@ -46,7 +46,7 @@ function optim.lbfgs(opfunc, x, config, state)
local lineSearchOpts = config.lineSearchOptions
local learningRate = config.learningRate or 1
local isverbose = config.verbose or false
-
+
state.funcEval = state.funcEval or 0
state.nIter = state.nIter or 0
@@ -142,7 +142,7 @@ function optim.lbfgs(opfunc, x, config, state)
table.insert(state.stp_bufs, s)
end
- -- compute the approximate (L-BFGS) inverse Hessian
+ -- compute the approximate (L-BFGS) inverse Hessian
-- multiplied by the gradient
local k = #old_dirs
diff --git a/nag.lua b/nag.lua
index fd4210d..875d81e 100644
--- a/nag.lua
+++ b/nag.lua
@@ -1,11 +1,11 @@
----------------------------------------------------------------------
--- An implementation of SGD adapted with features of Nesterov's
+-- An implementation of SGD adapted with features of Nesterov's
-- Accelerated Gradient method, based on the paper
-- On the Importance of Initialization and Momentum in Deep Learning
-- Sutsveker et. al., ICML 2013
--
-- ARGS:
--- opfunc : a function that takes a single input (X), the point of
+-- opfunc : a function that takes a single input (X), the point of
-- evaluation, and returns f(X) and df/dX
-- x : the initial point
-- state : a table describing the state of the optimizer; after each
@@ -44,7 +44,7 @@ function optim.nag(opfunc, x, config, state)
-- first step in the direction of the momentum vector
if state.dfdx then
- x:add(mom, state.dfdx)
+ x:add(mom, state.dfdx)
end
-- then compute gradient at that point
-- comment out the above line to get the original SGD
diff --git a/rmsprop.lua b/rmsprop.lua
index 038af21..1eb526d 100644
--- a/rmsprop.lua
+++ b/rmsprop.lua
@@ -22,36 +22,36 @@ RETURN:
]]
function optim.rmsprop(opfunc, x, config, state)
- -- (0) get/update state
- local config = config or {}
- local state = state or config
- local lr = config.learningRate or 1e-2
- local alpha = config.alpha or 0.99
- local epsilon = config.epsilon or 1e-8
- local wd = config.weightDecay or 0
-
- -- (1) evaluate f(x) and df/dx
- local fx, dfdx = opfunc(x)
-
- -- (2) weight decay
- if wd ~= 0 then
+ -- (0) get/update state
+ local config = config or {}
+ local state = state or config
+ local lr = config.learningRate or 1e-2
+ local alpha = config.alpha or 0.99
+ local epsilon = config.epsilon or 1e-8
+ local wd = config.weightDecay or 0
+
+ -- (1) evaluate f(x) and df/dx
+ local fx, dfdx = opfunc(x)
+
+ -- (2) weight decay
+ if wd ~= 0 then
dfdx:add(wd, x)
- end
+ end
- -- (3) initialize mean square values and square gradient storage
- if not state.m then
+ -- (3) initialize mean square values and square gradient storage
+ if not state.m then
state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):fill(1)
state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx)
- end
+ end
- -- (4) calculate new (leaky) mean squared values
- state.m:mul(alpha)
- state.m:addcmul(1.0-alpha, dfdx, dfdx)
+ -- (4) calculate new (leaky) mean squared values
+ state.m:mul(alpha)
+ state.m:addcmul(1.0-alpha, dfdx, dfdx)
- -- (5) perform update
- state.tmp:sqrt(state.m):add(epsilon)
- x:addcdiv(-lr, dfdx, state.tmp)
+ -- (5) perform update
+ state.tmp:sqrt(state.m):add(epsilon)
+ x:addcdiv(-lr, dfdx, state.tmp)
- -- return x*, f(x) before optimization
- return x, {fx}
+ -- return x*, f(x) before optimization
+ return x, {fx}
end
diff --git a/rprop.lua b/rprop.lua
index d6c9579..d7af164 100644
--- a/rprop.lua
+++ b/rprop.lua
@@ -20,83 +20,83 @@ RETURN:
(Martin Riedmiller, Koray Kavukcuoglu 2013)
--]]
function optim.rprop(opfunc, x, config, state)
- if config == nil and state == nil then
- print('no state table RPROP initializing')
- end
- -- (0) get/update state
- local config = config or {}
- local state = state or config
- local stepsize = config.stepsize or 0.1
- local etaplus = config.etaplus or 1.2
- local etaminus = config.etaminus or 0.5
- local stepsizemax = config.stepsizemax or 50.0
- local stepsizemin = config.stepsizemin or 1E-06
- local niter = config.niter or 1
-
- local hfx = {}
-
- for i=1,niter do
-
- -- (1) evaluate f(x) and df/dx
- local fx,dfdx = opfunc(x)
-
- -- init temp storage
- if not state.delta then
- state.delta = dfdx.new(dfdx:size()):zero()
- state.stepsize = dfdx.new(dfdx:size()):fill(stepsize)
- state.sign = dfdx.new(dfdx:size())
- state.psign = torch.ByteTensor(dfdx:size())
- state.nsign = torch.ByteTensor(dfdx:size())
- state.zsign = torch.ByteTensor(dfdx:size())
- state.dminmax = torch.ByteTensor(dfdx:size())
- if torch.type(x)=='torch.CudaTensor' then
- -- Push to GPU
- state.psign = state.psign:cuda()
- state.nsign = state.nsign:cuda()
- state.zsign = state.zsign:cuda()
- state.dminmax = state.dminmax:cuda()
- end
- end
-
- -- sign of derivative from last step to this one
- torch.cmul(state.sign, dfdx, state.delta)
- torch.sign(state.sign, state.sign)
-
- -- get indices of >0, <0 and ==0 entries
- state.sign.gt(state.psign, state.sign, 0)
- state.sign.lt(state.nsign, state.sign, 0)
- state.sign.eq(state.zsign, state.sign, 0)
-
- -- get step size updates
- state.sign[state.psign] = etaplus
- state.sign[state.nsign] = etaminus
- state.sign[state.zsign] = 1
-
- -- update stepsizes with step size updates
- state.stepsize:cmul(state.sign)
-
- -- threshold step sizes
- -- >50 => 50
- state.stepsize.gt(state.dminmax, state.stepsize, stepsizemax)
- state.stepsize[state.dminmax] = stepsizemax
- -- <1e-6 ==> 1e-6
- state.stepsize.lt(state.dminmax, state.stepsize, stepsizemin)
- state.stepsize[state.dminmax] = stepsizemin
-
- -- for dir<0, dfdx=0
- -- for dir>=0 dfdx=dfdx
- dfdx[state.nsign] = 0
- -- state.sign = sign(dfdx)
- torch.sign(state.sign,dfdx)
-
- -- update weights
- x:addcmul(-1,state.sign,state.stepsize)
-
- -- update state.dfdx with current dfdx
- state.delta:copy(dfdx)
-
- table.insert(hfx,fx)
- end
+ if config == nil and state == nil then
+ print('no state table RPROP initializing')
+ end
+ -- (0) get/update state
+ local config = config or {}
+ local state = state or config
+ local stepsize = config.stepsize or 0.1
+ local etaplus = config.etaplus or 1.2
+ local etaminus = config.etaminus or 0.5
+ local stepsizemax = config.stepsizemax or 50.0
+ local stepsizemin = config.stepsizemin or 1E-06
+ local niter = config.niter or 1
+
+ local hfx = {}
+
+ for i=1,niter do
+
+ -- (1) evaluate f(x) and df/dx
+ local fx,dfdx = opfunc(x)
+
+ -- init temp storage
+ if not state.delta then
+ state.delta = dfdx.new(dfdx:size()):zero()
+ state.stepsize = dfdx.new(dfdx:size()):fill(stepsize)
+ state.sign = dfdx.new(dfdx:size())
+ state.psign = torch.ByteTensor(dfdx:size())
+ state.nsign = torch.ByteTensor(dfdx:size())
+ state.zsign = torch.ByteTensor(dfdx:size())
+ state.dminmax = torch.ByteTensor(dfdx:size())
+ if torch.type(x)=='torch.CudaTensor' then
+ -- Push to GPU
+ state.psign = state.psign:cuda()
+ state.nsign = state.nsign:cuda()
+ state.zsign = state.zsign:cuda()
+ state.dminmax = state.dminmax:cuda()
+ end
+ end
+
+ -- sign of derivative from last step to this one
+ torch.cmul(state.sign, dfdx, state.delta)
+ torch.sign(state.sign, state.sign)
+
+ -- get indices of >0, <0 and ==0 entries
+ state.sign.gt(state.psign, state.sign, 0)
+ state.sign.lt(state.nsign, state.sign, 0)
+ state.sign.eq(state.zsign, state.sign, 0)
+
+ -- get step size updates
+ state.sign[state.psign] = etaplus
+ state.sign[state.nsign] = etaminus
+ state.sign[state.zsign] = 1
+
+ -- update stepsizes with step size updates
+ state.stepsize:cmul(state.sign)
+
+ -- threshold step sizes
+ -- >50 => 50
+ state.stepsize.gt(state.dminmax, state.stepsize, stepsizemax)
+ state.stepsize[state.dminmax] = stepsizemax
+ -- <1e-6 ==> 1e-6
+ state.stepsize.lt(state.dminmax, state.stepsize, stepsizemin)
+ state.stepsize[state.dminmax] = stepsizemin
+
+ -- for dir<0, dfdx=0
+ -- for dir>=0 dfdx=dfdx
+ dfdx[state.nsign] = 0
+ -- state.sign = sign(dfdx)
+ torch.sign(state.sign,dfdx)
+
+ -- update weights
+ x:addcmul(-1,state.sign,state.stepsize)
+
+ -- update state.dfdx with current dfdx
+ state.delta:copy(dfdx)
+
+ table.insert(hfx,fx)
+ end
-- return x*, f(x) before optimization
return x,hfx
diff --git a/sgd.lua b/sgd.lua
index ea13c55..e21c696 100644
--- a/sgd.lua
+++ b/sgd.lua
@@ -70,7 +70,7 @@ function optim.sgd(opfunc, x, config, state)
-- (4) learning rate decay (annealing)
local clr = lr / (1 + nevals*lrd)
-
+
-- (5) parameter update with single or individual learning rates
if lrs then
if not state.deltaParameters then