github.com/torch/nn.git
author     Hugh Perkins <hughperkins@gmail.com>   2016-05-17 16:11:05 +0300
committer  Hugh Perkins <hughperkins@gmail.com>   2016-05-17 16:11:05 +0300
commit     dd61a0289852d990117c21a2ba886eccc6d4213f (patch)
tree       7c47323a3799909e7382e9cedd07b3d5596afbbf
parent     7ef5ec0716a7895edbbbef278f2af969b4c19210 (diff)
move manual training to the front
-rw-r--r--   README.md                          |   2
-rw-r--r--   doc/image/parameterflattening.svg  |  45
-rw-r--r--   doc/training.md                    | 196
3 files changed, 130 insertions, 113 deletions
diff --git a/README.md b/README.md
index 378a440..e848fd8 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,6 @@ This package provides an easy and modular way to build and train simple or compl
* [`ClassNLLCriterion`](doc/criterion.md#nn.ClassNLLCriterion): the Negative Log Likelihood criterion used for classification;
* Additional documentation:
* [Overview](doc/overview.md#nn.overview.dok) of the package essentials including modules, containers and training;
- * [Training](doc/training.md#nn.traningneuralnet.dok): how to train a neural network using [`StochasticGradient`](doc/training.md#nn.StochasticGradient);
+ * [Training](doc/training.md#nn.traningneuralnet.dok): how to train a neural network using [optim](https://github.com/torch/optim);
* [Testing](doc/testing.md): how to test your modules.
* [Experimental Modules](https://github.com/clementfarabet/lua---nnx/blob/master/README.md): a package containing experimental modules and criteria.
diff --git a/doc/image/parameterflattening.svg b/doc/image/parameterflattening.svg
index 2aebe4c..d58d62f 100644
--- a/doc/image/parameterflattening.svg
+++ b/doc/image/parameterflattening.svg
@@ -9,13 +9,16 @@
xmlns="http://www.w3.org/2000/svg"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
- width="400mm"
- height="600mm"
- viewBox="0 0 1417.3228 2125.9842"
+ width="275.54715mm"
+ height="214.99242mm"
+ viewBox="0 0 976.34814 761.78413"
id="svg2"
version="1.1"
inkscape:version="0.91 r13725"
- sodipodi:docname="parameterflattening.svg">
+ sodipodi:docname="parameterflattening.svg"
+ inkscape:export-filename="/home/ubuntu/git/nn/doc/image/parameterflattening.svg.png"
+ inkscape:export-xdpi="90"
+ inkscape:export-ydpi="90">
<defs
id="defs4" />
<sodipodi:namedview
@@ -26,8 +29,8 @@
inkscape:pageopacity="0.0"
inkscape:pageshadow="2"
inkscape:zoom="0.7"
- inkscape:cx="343.74474"
- inkscape:cy="509.6602"
+ inkscape:cx="165.78568"
+ inkscape:cy="360.0347"
inkscape:document-units="px"
inkscape:current-layer="layer1"
showgrid="false"
@@ -35,7 +38,11 @@
inkscape:window-height="1024"
inkscape:window-x="0"
inkscape:window-y="0"
- inkscape:window-maximized="1" />
+ inkscape:window-maximized="1"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0" />
<metadata
id="metadata7">
<rdf:RDF>
@@ -52,7 +59,7 @@
inkscape:label="Layer 1"
inkscape:groupmode="layer"
id="layer1"
- transform="translate(0,1073.622)">
+ transform="translate(-145.10191,-140.95261)">
<rect
id="rect3336"
width="264.20071"
@@ -163,7 +170,7 @@
y="190.93362"></tspan></text>
<text
xml:space="preserve"
- style="font-size:20px;fill:none;stroke:#000000;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
x="635.71429"
y="768.07654"
id="text4290"><tspan
@@ -173,7 +180,7 @@
y="768.07654">conv1</tspan></text>
<text
xml:space="preserve"
- style="font-size:20px;fill:none;stroke:#000000;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
x="627.14288"
y="613.79077"
id="text4294"><tspan
@@ -183,7 +190,7 @@
y="613.79077">conv2</tspan></text>
<text
xml:space="preserve"
- style="font-size:20px;fill:none;stroke:#000000;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
x="632.85718"
y="443.79074"
id="text4298"><tspan
@@ -193,7 +200,7 @@
y="443.79074">conv3</tspan></text>
<text
xml:space="preserve"
- style="font-size:20px;fill:none;stroke:#000000;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
x="631.42865"
y="259.50507"
id="text4302"><tspan
@@ -203,7 +210,7 @@
y="259.50507">conv4</tspan></text>
<text
xml:space="preserve"
- style="font-size:20px;fill:none;stroke:#000000;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
x="528.57141"
y="156.64792"
id="text4306"><tspan
@@ -213,7 +220,7 @@
y="156.64792">Network layers:</tspan></text>
<text
xml:space="preserve"
- style="font-size:20px;fill:none;stroke:#000000;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;stroke:#000000;stroke-opacity:1;fill-opacity:1;"
x="145.14287"
y="159.79077"
id="text4310"><tspan
@@ -223,7 +230,7 @@
id="tspan4314">flattened tensors:</tspan></text>
<text
xml:space="preserve"
- style="font-size:20px;fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-opacity:1;"
x="175.71434"
y="898.0766"
id="text4337"><tspan
@@ -233,7 +240,7 @@
y="898.0766">params tensor</tspan></text>
<text
xml:space="preserve"
- style="font-size:20px;fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-opacity:1;"
x="288.57147"
y="815.21936"
id="text4341"><tspan
@@ -243,7 +250,7 @@
y="815.21936">gradParams</tspan><tspan
sodipodi:role="line"
x="288.57147"
- y="840.2193"
+ y="840.21936"
id="tspan4345">tensor</tspan></text>
<path
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
@@ -287,7 +294,7 @@
inkscape:connector-curvature="0" />
<text
xml:space="preserve"
- style="font-size:20px;fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1;"
x="801.42853"
y="649.50513"
id="text4363"><tspan
@@ -306,7 +313,7 @@
inkscape:connector-curvature="0" />
<text
xml:space="preserve"
- style="font-size:20px;fill:#fcf2cd;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
+ style="font-size:20px;fill:#000000;fill-opacity:1;stroke:#000000;stroke-width:1;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1;"
x="791.42853"
y="400.93353"
id="text4377"><tspan
diff --git a/doc/training.md b/doc/training.md
index 67d438b..49d8dda 100644
--- a/doc/training.md
+++ b/doc/training.md
@@ -1,12 +1,107 @@
<a name="nn.traningneuralnet.dok"></a>
# Training a neural network #
-Training a neural network is easy with a [simple `for` loop](#nn.DoItYourself).
-While doing your own loop provides great flexibility, you might
-want sometimes a quick way of training neural
-networks. [optim](https://github.com/torch/optim) is the standard way of training Torch7 neural networks.
+Training a neural network is easy with a [simple `for` loop](#nn.DoItYourself). Typically, however, we would
+use the `optim` optimizer, which implements useful functionality such as Nesterov momentum,
+[adagrad](https://github.com/torch/optim/blob/master/doc/index.md#x-adagradopfunc-x-config-state) and
+[adam](https://github.com/torch/optim/blob/master/doc/index.md#x-adamopfunc-x-config-state).
-`optim` is a quite general optimizer, for minimizing any function that outputs a loss. In our case, our
+We will demonstrate using a for-loop first, to show the low-level view of what happens in training, and then
+we will show how to train using `optim`.
+
+<a name="nn.DoItYourself"></a>
+## Example of manual training of a neural network ##
+
+We show an example here on a classical XOR problem.
+
+__Neural Network__
+
+We create a simple neural network with one hidden layer.
+```lua
+require "nn"
+mlp = nn.Sequential(); -- make a multi-layer perceptron
+inputs = 2; outputs = 1; HUs = 20; -- parameters
+mlp:add(nn.Linear(inputs, HUs))
+mlp:add(nn.Tanh())
+mlp:add(nn.Linear(HUs, outputs))
+```
+
+__Loss function__
+
+We choose the Mean Squared Error criterion:
+```lua
+criterion = nn.MSECriterion()
+```
+
+__Training__
+
+We create data _on the fly_ and feed it to the neural network.
+
+```lua
+for i = 1,2500 do
+ -- random sample
+ local input= torch.randn(2); -- normally distributed example in 2d
+ local output= torch.Tensor(1);
+ if input[1]*input[2] > 0 then -- calculate label for XOR function
+ output[1] = -1
+ else
+ output[1] = 1
+ end
+
+ -- feed it to the neural network and the criterion
+ criterion:forward(mlp:forward(input), output)
+
+ -- train over this example in 3 steps
+ -- (1) zero the accumulation of the gradients
+ mlp:zeroGradParameters()
+ -- (2) accumulate gradients
+ mlp:backward(input, criterion:backward(mlp.output, output))
+ -- (3) update parameters with a 0.01 learning rate
+ mlp:updateParameters(0.01)
+end
+```
+
+__Test the network__
+
+```lua
+x = torch.Tensor(2)
+x[1] = 0.5; x[2] = 0.5; print(mlp:forward(x))
+x[1] = 0.5; x[2] = -0.5; print(mlp:forward(x))
+x[1] = -0.5; x[2] = 0.5; print(mlp:forward(x))
+x[1] = -0.5; x[2] = -0.5; print(mlp:forward(x))
+```
+
+You should see something like:
+```lua
+> x = torch.Tensor(2)
+> x[1] = 0.5; x[2] = 0.5; print(mlp:forward(x))
+
+-0.6140
+[torch.Tensor of dimension 1]
+
+> x[1] = 0.5; x[2] = -0.5; print(mlp:forward(x))
+
+ 0.8878
+[torch.Tensor of dimension 1]
+
+> x[1] = -0.5; x[2] = 0.5; print(mlp:forward(x))
+
+ 0.8548
+[torch.Tensor of dimension 1]
+
+> x[1] = -0.5; x[2] = -0.5; print(mlp:forward(x))
+
+-0.5498
+[torch.Tensor of dimension 1]
+```
+
+<a name="nn.DoItYourself"></a>
+## Training using optim ##
+
+[optim](https://github.com/torch/optim) is the standard way of training Torch7 neural networks.
+
+`optim` is a quite general optimizer, for minimizing any function with respect to a set
+of parameters. In our case, our
function will be the loss of our network, given an input, and a set of weights. The goal of training
a neural net is to
optimize the weights to give the lowest loss over our training set of input data. So, we are going to use optim
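
A minimal sketch of the `optim` training pattern described in the paragraph above, assuming a small illustrative network and an XOR-style minibatch (the model, data, and learning rate here are placeholders, not taken from the patch):

```lua
require 'nn'
require 'optim'

-- illustrative network and data: an XOR-style problem in a single minibatch
local model = nn.Sequential()
model:add(nn.Linear(2, 20))
model:add(nn.Tanh())
model:add(nn.Linear(20, 1))
local criterion = nn.MSECriterion()

local batchInputs = torch.Tensor({{0.5, 0.5}, {0.5, -0.5}, {-0.5, 0.5}, {-0.5, -0.5}})
local batchLabels = torch.Tensor({{-1}, {1}, {1}, {-1}})

-- flatten all learnable weights and their gradients into two 1-d tensors
local params, gradParams = model:getParameters()
local optimState = {learningRate = 0.01}

for epoch = 1, 50 do
   -- feval maps the flat parameter vector to (loss, dloss/dparams)
   local function feval(p)
      gradParams:zero()
      local outputs = model:forward(batchInputs)
      local loss = criterion:forward(outputs, batchLabels)
      local dloss_doutputs = criterion:backward(outputs, batchLabels)
      model:backward(batchInputs, dloss_doutputs)
      return loss, gradParams
   end
   optim.sgd(feval, params, optimState)
end
```

The closure-plus-flat-parameters shape is what lets the same loop swap `optim.sgd` for `optim.adagrad` or `optim.adam` without touching the model code.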
@@ -119,8 +214,8 @@ for epoch=1,50 do
-- it takes current weights as input, and outputs the loss
-- and the gradient of the loss with respect to the weights
-- gradParams is calculated implicitly by calling 'backward',
- -- because gradParams is a view onto the model's weight and bias
- -- gradient tensors
+ -- because the model's weight and bias gradient tensors
+ -- are simply views onto gradParams
local function feval(params)
gradParams:zero()
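
The rewritten comment in this hunk relies on `getParameters()` rebuilding the module's weight and gradient tensors as views onto the flat `params` and `gradParams` vectors. A small sketch of that sharing, using an illustrative two-input linear layer rather than anything from the patch:

```lua
require 'nn'

local mlp = nn.Sequential():add(nn.Linear(2, 3))
local params, gradParams = mlp:getParameters()

-- after getParameters(), the module's weight tensor shares storage with the
-- flat params vector, so writes through params are visible inside the module
params:fill(0.5)
print(mlp:get(1).weight)          -- every entry is now 0.5

-- likewise, gradients accumulated by backward() land directly in gradParams
gradParams:zero()
local input = torch.randn(4, 2)
local output = mlp:forward(input)
mlp:backward(input, torch.ones(output:size()))
print(gradParams:sum())           -- non-zero: backward wrote into the flat tensor
```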
@@ -195,90 +290,5 @@ You should see something like:
[torch.Tensor of size 4]
```
-In this case, the output tensor contains one value for each of our input data samples.
-
-<a name="nn.DoItYourself"></a>
-## Example of manual training of a neural network ##
-
-We show an example here on a classical XOR problem.
-
-__Neural Network__
-
-We create a simple neural network with one hidden layer.
-```lua
-require "nn"
-mlp = nn.Sequential(); -- make a multi-layer perceptron
-inputs = 2; outputs = 1; HUs = 20; -- parameters
-mlp:add(nn.Linear(inputs, HUs))
-mlp:add(nn.Tanh())
-mlp:add(nn.Linear(HUs, outputs))
-```
-
-__Loss function__
-
-We choose the Mean Squared Error criterion.
-```lua
-criterion = nn.MSECriterion()
-```
-
-__Training__
-
-We create data _on the fly_ and feed it to the neural network.
-
-```lua
-for i = 1,2500 do
- -- random sample
- local input= torch.randn(2); -- normally distributed example in 2d
- local output= torch.Tensor(1);
- if input[1]*input[2] > 0 then -- calculate label for XOR function
- output[1] = -1
- else
- output[1] = 1
- end
-
- -- feed it to the neural network and the criterion
- criterion:forward(mlp:forward(input), output)
-
- -- train over this example in 3 steps
- -- (1) zero the accumulation of the gradients
- mlp:zeroGradParameters()
- -- (2) accumulate gradients
- mlp:backward(input, criterion:backward(mlp.output, output))
- -- (3) update parameters with a 0.01 learning rate
- mlp:updateParameters(0.01)
-end
-```
-
-__Test the network__
+That's it! For minibatched prediction, the output tensor contains one value for each of our input data samples.
-```lua
-x = torch.Tensor(2)
-x[1] = 0.5; x[2] = 0.5; print(mlp:forward(x))
-x[1] = 0.5; x[2] = -0.5; print(mlp:forward(x))
-x[1] = -0.5; x[2] = 0.5; print(mlp:forward(x))
-x[1] = -0.5; x[2] = -0.5; print(mlp:forward(x))
-```
-
-You should see something like:
-```lua
-> x = torch.Tensor(2)
-> x[1] = 0.5; x[2] = 0.5; print(mlp:forward(x))
-
--0.6140
-[torch.Tensor of dimension 1]
-
-> x[1] = 0.5; x[2] = -0.5; print(mlp:forward(x))
-
- 0.8878
-[torch.Tensor of dimension 1]
-
-> x[1] = -0.5; x[2] = 0.5; print(mlp:forward(x))
-
- 0.8548
-[torch.Tensor of dimension 1]
-
-> x[1] = -0.5; x[2] = -0.5; print(mlp:forward(x))
-
--0.5498
-[torch.Tensor of dimension 1]
-```
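
The new closing sentence states that, for minibatched prediction, the output tensor contains one value per input sample. As a minimal sketch of what that looks like with the XOR architecture used in this document (an untrained, illustrative copy of the network, not part of the patch):

```lua
require 'nn'

-- illustrative stand-in for the XOR network built earlier in this document
local mlp = nn.Sequential()
mlp:add(nn.Linear(2, 20))
mlp:add(nn.Tanh())
mlp:add(nn.Linear(20, 1))

-- forward a whole minibatch at once: one row per sample, one output per row
local batch = torch.Tensor({{0.5, 0.5}, {0.5, -0.5}, {-0.5, 0.5}, {-0.5, -0.5}})
print(mlp:forward(batch))          -- a 4x1 tensor, one prediction per sample
```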