diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-09-20 18:09:32 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-09-20 18:09:32 +0300 |
commit | bd5b7bd742eff867a3d2bc7c34bccc660d8b58a7 (patch) | |
tree | 109debee87f7de7dde1de0fcceacc53a3e9b64b0 | |
parent | 66a6ece7e5288514b1bfa23b8be0b22f8230c70a (diff) | |
parent | afa2bc9a0e0bb9d54a126257707756c4cbd7dc2a (diff) |
Merge pull request #134 from hughperkins/migrate-example-from-nn
move optim doc from nn
-rw-r--r-- | doc/image/parameterflattening.png | bin | 0 -> 74658 bytes | |||
-rw-r--r-- | doc/image/parameterflattening.svg | 338 | ||||
-rw-r--r-- | doc/image/parameterflattening.svg.png | bin | 0 -> 74546 bytes | |||
-rw-r--r-- | doc/intro.md | 194 |
4 files changed, 532 insertions, 0 deletions
diff --git a/doc/image/parameterflattening.png b/doc/image/parameterflattening.png Binary files differnew file mode 100644 index 0000000..efab4de --- /dev/null +++ b/doc/image/parameterflattening.png diff --git a/doc/image/parameterflattening.svg b/doc/image/parameterflattening.svg new file mode 100644 index 0000000..d58d62f --- /dev/null +++ b/doc/image/parameterflattening.svg @@ -0,0 +1,338 @@ +<?xml version="1.0" encoding="UTF-8" standalone="no"?> +<!-- Created with Inkscape (http://www.inkscape.org/) --> + +<svg + xmlns:dc="http://purl.org/dc/elements/1.1/" + xmlns:cc="http://creativecommons.org/ns#" + xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" + xmlns:svg="http://www.w3.org/2000/svg" + xmlns="http://www.w3.org/2000/svg" + xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" + xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" + width="275.54715mm" + height="214.99242mm" + viewBox="0 0 976.34814 761.78413" + id="svg2" + version="1.1" + inkscape:version="0.91 r13725" + sodipodi:docname="parameterflattening.svg" + inkscape:export-filename="/home/ubuntu/git/nn/doc/image/parameterflattening.svg.png" + inkscape:export-xdpi="90" + inkscape:export-ydpi="90"> + <defs + id="defs4" /> + <sodipodi:namedview + id="base" + pagecolor="#ffffff" + bordercolor="#666666" + borderopacity="1.0" + inkscape:pageopacity="0.0" + inkscape:pageshadow="2" + inkscape:zoom="0.7" + inkscape:cx="165.78568" + inkscape:cy="360.0347" + inkscape:document-units="px" + inkscape:current-layer="layer1" + showgrid="false" + inkscape:window-width="1920" + inkscape:window-height="1024" + inkscape:window-x="0" + inkscape:window-y="0" + inkscape:window-maximized="1" + fit-margin-top="0" + fit-margin-left="0" + fit-margin-right="0" + fit-margin-bottom="0" /> + <metadata + id="metadata7"> + <rdf:RDF> + <cc:Work + rdf:about=""> + <dc:format>image/svg+xml</dc:format> + <dc:type + rdf:resource="http://purl.org/dc/dcmitype/StillImage" /> + <dc:title></dc:title> + 
</cc:Work> + </rdf:RDF> + </metadata> + <g + inkscape:label="Layer 1" + inkscape:groupmode="layer" + id="layer1" + transform="translate(-145.10191,-140.95261)"> + <rect + id="rect3336" + width="264.20071" + height="127.05788" + x="498.61389" + y="212.40469" + style="fill:none;stroke:#000000;stroke-width:1.08497822;stroke-opacity:1" /> + <rect + id="rect3336-7" + width="264.20071" + height="127.05788" + x="499.32819" + y="384.54752" + style="fill:none;stroke:#000000;stroke-width:1.08497822;stroke-opacity:1" /> + <rect + id="rect3336-7-1" + width="264.20071" + height="127.05788" + x="502.18533" + y="554.54755" + style="fill:none;stroke:#000000;stroke-width:1.08497822;stroke-opacity:1" /> + <rect + id="rect3336-7-1-4" + width="264.20071" + height="127.05788" + x="499.32816" + y="705.97614" + style="fill:none;stroke:#000000;stroke-width:1.08497822;stroke-opacity:1" /> + <rect + style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-opacity:1" + id="rect4183" + width="18.571428" + height="631.42859" + x="170.00005" + y="206.64792" /> + <rect + style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-opacity:1" + id="rect4185" + width="18.571428" + height="631.42859" + x="207.14287" + y="207.50507" /> + <rect + style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1" + id="rect4187" + width="84.285713" + height="41.42857" + x="518.57141" + y="229.50507" /> + <rect + style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1" + id="rect4187-3" + width="84.285713" + height="41.42857" + x="518.42853" + y="283.07651" /> + <rect + style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1" + id="rect4187-8" + width="84.285713" + height="41.42857" + x="519.35712" + y="400.57651" /> + <rect + 
style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1" + id="rect4187-3-3" + width="84.285713" + height="41.42857" + x="519.21423" + y="454.14792" /> + <rect + style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1" + id="rect4187-8-7" + width="84.285713" + height="41.42857" + x="526.5" + y="572.00507" /> + <rect + style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1" + id="rect4187-3-3-8" + width="84.285713" + height="41.42857" + x="526.35712" + y="625.57648" /> + <rect + style="fill:#aafff8;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1" + id="rect4187-8-7-8" + width="84.285713" + height="41.42857" + x="529.35718" + y="722.00513" /> + <rect + style="fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:8, 8;stroke-dashoffset:0;stroke-opacity:1" + id="rect4187-3-3-8-3" + width="84.285713" + height="41.42857" + x="529.21429" + y="775.57648" /> + <text + xml:space="preserve" + style="font-size:20px;fill:none;stroke:#000000;stroke-opacity:1" + x="1515.7142" + y="190.93362" + id="text4278"><tspan + sodipodi:role="line" + id="tspan4280" + x="1515.7142" + y="190.93362"></tspan></text> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;" + x="635.71429" + y="768.07654" + id="text4290"><tspan + sodipodi:role="line" + id="tspan4292" + x="635.71429" + y="768.07654">conv1</tspan></text> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;" + x="627.14288" + y="613.79077" + id="text4294"><tspan + sodipodi:role="line" + id="tspan4296" + x="627.14288" + 
y="613.79077">conv2</tspan></text> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;" + x="632.85718" + y="443.79074" + id="text4298"><tspan + sodipodi:role="line" + id="tspan4300" + x="632.85718" + y="443.79074">conv3</tspan></text> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;" + x="631.42865" + y="259.50507" + id="text4302"><tspan + sodipodi:role="line" + id="tspan4304" + x="631.42865" + y="259.50507">conv4</tspan></text> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;" + x="528.57141" + y="156.64792" + id="text4306"><tspan + sodipodi:role="line" + id="tspan4308" + x="528.57141" + y="156.64792">Network layers:</tspan></text> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;" + x="145.14287" + y="159.79077" + id="text4310"><tspan + sodipodi:role="line" + x="145.14287" + y="159.79077" + id="tspan4314">flattened tensors:</tspan></text> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-opacity:1;" + x="175.71434" + y="898.0766" + id="text4337"><tspan + sodipodi:role="line" + id="tspan4339" + x="175.71434" + y="898.0766">params tensor</tspan></text> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-opacity:1;" + x="288.57147" + y="815.21936" + id="text4341"><tspan + sodipodi:role="line" + id="tspan4343" + x="288.57147" + y="815.21936">gradParams</tspan><tspan + sodipodi:role="line" + x="288.57147" + y="840.21936" + id="tspan4345">tensor</tspan></text> + <path + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="M 284.28571,810.93366 228.57143,793.79078" + id="path4347" + inkscape:connector-curvature="0" /> + <path 
+ style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="M 191.42857,872.36216 180,843.79076" + id="path4349" + inkscape:connector-curvature="0" /> + <path + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="M 522.85714,230.93364 185.71429,205.21935" + id="path4351" + inkscape:connector-curvature="0" /> + <path + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="M 517.14285,269.50506 187.14286,342.36221" + id="path4353" + inkscape:connector-curvature="0" /> + <path + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="M 521.42857,396.64792 187.14286,340.93364" + id="path4355" + inkscape:connector-curvature="0" /> + <path + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="M 521.42857,440.93364 185.71429,483.79078" + id="path4357" + inkscape:connector-curvature="0" /> + <path + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="M 527.14285,625.21935 225.71428,506.64792" + id="path4359" + inkscape:connector-curvature="0" /> + <path + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="M 522.85714,666.64792 225.71428,659.50506" + id="path4361" + inkscape:connector-curvature="0" /> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1;" + x="801.42853" + y="649.50513" + id="text4363"><tspan + sodipodi:role="line" + id="tspan4365" + x="801.42853" + y="649.50513">conv2 grad 
weight:</tspan><tspan + sodipodi:role="line" + x="801.42853" + y="674.50513" + id="tspan4367">view onto flattened gradParams</tspan></text> + <path + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="m 612.85708,640.9336 180,14.2857" + id="path4375" + inkscape:connector-curvature="0" /> + <text + xml:space="preserve" + style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1;" + x="791.42853" + y="400.93353" + id="text4377"><tspan + sodipodi:role="line" + id="tspan4379" + x="791.42853" + y="400.93353">conv3 weight:</tspan><tspan + sodipodi:role="line" + x="791.42853" + y="425.93353" + id="tspan4381">view onto flattened params</tspan><tspan + sodipodi:role="line" + x="791.42853" + y="450.93353" + id="tspan4383">tensor</tspan></text> + <path + style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1" + d="m 782.85708,403.7907 -180,11.4286" + id="path4387" + inkscape:connector-curvature="0" /> + </g> +</svg> diff --git a/doc/image/parameterflattening.svg.png b/doc/image/parameterflattening.svg.png Binary files differnew file mode 100644 index 0000000..ecf4068 --- /dev/null +++ b/doc/image/parameterflattening.svg.png diff --git a/doc/intro.md b/doc/intro.md index 4032029..a95aa53 100644 --- a/doc/intro.md +++ b/doc/intro.md @@ -39,3 +39,197 @@ for i, sample in ipairs(training_samples) do end ``` +<a name="optim.training"></a> +## Training using optim ## + +`optim` is a quite general optimizer, for minimizing any function with respect to a set +of parameters. In our case, our +function will be the loss of our network, given an input, and a set of weights. The goal of training +a neural net is to +optimize the weights to give the lowest loss over our training set of input data. 
So, we are going to use optim +to minimize the loss with respect to the weights, over our training set. We will feed the data to +`optim` in minibatches. For this particular example, we will use just one minibatch, but in your own training +you will almost certainly want to break your training set into minibatches, and feed each minibatch to `optim`, +one by one. + +We need to give `optim` a function that will output the loss and the derivative of the loss with respect to the +weights, given the current weights, as a function parameter. The function will have access to our training minibatch, and use this +to calculate the loss, for this minibatch. Typically, the function would be defined inside our loop over +batches, and therefore have access to the current minibatch data. + +Here's how this looks: + +__Neural Network__ + +We create a simple neural network with one hidden layer. +```lua +require 'nn' + +local model = nn.Sequential(); -- make a multi-layer perceptron +local inputs = 2; local outputs = 1; local HUs = 20; -- parameters +model:add(nn.Linear(inputs, HUs)) +model:add(nn.Tanh()) +model:add(nn.Linear(HUs, outputs)) +``` + +__Criterion__ + +We choose the Mean Squared Error loss criterion: +```lua +local criterion = nn.MSECriterion() +``` + +We are using an `nn.MSECriterion` because we are training on a regression task, predicting float target values. +For a classification task, we would add an `nn.LogSoftMax()` layer to the end of our +network, and use a `nn.ClassNLLCriterion` loss criterion. + +__Dataset__ + +We will just create one minibatch of 128 examples. In your own networks, you'd want to break down your +rather larger dataset into multiple minibatches, of around 32-512 examples each. 
+ +```lua +local batchSize = 128 +local batchInputs = torch.Tensor(batchSize, inputs) +local batchLabels = torch.DoubleTensor(batchSize) + +for i=1,batchSize do + local input = torch.randn(2) -- normally distributed example in 2d + local label = 1 + if input[1]*input[2]>0 then -- calculate label for XOR function + label = -1; + end + batchInputs[i]:copy(input) + batchLabels[i] = label +end +``` + +__Flatten Parameters__ + +`optim` expects the parameters that are to be optimized, and their gradients, to be one-dimensional tensors. +But our network model probably contains multiple modules, typically multiple convolutional layers, and each +of these layers has its own weight and bias tensors. How do we handle this? + +It is simple: we can call a standard method `:getParameters()`, that is defined for any network module. When +we call this method, the following magic will happen: +- a new tensor will be created, large enough to hold all the weights and biases of the entire network model +- the model weight and bias tensors are replaced with views onto the new contiguous parameter tensor +- and the exact same thing will happen for all the gradient tensors: replaced with views onto one single +contiguous gradient tensor + +We can call this method as follows: +```lua +local params, gradParams = model:getParameters() +``` + +These flattened tensors have the following characteristics: +- to `optim`, the parameters it needs to optimize are all contained in one single one-dimensional tensor +- when `optim` optimizes the parameters in this large one-dimensional tensor, it is implicitly optimizing +the weights and biases in our network model, since those are now simply views onto this large one-dimensional +parameter tensor. + +It will look something like this: + +![Parameter Flattening](image/parameterflattening.png?raw=true "Parameter Flattening") + +Note that flattening the parameters redefines the weight and bias tensors for all the network modules +in our network model. 
Therefore, any pre-existing references to the original model layer weight and bias tensors +will no longer point to the model weight and bias tensors, after flattening. + +__Training__ + +Now that we have created our model, our training set, and prepared the flattened network parameters, +we can run training, using `optim`. `optim` provides [various training algorithms](https://github.com/torch/optim/blob/master/doc/index.md). We +will use the stochastic gradient descent algorithm [sgd](https://github.com/torch/optim/blob/master/doc/index.md#x-sgdopfunc-x-state). We +need to provide the learning rate, via an optimization state table: + +```lua +local optimState = {learningRate=0.01} +``` + +We define an evaluation function, inside our training loop, and use `optim.sgd` to run training: +```lua +require 'optim' + +for epoch=1,50 do + -- local function we give to optim + -- it takes current weights as input, and outputs the loss + -- and the gradient of the loss with respect to the weights + -- gradParams is calculated implicitly by calling 'backward', + -- because the model's weight and bias gradient tensors + -- are simply views onto gradParams + local function feval(params) + gradParams:zero() + + local outputs = model:forward(batchInputs) + local loss = criterion:forward(outputs, batchLabels) + local dloss_doutput = criterion:backward(outputs, batchLabels) + model:backward(batchInputs, dloss_doutput) + + return loss,gradParams + end + optim.sgd(feval, params, optimState) +end +``` +__Test the network__ + +For the prediction task, we will also typically use minibatches, although we can run prediction sample by +sample too. In this example, we will predict sample by sample. To run prediction on a minibatch, simply +pass in a tensor with one additional dimension, which represents the sample index. 
+ +```lua +x = torch.Tensor(2) +x[1] = 0.5; x[2] = 0.5; print(model:forward(x)) +x[1] = 0.5; x[2] = -0.5; print(model:forward(x)) +x[1] = -0.5; x[2] = 0.5; print(model:forward(x)) +x[1] = -0.5; x[2] = -0.5; print(model:forward(x)) +``` + +You should see something like: +```lua +> x = torch.Tensor(2) +> x[1] = 0.5; x[2] = 0.5; print(model:forward(x)) + +-0.3490 +[torch.Tensor of dimension 1] + +> x[1] = 0.5; x[2] = -0.5; print(model:forward(x)) + + 1.0561 +[torch.Tensor of dimension 1] + +> x[1] = -0.5; x[2] = 0.5; print(model:forward(x)) + + 0.8640 +[torch.Tensor of dimension 1] + +> x[1] = -0.5; x[2] = -0.5; print(model:forward(x)) + +-0.2941 +[torch.Tensor of dimension 1] +``` + +If we were running on a GPU, we would probably want to predict using minibatches, because this will +hide the latencies involved in transferring data from main memory to the GPU. To predict +on a minibatch, we could do something like: + +```lua +local x = torch.Tensor({ + {0.5, 0.5}, + {0.5, -0.5}, + {-0.5, 0.5}, + {-0.5, -0.5} +}) +print(model:forward(x)) +``` +You should see something like: +```lua +> print(model:forward(x)) + -0.3490 + 1.0561 + 0.8640 + -0.2941 +[torch.Tensor of size 4] +``` + +That's it! For minibatched prediction, the output tensor contains one value for each of our input data samples. |