
github.com/torch/optim.git
author    Soumith Chintala <soumith@gmail.com>  2016-09-20 18:09:32 +0300
committer GitHub <noreply@github.com>  2016-09-20 18:09:32 +0300
commit    bd5b7bd742eff867a3d2bc7c34bccc660d8b58a7 (patch)
tree      109debee87f7de7dde1de0fcceacc53a3e9b64b0
parent    66a6ece7e5288514b1bfa23b8be0b22f8230c70a (diff)
parent    afa2bc9a0e0bb9d54a126257707756c4cbd7dc2a (diff)

Merge pull request #134 from hughperkins/migrate-example-from-nn

move optim doc from nn
-rw-r--r--  doc/image/parameterflattening.png      bin  0 -> 74658 bytes
-rw-r--r--  doc/image/parameterflattening.svg      338
-rw-r--r--  doc/image/parameterflattening.svg.png  bin  0 -> 74546 bytes
-rw-r--r--  doc/intro.md                           194
4 files changed, 532 insertions(+), 0 deletions(-)
diff --git a/doc/image/parameterflattening.png b/doc/image/parameterflattening.png
new file mode 100644
index 0000000..efab4de
--- /dev/null
+++ b/doc/image/parameterflattening.png
Binary files differ
diff --git a/doc/image/parameterflattening.svg b/doc/image/parameterflattening.svg
new file mode 100644
index 0000000..d58d62f
--- /dev/null
+++ b/doc/image/parameterflattening.svg
@@ -0,0 +1,338 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ width="275.54715mm"
+ height="214.99242mm"
+ viewBox="0 0 976.34814 761.78413"
+ id="svg2"
+ version="1.1"
+ inkscape:version="0.91 r13725"
+ sodipodi:docname="parameterflattening.svg"
+ inkscape:export-filename="/home/ubuntu/git/nn/doc/image/parameterflattening.svg.png"
+ inkscape:export-xdpi="90"
+ inkscape:export-ydpi="90">
+ <defs
+ id="defs4" />
+ <sodipodi:namedview
+ id="base"
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1.0"
+ inkscape:pageopacity="0.0"
+ inkscape:pageshadow="2"
+ inkscape:zoom="0.7"
+ inkscape:cx="165.78568"
+ inkscape:cy="360.0347"
+ inkscape:document-units="px"
+ inkscape:current-layer="layer1"
+ showgrid="false"
+ inkscape:window-width="1920"
+ inkscape:window-height="1024"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0" />
+ <metadata
+ id="metadata7">
+ <rdf:RDF>
+ <cc:Work
+ rdf:about="">
+ <dc:format>image/svg+xml</dc:format>
+ <dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
+ <dc:title></dc:title>
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <g
+ inkscape:label="Layer 1"
+ inkscape:groupmode="layer"
+ id="layer1"
+ transform="translate(-145.10191,-140.95261)">
+ <rect
+ id="rect3336"
+ width="264.20071"
+ height="127.05788"
+ x="498.61389"
+ y="212.40469"
+ style="fill:none;stroke:#000000;stroke-width:1.08497822;stroke-opacity:1" />
+ <rect
+ id="rect3336-7"
+ width="264.20071"
+ height="127.05788"
+ x="499.32819"
+ y="384.54752"
+ style="fill:none;stroke:#000000;stroke-width:1.08497822;stroke-opacity:1" />
+ <rect
+ id="rect3336-7-1"
+ width="264.20071"
+ height="127.05788"
+ x="502.18533"
+ y="554.54755"
+ style="fill:none;stroke:#000000;stroke-width:1.08497822;stroke-opacity:1" />
+ <rect
+ id="rect3336-7-1-4"
+ width="264.20071"
+ height="127.05788"
+ x="499.32816"
+ y="705.97614"
+ style="fill:none;stroke:#000000;stroke-width:1.08497822;stroke-opacity:1" />
+ <rect
+ style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-opacity:1"
+ id="rect4183"
+ width="18.571428"
+ height="631.42859"
+ x="170.00005"
+ y="206.64792" />
+ <rect
+ style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-opacity:1"
+ id="rect4185"
+ width="18.571428"
+ height="631.42859"
+ x="207.14287"
+ y="207.50507" />
+ <rect
+ style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1"
+ id="rect4187"
+ width="84.285713"
+ height="41.42857"
+ x="518.57141"
+ y="229.50507" />
+ <rect
+ style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1"
+ id="rect4187-3"
+ width="84.285713"
+ height="41.42857"
+ x="518.42853"
+ y="283.07651" />
+ <rect
+ style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1"
+ id="rect4187-8"
+ width="84.285713"
+ height="41.42857"
+ x="519.35712"
+ y="400.57651" />
+ <rect
+ style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1"
+ id="rect4187-3-3"
+ width="84.285713"
+ height="41.42857"
+ x="519.21423"
+ y="454.14792" />
+ <rect
+ style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1"
+ id="rect4187-8-7"
+ width="84.285713"
+ height="41.42857"
+ x="526.5"
+ y="572.00507" />
+ <rect
+ style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1"
+ id="rect4187-3-3-8"
+ width="84.285713"
+ height="41.42857"
+ x="526.35712"
+ y="625.57648" />
+ <rect
+ style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1"
+ id="rect4187-8-7-8"
+ width="84.285713"
+ height="41.42857"
+ x="529.35718"
+ y="722.00513" />
+ <rect
+ style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1"
+ id="rect4187-3-3-8-3"
+ width="84.285713"
+ height="41.42857"
+ x="529.21429"
+ y="775.57648" />
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:none;stroke:#000000;stroke-opacity:1"
+ x="1515.7142"
+ y="190.93362"
+ id="text4278"><tspan
+ sodipodi:role="line"
+ id="tspan4280"
+ x="1515.7142"
+ y="190.93362"></tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
+ x="635.71429"
+ y="768.07654"
+ id="text4290"><tspan
+ sodipodi:role="line"
+ id="tspan4292"
+ x="635.71429"
+ y="768.07654">conv1</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
+ x="627.14288"
+ y="613.79077"
+ id="text4294"><tspan
+ sodipodi:role="line"
+ id="tspan4296"
+ x="627.14288"
+ y="613.79077">conv2</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
+ x="632.85718"
+ y="443.79074"
+ id="text4298"><tspan
+ sodipodi:role="line"
+ id="tspan4300"
+ x="632.85718"
+ y="443.79074">conv3</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
+ x="631.42865"
+ y="259.50507"
+ id="text4302"><tspan
+ sodipodi:role="line"
+ id="tspan4304"
+ x="631.42865"
+ y="259.50507">conv4</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
+ x="528.57141"
+ y="156.64792"
+ id="text4306"><tspan
+ sodipodi:role="line"
+ id="tspan4308"
+ x="528.57141"
+ y="156.64792">Network layers:</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
+ x="145.14287"
+ y="159.79077"
+ id="text4310"><tspan
+ sodipodi:role="line"
+ x="145.14287"
+ y="159.79077"
+ id="tspan4314">flattened tensors:</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-opacity:1;"
+ x="175.71434"
+ y="898.0766"
+ id="text4337"><tspan
+ sodipodi:role="line"
+ id="tspan4339"
+ x="175.71434"
+ y="898.0766">params tensor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-opacity:1;"
+ x="288.57147"
+ y="815.21936"
+ id="text4341"><tspan
+ sodipodi:role="line"
+ id="tspan4343"
+ x="288.57147"
+ y="815.21936">gradParams</tspan><tspan
+ sodipodi:role="line"
+ x="288.57147"
+ y="840.21936"
+ id="tspan4345">tensor</tspan></text>
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="M 284.28571,810.93366 228.57143,793.79078"
+ id="path4347"
+ inkscape:connector-curvature="0" />
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="M 191.42857,872.36216 180,843.79076"
+ id="path4349"
+ inkscape:connector-curvature="0" />
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="M 522.85714,230.93364 185.71429,205.21935"
+ id="path4351"
+ inkscape:connector-curvature="0" />
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="M 517.14285,269.50506 187.14286,342.36221"
+ id="path4353"
+ inkscape:connector-curvature="0" />
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="M 521.42857,396.64792 187.14286,340.93364"
+ id="path4355"
+ inkscape:connector-curvature="0" />
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="M 521.42857,440.93364 185.71429,483.79078"
+ id="path4357"
+ inkscape:connector-curvature="0" />
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="M 527.14285,625.21935 225.71428,506.64792"
+ id="path4359"
+ inkscape:connector-curvature="0" />
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="M 522.85714,666.64792 225.71428,659.50506"
+ id="path4361"
+ inkscape:connector-curvature="0" />
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1;"
+ x="801.42853"
+ y="649.50513"
+ id="text4363"><tspan
+ sodipodi:role="line"
+ id="tspan4365"
+ x="801.42853"
+ y="649.50513">conv2 grad weight:</tspan><tspan
+ sodipodi:role="line"
+ x="801.42853"
+ y="674.50513"
+ id="tspan4367">view onto flattened gradParams</tspan></text>
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="m 612.85708,640.9336 180,14.2857"
+ id="path4375"
+ inkscape:connector-curvature="0" />
+ <text
+ xml:space="preserve"
+ style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1;"
+ x="791.42853"
+ y="400.93353"
+ id="text4377"><tspan
+ sodipodi:role="line"
+ id="tspan4379"
+ x="791.42853"
+ y="400.93353">conv3 weight:</tspan><tspan
+ sodipodi:role="line"
+ x="791.42853"
+ y="425.93353"
+ id="tspan4381">view onto flattened params</tspan><tspan
+ sodipodi:role="line"
+ x="791.42853"
+ y="450.93353"
+ id="tspan4383">tensor</tspan></text>
+ <path
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
+ d="m 782.85708,403.7907 -180,11.4286"
+ id="path4387"
+ inkscape:connector-curvature="0" />
+ </g>
+</svg>
diff --git a/doc/image/parameterflattening.svg.png b/doc/image/parameterflattening.svg.png
new file mode 100644
index 0000000..ecf4068
--- /dev/null
+++ b/doc/image/parameterflattening.svg.png
Binary files differ
diff --git a/doc/intro.md b/doc/intro.md
index 4032029..a95aa53 100644
--- a/doc/intro.md
+++ b/doc/intro.md
@@ -39,3 +39,197 @@ for i, sample in ipairs(training_samples) do
end
```
+<a name="optim.training"></a>
+## Training using optim ##
+
+`optim` is a general-purpose optimizer: it minimizes any function with respect to a set of
+parameters. In our case, the function is the loss of our network, given an input and a set of
+weights. The goal of training a neural net is to find the weights that give the lowest loss
+over our training set of input data, so we are going to use `optim` to minimize the loss with
+respect to the weights, over our training set. We will feed the data to `optim` in minibatches.
+For this particular example, we will use just one minibatch, but in your own training you will
+almost certainly want to break your training set into minibatches and feed them to `optim`
+one by one.
+
+We need to give `optim` a function that, given the current weights as its argument, outputs the
+loss and the derivative of the loss with respect to the weights. The function has access to our
+training minibatch and uses it to calculate the loss for that minibatch. Typically, the function
+is defined inside our loop over minibatches, and therefore has access to the current minibatch data.
+
+Here's how this looks:
+
+__Neural Network__
+
+We create a simple neural network with one hidden layer.
+```lua
+require 'nn'
+
+local model = nn.Sequential(); -- make a multi-layer perceptron
+local inputs = 2; local outputs = 1; local HUs = 20; -- parameters
+model:add(nn.Linear(inputs, HUs))
+model:add(nn.Tanh())
+model:add(nn.Linear(HUs, outputs))
+```
+
+__Criterion__
+
+We choose the Mean Squared Error loss criterion:
+```lua
+local criterion = nn.MSECriterion()
+```
+
+We use `nn.MSECriterion` because we are training on a regression task, predicting continuous
+target values. For a classification task, we would add an `nn.LogSoftMax()` layer to the end of
+our network and use an `nn.ClassNLLCriterion` loss criterion, as sketched below.
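+
+For illustration, here is a minimal sketch of that classification variant, assuming a
+hypothetical ten-class problem (`nClasses` is a name introduced here for illustration):
+```lua
+require 'nn'
+
+local nClasses = 10                       -- hypothetical number of classes
+local model = nn.Sequential()
+model:add(nn.Linear(inputs, HUs))         -- `inputs` and `HUs` as defined above
+model:add(nn.Tanh())
+model:add(nn.Linear(HUs, nClasses))       -- one output per class
+model:add(nn.LogSoftMax())                -- outputs log-probabilities
+local criterion = nn.ClassNLLCriterion()  -- expects log-probabilities and integer class labels
+```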
+
+__Dataset__
+
+We will create just one minibatch of 128 examples. For your own networks, you would break your
+(rather larger) dataset into multiple minibatches of around 32-512 examples each (see the sketch
+after the following code block).
+
+```lua
+local batchSize = 128
+local batchInputs = torch.Tensor(batchSize, inputs)
+local batchLabels = torch.DoubleTensor(batchSize)
+
+for i = 1, batchSize do
+   local input = torch.randn(2)     -- normally distributed example in 2d
+   local label = 1
+   if input[1] * input[2] > 0 then  -- calculate label for XOR-like function
+      label = -1
+   end
+   batchInputs[i]:copy(input)
+   batchLabels[i] = label
+end
+```
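+
+As a rough sketch of carving minibatches out of a larger dataset, one could take views along
+the sample dimension with `Tensor:narrow` (the full dataset here is hypothetical):
+```lua
+local nSamples = 1024
+local allInputs = torch.randn(nSamples, inputs)  -- hypothetical full dataset
+for b = 1, nSamples - batchSize + 1, batchSize do
+   local batchInputs = allInputs:narrow(1, b, batchSize)  -- a view, no copy
+   -- ... compute labels and train on this minibatch ...
+end
+```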
+
+__Flatten Parameters__
+
+`optim` expects the parameters to be optimized, and their gradients, to be one-dimensional
+tensors. But our network model probably contains multiple modules, typically several layers,
+and each of these layers has its own weight and bias tensors. How do we handle this?
+
+It is simple: we call the standard method `:getParameters()`, which is defined for any network
+module. When we call it, the following happens:
+- a new tensor is created, large enough to hold all the weights and biases of the entire network model
+- the model's weight and bias tensors are replaced with views onto this new contiguous parameter tensor
+- exactly the same thing happens for the gradient tensors: they are replaced with views onto a
+single contiguous gradient tensor
+
+We can call this method as follows:
+```lua
+local params, gradParams = model:getParameters()
+```
+
+These flattened tensors have the following characteristics:
+- to `optim`, the parameters it needs to optimize are all contained in a single one-dimensional tensor
+- when `optim` updates the values in this large one-dimensional tensor, it implicitly updates the
+weights and biases of our network model, since those are now simply views onto the large
+one-dimensional parameter tensor
+
+It will look something like this:
+
+![Parameter Flattening](image/parameterflattening.png?raw=true "Parameter Flattening")
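+
+We can check the view relationship directly; a small sketch (this check is our own
+illustration, not part of the original example):
+```lua
+-- the first layer's weight now shares storage with params:
+print(torch.pointer(model:get(1).weight:storage()) == torch.pointer(params:storage()))  -- true
+```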
+
+Note that flattening the parameters redefines the weight and bias tensors of every module in
+the network. Therefore, any references to the original weight and bias tensors, taken before
+flattening, will no longer point to the model's live weight and bias tensors afterwards.
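+
+A minimal sketch of this pitfall, run on a fresh copy of the model so that we can take a
+reference before flattening (the variable names are our own):
+```lua
+local m = model:clone()            -- fresh copy, for illustration only
+local oldWeight = m:get(1).weight  -- reference taken BEFORE flattening
+local p, gp = m:getParameters()
+-- m:get(1).weight is now a view onto p; the old reference points at abandoned storage:
+print(torch.pointer(m:get(1).weight:storage()) == torch.pointer(oldWeight:storage()))  -- false
+```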
+
+__Training__
+
+Now that we have created our model and our training set, and prepared the flattened network
+parameters, we can run training using `optim`. `optim` provides [various training algorithms](https://github.com/torch/optim/blob/master/doc/index.md).
+We will use the stochastic gradient descent algorithm, [sgd](https://github.com/torch/optim/blob/master/doc/index.md#x-sgdopfunc-x-state).
+We need to provide the learning rate via an optimization state table:
+
+```lua
+local optimState = {learningRate=0.01}
+```
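+
+The state table can carry further optional fields understood by `optim.sgd`, such as momentum
+and weight decay; for example (values are illustrative, not recommendations):
+```lua
+local optimState = {
+   learningRate = 0.01,
+   learningRateDecay = 1e-4,  -- anneal the learning rate over iterations
+   weightDecay = 1e-3,        -- L2 regularization
+   momentum = 0.9,            -- classical momentum
+}
+```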
+
+We define an evaluation function, inside our training loop, and use `optim.sgd` to run training:
+```lua
+require 'optim'
+
+for epoch = 1, 50 do
+   -- local function we give to optim
+   -- it takes current weights as input, and outputs the loss
+   -- and the gradient of the loss with respect to the weights
+   -- gradParams is calculated implicitly by calling 'backward',
+   -- because the model's weight and bias gradient tensors
+   -- are simply views onto gradParams
+   local function feval(params)
+      gradParams:zero()
+
+      local outputs = model:forward(batchInputs)
+      local loss = criterion:forward(outputs, batchLabels)
+      local dloss_doutput = criterion:backward(outputs, batchLabels)
+      model:backward(batchInputs, dloss_doutput)
+
+      return loss, gradParams
+   end
+   optim.sgd(feval, params, optimState)
+end
+```
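+
+`optim.sgd` returns the updated parameter tensor and a table whose first entry is the loss
+value computed by `feval`, so the call inside the loop above could be extended to log
+training progress; a sketch:
+```lua
+-- inside the epoch loop above:
+local _, fs = optim.sgd(feval, params, optimState)
+print(string.format('epoch %d, loss = %.6f', epoch, fs[1]))
+```
+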
+__Test the network__
+
+For prediction, we will also typically use minibatches, although we can run prediction sample
+by sample too. In this example, we will predict sample by sample. To run prediction on a
+minibatch, simply pass in a tensor with one additional leading dimension, which indexes the samples.
+
+```lua
+x = torch.Tensor(2)
+x[1] = 0.5; x[2] = 0.5; print(model:forward(x))
+x[1] = 0.5; x[2] = -0.5; print(model:forward(x))
+x[1] = -0.5; x[2] = 0.5; print(model:forward(x))
+x[1] = -0.5; x[2] = -0.5; print(model:forward(x))
+```
+
+You should see something like:
+```lua
+> x = torch.Tensor(2)
+> x[1] = 0.5; x[2] = 0.5; print(model:forward(x))
+
+-0.3490
+[torch.Tensor of dimension 1]
+
+> x[1] = 0.5; x[2] = -0.5; print(model:forward(x))
+
+ 1.0561
+[torch.Tensor of dimension 1]
+
+> x[1] = -0.5; x[2] = 0.5; print(model:forward(x))
+
+ 0.8640
+[torch.Tensor of dimension 1]
+
+> x[1] = -0.5; x[2] = -0.5; print(model:forward(x))
+
+-0.2941
+[torch.Tensor of dimension 1]
+```
+
+If we were running on a GPU, we would probably want to predict using minibatches, because this
+hides the latencies involved in transferring data from main memory to the GPU. To predict
+on a minibatch, we could do something like:
+
+```lua
+local x = torch.Tensor({
+ {0.5, 0.5},
+ {0.5, -0.5},
+ {-0.5, 0.5},
+ {-0.5, -0.5}
+})
+print(model:forward(x))
+```
+You should see something like:
+```lua
+> print(model:forward(x))
+ -0.3490
+ 1.0561
+ 0.8640
+ -0.2941
+[torch.Tensor of size 4]
+```
+
+That's it! For minibatched prediction, the output tensor contains one value for each of our input data samples.
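+
+If you do want to predict (or train) on a GPU, a minimal sketch of the conversion, assuming
+the `cutorch` and `cunn` packages are installed (this is our illustration, not part of the
+original example):
+```lua
+require 'cunn'
+
+model:cuda()            -- move the model to the GPU
+local xCuda = x:cuda()  -- move the minibatch to the GPU
+print(model:forward(xCuda))
+-- note: converting the model invalidates previously flattened views, so call
+-- model:getParameters() again before any further training
+```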