author    | Soumith Chintala <soumith@gmail.com> | 2016-06-30 15:15:40 +0300
committer | GitHub <noreply@github.com>          | 2016-06-30 15:15:40 +0300
commit    | e24fd8550d366668473d0b5a89c00546d2145c81 (patch)
tree      | e02ff522a1e8c03aa8c3f8e186a79c0a0420ffa6
parent    | f6be4bb195e3e128ab027326255172ff36b6c63c (diff)
parent    | 06a42e2af40697bd2c95843aee1b75bc51d4270d (diff)
Merge pull request #121 from Atcold/doc-fix
Documentation and code refactoring
-rw-r--r-- | README.md           |  47
-rw-r--r-- | adam.lua            |  70
-rw-r--r-- | adamax.lua          |  70
-rw-r--r-- | asgd.lua            |   4
-rw-r--r-- | cmaes.lua           |  62
-rw-r--r-- | doc/algos.md        | 363
-rw-r--r-- | doc/index.md        | 409
-rw-r--r-- | doc/intro.md        |  41
-rw-r--r-- | doc/logger.md       |  73
-rw-r--r-- | doc/logger_plot.png | bin | 0 -> 45532 bytes
-rw-r--r-- | fista.lua           |  14
-rw-r--r-- | lbfgs.lua           |   6
-rw-r--r-- | nag.lua             |   6
-rw-r--r-- | rmsprop.lua         |  50
-rw-r--r-- | rprop.lua           | 154
-rw-r--r-- | sgd.lua             |   2
16 files changed, 701 insertions, 670 deletions
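The documentation reorganised by this merge centres on `optim`'s closure-based interface, `x*, {f}, ... = optim.method(opfunc, x[, config][, state])`. As a point of reference while reading the diffs below, here is a minimal sketch of that interface driving `optim.sgd`; the names `model`, `criterion`, `inputs` and `targets` are illustrative placeholders (e.g. an `nn` module and criterion), not part of this change.

```lua
require 'optim'

-- Assumed to exist: an nn module `model`, an nn criterion `criterion`,
-- and a batch of data `inputs`/`targets`. Only the optim usage matters here.
local params, gradParams = model:getParameters()

local config = {learningRate = 1e-3, momentum = 0.5}

local feval = function(x)
   if x ~= params then params:copy(x) end
   gradParams:zero()
   local outputs = model:forward(inputs)
   local loss = criterion:forward(outputs, targets)
   model:backward(inputs, criterion:backward(outputs, targets))
   return loss, gradParams            -- f(x) and df/dx, as the API requires
end

for epoch = 1, 10 do
   local _, fs = optim.sgd(feval, params, config)
   print(string.format('epoch %d  loss %.4f', epoch, fs[1]))
end
```

Because `config` doubles as the state table here, the momentum buffer and evaluation counter persist across calls, which is the black-box behaviour the README and `doc/intro.md` text below describe.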
@@ -1,45 +1,8 @@ +<a name='optim.dok'></a> # Optimization package -This package contains several optimization routines for [Torch](https://github.com/torch/torch7/blob/master/README.md). -Each optimization algorithm is based on the same interface: +This package contains several optimization routines and a logger for [Torch](https://github.com/torch/torch7/blob/master/README.md): -```lua -x*, {f}, ... = optim.method(func, x, state) -``` - -where: - -* `func`: a user-defined closure that respects this API: `f, df/dx = func(x)` -* `x`: the current parameter vector (a 1D `torch.Tensor`) -* `state`: a table of parameters, and state variables, dependent upon the algorithm -* `x*`: the new parameter vector that minimizes `f, x* = argmin_x f(x)` -* `{f}`: a table of all f values, in the order they've been evaluated - (for some simple algorithms, like SGD, `#f == 1`) - -## Available algorithms - -Please check [this file](doc/index.md) for the full list of -optimization algorithms available and examples. Get also into the -[`test`](test/) directory for straightforward examples using the -[Rosenbrock's](test/rosenbrock.lua) function. - -## Important Note - -The state table is used to hold the state of the algorithm. -It's usually initialized once, by the user, and then passed to the optim function -as a black box. Example: - -```lua -state = { - learningRate = 1e-3, - momentum = 0.5 -} - -for i,sample in ipairs(training_samples) do - local func = function(x) - -- define eval function - return f,df_dx - end - optim.sgd(func,x,state) -end -``` + * [Overview](doc/intro.md); + * [Optimization algorithms](doc/algos.md); + * [Logger](doc/logger.md). @@ -21,47 +21,47 @@ RETURN: ]] function optim.adam(opfunc, x, config, state) - -- (0) get/update state - local config = config or {} - local state = state or config - local lr = config.learningRate or 0.001 + -- (0) get/update state + local config = config or {} + local state = state or config + local lr = config.learningRate or 0.001 - local beta1 = config.beta1 or 0.9 - local beta2 = config.beta2 or 0.999 - local epsilon = config.epsilon or 1e-8 - local wd = config.weightDecay or 0 + local beta1 = config.beta1 or 0.9 + local beta2 = config.beta2 or 0.999 + local epsilon = config.epsilon or 1e-8 + local wd = config.weightDecay or 0 - -- (1) evaluate f(x) and df/dx - local fx, dfdx = opfunc(x) + -- (1) evaluate f(x) and df/dx + local fx, dfdx = opfunc(x) - -- (2) weight decay - if wd ~= 0 then - dfdx:add(wd, x) - end + -- (2) weight decay + if wd ~= 0 then + dfdx:add(wd, x) + end - -- Initialization - state.t = state.t or 0 - -- Exponential moving average of gradient values - state.m = state.m or x.new(dfdx:size()):zero() - -- Exponential moving average of squared gradient values - state.v = state.v or x.new(dfdx:size()):zero() - -- A tmp tensor to hold the sqrt(v) + epsilon - state.denom = state.denom or x.new(dfdx:size()):zero() + -- Initialization + state.t = state.t or 0 + -- Exponential moving average of gradient values + state.m = state.m or x.new(dfdx:size()):zero() + -- Exponential moving average of squared gradient values + state.v = state.v or x.new(dfdx:size()):zero() + -- A tmp tensor to hold the sqrt(v) + epsilon + state.denom = state.denom or x.new(dfdx:size()):zero() - state.t = state.t + 1 - - -- Decay the first and second moment running average coefficient - state.m:mul(beta1):add(1-beta1, dfdx) - state.v:mul(beta2):addcmul(1-beta2, dfdx, dfdx) + state.t = state.t + 1 - state.denom:copy(state.v):sqrt():add(epsilon) + -- Decay the first and 
second moment running average coefficient + state.m:mul(beta1):add(1-beta1, dfdx) + state.v:mul(beta2):addcmul(1-beta2, dfdx, dfdx) - local biasCorrection1 = 1 - beta1^state.t - local biasCorrection2 = 1 - beta2^state.t - local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1 - -- (3) update x - x:addcdiv(-stepSize, state.m, state.denom) + state.denom:copy(state.v):sqrt():add(epsilon) - -- return x*, f(x) before optimization - return x, {fx} + local biasCorrection1 = 1 - beta1^state.t + local biasCorrection2 = 1 - beta2^state.t + local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1 + -- (3) update x + x:addcdiv(-stepSize, state.m, state.denom) + + -- return x*, f(x) before optimization + return x, {fx} end @@ -20,47 +20,47 @@ RETURN: ]] function optim.adamax(opfunc, x, config, state) - -- (0) get/update state - local config = config or {} - local state = state or config - local lr = config.learningRate or 0.002 + -- (0) get/update state + local config = config or {} + local state = state or config + local lr = config.learningRate or 0.002 - local beta1 = config.beta1 or 0.9 - local beta2 = config.beta2 or 0.999 - local epsilon = config.epsilon or 1e-38 - local wd = config.weightDecay or 0 + local beta1 = config.beta1 or 0.9 + local beta2 = config.beta2 or 0.999 + local epsilon = config.epsilon or 1e-38 + local wd = config.weightDecay or 0 - -- (1) evaluate f(x) and df/dx - local fx, dfdx = opfunc(x) + -- (1) evaluate f(x) and df/dx + local fx, dfdx = opfunc(x) - -- (2) weight decay - if wd ~= 0 then - dfdx:add(wd, x) - end + -- (2) weight decay + if wd ~= 0 then + dfdx:add(wd, x) + end - -- Initialization - state.t = state.t or 0 - -- Exponential moving average of gradient values - state.m = state.m or x.new(dfdx:size()):zero() - -- Exponential moving average of the infinity norm - state.u = state.u or x.new(dfdx:size()):zero() - -- A tmp tensor to hold the input to max() - state.max = state.max or x.new(2, unpack(dfdx:size():totable())):zero() + -- Initialization + state.t = state.t or 0 + -- Exponential moving average of gradient values + state.m = state.m or x.new(dfdx:size()):zero() + -- Exponential moving average of the infinity norm + state.u = state.u or x.new(dfdx:size()):zero() + -- A tmp tensor to hold the input to max() + state.max = state.max or x.new(2, unpack(dfdx:size():totable())):zero() - state.t = state.t + 1 + state.t = state.t + 1 - -- Update biased first moment estimate. - state.m:mul(beta1):add(1-beta1, dfdx) - -- Update the exponentially weighted infinity norm. - state.max[1]:copy(state.u):mul(beta2) - state.max[2]:copy(dfdx):abs():add(epsilon) - state.u:max(state.max, 1) + -- Update biased first moment estimate. + state.m:mul(beta1):add(1-beta1, dfdx) + -- Update the exponentially weighted infinity norm. 
+ state.max[1]:copy(state.u):mul(beta2) + state.max[2]:copy(dfdx):abs():add(epsilon) + state.u:max(state.max, 1) - local biasCorrection1 = 1 - beta1^state.t - local stepSize = lr/biasCorrection1 - -- (2) update x - x:addcdiv(-stepSize, state.m, state.u) + local biasCorrection1 = 1 - beta1^state.t + local stepSize = lr/biasCorrection1 + -- (2) update x + x:addcdiv(-stepSize, state.m, state.u) - -- return x*, f(x) before optimization - return x, {fx} + -- return x*, f(x) before optimization + return x, {fx} end @@ -1,6 +1,6 @@ --[[ An implementation of ASGD -ASGD: +ASGD: x := (1 - lambda eta_t) x - eta_t df/dx(z,x) a := a + mu_t [ x - a ] @@ -12,7 +12,7 @@ implements ASGD algoritm as in L.Bottou's sgd-2.0 ARGS: -- `opfunc` : a function that takes a single input (X), the point of +- `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX - `x` : the initial point - `state` : a table describing the state of the optimizer; after each @@ -1,16 +1,16 @@ require 'torch' require 'math' -local BestSolution = {} ---[[ An implementation of `CMAES` (Covariance Matrix Adaptation Evolution Strategy), +local BestSolution = {} +--[[ An implementation of `CMAES` (Covariance Matrix Adaptation Evolution Strategy), ported from https://www.lri.fr/~hansen/barecmaes2.html. - + Parameters ---------- ARGS: -- `opfunc` : a function that takes a single input (X), the point of - evaluation, and returns f(X) and df/dX. Note that df/dX is not used +- `opfunc` : a function that takes a single input (X), the point of + evaluation, and returns f(X) and df/dX. Note that df/dX is not used - `x` : the initial point - `state.sigma` float, initial step-size (standard deviation in each @@ -20,16 +20,16 @@ ARGS: - `state.ftarget` float, target function value - `state.popsize` - population size. If this is left empty, + population size. 
If this is left empty, 4 + int(3 * log(|x|)) will be used -- `state.ftarget` +- `state.ftarget` stop if fitness < ftarget - `state.verb_disp` int, display on console every verb_disp iteration, 0 for never RETURN: - `x*` : the new `x` vector, at the optimal point -- `f` : a table of all function values: +- `f` : a table of all function values: `f[1]` is the value of the function before any optimization and `f[#f]` is the final fully optimized value, at `x*` --]] @@ -50,13 +50,13 @@ function optim.cmaes(opfunc, x, config, state) local min_iterations = state.min_iterations or 1 local lambda = state.popsize -- population size, offspring number - -- Strategy parameter setting: Selection + -- Strategy parameter setting: Selection if state.popsize == nil then lambda = 4 + math.floor(3 * math.log(N)) end local mu = lambda / 2 -- number of parents/points for recombination - local weights = torch.range(0,mu-1):apply(function(i) + local weights = torch.range(0,mu-1):apply(function(i) return math.log(mu+0.5) - math.log(i+1) end) -- recombination weights weights:div(weights:sum()) -- normalize recombination weights array local mueff = weights:sum()^2 / torch.pow(weights,2):sum() -- variance-effectiveness of sum w_i x_i @@ -69,18 +69,18 @@ function optim.cmaes(opfunc, x, config, state) local cmu = math.min(1 - c1, 2 * (mueff - 2 + 1/mueff) / ((N + 2)^2 + mueff)) -- and for rank-mu update local damps = 2 * mueff/lambda + 0.3 + cs -- damping for sigma, usually close to 1 - -- Initialize dynamic (internal) state variables + -- Initialize dynamic (internal) state variables local pc = torch.Tensor(N):zero():typeAs(x) -- evolution paths for C local ps = torch.Tensor(N):zero():typeAs(x) -- evolution paths for sigma - local B = torch.eye(N):typeAs(x) -- B defines the coordinate system + local B = torch.eye(N):typeAs(x) -- B defines the coordinate system local D = torch.Tensor(N):fill(1):typeAs(x) -- diagonal D defines the scaling - local C = torch.eye(N):typeAs(x) -- covariance matrix + local C = torch.eye(N):typeAs(x) -- covariance matrix if not pcall(function () torch.symeig(C,'V') end) then -- if error occurs trying to use symeig - error('torch.symeig not available for ' .. x:type() .. + error('torch.symeig not available for ' .. x:type() .. 
" please use Float- or DoubleTensor for x") end local candidates = torch.Tensor(lambda,N):typeAs(x) - local invsqrtC = torch.eye(N):typeAs(x) -- C^-1/2 + local invsqrtC = torch.eye(N):typeAs(x) -- C^-1/2 local eigeneval = 0 -- tracking the update of B and D local counteval = 0 local f_hist = {[1]=opfunc(x)} -- for bookkeeping output and termination @@ -90,7 +90,7 @@ function optim.cmaes(opfunc, x, config, state) local function ask() - --[[return a list of lambda candidate solutions according to + --[[return a list of lambda candidate solutions according to m + sig * Normal(0,C) = m + sig * B * D * Normal(0,I) --]] -- Eigendecomposition: first update B, D and invsqrtC from C @@ -117,9 +117,9 @@ function optim.cmaes(opfunc, x, config, state) Parameters ---------- - `arx` + `arx` a list of solutions, presumably from `ask()` - `fitvals` + `fitvals` the corresponding objective function values --]] -- bookkeeping, preparation counteval = counteval + lambda -- slightly artificial to do here @@ -142,7 +142,7 @@ function optim.cmaes(opfunc, x, config, state) local c = (cs * (2-cs) * mueff)^0.5 / sigma ps = ps - ps * cs + z * c -- exponential decay on ps - local hsig = (torch.sum(torch.pow(ps,2)) / + local hsig = (torch.sum(torch.pow(ps,2)) / (1-(1-cs)^(2*counteval/lambda)) / N < 2 + 4./(N+1)) hsig = hsig and 1.0 or 0.0 --use binary numbers @@ -155,23 +155,23 @@ function optim.cmaes(opfunc, x, config, state) for i=1,N do for j=1,N do local r = torch.range(1,mu) - r:apply(function(k) + r:apply(function(k) return weights[k] * (arx[k][i]-xold[i]) * (arx[k][j]-xold[j]) end) local Cmuij = torch.sum(r) / sigma^2 -- rank-mu update - C[i][j] = C[i][j] + ((-c1a - cmu) * C[i][j] + + C[i][j] = C[i][j] + ((-c1a - cmu) * C[i][j] + c1 * pc[i]*pc[j] + cmu * Cmuij) end end -- Adapt step-size sigma with factor <= exp(0.6) \approx 1.82 - sigma = sigma * math.exp(math.min(0.6, + sigma = sigma * math.exp(math.min(0.6, (cs / damps) * (torch.sum(torch.pow(ps,2))/N - 1)/2)) end - local function stop() - --[[return satisfied termination conditions in a table like - {'termination reason':value, ...}, for example {'tolfun':1e-12}, - or the empty table {}--]] + local function stop() + --[[return satisfied termination conditions in a table like + {'termination reason':value, ...}, for example {'tolfun':1e-12}, + or the empty table {}--]] local res = {} if counteval > 0 then if counteval >= maxEval then @@ -184,7 +184,7 @@ function optim.cmaes(opfunc, x, config, state) res['condition'] = 1e7 end if fitvals:nElement() > 1 and fitvals[fitvals:nElement()] - fitvals[1] < 1e-12 then - res['tolfun'] = 1e-12 + res['tolfun'] = 1e-12 end if sigma * torch.max(D) < 1e-11 then -- remark: max(D) >= max(diag(C))^0.5 @@ -206,8 +206,8 @@ function optim.cmaes(opfunc, x, config, state) end if iteration <= 2 or iteration % verb_modulo == 0 then local max_std = math.sqrt(torch.max(torch.diag(C))) - print(tostring(counteval).. ': ' .. - string.format(' %6.1f %8.1e ', torch.max(D) / torch.min(D), sigma * max_std) + print(tostring(counteval).. ': ' .. + string.format(' %6.1f %8.1e ', torch.max(D) / torch.min(D), sigma * max_std) .. 
tostring(fitvals[1])) end @@ -224,7 +224,7 @@ function optim.cmaes(opfunc, x, config, state) fitvals[i] = objfunc(candidate) end - tell(X) + tell(X) disp(verb_disp) end @@ -245,7 +245,7 @@ end -BestSolution.__index = BestSolution +BestSolution.__index = BestSolution function BestSolution.new(x, f, evals) local self = setmetatable({}, BestSolution) self.x = x diff --git a/doc/algos.md b/doc/algos.md new file mode 100644 index 0000000..a671420 --- /dev/null +++ b/doc/algos.md @@ -0,0 +1,363 @@ +<a name='optim.algos'></a> +# Optimization algorithms + +The following algorithms are provided: + + * [*Stochastic Gradient Descent*](#optim.sgd) + * [*Averaged Stochastic Gradient Descent*](#optim.asgd) + * [*L-BFGS*](#optim.lbfgs) + * [*Congugate Gradients*](#optim.cg) + * [*AdaDelta*](#optim.adadelta) + * [*AdaGrad*](#optim.adagrad) + * [*Adam*](#optim.adam) + * [*AdaMax*](#optim.adamax) + * [*FISTA with backtracking line search*](#optim.FistaLS) + * [*Nesterov's Accelerated Gradient method*](#optim.nag) + * [*RMSprop*](#optim.rmsprop) + * [*Rprop*](#optim.rprop) + * [*CMAES*](#optim.cmaes) + +All these algorithms are designed to support batch optimization as well as stochastic optimization. +It's up to the user to construct an objective function that represents the batch, mini-batch, or single sample on which to evaluate the objective. + +Some of these algorithms support a line search, which can be passed as a function (*L-BFGS*), whereas others only support a learning rate (*SGD*). + +General interface: + +```lua +x*, {f}, ... = optim.method(opfunc, x[, config][, state]) +``` + + +<a name='optim.sgd'></a> +## sgd(opfunc, x[, config][, state]) + +An implementation of *Stochastic Gradient Descent* (*SGD*). + +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of a evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.learningRate`: learning rate + * `config.learningRateDecay`: learning rate decay + * `config.weightDecay`: weight decay + * `config.weightDecays`: vector of individual weight decays + * `config.momentum`: momentum + * `config.dampening`: dampening for momentum + * `config.nesterov`: enables Nesterov momentum + * `state`: a table describing the state of the optimizer; after each call the state is modified + * `state.learningRates`: vector of individual learning rates + +Returns: + + * `x*`: the new x vector + * `f(x)`: the function, evaluated before the update + + +<a name='optim.asgd'></a> +## asgd(opfunc, x[, config][, state]) + +An implementation of *Averaged Stochastic Gradient Descent* (*ASGD*): + +```lua +x = (1 - lambda eta_t) x - eta_t df / dx(z, x) +a = a + mu_t [ x - a ] + +eta_t = eta0 / (1 + lambda eta0 t) ^ 0.75 +mu_t = 1 / max(1, t - t0) +``` + +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.eta0`: learning rate + * `config.lambda`: decay term + * `config.alpha`: power for eta update + * `config.t0`: point at which to start averaging + +Returns: + + * `x*`: the new x vector + * `f(x)`: the function, evaluated before the update + * `ax`: the averaged x vector + + +<a name='optim.lbfgs'></a> +## lbfgs(opfunc, x[, config][, state]) + +An implementation of *L-BFGS* that relies on a user-provided line search function (`state.lineSearch`). 
+If this function is not provided, then a simple learning rate is used to produce fixed size steps. +Fixed size steps are much less costly than line searches, and can be useful for stochastic problems. + +The learning rate is used even when a line search is provided. +This is also useful for large-scale stochastic problems, where opfunc is a noisy approximation of `f(x)`. +In that case, the learning rate allows a reduction of confidence in the step size. + +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.maxIter`: Maximum number of iterations allowed + * `config.maxEval`: Maximum number of function evaluations + * `config.tolFun`: Termination tolerance on the first-order optimality + * `config.tolX`: Termination tol on progress in terms of func/param changes + * `config.lineSearch`: A line search function + * `config.learningRate`: If no line search provided, then a fixed step size is used + +Returns: + * `x*`: the new `x` vector, at the optimal point + * `f`: a table of all function values: + * `f[1]` is the value of the function before any optimization and + * `f[#f]` is the final fully optimized value, at `x*` + + +<a name='optim.cg'></a> +## cg(opfunc, x[, config][, state]) + +An implementation of the *Conjugate Gradient* method which is a rewrite of `minimize.m` written by Carl E. Rasmussen. +It is supposed to produce exactly same results (give or take numerical accuracy due to some changed order of operations). +You can compare the result on rosenbrock with [`minimize.m`](http://www.gatsby.ucl.ac.uk/~edward/code/minimize/example.html). + +```lua +x, fx, c = minimize([0, 0]', 'rosenbrock', -25) +``` + +Note that we limit the number of function evaluations only, it seems much more important in practical use. + +Arguments: + + * `opfunc`: a function that takes a single input, the point of evaluation. + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.maxEval`: max number of function evaluations + * `config.maxIter`: max number of iterations + * `state`: a table of parameters and temporary allocations. + * `state.df[0, 1, 2, 3]`: if you pass `Tensor` they will be used for temp storage + * `state.[s, x0]`: if you pass `Tensor` they will be used for temp storage + +Returns: + + * `x*`: the new `x` vector, at the optimal point + * `f`: a table of all function values where + * `f[1]` is the value of the function before any optimization and + * `f[#f]` is the final fully optimized value, at `x*` + + +<a name='optim.adadelta'></a> +## adadelta(opfunc, x[, config][, state]) + +*AdaDelta* implementation for *SGD* http://arxiv.org/abs/1212.5701. + +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table of hyper-parameters + * `config.rho`: interpolation parameter + * `config.eps`: for numerical stability + * `state`: a table describing the state of the optimizer; after each call the state is modified + * `state.paramVariance`: vector of temporal variances of parameters + * `state.accDelta`: vector of accummulated delta of gradients + +Returns: + + * `x*`: the new x vector + * `f(x)`: the function, evaluated before the update + + +<a name='optim.adagrad'></a> +## adagrad(opfunc, x[, config][, state]) + +*AdaGrad* implementation for *SGD*. 
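To make the AdaGrad description above concrete, here is a generic sketch of the per-parameter update it refers to (an illustration only, not the code of the package's `adagrad.lua`); the argument list for `adagrad` follows below.

```lua
-- Illustrative AdaGrad step: each coordinate is scaled by the square root
-- of its accumulated squared gradients. Not a copy of adagrad.lua.
local function adagrad_step(x, dfdx, state, lr)
   state.G = state.G or x.new(dfdx:size()):zero()  -- running sum of squared gradients
   state.G:addcmul(1, dfdx, dfdx)                  -- G <- G + dfdx .* dfdx
   local denom = torch.sqrt(state.G):add(1e-10)    -- small constant avoids division by zero
   x:addcdiv(-lr, dfdx, denom)                     -- x <- x - lr * dfdx ./ sqrt(G)
end
```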
+ +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.learningRate`: learning rate + * `state`: a table describing the state of the optimizer; after each call the state is modified + * `state.paramVariance`: vector of temporal variances of parameters + +Returns: + + * `x*`: the new `x` vector + * `f(x)`: the function, evaluated before the update + + +<a name='optim.adam'></a> +## adam(opfunc, x[, config][, state]) + +An implementation of *Adam* from http://arxiv.org/pdf/1412.6980.pdf. + +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of a evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.learningRate`: learning rate + * `config.beta1`: first moment coefficient + * `config.beta2`: second moment coefficient + * `config.epsilon`: for numerical stability + * `state`: a table describing the state of the optimizer; after each call the state is modified + +Returns: + + * `x*`: the new x vector + * `f(x)`: the function, evaluated before the update + + +<a name='optim.adamax'></a> +## adamax(opfunc, x[, config][, state]) + +An implementation of *AdaMax* http://arxiv.org/pdf/1412.6980.pdf. + +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of a evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.learningRate`: learning rate + * `config.beta1`: first moment coefficient + * `config.beta2`: second moment coefficient + * `config.epsilon`: for numerical stability + * `state`: a table describing the state of the optimizer; after each call the state is modified + +Returns: + + * `x*`: the new `x` vector + * `f(x)`: the function, evaluated before the update + + +<a name='optim.FistaLS'></a> +## FistaLS(f, g, pl, xinit[, params]) + +*Fista* with backtracking *Line Search*: + + * `f`: smooth function + * `g`: non-smooth function + * `pl`: minimizer of intermediate problem Q(x, y) + * `xinit`: initial point + * `params`: table of parameters (**optional**) + * `params.L`: 1/(step size) for ISTA/FISTA iteration (0.1) + * `params.Lstep`: step size multiplier at each iteration (1.5) + * `params.maxiter`: max number of iterations (50) + * `params.maxline`: max number of line search iterations per iteration (20) + * `params.errthres`: Error thershold for convergence check (1e-4) + * `params.doFistaUpdate`: true : use FISTA, false: use ISTA (true) + * `params.verbose`: store each iteration solution and print detailed info (false) + +On output, `params` will contain these additional fields that can be reused. + * `params.L`: last used L value will be written. + +These are temporary storages needed by the algo and if the same params object is +passed a second time, these same storages will be used without new allocation. + * `params.xkm`: previous iterarion point + * `params.y`: fista iteration + * `params.ply`: `ply = pl(y * 1/L grad(f))` + +Returns the solution `x` and history of `{function evals, number of line search , ...}`. 
+ +Algorithm is published in http://epubs.siam.org/doi/abs/10.1137/080716542 + + +<a name='optim.nag'></a> +## nag(opfunc, x[, config][, state]) + +An implementation of *SGD* adapted with features of *Nesterov's Accelerated Gradient method*, based on the paper "On the Importance of Initialization and Momentum in Deep Learning" (Sutsveker et. al., ICML 2013) http://www.cs.toronto.edu/~fritz/absps/momentum.pdf. + +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.learningRate`: learning rate + * `config.learningRateDecay`: learning rate decay + * `config.weightDecay`: weight decay + * `config.momentum`: momentum + * `config.learningRates`: vector of individual learning rates + +Returns: + + * `x*`: the new `x` vector + * `f(x)`: the function, evaluated before the update + + +<a name='optim.rmsprop'></a> +## rmsprop(opfunc, x[, config][, state]) + +An implementation of *RMSprop*. + +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of a evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.learningRate`: learning rate + * `config.alpha`: smoothing constant + * `config.epsilon`: value with which to initialise m + * `state`: a table describing the state of the optimizer; after each call the state is modified + * `state.m`: leaky sum of squares of parameter gradients, + * `state.tmp`: and the square root (with epsilon smoothing) + +Returns: + + * `x*`: the new x vector + * `f(x)`: the function, evaluated before the update + + +<a name='optim.rprop'></a> +## rprop(opfunc, x[, config][, state]) + +A plain implementation of *Rprop* (Martin Riedmiller, Koray Kavukcuoglu 2013). + +Arguments: + + * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX` + * `x`: the initial point + * `config`: a table with configuration parameters for the optimizer + * `config.stepsize`: initial step size, common to all components + * `config.etaplus`: multiplicative increase factor, > 1 (default 1.2) + * `config.etaminus`: multiplicative decrease factor, < 1 (default 0.5) + * `config.stepsizemax`: maximum stepsize allowed (default 50) + * `config.stepsizemin`: minimum stepsize allowed (default 1e-6) + * `config.niter`: number of iterations (default 1) + +Returns: + + * `x*`: the new x vector + * `f(x)`: the function, evaluated before the update + + +<a name='optim.cmaes'></a> +## cmaes(opfunc, x[, config][, state]) + +An implementation of *CMAES* (*Covariance Matrix Adaptation Evolution Strategy*), ported from https://www.lri.fr/~hansen/barecmaes2.html. + +*CMAES* is a stochastic, derivative-free method for heuristic global optimization of non-linear or non-convex continuous optimization problems. +Note that this method will on average take much more function evaluations to converge then a gradient based method. + +Arguments: + +If `state` is specified, then `config` is not used at all. +Otherwise `state` is `config`. + + * `opfunc`: a function that takes a single input `X`, the point of evaluation, and returns `f(X)` and `df/dX`. 
Note that `df/dX` is not used and can be left 0 + * `x`: the initial point + * `state`: a table describing the state of the optimizer; after each call the state is modified + * `state.sigma`: float, initial step-size (standard deviation in each coordinate) + * `state.maxEval`: int, maximal number of function evaluations + * `state.ftarget`: float, target function value + * `state.popsize`: population size. If this is left empty, `4 + int(3 * log(|x|))` will be used + * `state.ftarget`: stop if `fitness < ftarget` + * `state.verb_disp`: display info on console every verb_disp iteration, 0 for never + +Returns: + * `x*`: the new `x` vector, at the optimal point + * `f`: a table of all function values: + * `f[1]` is the value of the function before any optimization and + * `f[#f]` is the final fully optimized value, at `x*` diff --git a/doc/index.md b/doc/index.md deleted file mode 100644 index a399206..0000000 --- a/doc/index.md +++ /dev/null @@ -1,409 +0,0 @@ -<a name='optim.dok'></a> -# Optim Package - -This package provides a set of optimization algorithms, which all follow -a unified, closure-based API. - -This package is fully compatible with the [nn](http://nn.readthedocs.org) package, but can also be -used to optimize arbitrary objective functions. - -For now, the following algorithms are provided: - - * [Stochastic Gradient Descent](#optim.sgd) - * [Averaged Stochastic Gradient Descent](#optim.asgd) - * [L-BFGS](#optim.lbfgs) - * [Congugate Gradients](#optim.cg) - * [AdaDelta](#optim.adadelta) - * [AdaGrad](#optim.adagrad) - * [Adam](#optim.adam) - * [AdaMax](#optim.adamax) - * [FISTA with backtracking line search](#optim.FistaLS) - * [Nesterov's Accelerated Gradient method](#optim.nag) - * [RMSprop](#optim.rmsprop) - * [Rprop](#optim.rprop) - * [CMAES](#optim.cmaes) - -All these algorithms are designed to support batch optimization as -well as stochastic optimization. It's up to the user to construct an -objective function that represents the batch, mini-batch, or single sample -on which to evaluate the objective. - -Some of these algorithms support a line search, which can be passed as -a function (L-BFGS), whereas others only support a learning rate (SGD). - -<a name='optim.overview'></a> -## Overview - -This package contains several optimization routines for [Torch](https://github.com/torch/torch7/blob/master/README.md). -Most optimization algorithms has the following interface: - -```lua -x*, {f}, ... = optim.method(opfunc, x, state) -``` - -where: - -* `opfunc`: a user-defined closure that respects this API: `f, df/dx = func(x)` -* `x`: the current parameter vector (a 1D `torch.Tensor`) -* `state`: a table of parameters, and state variables, dependent upon the algorithm -* `x*`: the new parameter vector that minimizes `f, x* = argmin_x f(x)` -* `{f}`: a table of all f values, in the order they've been evaluated (for some simple algorithms, like SGD, `#f == 1`) - -<a name='optim.example'></a> -## Example - -The state table is used to hold the state of the algorihtm. -It's usually initialized once, by the user, and then passed to the optim function -as a black box. 
Example: - -```lua -state = { - learningRate = 1e-3, - momentum = 0.5 -} - -for i,sample in ipairs(training_samples) do - local func = function(x) - -- define eval function - return f,df_dx - end - optim.sgd(func,x,state) -end -``` - -<a name='optim.algorithms'></a> -## Algorithms - -Most algorithms provided rely on a unified interface: -```lua -x_new,fs = optim.method(opfunc, x, state) -``` -where: -x is the trainable/adjustable parameter vector, -state contains both options for the algorithm and the state of the algorihtm, -opfunc is a closure that has the following interface: -```lua -f,df_dx = opfunc(x) -``` -x_new is the new parameter vector (after optimization), -fs is a a table containing all the values of the objective, as evaluated during -the optimization procedure: fs[1] is the value before optimization, and fs[#fs] -is the most optimized one (the lowest). - -<a name='optim.sgd'></a> -### [x] sgd(opfunc, x, state) - -An implementation of Stochastic Gradient Descent (SGD). - -Arguments: - - * `opfunc` : a function that takes a single input (`X`), the point of a evaluation, and returns `f(X)` and `df/dX` - * `x` : the initial point - * `config` : a table with configuration parameters for the optimizer - * `config.learningRate` : learning rate - * `config.learningRateDecay` : learning rate decay - * `config.weightDecay` : weight decay - * `config.weightDecays` : vector of individual weight decays - * `config.momentum` : momentum - * `config.dampening` : dampening for momentum - * `config.nesterov` : enables Nesterov momentum - * `state` : a table describing the state of the optimizer; after each call the state is modified - * `state.learningRates` : vector of individual learning rates - -Returns : - - * `x` : the new x vector - * `f(x)` : the function, evaluated before the update - -<a name='optim.asgd'></a> -### [x] asgd(opfunc, x, state) - -An implementation of Averaged Stochastic Gradient Descent (ASGD): - -``` -x = (1 - lambda eta_t) x - eta_t df/dx(z,x) -a = a + mu_t [ x - a ] - -eta_t = eta0 / (1 + lambda eta0 t) ^ 0.75 -mu_t = 1/max(1,t-t0) -``` - -Arguments: - - * `opfunc` : a function that takes a single input (`X`), the point of evaluation, and returns `f(X)` and `df/dX` - * `x` : the initial point - * `state` : a table describing the state of the optimizer; after each call the state is modified - * `state.eta0` : learning rate - * `state.lambda` : decay term - * `state.alpha` : power for eta update - * `state.t0` : point at which to start averaging - -Returns: - - * `x` : the new x vector - * `f(x)` : the function, evaluated before the update - * `ax` : the averaged x vector - - -<a name='optim.lbfgs'></a> -### [x] lbfgs(opfunc, x, state) - -An implementation of L-BFGS that relies on a user-provided line -search function (`state.lineSearch`). If this function is not -provided, then a simple learningRate is used to produce fixed -size steps. Fixed size steps are much less costly than line -searches, and can be useful for stochastic problems. - -The learning rate is used even when a line search is provided. -This is also useful for large-scale stochastic problems, where -opfunc is a noisy approximation of `f(x)`. In that case, the learning -rate allows a reduction of confidence in the step size. 
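As a concrete illustration of the L-BFGS description above, a full-batch call on the Rosenbrock test function might look like the sketch below; it assumes the `rosenbrock` closure from the package's `test/` directory has been loaded, and uses `optim.lswolfe` as the line search (omit `lineSearch` to fall back to the fixed-step behaviour just described).

```lua
-- Sketch: full-batch L-BFGS on the Rosenbrock test function.
-- Assumes `rosenbrock(x)`, returning f and df/dx, is defined,
-- e.g. by loading test/rosenbrock.lua with dofile().
require 'optim'

local x0    = torch.Tensor(2):zero()
local state = {maxIter = 100, lineSearch = optim.lswolfe}

local xstar, fs = optim.lbfgs(rosenbrock, x0, state)
print(string.format('f before: %g   f after: %g', fs[1], fs[#fs]))
```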
- -Arguments : - - * `opfunc` : a function that takes a single input (`X`), the point of evaluation, and returns `f(X)` and `df/dX` - * `x` : the initial point - * `state` : a table describing the state of the optimizer; after each call the state is modified - * `state.maxIter` : Maximum number of iterations allowed - * `state.maxEval` : Maximum number of function evaluations - * `state.tolFun` : Termination tolerance on the first-order optimality - * `state.tolX` : Termination tol on progress in terms of func/param changes - * `state.lineSearch` : A line search function - * `state.learningRate` : If no line search provided, then a fixed step size is used - -Returns : - * `x*` : the new `x` vector, at the optimal point - * `f` : a table of all function values: - * `f[1]` is the value of the function before any optimization and - * `f[#f]` is the final fully optimized value, at `x*` - - -<a name='optim.cg'></a> -### [x] cg(opfunc, x, state) - -An implementation of the Conjugate Gradient method which is a rewrite of -`minimize.m` written by Carl E. Rasmussen. -It is supposed to produce exactly same results (give -or take numerical accuracy due to some changed order of -operations). You can compare the result on rosenbrock with -[minimize.m](http://www.gatsby.ucl.ac.uk/~edward/code/minimize/example.html). -``` -[x fx c] = minimize([0 0]', 'rosenbrock', -25) -``` - -Note that we limit the number of function evaluations only, it seems much -more important in practical use. - -Arguments : - - * `opfunc` : a function that takes a single input, the point of evaluation. - * `x` : the initial point - * `state` : a table of parameters and temporary allocations. - * `state.maxEval` : max number of function evaluations - * `state.maxIter` : max number of iterations - * `state.df[0,1,2,3]` : if you pass torch.Tensor they will be used for temp storage - * `state.[s,x0]` : if you pass torch.Tensor they will be used for temp storage - -Returns : - - * `x*` : the new x vector, at the optimal point - * `f` : a table of all function values where - * `f[1]` is the value of the function before any optimization and - * `f[#f]` is the final fully optimized value, at x* - -<a name='optim.adadelta'></a> -### [x] adadelta(opfunc, x, config, state) -ADADELTA implementation for SGD http://arxiv.org/abs/1212.5701 - -Arguments : - -* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX -* `x` : the initial point -* `config` : a table of hyper-parameters -* `config.rho` : interpolation parameter -* `config.eps` : for numerical stability -* `state` : a table describing the state of the optimizer; after each call the state is modified -* `state.paramVariance` : vector of temporal variances of parameters -* `state.accDelta` : vector of accummulated delta of gradients - -Returns : - -* `x` : the new x vector -* `f(x)` : the function, evaluated before the update - -<a name='optim.adagrad'></a> -### [x] adagrad(opfunc, x, config, state) -AdaGrad implementation for SGD - -Arguments : - -* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX -* `x` : the initial point -* `state` : a table describing the state of the optimizer; after each call the state is modified -* `state.learningRate` : learning rate -* `state.paramVariance` : vector of temporal variances of parameters - -Returns : - -* `x` : the new x vector -* `f(x)` : the function, evaluated before the update - -<a name='optim.adam'></a> -### [x] adam(opfunc, x, config, 
state) -An implementation of Adam from http://arxiv.org/pdf/1412.6980.pdf - -Arguments : - -* `opfunc` : a function that takes a single input (X), the point of a evaluation, and returns f(X) and df/dX -* `x` : the initial point -* `config` : a table with configuration parameters for the optimizer -* `config.learningRate` : learning rate -* `config.beta1` : first moment coefficient -* `config.beta2` : second moment coefficient -* `config.epsilon` : for numerical stability -* `state` : a table describing the state of the optimizer; after each call the state is modified - -Returns : - -* `x` : the new x vector -* `f(x)` : the function, evaluated before the update - -<a name='optim.adamax'></a> -### [x] adamax(opfunc, x, config, state) -An implementation of AdaMax http://arxiv.org/pdf/1412.6980.pdf - -Arguments : - -* `opfunc` : a function that takes a single input (X), the point of a evaluation, and returns f(X) and df/dX -* `x` : the initial point -* `config` : a table with configuration parameters for the optimizer -* `config.learningRate` : learning rate -* `config.beta1` : first moment coefficient -* `config.beta2` : second moment coefficient -* `config.epsilon` : for numerical stability -* `state` : a table describing the state of the optimizer; after each call the state is modified. - -Returns : - -* `x` : the new x vector -* `f(x)` : the function, evaluated before the update - -<a name='optim.FistaLS'></a> -### [x] FistaLS(f, g, pl, xinit, params) -FISTA with backtracking line search -* `f` : smooth function -* `g` : non-smooth function -* `pl` : minimizer of intermediate problem Q(x,y) -* `xinit` : initial point -* `params` : table of parameters (**optional**) -* `params.L` : 1/(step size) for ISTA/FISTA iteration (0.1) -* `params.Lstep` : step size multiplier at each iteration (1.5) -* `params.maxiter` : max number of iterations (50) -* `params.maxline` : max number of line search iterations per iteration (20) -* `params.errthres`: Error thershold for convergence check (1e-4) -* `params.doFistaUpdate` : true : use FISTA, false: use ISTA (true) -* `params.verbose` : store each iteration solution and print detailed info (false) - -On output, `params` will contain these additional fields that can be reused. -* `params.L` : last used L value will be written. - -These are temporary storages needed by the algo and if the same params object is -passed a second time, these same storages will be used without new allocation. -* `params.xkm` : previous iterarion point -* `params.y` : fista iteration -* `params.ply` : ply = pl(y * 1/L grad(f)) - -Returns the solution x and history of {function evals, number of line search ,...} - -Algorithm is published in http://epubs.siam.org/doi/abs/10.1137/080716542 - -<a name='optim.nag'></a> -### [x] nag(opfunc, x, config, state) -An implementation of SGD adapted with features of Nesterov's -Accelerated Gradient method, based on the paper "On the Importance of Initialization and Momentum in Deep Learning" (Sutsveker et. al., ICML 2013). 
- -Arguments : - -* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX -* `x` : the initial point -* `state` : a table describing the state of the optimizer; after each call the state is modified -* `state.learningRate` : learning rate -* `state.learningRateDecay` : learning rate decay -* `astate.weightDecay` : weight decay -* `state.momentum` : momentum -* `state.learningRates` : vector of individual learning rates - -Returns : - -* `x` : the new x vector -* `f(x)` : the function, evaluated before the update - -<a name='optim.rmsprop'></a> -### [x] rmsprop(opfunc, x, config, state) -An implementation of RMSprop - -Arguments : - -* `opfunc` : a function that takes a single input (X), the point of a evaluation, and returns f(X) and df/dX -* `x` : the initial point -* `config` : a table with configuration parameters for the optimizer -* `config.learningRate` : learning rate -* `config.alpha` : smoothing constant -* `config.epsilon` : value with which to initialise m -* `state` : a table describing the state of the optimizer; after each call the state is modified -* `state.m` : leaky sum of squares of parameter gradients, -* `state.tmp` : and the square root (with epsilon smoothing) - -Returns : - -* `x` : the new x vector -* `f(x)` : the function, evaluated before the update - -<a name='optim.rprop'></a> -### [x] rprop(opfunc, x, config, state) -A plain implementation of Rprop -(Martin Riedmiller, Koray Kavukcuoglu 2013) - -Arguments : - -* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX -* `x` : the initial point -* `state` : a table describing the state of the optimizer; after each call the state is modified -* `state.stepsize` : initial step size, common to all components -* `state.etaplus` : multiplicative increase factor, > 1 (default 1.2) -* `state.etaminus` : multiplicative decrease factor, < 1 (default 0.5) -* `state.stepsizemax` : maximum stepsize allowed (default 50) -* `state.stepsizemin` : minimum stepsize allowed (default 1e-6) -* `state.niter` : number of iterations (default 1) - -Returns : - -* `x` : the new x vector -* `f(x)` : the function, evaluated before the update - - - - -<a name='optim.cmaes'></a> -### [x] cmaes(opfunc, x, config, state) -An implementation of `CMAES` (Covariance Matrix Adaptation Evolution Strategy), -ported from https://www.lri.fr/~hansen/barecmaes2.html. - -CMAES is a stochastic, derivative-free method for heuristic global optimization of non-linear or non-convex continuous optimization problems. Note that this method will on average take much more function evaluations to converge then a gradient based method. - -Arguments: - -* `opfunc` : a function that takes a single input (X), the point of evaluation, and returns f(X) and df/dX. Note that df/dX is not used and can be left 0 -* `x` : the initial point -* `state.sigma` : float, initial step-size (standard deviation in each coordinate) -* `state.maxEval` : int, maximal number of function evaluations -* `state.ftarget` : float, target function value -* `state.popsize` : population size. 
If this is left empty, 4 + int(3 * log(|x|)) will be used -* `state.ftarget` : stop if fitness < ftarget -* `state.verb_disp` : display info on console every verb_disp iteration, 0 for never - -Returns: -* `x*` : the new `x` vector, at the optimal point -* `f` : a table of all function values: - * `f[1]` is the value of the function before any optimization and - * `f[#f]` is the final fully optimized value, at `x*` diff --git a/doc/intro.md b/doc/intro.md new file mode 100644 index 0000000..b387235 --- /dev/null +++ b/doc/intro.md @@ -0,0 +1,41 @@ +<a name='optim.overview'></a> +# Overview + +Most optimization algorithms have the following interface: + +```lua +x*, {f}, ... = optim.method(opfunc, x[, config][, state]) +``` + +where: + +* `opfunc`: a user-defined closure that respects this API: `f, df/dx = func(x)` +* `x`: the current parameter vector (a 1D `Tensor`) +* `config`: a table of parameters, dependent upon the algorithm +* `state`: a table of state variables, if `nil`, `config` will contain the state +* `x*`: the new parameter vector that minimizes `f, x* = argmin_x f(x)` +* `{f}`: a table of all `f` values, in the order they've been evaluated (for some simple algorithms, like SGD, `#f == 1`) + + +<a name='optim.example'></a> +## Example + +The state table is used to hold the state of the algorihtm. +It's usually initialized once, by the user, and then passed to the optim function as a black box. +Example: + +```lua +config = { + learningRate = 1e-3, + momentum = 0.5 +} + +for i, sample in ipairs(training_samples) do + local func = function(x) + -- define eval function + return f, df_dx + end + optim.sgd(func, x, config) +end +``` + diff --git a/doc/logger.md b/doc/logger.md new file mode 100644 index 0000000..b7797d2 --- /dev/null +++ b/doc/logger.md @@ -0,0 +1,73 @@ +<a name='optim.logger'></a> +# Logger + +`optim` provides also logging and live plotting capabilities via the `optim.Logger()` function. + +Live logging is essential to monitor the *network accuracy* and *cost function* during training and testing, for spotting *under-* and *over-fitting*, for *early stopping* or just for monitoring the health of the current optimisation task. + + +## Logging data + +Let walk through an example to see how it works. + +We start with initialising our logger connected to a text file `accuracy.log`. + +```lua +logger = optim.Logger('accuracy.log') +``` + +We can decide to log on it, for example, *training* and *testing accuracies*. + +```lua +logger:setNames{'Training acc.', 'Test acc.'} +``` + +And now we can populate our logger randomly. + +```lua +for i = 1, 10 do + trainAcc = math.random(0, 100) + testAcc = math.random(0, 100) + logger:add{trainAcc, testAcc} +end +``` + +We can `cat` `accuracy.log` and see what's in it. + +``` +Training acc. Test acc. + 7.0000e+01 5.9000e+01 + 7.6000e+01 8.0000e+00 + 6.6000e+01 3.4000e+01 + 7.4000e+01 4.3000e+01 + 5.7000e+01 1.1000e+01 + 5.0000e+00 9.8000e+01 + 7.1000e+01 1.7000e+01 + 9.8000e+01 2.7000e+01 + 3.5000e+01 4.7000e+01 + 6.8000e+01 5.8000e+01 +``` + +## Visualising logs + +OK, cool, but how can we actually see what's going on? + +To have a better grasp of what's happening, we can plot our curves. +We need first to specify the plotting style, choosing from: + + * `.` for dots + * `+` for points + * `-` for lines + * `+-` for points and lines + * `~` for using smoothed lines with cubic interpolation + * `|` for using boxes + * custom string, one can also pass custom strings to use full capability of gnuplot. 
+ +```lua +logger:style{'+-', '+-'} +logger:plot() +``` + +![Logging plot](logger_plot.png) + +If we'd like an interactive visualisation, we can put the `logger:plot()` instruction within the `for` loop, and the chart will be updated at every iteration. diff --git a/doc/logger_plot.png b/doc/logger_plot.png Binary files differnew file mode 100644 index 0000000..c5e86ae --- /dev/null +++ b/doc/logger_plot.png @@ -17,7 +17,7 @@ On output, `params` will contain these additional fields that can be reused. - `params.L` : last used L value will be written. -These are temporary storages needed by the algo and if the same params object is +These are temporary storages needed by the algo and if the same params object is passed a second time, these same storages will be used without new allocation. - `params.xkm` : previous iterarion point @@ -26,7 +26,7 @@ passed a second time, these same storages will be used without new allocation. Returns the solution x and history of {function evals, number of line search ,...} -Algorithm is published in +Algorithm is published in @article{beck-fista-09, Author = {Beck, Amir and Teboulle, Marc}, @@ -38,7 +38,7 @@ Algorithm is published in Year = {2009}} ]] function optim.FistaLS(f, g, pl, xinit, params) - + local params = params or {} local L = params.L or 0.1 local Lstep = params.Lstep or 1.5 @@ -46,7 +46,7 @@ function optim.FistaLS(f, g, pl, xinit, params) local maxline = params.maxline or 20 local errthres = params.errthres or 1e-4 local doFistaUpdate = params.doFistaUpdate - local verbose = params.verbose + local verbose = params.verbose -- temporary allocations params.xkm = params.xkm or torch.Tensor() @@ -77,11 +77,11 @@ function optim.FistaLS(f, g, pl, xinit, params) -- get derivatives from smooth function local fy,gfy = f(y,'dx') --local gfy = f(y) - + local fply = 0 local gply = 0 local Q = 0 - + ---------------------------------------------- -- do line search to find new current location starting from fista loc local nline = 0 @@ -98,7 +98,7 @@ function optim.FistaLS(f, g, pl, xinit, params) -- evaluate this point F(ply) fply = f(ply) - + -- ply - y ply:add(-1, y) -- <ply-y , \Grad(f(y))> @@ -27,7 +27,7 @@ ARGS: RETURN: - `x*` : the new `x` vector, at the optimal point -- `f` : a table of all function values: +- `f` : a table of all function values: `f[1]` is the value of the function before any optimization and `f[#f]` is the final fully optimized value, at `x*` @@ -46,7 +46,7 @@ function optim.lbfgs(opfunc, x, config, state) local lineSearchOpts = config.lineSearchOptions local learningRate = config.learningRate or 1 local isverbose = config.verbose or false - + state.funcEval = state.funcEval or 0 state.nIter = state.nIter or 0 @@ -142,7 +142,7 @@ function optim.lbfgs(opfunc, x, config, state) table.insert(state.stp_bufs, s) end - -- compute the approximate (L-BFGS) inverse Hessian + -- compute the approximate (L-BFGS) inverse Hessian -- multiplied by the gradient local k = #old_dirs @@ -1,11 +1,11 @@ ---------------------------------------------------------------------- --- An implementation of SGD adapted with features of Nesterov's +-- An implementation of SGD adapted with features of Nesterov's -- Accelerated Gradient method, based on the paper -- On the Importance of Initialization and Momentum in Deep Learning -- Sutsveker et. 
al., ICML 2013 -- -- ARGS: --- opfunc : a function that takes a single input (X), the point of +-- opfunc : a function that takes a single input (X), the point of -- evaluation, and returns f(X) and df/dX -- x : the initial point -- state : a table describing the state of the optimizer; after each @@ -44,7 +44,7 @@ function optim.nag(opfunc, x, config, state) -- first step in the direction of the momentum vector if state.dfdx then - x:add(mom, state.dfdx) + x:add(mom, state.dfdx) end -- then compute gradient at that point -- comment out the above line to get the original SGD diff --git a/rmsprop.lua b/rmsprop.lua index 038af21..1eb526d 100644 --- a/rmsprop.lua +++ b/rmsprop.lua @@ -22,36 +22,36 @@ RETURN: ]] function optim.rmsprop(opfunc, x, config, state) - -- (0) get/update state - local config = config or {} - local state = state or config - local lr = config.learningRate or 1e-2 - local alpha = config.alpha or 0.99 - local epsilon = config.epsilon or 1e-8 - local wd = config.weightDecay or 0 - - -- (1) evaluate f(x) and df/dx - local fx, dfdx = opfunc(x) - - -- (2) weight decay - if wd ~= 0 then + -- (0) get/update state + local config = config or {} + local state = state or config + local lr = config.learningRate or 1e-2 + local alpha = config.alpha or 0.99 + local epsilon = config.epsilon or 1e-8 + local wd = config.weightDecay or 0 + + -- (1) evaluate f(x) and df/dx + local fx, dfdx = opfunc(x) + + -- (2) weight decay + if wd ~= 0 then dfdx:add(wd, x) - end + end - -- (3) initialize mean square values and square gradient storage - if not state.m then + -- (3) initialize mean square values and square gradient storage + if not state.m then state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):fill(1) state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx) - end + end - -- (4) calculate new (leaky) mean squared values - state.m:mul(alpha) - state.m:addcmul(1.0-alpha, dfdx, dfdx) + -- (4) calculate new (leaky) mean squared values + state.m:mul(alpha) + state.m:addcmul(1.0-alpha, dfdx, dfdx) - -- (5) perform update - state.tmp:sqrt(state.m):add(epsilon) - x:addcdiv(-lr, dfdx, state.tmp) + -- (5) perform update + state.tmp:sqrt(state.m):add(epsilon) + x:addcdiv(-lr, dfdx, state.tmp) - -- return x*, f(x) before optimization - return x, {fx} + -- return x*, f(x) before optimization + return x, {fx} end @@ -20,83 +20,83 @@ RETURN: (Martin Riedmiller, Koray Kavukcuoglu 2013) --]] function optim.rprop(opfunc, x, config, state) - if config == nil and state == nil then - print('no state table RPROP initializing') - end - -- (0) get/update state - local config = config or {} - local state = state or config - local stepsize = config.stepsize or 0.1 - local etaplus = config.etaplus or 1.2 - local etaminus = config.etaminus or 0.5 - local stepsizemax = config.stepsizemax or 50.0 - local stepsizemin = config.stepsizemin or 1E-06 - local niter = config.niter or 1 - - local hfx = {} - - for i=1,niter do - - -- (1) evaluate f(x) and df/dx - local fx,dfdx = opfunc(x) - - -- init temp storage - if not state.delta then - state.delta = dfdx.new(dfdx:size()):zero() - state.stepsize = dfdx.new(dfdx:size()):fill(stepsize) - state.sign = dfdx.new(dfdx:size()) - state.psign = torch.ByteTensor(dfdx:size()) - state.nsign = torch.ByteTensor(dfdx:size()) - state.zsign = torch.ByteTensor(dfdx:size()) - state.dminmax = torch.ByteTensor(dfdx:size()) - if torch.type(x)=='torch.CudaTensor' then - -- Push to GPU - state.psign = state.psign:cuda() - state.nsign = state.nsign:cuda() - state.zsign = state.zsign:cuda() - 
state.dminmax = state.dminmax:cuda() - end - end - - -- sign of derivative from last step to this one - torch.cmul(state.sign, dfdx, state.delta) - torch.sign(state.sign, state.sign) - - -- get indices of >0, <0 and ==0 entries - state.sign.gt(state.psign, state.sign, 0) - state.sign.lt(state.nsign, state.sign, 0) - state.sign.eq(state.zsign, state.sign, 0) - - -- get step size updates - state.sign[state.psign] = etaplus - state.sign[state.nsign] = etaminus - state.sign[state.zsign] = 1 - - -- update stepsizes with step size updates - state.stepsize:cmul(state.sign) - - -- threshold step sizes - -- >50 => 50 - state.stepsize.gt(state.dminmax, state.stepsize, stepsizemax) - state.stepsize[state.dminmax] = stepsizemax - -- <1e-6 ==> 1e-6 - state.stepsize.lt(state.dminmax, state.stepsize, stepsizemin) - state.stepsize[state.dminmax] = stepsizemin - - -- for dir<0, dfdx=0 - -- for dir>=0 dfdx=dfdx - dfdx[state.nsign] = 0 - -- state.sign = sign(dfdx) - torch.sign(state.sign,dfdx) - - -- update weights - x:addcmul(-1,state.sign,state.stepsize) - - -- update state.dfdx with current dfdx - state.delta:copy(dfdx) - - table.insert(hfx,fx) - end + if config == nil and state == nil then + print('no state table RPROP initializing') + end + -- (0) get/update state + local config = config or {} + local state = state or config + local stepsize = config.stepsize or 0.1 + local etaplus = config.etaplus or 1.2 + local etaminus = config.etaminus or 0.5 + local stepsizemax = config.stepsizemax or 50.0 + local stepsizemin = config.stepsizemin or 1E-06 + local niter = config.niter or 1 + + local hfx = {} + + for i=1,niter do + + -- (1) evaluate f(x) and df/dx + local fx,dfdx = opfunc(x) + + -- init temp storage + if not state.delta then + state.delta = dfdx.new(dfdx:size()):zero() + state.stepsize = dfdx.new(dfdx:size()):fill(stepsize) + state.sign = dfdx.new(dfdx:size()) + state.psign = torch.ByteTensor(dfdx:size()) + state.nsign = torch.ByteTensor(dfdx:size()) + state.zsign = torch.ByteTensor(dfdx:size()) + state.dminmax = torch.ByteTensor(dfdx:size()) + if torch.type(x)=='torch.CudaTensor' then + -- Push to GPU + state.psign = state.psign:cuda() + state.nsign = state.nsign:cuda() + state.zsign = state.zsign:cuda() + state.dminmax = state.dminmax:cuda() + end + end + + -- sign of derivative from last step to this one + torch.cmul(state.sign, dfdx, state.delta) + torch.sign(state.sign, state.sign) + + -- get indices of >0, <0 and ==0 entries + state.sign.gt(state.psign, state.sign, 0) + state.sign.lt(state.nsign, state.sign, 0) + state.sign.eq(state.zsign, state.sign, 0) + + -- get step size updates + state.sign[state.psign] = etaplus + state.sign[state.nsign] = etaminus + state.sign[state.zsign] = 1 + + -- update stepsizes with step size updates + state.stepsize:cmul(state.sign) + + -- threshold step sizes + -- >50 => 50 + state.stepsize.gt(state.dminmax, state.stepsize, stepsizemax) + state.stepsize[state.dminmax] = stepsizemax + -- <1e-6 ==> 1e-6 + state.stepsize.lt(state.dminmax, state.stepsize, stepsizemin) + state.stepsize[state.dminmax] = stepsizemin + + -- for dir<0, dfdx=0 + -- for dir>=0 dfdx=dfdx + dfdx[state.nsign] = 0 + -- state.sign = sign(dfdx) + torch.sign(state.sign,dfdx) + + -- update weights + x:addcmul(-1,state.sign,state.stepsize) + + -- update state.dfdx with current dfdx + state.delta:copy(dfdx) + + table.insert(hfx,fx) + end -- return x*, f(x) before optimization return x,hfx @@ -70,7 +70,7 @@ function optim.sgd(opfunc, x, config, state) -- (4) learning rate decay (annealing) local clr = 
lr / (1 + nevals*lrd) - + -- (5) parameter update with single or individual learning rates if lrs then if not state.deltaParameters then |
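To close, the Adam step whose re-indentation dominates the top of this diff can be restated as a compact standalone sketch; it mirrors the code shown in `adam.lua` above, operating on plain tensors.

```lua
-- Compact restatement of the Adam update from adam.lua (see the diff above).
-- x: parameter tensor, dfdx: gradient of f at x, state: persistent table.
local function adam_step(x, dfdx, state, lr, beta1, beta2, epsilon)
   state.t = (state.t or 0) + 1
   state.m = state.m or x.new(dfdx:size()):zero()   -- first-moment estimate
   state.v = state.v or x.new(dfdx:size()):zero()   -- second-moment estimate
   state.m:mul(beta1):add(1 - beta1, dfdx)
   state.v:mul(beta2):addcmul(1 - beta2, dfdx, dfdx)
   local denom    = torch.sqrt(state.v):add(epsilon)
   local stepSize = lr * math.sqrt(1 - beta2 ^ state.t) / (1 - beta1 ^ state.t)
   x:addcdiv(-stepSize, state.m, denom)             -- x <- x - stepSize * m / (sqrt(v) + eps)
end
```

With the defaults read from `config` in the diff (`lr = 0.001`, `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`), this performs essentially the same update as a single `optim.adam` call, weight decay aside.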