github.com/torch/nn.git
-rw-r--r--   BatchNormalization.lua          151
-rw-r--r--   SpatialBatchNormalization.lua   179
-rwxr-xr-x   doc/convolution.md               84
-rw-r--r--   doc/simple.md                   168
-rw-r--r--   init.lua                          2
-rw-r--r--   test.lua                        106
6 files changed, 608 insertions, 82 deletions
diff --git a/BatchNormalization.lua b/BatchNormalization.lua
new file mode 100644
index 0000000..bd38232
--- /dev/null
+++ b/BatchNormalization.lua
@@ -0,0 +1,151 @@
+--[[
+ This file implements Batch Normalization as described in the paper:
+ "Batch Normalization: Accelerating Deep Network Training
+ by Reducing Internal Covariate Shift"
+ by Sergey Ioffe, Christian Szegedy
+
+ This implementation is useful for inputs NOT coming from convolution layers.
+ For Convolution layers, see SpatialBatchNormalization.lua
+
+ The operation implemented is:
+ y = ( x - mean(x) )
+ -------------------- * gamma + beta
+ standard-deviation(x)
+ where gamma and beta are learnable parameters.
+
+ The learning of gamma and beta is optional.
+
+ Usage:
+ with learnable parameters: nn.BatchNormalization(N [, eps] [,momentum])
+ where N = dimensionality of input
+ without learnable parameters: nn.BatchNormalization(0 [, eps] [,momentum])
+
+ eps is a small value added to the variance (inside the square root) to avoid
+ division by zero. Defaults to 1e-5.
+
+ During training, this layer keeps a running estimate of its computed mean and std.
+ The running estimate is kept with a default momentum of 0.1 (unless overridden).
+ At test time, this running mean/std is used to normalize.
+]]--
+local BN,parent = torch.class('nn.BatchNormalization', 'nn.Module')
+
+function BN:__init(nOutput, eps, momentum)
+ parent.__init(self)
+ assert(nOutput and type(nOutput) == 'number',
+ 'Missing argument #1: dimensionality of input. ' ..
+ 'Give 0 for no affine transform')
+ self.eps = eps or 1e-5
+ self.train = true
+ self.momentum = momentum or 0.1
+
+ if nOutput > 0 then self.affine = true end
+ if self.affine then
+ self.weight = torch.Tensor(nOutput)
+ self.bias = torch.Tensor(nOutput)
+ self.gradWeight = torch.Tensor(nOutput)
+ self.gradBias = torch.Tensor(nOutput)
+ self:reset()
+ end
+end
+
+function BN:reset()
+ self.weight:uniform()
+ self.bias:zero()
+end
+
+function BN:updateOutput(input)
+ assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
+ .. input:dim() .. 'D tensor instead')
+ local nBatch = input:size(1)
+
+ -- buffers that are reused
+ self.buffer = self.buffer or input.new()
+ self.buffer2 = self.buffer2 or input.new()
+ self.centered = self.centered or input.new()
+ self.centered:resizeAs(input)
+ self.std = self.std or input.new()
+ self.normalized = self.normalized or input.new()
+ self.normalized:resizeAs(input)
+ self.output:resizeAs(input)
+ self.gradInput:resizeAs(input)
+ if self.train == false then
+ assert(self.running_mean,
+ 'Module never run on training data. First run on some training data before evaluating.')
+ self.output:copy(input)
+ self.buffer:repeatTensor(self.running_mean, nBatch, 1)
+ self.output:add(-1, self.buffer)
+ self.buffer:repeatTensor(self.running_std, nBatch, 1)
+ self.output:cmul(self.buffer)
+ else -- training mode
+ self.running_mean = self.running_mean or input.new(input:size(2)):zero()
+ self.running_std = self.running_std or input.new(input:size(2)):zero()
+
+ -- calculate mean over mini-batch
+ self.buffer:mean(input, 1) -- E(x) = expectation of x.
+ self.running_mean:mul(1 - self.momentum):add(self.momentum, self.buffer) -- add to running mean
+ self.buffer:repeatTensor(self.buffer, nBatch, 1)
+
+ -- subtract mean
+ self.centered:add(input, -1, self.buffer) -- x - E(x)
+
+ -- calculate standard deviation over mini-batch
+ self.buffer:copy(self.centered):cmul(self.buffer) -- [x - E(x)]^2
+
+ -- 1 / sqrt( E([x - E(x)]^2) + eps ), i.e. the inverse standard deviation
+ self.std:mean(self.buffer, 1):add(self.eps):sqrt():pow(-1)
+ self.running_std:mul(1 - self.momentum):add(self.momentum, self.std) -- update the running (inverse) std
+ self.buffer:repeatTensor(self.std, nBatch, 1)
+
+ -- divide by the standard deviation (multiply by its inverse)
+ self.output:cmul(self.centered, self.buffer)
+ self.normalized:copy(self.output)
+ end
+
+ if self.affine then
+ -- multiply with gamma and add beta
+ self.buffer:repeatTensor(self.weight, nBatch, 1)
+ self.output:cmul(self.buffer)
+ self.buffer:repeatTensor(self.bias, nBatch, 1)
+ self.output:add(self.buffer)
+ end
+ return self.output
+end
+
+function BN:updateGradInput(input, gradOutput)
+ assert(input:dim() == 2, 'only mini-batch supported')
+ assert(gradOutput:dim() == 2, 'only mini-batch supported')
+ assert(self.train == true, 'updateGradInput is only defined in training mode (self.train = true)')
+ local nBatch = input:size(1)
+
+ self.gradInput:cmul(self.centered, gradOutput) -- (x - E(x)) .* dL/dy
+ self.buffer:mean(self.gradInput, 1) -- E_b[(x - E(x)) .* dL/dy]
+ self.gradInput:repeatTensor(self.buffer, nBatch, 1) -- broadcast over the batch
+ self.gradInput:cmul(self.centered):mul(-1) -- -(x - E(x)) * E_b[(x - E(x)) .* dL/dy]
+ self.buffer:repeatTensor(self.std, nBatch, 1)
+ self.gradInput:cmul(self.buffer):cmul(self.buffer) -- divide by (var + eps); std holds the inverse std
+
+ self.buffer:mean(gradOutput, 1) -- E_b[dL/dy]
+ self.buffer:repeatTensor(self.buffer, nBatch, 1)
+ self.gradInput:add(gradOutput):add(-1, self.buffer) -- + dL/dy - E_b[dL/dy]
+ self.buffer:repeatTensor(self.std, nBatch, 1)
+ self.gradInput:cmul(self.buffer) -- finally scale by 1 / sqrt(var + eps)
+
+ if self.affine then
+ self.buffer:repeatTensor(self.weight, nBatch, 1)
+ self.gradInput:cmul(self.buffer)
+ end
+
+ return self.gradInput
+end
+
+function BN:accGradParameters(input, gradOutput, scale)
+ if self.affine then
+ scale = scale or 1.0
+ self.buffer2:resizeAs(self.normalized):copy(self.normalized)
+ self.buffer2:cmul(gradOutput)
+ self.buffer:sum(self.buffer2, 1) -- sum over mini-batch
+ self.gradWeight:add(scale, self.buffer)
+ self.buffer:sum(gradOutput, 1) -- sum over mini-batch
+ self.gradBias:add(scale, self.buffer)
+ end
+end
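The header comment above describes construction and the train/test switch; below is a minimal usage sketch under those assumptions (the batch size 16 and input dimensionality 100 are arbitrary, and the `train` flag is toggled directly, as the module itself does):

```lua
require 'nn'

local bn = nn.BatchNormalization(100)      -- 100 input dimensions, affine transform enabled

-- training mode (the default): statistics come from the current mini-batch
local x  = torch.randn(16, 100)
local y  = bn:forward(x)                   -- 16x100, normalized per dimension, then scaled/shifted
local dx = bn:backward(x, torch.randn(16, 100))

-- test mode: the running mean/std accumulated during training are used instead
bn.train = false
local ytest = bn:forward(torch.randn(16, 100))
```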
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
new file mode 100644
index 0000000..ed612aa
--- /dev/null
+++ b/SpatialBatchNormalization.lua
@@ -0,0 +1,179 @@
+--[[
+ This file implements Batch Normalization as described in the paper:
+ "Batch Normalization: Accelerating Deep Network Training
+ by Reducing Internal Covariate Shift"
+ by Sergey Ioffe, Christian Szegedy
+
+ This implementation is useful for inputs coming from convolution layers.
+ For Non-convolutional layers, see BatchNormalization.lua
+
+ The operation implemented is:
+ y = ( x - mean(x) )
+ -------------------- * gamma + beta
+ standard-deviation(x)
+ where gamma and beta are learnable parameters.
+
+ The learning of gamma and beta is optional.
+
+ Usage:
+ with learnable parameters: nn.SpatialBatchNormalization(N [, eps] [, momentum])
+ where N = number of input feature maps
+ without learnable parameters: nn.SpatialBatchNormalization(0 [, eps] [, momentum])
+
+ eps is a small value added to the variance (inside the square root) to avoid
+ division by zero. Defaults to 1e-5.
+
+ During training, this layer keeps a running estimate of its computed mean and std.
+ The running estimate is kept with a default momentum of 0.1 (unless overridden).
+ At test time, this running mean/std is used to normalize.
+
+]]--
+local BN,parent = torch.class('nn.SpatialBatchNormalization', 'nn.Module')
+
+function BN:__init(nFeature, eps, momentum)
+ parent.__init(self)
+ assert(nFeature and type(nFeature) == 'number',
+ 'Missing argument #1: Number of feature planes. ' ..
+ 'Give 0 for no affine transform')
+ self.eps = eps or 1e-5
+ self.train = true
+ self.momentum = momentum or 0.1
+
+ if nFeature > 0 then self.affine = true end
+ if self.affine then
+ self.weight = torch.Tensor(nFeature)
+ self.bias = torch.Tensor(nFeature)
+ self.gradWeight = torch.Tensor(nFeature)
+ self.gradBias = torch.Tensor(nFeature)
+ self:reset()
+ end
+end
+
+function BN:reset()
+ self.weight:uniform()
+ self.bias:zero()
+end
+
+function BN:updateOutput(input)
+ assert(input:dim() == 4, 'only mini-batch supported (4D tensor), got '
+ .. input:dim() .. 'D tensor instead')
+ local nBatch = input:size(1)
+ local nFeature = input:size(2)
+ local iH = input:size(3)
+ local iW = input:size(4)
+
+ -- buffers that are reused
+ self.buffer = self.buffer or input.new()
+ self.buffer2 = self.buffer2 or input.new()
+ self.centered = self.centered or input.new()
+ self.centered:resizeAs(input)
+ self.std = self.std or input.new()
+ self.normalized = self.normalized or input.new()
+ self.normalized:resizeAs(input)
+ self.output:resizeAs(input)
+ self.gradInput:resizeAs(input)
+ if self.train == false then
+ assert(self.running_mean,
+ 'Module never run on training data. First run on some training data before evaluating.')
+ self.output:copy(input)
+ self.buffer:repeatTensor(self.running_mean:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
+ self.output:add(-1, self.buffer)
+ self.buffer:repeatTensor(self.running_std:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
+ self.output:cmul(self.buffer)
+ else -- training mode
+ self.running_mean = self.running_mean or input.new(nFeature):zero()
+ self.running_std = self.running_std or input.new(nFeature):zero()
+
+ -- calculate mean over mini-batch, over feature-maps
+ local in_folded = input:view(nBatch, nFeature, iH * iW)
+ self.buffer:mean(in_folded, 1)
+ self.buffer2:mean(self.buffer, 3)
+ self.running_mean:mul(1 - self.momentum):add(self.momentum, self.buffer2) -- add to running mean
+ self.buffer:repeatTensor(self.buffer2:view(1, nFeature, 1, 1),
+ nBatch, 1, iH, iW)
+
+ -- subtract mean
+ self.centered:add(input, -1, self.buffer) -- x - E(x)
+
+ -- calculate standard deviation over mini-batch
+ self.buffer:copy(self.centered):cmul(self.buffer) -- [x - E(x)]^2
+ local buf_folded = self.buffer:view(nBatch,nFeature,iH*iW)
+ self.std:mean(self.buffer2:mean(buf_folded, 1), 3)
+ self.std:add(self.eps):sqrt():pow(-1) -- 1 / sqrt( E([x - E(x)]^2) + eps )
+ self.running_std:mul(1 - self.momentum):add(self.momentum, self.std) -- update the running (inverse) std
+ self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1),
+ nBatch, 1, iH, iW)
+
+ -- divide by the standard deviation (multiply by its inverse)
+ self.output:cmul(self.centered, self.buffer)
+ self.normalized:copy(self.output)
+ end
+
+ if self.affine then
+ -- multiply with gamma and add beta
+ self.buffer:repeatTensor(self.weight:view(1, nFeature, 1, 1),
+ nBatch, 1, iH, iW)
+ self.output:cmul(self.buffer)
+ self.buffer:repeatTensor(self.bias:view(1, nFeature, 1, 1),
+ nBatch, 1, iH, iW)
+ self.output:add(self.buffer)
+ end
+
+ return self.output
+end
+
+function BN:updateGradInput(input, gradOutput)
+ assert(input:dim() == 4, 'only mini-batch supported')
+ assert(gradOutput:dim() == 4, 'only mini-batch supported')
+ assert(self.train == true, 'updateGradInput is only defined in training mode (self.train = true)')
+ local nBatch = input:size(1)
+ local nFeature = input:size(2)
+ local iH = input:size(3)
+ local iW = input:size(4)
+
+ self.gradInput:cmul(self.centered, gradOutput)
+ local gi_folded = self.gradInput:view(nBatch, nFeature, iH * iW)
+ self.buffer2:mean(self.buffer:mean(gi_folded, 1), 3)
+ self.gradInput:repeatTensor(self.buffer2:view(1, nFeature, 1, 1),
+ nBatch, 1, iH, iW)
+ self.gradInput:cmul(self.centered):mul(-1)
+ self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1),
+ nBatch, 1, iH, iW)
+ self.gradInput:cmul(self.buffer):cmul(self.buffer)
+
+ self.buffer:mean(gradOutput:view(nBatch, nFeature, iH*iW), 1)
+ self.buffer2:mean(self.buffer, 3)
+ self.buffer:repeatTensor(self.buffer2:view(1, nFeature, 1, 1),
+ nBatch, 1, iH, iW)
+ self.gradInput:add(gradOutput):add(-1, self.buffer)
+ self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1),
+ nBatch, 1, iH, iW)
+ self.gradInput:cmul(self.buffer)
+
+ if self.affine then
+ self.buffer:repeatTensor(self.weight:view(1, nFeature, 1, 1),
+ nBatch, 1, iH, iW)
+ self.gradInput:cmul(self.buffer)
+ end
+
+ return self.gradInput
+end
+
+function BN:accGradParameters(input, gradOutput, scale)
+ if self.affine then
+ scale = scale or 1.0
+ local nBatch = input:size(1)
+ local nFeature = input:size(2)
+ local iH = input:size(3)
+ local iW = input:size(4)
+ self.buffer2:resizeAs(self.normalized):copy(self.normalized)
+ self.buffer2 = self.buffer2:cmul(gradOutput):view(nBatch, nFeature, iH*iW)
+ self.buffer:sum(self.buffer2, 1) -- sum over mini-batch
+ self.buffer2:sum(self.buffer, 3) -- sum over pixels
+ self.gradWeight:add(scale, self.buffer2)
+
+ self.buffer:sum(gradOutput:view(nBatch, nFeature, iH*iW), 1)
+ self.buffer2:sum(self.buffer, 3)
+ self.gradBias:add(scale, self.buffer2) -- sum over mini-batch
+ end
+end
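A short sketch of how this module might be used after a convolution layer; the layer sizes are illustrative only, and `nn.SpatialConvolutionMM`/`nn.ReLU` are simply the usual modules from the rest of the package:

```lua
require 'nn'

local net = nn.Sequential()
net:add(nn.SpatialConvolutionMM(3, 16, 5, 5))   -- 3 input planes -> 16 feature maps
net:add(nn.SpatialBatchNormalization(16))       -- normalize each of the 16 feature maps
net:add(nn.ReLU())

local input  = torch.randn(8, 3, 32, 32)        -- batch of 8 RGB 32x32 images
local output = net:forward(input)               -- 8 x 16 x 28 x 28
```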
diff --git a/doc/convolution.md b/doc/convolution.md
index ef1c9f2..5b6eebd 100755
--- a/doc/convolution.md
+++ b/doc/convolution.md
@@ -2,7 +2,7 @@
# Convolutional layers #
A convolution is an integral that expresses the amount of overlap of one function `g` as it is shifted over another function `f`. It therefore "blends" one function with another. The neural network package supports convolution, pooling, subsampling and other relevant facilities. These are divided base on the dimensionality of the input and output [Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor):
- * [Temporal Modules](#nn.TemporalModules) apply to sequences with a one-dimensional relationship
+ * [Temporal Modules](#nn.TemporalModules) apply to sequences with a one-dimensional relationship
(e.g. sequences of words, phonemes and letters. Strings of some kind).
* [TemporalConvolution](#nn.TemporalConvolution) : a 1D convolution over an input sequence ;
* [TemporalSubSampling](#nn.TemporalSubSampling) : a 1D sub-sampling over an input sequence ;
@@ -18,6 +18,7 @@ A convolution is an integral that expresses the amount of overlap of one functio
* [SpatialConvolutionMap](#nn.SpatialConvolutionMap) : a 2D convolution that uses a generic connection table ;
* [SpatialZeroPadding](#nn.SpatialZeroPadding) : pads a feature map with a specified number of zeros ;
* [SpatialSubtractiveNormalization](#nn.SpatialSubtractiveNormalization) : a spatial subtraction operation on a series of 2D inputs using
a kernel for computing the weighted average in a neighborhood ;
+ * [SpatialBatchNormalization](#nn.SpatialBatchNormalization) : mean/std normalization over the mini-batch inputs and pixels, with an optional affine transform that follows ;
* [Volumetric Modules](#nn.VolumetricModules) apply to inputs with three-dimensional relationships (e.g. videos) :
* [VolumetricConvolution](#nn.VolumetricConvolution) : a 3D convolution over an input video (a sequence of images) ;
@@ -27,10 +28,10 @@ a kernel for computing the weighted average in a neighborhood ;
## Temporal Modules ##
Excluding an optional first batch dimension, temporal layers expect a 2D Tensor as input. The
first dimension is the number of frames in the sequence (e.g. `nInputFrame`), the last dimension
-is the number of features per frame (e.g. `inputFrameSize`). The output will normally have the same number
-of dimensions, although the size of each dimension may change. These are commonly used for processing acoustic signals or sequences of words, i.e. in Natural Language Processing.
+is the number of features per frame (e.g. `inputFrameSize`). The output will normally have the same number
+of dimensions, although the size of each dimension may change. These are commonly used for processing acoustic signals or sequences of words, i.e. in Natural Language Processing.
-Note: The [LookupTable](#nn.LookupTable) is special in that while it does output a temporal Tensor of size `nOutputFrame x outputFrameSize`,
+Note: The [LookupTable](#nn.LookupTable) is special in that while it does output a temporal Tensor of size `nOutputFrame x outputFrameSize`,
its input is a 1D Tensor of indices of size `nIndices`. Again, this is excluding the optional first batch dimension.
<a name="nn.TemporalConvolution"/>
@@ -77,7 +78,7 @@ output[t][i] = bias[i]
Here is a simple example:
```lua
-inp=5; -- dimensionality of one sequence element
+inp=5; -- dimensionality of one sequence element
outp=1; -- number of derived features for one sequence element
kw=1; -- kernel only operates on one sequence element per step
dw=1; -- we step once and go on to the next sequence element
@@ -93,8 +94,8 @@ which gives:
-0.9872
-0.6808
-0.9403
--0.9680
--0.6901
+-0.9680
+-0.6901
-0.6387
[torch.Tensor of dimension 7x1]
```
@@ -128,7 +129,7 @@ module = nn.TemporalMaxPooling(kW, [dW])
Applies 1D max-pooling operation in `kW` regions by step size
`dW` steps. Input sequence composed of `nInputFrame` frames. The `input` tensor in
-`forward(input)` is expected to be a 2D tensor (`nInputFrame x inputFrameSize`)
+`forward(input)` is expected to be a 2D tensor (`nInputFrame x inputFrameSize`)
or a 3D tensor (`nBatchFrame x nInputFrame x inputFrameSize`).
If the input sequence is a 2D tensor of dimension `nInputFrame x inputFrameSize`, the output sequence will be
@@ -185,16 +186,16 @@ module = nn.LookupTable(nIndex, size1, [size2], [size3], ...)
```
This layer is a particular case of a convolution, where the width of the convolution would be `1`.
-When calling `forward(input)`, it assumes `input` is a 1D or 2D tensor filled with indices.
+When calling `forward(input)`, it assumes `input` is a 1D or 2D tensor filled with indices.
If the input is a matrix, then each row is assumed to be an input sample of given batch. Indices start
at `1` and can go up to `nIndex`. For each index, it outputs a corresponding `Tensor` of size
specified by `sizes` (a `LongStorage`) or `size1 x size2 x...`.
-Given a 1D input, the output tensors are concatenated,
+Given a 1D input, the output tensors are concatenated,
generating a `n x size1 x size2 x ... x sizeN` tensor, where `n`
-is the size of a 1D `input` tensor.
+is the size of a 1D `input` tensor.
-Again with a 1D input, when only `size1` is provided, the `forward(input)` is equivalent to
+Again with a 1D input, when only `size1` is provided, the `forward(input)` is equivalent to
performing the following matrix-matrix multiplication in an efficient manner:
```lua
M P
@@ -205,7 +206,7 @@ where `M` is a 2D matrix `size1 x nIndex` containing the parameters of the looku
1D example:
```lua
-- a lookup table containing 10 tensors of size 3
- module = nn.LookupTable(10, 3)
+ module = nn.LookupTable(10, 3)
input = torch.Tensor{1,2,1,10}
print(module:forward(input))
@@ -221,14 +222,14 @@ Outputs something like:
```
Note that the first row vector is the same as the 3rd one!
-Given a 2D input tensor of size `m x n`, the output is a `m x n x size1 x size2 x ... x sizeN`
-tensor, where `m` is the number of samples in
+Given a 2D input tensor of size `m x n`, the output is a `m x n x size1 x size2 x ... x sizeN`
+tensor, where `m` is the number of samples in
the batch and `n` is the number of indices per sample.
2D example:
```lua
-- a lookup table containing 10 tensors of size 3
- module = nn.LookupTable(10, 3)
+ module = nn.LookupTable(10, 3)
-- a batch of 2 samples of 4 indices each
input = torch.Tensor({{1,2,4,5},{4,3,2,10}})
@@ -237,13 +238,13 @@ the batch and `n` is the number of indices per sample.
Outputs something like:
```lua
-(1,.,.) =
+(1,.,.) =
-0.0570 -1.5354 1.8555
-0.9067 1.3392 0.6275
1.9662 0.4645 -0.8111
0.1103 1.7811 1.5969
-(2,.,.) =
+(2,.,.) =
1.9662 0.4645 -0.8111
0.0026 -1.4547 -0.5154
-0.9067 1.3392 0.6275
@@ -445,7 +446,7 @@ Applies a 2D up-sampling over an input image composed of several input planes. T
The parameters are the following:
* `scale`: The upscale ratio. Must be a positive integer
-The up-scaling method is simple nearest neighbor, ie:
+The up-scaling method is simple nearest neighbor, ie:
```lua
output(u,v) = input(floor((u-1)/scale)+1, floor((v-1)/scale)+1)
@@ -500,10 +501,53 @@ w2=image.display(processed)
```
![](image/lena.jpg)![](image/lenap.jpg)
+<a name="nn.SpatialBatchNormalization"/>
+## SpatialBatchNormalization ##
+
+`module` = `nn.SpatialBatchNormalization(N [, eps] [, momentum])`
+where `N` is the number of input feature maps.
+Giving `N = 0` disables the learnable affine transform.
+`eps` is a small value added to the variance (inside the square root) to avoid division by zero. Defaults to `1e-5`.
+
+Implements Batch Normalization as described in the paper:
+ "Batch Normalization: Accelerating Deep Network Training
+ by Reducing Internal Covariate Shift"
+ by Sergey Ioffe, Christian Szegedy
+
+The operation implemented is:
+```
+ y = ( x - mean(x) )
+ -------------------- * gamma + beta
+ standard-deviation(x)
+```
+where the mean and standard-deviation are calculated per feature-map over the mini-batch and pixels
+and where gamma and beta are learnable parameter vectors of size N (where N = number of feature maps).
+The learning of gamma and beta is optional.
+
+During training, this layer keeps a running estimate of its computed mean and std.
+The running estimate is kept with a default momentum of 0.1 (unless overridden).
+At test time, this running mean/std is used to normalize.
+
+The module only accepts 4D inputs.
+
+```lua
+-- with learnable parameters
+model = nn.SpatialBatchNormalization(m)
+A = torch.randn(b, m, h, w)
+C = model:forward(A) -- C will be of size `b x m x h x w`
+
+-- without learnable parameters
+model = nn.SpatialBatchNormalization(0)
+A = torch.randn(b, m, h, w)
+C = model:forward(A) -- C will be of size `b x m x h x w`
+```
+
<a name="nn.VolumetricModules"/>
## Volumetric Modules ##
Excluding an optional batch dimension, volumetric layers expect a 4D Tensor as input. The
-first dimension is the number of features (e.g. `frameSize`), the second is sequential (e.g. `time`) and the
+first dimension is the number of features (e.g. `frameSize`), the second is sequential (e.g. `time`) and the
last two dimensions are spatial (e.g. `height x width`). These are commonly used for processing videos (sequences of images).
<a name="nn.VolumetricConvolution"/>
diff --git a/doc/simple.md b/doc/simple.md
index 999880e..f649464 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -1,6 +1,6 @@
<a name="nn.simplelayers.dok"/>
# Simple layers #
-Simple Modules are used for various tasks like adapting Tensor methods
+Simple Modules are used for various tasks like adapting Tensor methods
and providing affine transformations :
* Parameterized Modules :
* [Linear](#nn.Linear) : a linear transformation ;
@@ -29,10 +29,12 @@ and providing affine transformations :
* [Square](#nn.Square) : an element-wise square operation ;
* [Sqrt](#nn.Sqrt) : an element-wise [sqrt](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchsqrtres-x) operation ;
* [MM](#nn.MM) : matrix-matrix multiplication (also supports batches of matrices) ;
+* Normalization modules :
+ * [BatchNormalization](#nn.BatchNormalization) : mean/std normalization over the mini-batch inputs, with an optional affine transform that follows ;
* Miscellaneous Modules :
* [Identity](#nn.Identity) : forward input as-is to output (useful with [ParallelTable](table.md#nn.ParallelTable));
* [Dropout](#nn.Dropout) : masks parts of the `input` using binary samples from a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution ;
-
+
<a name="nn.Linear"/>
## Linear ##
@@ -77,9 +79,9 @@ applying the linear transformation is performed with:
Applies a linear transformation to the incoming sparse data, i.e.
_y= Ax+b_. The `input` tensor given in `forward(input)` must
-be a sparse vector represented as 2D tensor of the form
+be a sparse vector represented as 2D tensor of the form
torch.Tensor(N, 2) where the pairs represent indices and values.
-The SparseLinear layer is useful when the number of input
+The SparseLinear layer is useful when the number of input
dimensions is very large and the input data is sparse.
You can create a sparse linear layer in the following way:
@@ -87,9 +89,9 @@ You can create a sparse linear layer in the following way:
```lua
module= nn.SparseLinear(10000,2) -- 10000 inputs, 2 outputs
```
-The sparse linear module may be used as part of a larger network,
-and apart from the form of the input,
-[SparseLinear](#nn.SparseLinear)
+The sparse linear module may be used as part of a larger network,
+and apart from the form of the input,
+[SparseLinear](#nn.SparseLinear)
operates in exactly the same way as the [Linear](#nn.Linear) layer.
A sparse input vector may be created as so..
@@ -107,9 +109,9 @@ A sparse input vector may be created as so..
```
-The first column contains indices, the second column contains
-values in a a vector where all other elements are zeros. The
-indices should not exceed the stated dimensions of the input to the
+The first column contains indices, the second column contains
+values in a vector where all other elements are zeros. The
+indices should not exceed the stated dimensions of the input to the
layer (10000 in the example).
@@ -118,17 +120,17 @@ layer (10000 in the example).
`module` = `nn.Dropout(p)`
-During training, `Dropout` masks parts of the `input` using binary samples from
+During training, `Dropout` masks parts of the `input` using binary samples from
a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution.
Each `input` element has a probability of `p` of being dropped, i.e having its
-commensurate output element be zero. This has proven an effective technique for
-regularization and preventing the co-adaptation of neurons
-(see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)).
+commensurate output element be zero. This has proven an effective technique for
+regularization and preventing the co-adaptation of neurons
+(see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)).
-Furthermore, the ouputs are scaled by a factor of `1/(1-p)` during training. This allows the
+Furthermore, the outputs are scaled by a factor of `1/(1-p)` during training. This allows the
`input` to be simply forwarded as-is during evaluation.
-In this example, we demonstrate how the call to [forward](module.md#output-forwardinput) samples
+In this example, we demonstrate how the call to [forward](module.md#output-forwardinput) samples
different `outputs` to dropout (the zeros) given the same `input`:
```lua
module = nn.Dropout()
@@ -162,7 +164,7 @@ module = nn.Dropout()
```
In both cases the `gradOutput` and `input` are scaled by `1/(1-p)`, which in this case is `2`.
-During [evaluation](module.md#evaluate), `Dropout` does nothing more than
+During [evaluation](module.md#evaluate), `Dropout` does nothing more than
forward the input such that all elements of the input are considered.
```lua
> module:evaluate()
@@ -185,13 +187,13 @@ We can return to training our model by first calling [Module:training()](module.
```
-When used, `Dropout` should normally be applied to the input of parameterized
-[Modules](module.md#nn.Module) like [Linear](#nn.Linear)
+When used, `Dropout` should normally be applied to the input of parameterized
+[Modules](module.md#nn.Module) like [Linear](#nn.Linear)
or [SpatialConvolution](convolution.md#nn.SpatialConvolution).
A `p` of `0.5` (the default) is usually okay for hidden layers.
`Dropout` can sometimes be used successfully on the dataset inputs with a `p` around `0.2`.
-It sometimes works best following [Transfer](transfer.md) Modules
-like [ReLU](transfer.md#nn.ReLU). All this depends a great deal on the dataset so its up
+It sometimes works best following [Transfer](transfer.md) Modules
+like [ReLU](transfer.md#nn.ReLU). All this depends a great deal on the dataset so it's up
to the user to try different combinations.
@@ -221,15 +223,15 @@ gnuplot.grid(true)
Applies a bias term to the incoming data, i.e.
_y_i= x_i + b_i, or if _scalar=true_ then uses a single bias term,
-_y_i= x_i + b.
+_y_i= x_i + b.
Example:
```lua
-y=torch.Tensor(5);
+y=torch.Tensor(5);
mlp=nn.Sequential()
mlp:add(nn.Add(5))
-function gradUpdate(mlp, x, y, criterion, learningRate)
+function gradUpdate(mlp, x, y, criterion, learningRate)
local pred = mlp:forward(x)
local err = criterion:forward(pred, y)
local gradCriterion = criterion:backward(pred, y)
@@ -241,7 +243,7 @@ end
for i=1,10000 do
x=torch.rand(5)
- y:copy(x);
+ y:copy(x);
for i=1,5 do y[i]=y[i]+i; end
err=gradUpdate(mlp,x,y,nn.MSECriterion(),0.01)
end
@@ -256,7 +258,7 @@ gives the output:
5.0000
[torch.Tensor of dimension 5]
```
-i.e. the network successfully learns the input _x_ has been shifted
+i.e. the network successfully learns the input _x_ has been shifted
to produce the output _y_.
@@ -266,15 +268,15 @@ to produce the output _y_.
`module` = `Mul()`
Applies a _single_ scaling factor to the incoming data, i.e.
-_y= w x_, where _w_ is a scalar.
+_y= w x_, where _w_ is a scalar.
Example:
```lua
-y=torch.Tensor(5);
+y=torch.Tensor(5);
mlp=nn.Sequential()
mlp:add(nn.Mul())
-function gradUpdate(mlp, x, y, criterion, learningRate)
+function gradUpdate(mlp, x, y, criterion, learningRate)
local pred = mlp:forward(x)
local err = criterion:forward(pred,y)
local gradCriterion = criterion:backward(pred,y);
@@ -307,7 +309,7 @@ pi.
Applies a component-wise multiplication to the incoming data, i.e.
`y_i = w_i * x_i`. Argument `size` can be one or many numbers (sizes)
-or a `torch.LongStorage`. For example, `nn.CMul(3,4,5)` is equivalent to
+or a `torch.LongStorage`. For example, `nn.CMul(3,4,5)` is equivalent to
`nn.CMul(torch.LongStorage{3,4,5})`.
Example:
@@ -315,10 +317,10 @@ Example:
mlp=nn.Sequential()
mlp:add(nn.CMul(5))
-y=torch.Tensor(5);
+y=torch.Tensor(5);
sc=torch.Tensor(5); for i=1,5 do sc[i]=i; end -- scale input with this
-function gradUpdate(mlp,x,y,criterion,learningRate)
+function gradUpdate(mlp,x,y,criterion,learningRate)
local pred = mlp:forward(x)
local err = criterion:forward(pred,y)
local gradCriterion = criterion:backward(pred,y);
@@ -394,7 +396,7 @@ then an `nxq` matrix would be output.
Outputs the Euclidean distance of the input to `outputSize` centers,
i.e. this layer has the weights `w_j`, for `j` = `1`,..,`outputSize`, where
-`w_j` are vectors of dimension `inputSize`.
+`w_j` are vectors of dimension `inputSize`.
The distance `y_j` between center `j` and input `x` is formulated as
`y_j = || w_j - x ||`.
@@ -406,10 +408,10 @@ The distance `y_j` between center `j` and input `x` is formulated as
This module is similar to [Euclidean](#nn.Euclidean), but
additionally learns a separate diagonal covariance matrix across the
-features of the input space _for each center_.
+features of the input space _for each center_.
-In other words, for each of the `outputSize` centers `w_j`, there is
-a diagonal covariance matrices `c_j`, for `j` = `1`,..,`outputSize`,
+In other words, for each of the `outputSize` centers `w_j`, there is
+a diagonal covariance matrix `c_j`, for `j` = `1`,..,`outputSize`,
where `c_j` are stored as vectors of size `inputSize`.
The distance `y_j` between center `j` and input `x` is formulated as
@@ -420,8 +422,8 @@ The distance `y_j` between center `j` and input `x` is formulated as
`module` = `Identity()`
-Creates a module that returns whatever is input to it as output.
-This is useful when combined with the module
+Creates a module that returns whatever is input to it as output.
+This is useful when combined with the module
[ParallelTable](table.md#nn.ParallelTable)
in case you do not wish to do anything to one of the input Tensors.
Example:
@@ -429,7 +431,7 @@ Example:
mlp=nn.Identity()
print(mlp:forward(torch.ones(5,2)))
```
-gives the output:
+gives the output:
```lua
1 1
1 1
@@ -440,10 +442,10 @@ gives the output:
```
Here is a more useful example, where one can implement a network which also computes a Criterion using this module:
-```lua
+```lua
pred_mlp=nn.Sequential(); -- A network that makes predictions given x.
-pred_mlp:add(nn.Linear(5,4))
-pred_mlp:add(nn.Linear(4,3))
+pred_mlp:add(nn.Linear(5,4))
+pred_mlp:add(nn.Linear(4,3))
xy_mlp=nn.ParallelTable();-- A network for predictions and for keeping the
xy_mlp:add(pred_mlp) -- true label for comparison with a criterion
@@ -457,14 +459,14 @@ mlp:add(cr_wrap) -- and then applies the criterion.
for i=1,100 do -- Do a few training iterations
x=torch.ones(5); -- Make input features.
- y=torch.Tensor(3);
+ y=torch.Tensor(3);
y:copy(x:narrow(1,1,3)) -- Make output label.
err=mlp:forward{x,y} -- Forward both input and output.
print(err) -- Print error from criterion.
- mlp:zeroGradParameters(); -- Do backprop...
- mlp:backward({x, y} );
- mlp:updateParameters(0.05);
+ mlp:zeroGradParameters(); -- Do backprop...
+ mlp:backward({x, y} );
+ mlp:updateParameters(0.05);
end
```
@@ -476,9 +478,9 @@ end
This layer copies the input to the output with type casting from
`inputType` to `outputType`. Unless `forceCopy` is true, when
the first two arguments are the same, the input isn't copied, only transferred
-as the output. The default `forceCopy` is false.
+as the output. The default `forceCopy` is false.
When `dontCast` is true, a call to `nn.Copy:type(type)` will not cast
-the module's `output` and `gradInput` Tensors to the new type. The default
+the module's `output` and `gradInput` Tensors to the new type. The default
is false.
<a name="nn.Narrow"/>
@@ -543,16 +545,16 @@ torch> =o
`module` = `Reshape(dimension1, dimension2, ... [, batchMode])`
Reshapes an `nxpxqx..` Tensor into a `dimension1xdimension2x...` Tensor,
-taking the elements column-wise.
+taking the elements column-wise.
-The optional last argument `batchMode`,
-when `true` forces the first dimension of the input to be considered
-the batch dimension, and thus keep its size fixed. This is necessary when
-dealing with batch sizes of one. When `false`, it forces the
-entire input (including the first dimension) to be reshaped to the
-input size. Default `batchMode=nil`, which means that the module
-considers inputs with more elements than the produce of provided sizes,
-i.e. `dimension1xdimension2x...`, to be batches.
+The optional last argument `batchMode`,
+when `true` forces the first dimension of the input to be considered
+the batch dimension, and thus keep its size fixed. This is necessary when
+dealing with batch sizes of one. When `false`, it forces the
+entire input (including the first dimension) to be reshaped to the
+input size. Default `batchMode=nil`, which means that the module
+considers inputs with more elements than the product of the provided sizes,
+i.e. `dimension1xdimension2x...`, to be batches.
Example:
```lua
@@ -731,17 +733,17 @@ gives the output:
```
This can be used in conjunction with [Concat](containers.md#nn.Concat)
-to emulate the behavior
-of [Parallel](containers.md#nn.Parallel), or to select various parts of an input Tensor to
+to emulate the behavior
+of [Parallel](containers.md#nn.Parallel), or to select various parts of an input Tensor to
perform operations on. Here is a fairly complicated example:
```lua
mlp=nn.Sequential();
-c=nn.Concat(2)
+c=nn.Concat(2)
for i=1,10 do
local t=nn.Sequential()
t:add(nn.Select(1,i))
- t:add(nn.Linear(3,2))
+ t:add(nn.Linear(3,2))
t:add(nn.Reshape(2,1))
c:add(t)
end
@@ -759,7 +761,7 @@ for i=1,10000 do -- Train for a few iterations
err=criterion:forward(pred,y)
gradCriterion = criterion:backward(pred,y);
mlp:zeroGradParameters();
- mlp:backward(x, gradCriterion);
+ mlp:backward(x, gradCriterion);
mlp:updateParameters(0.01);
print(err)
end
@@ -853,3 +855,45 @@ A = torch.randn(b, m, n)
B = torch.randn(b, n, p)
C = model.forward({A, B}) -- C will be of size `b x m x n`
```
+
+<a name="nn.BatchNormalization"/>
+## BatchNormalization ##
+
+`module` = `nn.BatchNormalization(N [, eps] [, momentum])`
+where `N` is the dimensionality of the input.
+Giving `N = 0` disables the learnable affine transform.
+`eps` is a small value added to the variance (inside the square root) to avoid division by zero. Defaults to `1e-5`.
+
+During training, this layer keeps a running estimate of its computed mean and std.
+The running estimate is kept with a default momentum of 0.1 (unless overridden).
+At test time, this running mean/std is used to normalize.
+
+
+Implements Batch Normalization as described in the paper:
+ "Batch Normalization: Accelerating Deep Network Training
+ by Reducing Internal Covariate Shift"
+ by Sergey Ioffe, Christian Szegedy
+
+The operation implemented is:
+```
+ y = ( x - mean(x) )
+ -------------------- * gamma + beta
+ standard-deviation(x)
+```
+where the mean and standard-deviation are calculated per-dimension over the mini-batch
+and where gamma and beta are learnable parameter vectors of size N (where N = input dimensionality).
+The learning of gamma and beta is optional.
+
+The module only accepts 2D inputs.
+
+```lua
+-- with learnable parameters
+model = nn.BatchNormalization(m)
+A = torch.randn(b, m)
+C = model:forward(A) -- C will be of size `b x m`
+
+-- without learnable parameters
+model = nn.BatchNormalization(0)
+A = torch.randn(b, m)
+C = model:forward(A) -- C will be of size `b x m`
+```
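The formula documented above can be checked numerically; the sketch below (with arbitrary `b = 64`, `m = 10`) forwards random data through a module without the affine transform and inspects the per-dimension statistics of the output:

```lua
require 'nn'

local b, m = 64, 10
local bn = nn.BatchNormalization(0)   -- N = 0: no affine transform, pure normalization
local A  = bn:forward(torch.randn(b, m))

print(A:mean(1))   -- each entry close to 0
print(A:std(1))    -- each entry close to 1 (up to eps and the biased variance estimate)
```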
diff --git a/init.lua b/init.lua
index f25275b..7704712 100644
--- a/init.lua
+++ b/init.lua
@@ -18,6 +18,7 @@ include('Select.lua')
include('Narrow.lua')
include('Replicate.lua')
include('Transpose.lua')
+include('BatchNormalization.lua')
include('Copy.lua')
include('Min.lua')
@@ -86,6 +87,7 @@ include('SpatialDivisiveNormalization.lua')
include('SpatialContrastiveNormalization.lua')
include('SpatialZeroPadding.lua')
include('SpatialUpSamplingNearest.lua')
+include('SpatialBatchNormalization.lua')
include('VolumetricConvolution.lua')
include('VolumetricMaxPooling.lua')
diff --git a/test.lua b/test.lua
index 3661d03..a2c2465 100644
--- a/test.lua
+++ b/test.lua
@@ -3188,6 +3188,112 @@ function nntest.CosineEmbeddingCriterion()
equal(grads[2], zero, 'gradient should be zero')
end
+function nntest.BatchNormalization()
+ local nframes = torch.random(50,70)
+ local indim = torch.random(1,10)
+ local input = torch.zeros(nframes, indim):uniform()
+ local module = nn.BatchNormalization(indim)
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ local err = jac.testJacobianParameters(module, input,
+ module.weight, module.gradWeight)
+ mytester:assertlt(err,precision, 'error on weight ')
+
+ local err = jac.testJacobianParameters(module, input,
+ module.bias, module.gradBias)
+ mytester:assertlt(err,precision, 'error on bias ')
+
+ local err = jac.testJacobianUpdateParameters(module, input, module.weight)
+ mytester:assertlt(err,precision, 'error on weight [direct update] ')
+
+ local err = jac.testJacobianUpdateParameters(module, input, module.bias)
+ mytester:assertlt(err,precision, 'error on bias [direct update] ')
+
+ for t,err in pairs(jac.testAllUpdate(module, input,
+ 'weight', 'gradWeight')) do
+ mytester:assertlt(err, precision, string.format(
+ 'error on weight [%s]', t))
+ end
+
+ for t,err in pairs(jac.testAllUpdate(module, input,
+ 'bias', 'gradBias')) do
+ mytester:assertlt(err, precision, string.format(
+ 'error on bias [%s]', t))
+ end
+
+ -- IO
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ -- batch norm without affine transform
+ module = nn.BatchNormalization(0)
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ -- IO
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+end
+
+function nntest.SpatialBatchNormalization()
+ local nframes = torch.random(1,10)
+ local indim = torch.random(1,4)
+ local ini = torch.random(1,5)
+ local inj = torch.random(1,5)
+ local input = torch.zeros(nframes, indim, ini, inj):uniform()
+ local module = nn.SpatialBatchNormalization(indim)
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ local err = jac.testJacobianParameters(module, input,
+ module.weight, module.gradWeight)
+ mytester:assertlt(err,precision, 'error on weight ')
+
+ local err = jac.testJacobianParameters(module, input,
+ module.bias, module.gradBias)
+ mytester:assertlt(err,precision, 'error on bias ')
+
+ local err = jac.testJacobianUpdateParameters(module, input, module.weight)
+ mytester:assertlt(err,precision, 'error on weight [direct update] ')
+
+ local err = jac.testJacobianUpdateParameters(module, input, module.bias)
+ mytester:assertlt(err,precision, 'error on bias [direct update] ')
+
+ for t,err in pairs(jac.testAllUpdate(module, input,
+ 'weight', 'gradWeight')) do
+ mytester:assertlt(err, precision, string.format(
+ 'error on weight [%s]', t))
+ end
+
+ for t,err in pairs(jac.testAllUpdate(module, input,
+ 'bias', 'gradBias')) do
+ mytester:assertlt(err, precision, string.format(
+ 'error on bias [%s]', t))
+ end
+
+ -- IO
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ -- batch norm without affine transform
+ module = nn.SpatialBatchNormalization(0)
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ -- IO
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+end
+
mytester:add(nntest)
if not nn then
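The tests added above rely on `nn.Jacobian`; a standalone sketch of the same finite-difference check (the module and batch sizes are arbitrary):

```lua
require 'nn'

local module = nn.BatchNormalization(5)
local input  = torch.zeros(32, 5):uniform()

-- compares the analytic Jacobian from updateGradInput against a finite-difference estimate
local err = nn.Jacobian.testJacobian(module, input)
print('max Jacobian error: ' .. err)   -- expected to be well below the test precision (~1e-5)
```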