author | Soumith Chintala <soumith@gmail.com> | 2015-03-21 08:52:30 +0300
---|---|---
committer | Soumith Chintala <soumith@gmail.com> | 2015-03-21 08:52:30 +0300
commit | a716e978686d8d5cfb3092aff9c7773883717d2b (patch) |
tree | 2ad7bedf709f5b96199203aa5c0a1023b1b8fa85 |
parent | c2394aefb897914f16e958ec9489d9327fa7e8c6 (diff) |
parent | fa12c6fc92095b65f44163bb70b3ba32f0970229 (diff) |
Merge pull request #189 from torch/batchnorm
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
-rw-r--r-- | BatchNormalization.lua | 151
-rw-r--r-- | SpatialBatchNormalization.lua | 179
-rwxr-xr-x | doc/convolution.md | 84
-rw-r--r-- | doc/simple.md | 168
-rw-r--r-- | init.lua | 2
-rw-r--r-- | test.lua | 106
6 files changed, 608 insertions, 82 deletions
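The two new modules are documented in doc/simple.md and doc/convolution.md below. As a rough orientation before the diff itself, here is a minimal usage sketch based on those docs; it is not part of the commit, and the layer sizes and surrounding modules (Linear, SpatialConvolution, ReLU) are illustrative choices, not taken from the diff.

```lua
-- Minimal usage sketch (not part of this commit).
-- nn.BatchNormalization expects 2D (nBatch x N) input,
-- nn.SpatialBatchNormalization expects 4D (nBatch x nFeature x H x W) input.
require 'nn'

-- fully-connected case
local mlp = nn.Sequential()
mlp:add(nn.Linear(16, 32))
mlp:add(nn.BatchNormalization(32))        -- learnable gamma/beta of size 32
mlp:add(nn.ReLU())
local y = mlp:forward(torch.randn(8, 16)) -- y is 8 x 32

-- convolutional case
local cnn = nn.Sequential()
cnn:add(nn.SpatialConvolution(3, 32, 5, 5))
cnn:add(nn.SpatialBatchNormalization(32)) -- one gamma/beta per feature map
cnn:add(nn.ReLU())
local out = cnn:forward(torch.randn(8, 3, 28, 28)) -- out is 8 x 32 x 24 x 24
```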
diff --git a/BatchNormalization.lua b/BatchNormalization.lua new file mode 100644 index 0000000..bd38232 --- /dev/null +++ b/BatchNormalization.lua @@ -0,0 +1,151 @@ +--[[ + This file implements Batch Normalization as described in the paper: + "Batch Normalization: Accelerating Deep Network Training + by Reducing Internal Covariate Shift" + by Sergey Ioffe, Christian Szegedy + + This implementation is useful for inputs NOT coming from convolution layers. + For Convolution layers, see SpatialBatchNormalization.lua + + The operation implemented is: + y = ( x - mean(x) ) + -------------------- * gamma + beta + standard-deviation(x) + where gamma and beta are learnable parameters. + + The learning of gamma and beta is optional. + + Usage: + with learnable parameters: nn.BatchNormalization(N [, eps] [,momentum]) + where N = dimensionality of input + without learnable parameters: nn.BatchNormalization(0 [, eps] [,momentum]) + + eps is a small value added to the standard-deviation to avoid divide-by-zero. + Defaults to 1e-5 + + In training time, this layer keeps a running estimate of it's computed mean and std. + The running sum is kept with a default momentup of 0.1 (unless over-ridden) + In test time, this running mean/std is used to normalize. +]]-- +local BN,parent = torch.class('nn.BatchNormalization', 'nn.Module') + +function BN:__init(nOutput, eps, momentum) + parent.__init(self) + assert(nOutput and type(nOutput) == 'number', + 'Missing argument #1: dimensionality of input. ' .. + 'Give 0 for no affine transform') + self.eps = eps or 1e-5 + self.train = true + self.momentum = momentum or 0.1 + + if nOutput > 0 then self.affine = true end + if self.affine then + self.weight = torch.Tensor(nOutput) + self.bias = torch.Tensor(nOutput) + self.gradWeight = torch.Tensor(nOutput) + self.gradBias = torch.Tensor(nOutput) + self:reset() + end +end + +function BN:reset() + self.weight:uniform() + self.bias:zero() +end + +function BN:updateOutput(input) + assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got ' + .. input:dim() .. 'D tensor instead') + local nBatch = input:size(1) + + -- buffers that are reused + self.buffer = self.buffer or input.new() + self.buffer2 = self.buffer2 or input.new() + self.centered = self.centered or input.new() + self.centered:resizeAs(input) + self.std = self.std or input.new() + self.normalized = self.normalized or input.new() + self.normalized:resizeAs(input) + self.output:resizeAs(input) + self.gradInput:resizeAs(input) + if self.train == false then + assert(self.running_mean, + 'Module never run on training data. First run on some training data before evaluating.') + self.output:copy(input) + self.buffer:repeatTensor(self.running_mean, nBatch, 1) + self.output:add(-1, self.buffer) + self.buffer:repeatTensor(self.running_std, nBatch, 1) + self.output:cmul(self.buffer) + else -- training mode + self.running_mean = self.running_mean or input.new(input:size(2)):zero() + self.running_std = self.running_std or input.new(input:size(2)):zero() + + -- calculate mean over mini-batch + self.buffer:mean(input, 1) -- E(x) = expectation of x. 
+ self.running_mean:mul(1 - self.momentum):add(self.momentum, self.buffer) -- add to running mean + self.buffer:repeatTensor(self.buffer, nBatch, 1) + + -- subtract mean + self.centered:add(input, -1, self.buffer) -- x - E(x) + + -- calculate standard deviation over mini-batch + self.buffer:copy(self.centered):cmul(self.buffer) -- [x - E(x)]^2 + + -- 1 / E([x - E(x)]^2) + self.std:mean(self.buffer, 1):add(self.eps):sqrt():pow(-1) + self.running_std:mul(1 - self.momentum):add(self.momentum, self.std) -- add to running stdv + self.buffer:repeatTensor(self.std, nBatch, 1) + + -- divide standard-deviation + eps + self.output:cmul(self.centered, self.buffer) + self.normalized:copy(self.output) + end + + if self.affine then + -- multiply with gamma and add beta + self.buffer:repeatTensor(self.weight, nBatch, 1) + self.output:cmul(self.buffer) + self.buffer:repeatTensor(self.bias, nBatch, 1) + self.output:add(self.buffer) + end + return self.output +end + +function BN:updateGradInput(input, gradOutput) + assert(input:dim() == 2, 'only mini-batch supported') + assert(gradOutput:dim() == 2, 'only mini-batch supported') + assert(self.train == true, 'should be in training mode when self.train is true') + local nBatch = input:size(1) + + self.gradInput:cmul(self.centered, gradOutput) + self.buffer:mean(self.gradInput, 1) + self.gradInput:repeatTensor(self.buffer, nBatch, 1) + self.gradInput:cmul(self.centered):mul(-1) + self.buffer:repeatTensor(self.std, nBatch, 1) + self.gradInput:cmul(self.buffer):cmul(self.buffer) + + self.buffer:mean(gradOutput, 1) + self.buffer:repeatTensor(self.buffer, nBatch, 1) + self.gradInput:add(gradOutput):add(-1, self.buffer) + self.buffer:repeatTensor(self.std, nBatch, 1) + self.gradInput:cmul(self.buffer) + + if self.affine then + self.buffer:repeatTensor(self.weight, nBatch, 1) + self.gradInput:cmul(self.buffer) + end + + return self.gradInput +end + +function BN:accGradParameters(input, gradOutput, scale) + if self.affine then + scale = scale or 1.0 + self.buffer2:resizeAs(self.normalized):copy(self.normalized) + self.buffer2:cmul(gradOutput) + self.buffer:sum(self.buffer2, 1) -- sum over mini-batch + self.gradWeight:add(scale, self.buffer) + self.buffer:sum(gradOutput, 1) -- sum over mini-batch + self.gradBias:add(scale, self.buffer) + end +end diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua new file mode 100644 index 0000000..ed612aa --- /dev/null +++ b/SpatialBatchNormalization.lua @@ -0,0 +1,179 @@ +--[[ + This file implements Batch Normalization as described in the paper: + "Batch Normalization: Accelerating Deep Network Training + by Reducing Internal Covariate Shift" + by Sergey Ioffe, Christian Szegedy + + This implementation is useful for inputs coming from convolution layers. + For Non-convolutional layers, see BatchNormalization.lua + + The operation implemented is: + y = ( x - mean(x) ) + -------------------- * gamma + beta + standard-deviation(x) + where gamma and beta are learnable parameters. + + The learning of gamma and beta is optional. + + Usage: + with learnable parameters: nn.BatchNormalization(N [,eps] [,momentum]) + where N = dimensionality of input + without learnable parameters: nn.BatchNormalization(0 [,eps] [,momentum]) + + eps is a small value added to the standard-deviation to avoid divide-by-zero. + Defaults to 1e-5 + + In training time, this layer keeps a running estimate of it's computed mean and std. 
+ The running sum is kept with a default momentup of 0.1 (unless over-ridden) + In test time, this running mean/std is used to normalize. + +]]-- +local BN,parent = torch.class('nn.SpatialBatchNormalization', 'nn.Module') + +function BN:__init(nFeature, eps, momentum) + parent.__init(self) + assert(nFeature and type(nFeature) == 'number', + 'Missing argument #1: Number of feature planes. ' .. + 'Give 0 for no affine transform') + self.eps = eps or 1e-5 + self.train = true + self.momentum = momentum or 0.1 + + if nFeature > 0 then self.affine = true end + if self.affine then + self.weight = torch.Tensor(nFeature) + self.bias = torch.Tensor(nFeature) + self.gradWeight = torch.Tensor(nFeature) + self.gradBias = torch.Tensor(nFeature) + self:reset() + end +end + +function BN:reset() + self.weight:uniform() + self.bias:zero() +end + +function BN:updateOutput(input) + assert(input:dim() == 4, 'only mini-batch supported (4D tensor), got ' + .. input:dim() .. 'D tensor instead') + local nBatch = input:size(1) + local nFeature = input:size(2) + local iH = input:size(3) + local iW = input:size(4) + + -- buffers that are reused + self.buffer = self.buffer or input.new() + self.buffer2 = self.buffer2 or input.new() + self.centered = self.centered or input.new() + self.centered:resizeAs(input) + self.std = self.std or input.new() + self.normalized = self.normalized or input.new() + self.normalized:resizeAs(input) + self.output:resizeAs(input) + self.gradInput:resizeAs(input) + if self.train == false then + assert(self.running_mean, + 'Module never run on training data. First run on some training data before evaluating.') + self.output:copy(input) + self.buffer:repeatTensor(self.running_mean:view(1, nFeature, 1, 1), nBatch, 1, iH, iW) + self.output:add(-1, self.buffer) + self.buffer:repeatTensor(self.running_std:view(1, nFeature, 1, 1), nBatch, 1, iH, iW) + self.output:cmul(self.buffer) + else -- training mode + self.running_mean = self.running_mean or input.new(nFeature):zero() + self.running_std = self.running_std or input.new(nFeature):zero() + + -- calculate mean over mini-batch, over feature-maps + local in_folded = input:view(nBatch, nFeature, iH * iW) + self.buffer:mean(in_folded, 1) + self.buffer2:mean(self.buffer, 3) + self.running_mean:mul(1 - self.momentum):add(self.momentum, self.buffer2) -- add to running mean + self.buffer:repeatTensor(self.buffer2:view(1, nFeature, 1, 1), + nBatch, 1, iH, iW) + + -- subtract mean + self.centered:add(input, -1, self.buffer) -- x - E(x) + + -- calculate standard deviation over mini-batch + self.buffer:copy(self.centered):cmul(self.buffer) -- [x - E(x)]^2 + local buf_folded = self.buffer:view(nBatch,nFeature,iH*iW) + self.std:mean(self.buffer2:mean(buf_folded, 1), 3) + self.std:add(self.eps):sqrt():pow(-1) -- 1 / E([x - E(x)]^2) + self.running_std:mul(1 - self.momentum):add(self.momentum, self.std) -- add to running stdv + self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1), + nBatch, 1, iH, iW) + + -- divide standard-deviation + eps + self.output:cmul(self.centered, self.buffer) + self.normalized:copy(self.output) + end + + if self.affine then + -- multiply with gamma and add beta + self.buffer:repeatTensor(self.weight:view(1, nFeature, 1, 1), + nBatch, 1, iH, iW) + self.output:cmul(self.buffer) + self.buffer:repeatTensor(self.bias:view(1, nFeature, 1, 1), + nBatch, 1, iH, iW) + self.output:add(self.buffer) + end + + return self.output +end + +function BN:updateGradInput(input, gradOutput) + assert(input:dim() == 4, 'only mini-batch supported') + 
assert(gradOutput:dim() == 4, 'only mini-batch supported') + assert(self.train == true, 'should be in training mode when self.train is true') + local nBatch = input:size(1) + local nFeature = input:size(2) + local iH = input:size(3) + local iW = input:size(4) + + self.gradInput:cmul(self.centered, gradOutput) + local gi_folded = self.gradInput:view(nBatch, nFeature, iH * iW) + self.buffer2:mean(self.buffer:mean(gi_folded, 1), 3) + self.gradInput:repeatTensor(self.buffer2:view(1, nFeature, 1, 1), + nBatch, 1, iH, iW) + self.gradInput:cmul(self.centered):mul(-1) + self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1), + nBatch, 1, iH, iW) + self.gradInput:cmul(self.buffer):cmul(self.buffer) + + self.buffer:mean(gradOutput:view(nBatch, nFeature, iH*iW), 1) + self.buffer2:mean(self.buffer, 3) + self.buffer:repeatTensor(self.buffer2:view(1, nFeature, 1, 1), + nBatch, 1, iH, iW) + self.gradInput:add(gradOutput):add(-1, self.buffer) + self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1), + nBatch, 1, iH, iW) + self.gradInput:cmul(self.buffer) + + if self.affine then + self.buffer:repeatTensor(self.weight:view(1, nFeature, 1, 1), + nBatch, 1, iH, iW) + self.gradInput:cmul(self.buffer) + end + + return self.gradInput +end + +function BN:accGradParameters(input, gradOutput, scale) + if self.affine then + scale = scale or 1.0 + local nBatch = input:size(1) + local nFeature = input:size(2) + local iH = input:size(3) + local iW = input:size(4) + self.buffer2:resizeAs(self.normalized):copy(self.normalized) + self.buffer2 = self.buffer2:cmul(gradOutput):view(nBatch, nFeature, iH*iW) + self.buffer:sum(self.buffer2, 1) -- sum over mini-batch + self.buffer2:sum(self.buffer, 3) -- sum over pixels + self.gradWeight:add(scale, self.buffer2) + + self.buffer:sum(gradOutput:view(nBatch, nFeature, iH*iW), 1) + self.buffer2:sum(self.buffer, 3) + self.gradBias:add(scale, self.buffer2) -- sum over mini-batch + end +end diff --git a/doc/convolution.md b/doc/convolution.md index 47a7d4d..d5f94aa 100755 --- a/doc/convolution.md +++ b/doc/convolution.md @@ -2,7 +2,7 @@ # Convolutional layers # A convolution is an integral that expresses the amount of overlap of one function `g` as it is shifted over another function `f`. It therefore "blends" one function with another. The neural network package supports convolution, pooling, subsampling and other relevant facilities. These are divided base on the dimensionality of the input and output [Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor): - * [Temporal Modules](#nn.TemporalModules) apply to sequences with a one-dimensional relationship + * [Temporal Modules](#nn.TemporalModules) apply to sequences with a one-dimensional relationship (e.g. sequences of words, phonemes and letters. Strings of some kind). 
* [TemporalConvolution](#nn.TemporalConvolution) : a 1D convolution over an input sequence ; * [TemporalSubSampling](#nn.TemporalSubSampling) : a 1D sub-sampling over an input sequence ; @@ -18,6 +18,7 @@ A convolution is an integral that expresses the amount of overlap of one functio * [SpatialConvolutionMap](#nn.SpatialConvolutionMap) : a 2D convolution that uses a generic connection table ; * [SpatialZeroPadding](#nn.SpatialZeroPadding) : padds a feature map with specified number of zeros ; * [SpatialSubtractiveNormalization](#nn.SpatialSubtractiveNormalization) : a spatial subtraction operation on a series of 2D inputs using + * [SpatialBatchNormalization](#nn.SpatialBatchNormalization): mean/std normalization over the mini-batch inputs and pixels, with an optional affine transform that follows a kernel for computing the weighted average in a neighborhood ; * [Volumetric Modules](#nn.VolumetricModules) apply to inputs with three-dimensional relationships (e.g. videos) : * [VolumetricConvolution](#nn.VolumetricConvolution) : a 3D convolution over an input video (a sequence of images) ; @@ -27,10 +28,10 @@ a kernel for computing the weighted average in a neighborhood ; ## Temporal Modules ## Excluding an optional first batch dimension, temporal layers expect a 2D Tensor as input. The first dimension is the number of frames in the sequence (e.g. `nInputFrame`), the last dimenstion -is the number of features per frame (e.g. `inputFrameSize`). The output will normally have the same number -of dimensions, although the size of each dimension may change. These are commonly used for processing acoustic signals or sequences of words, i.e. in Natural Language Processing. +is the number of features per frame (e.g. `inputFrameSize`). The output will normally have the same number +of dimensions, although the size of each dimension may change. These are commonly used for processing acoustic signals or sequences of words, i.e. in Natural Language Processing. -Note: The [LookupTable](#nn.LookupTable) is special in that while it does output a temporal Tensor of size `nOutputFrame x outputFrameSize`, +Note: The [LookupTable](#nn.LookupTable) is special in that while it does output a temporal Tensor of size `nOutputFrame x outputFrameSize`, its input is a 1D Tensor of indices of size `nIndices`. Again, this is excluding the option first batch dimension. <a name="nn.TemporalConvolution"/> @@ -77,7 +78,7 @@ output[t][i] = bias[i] Here is a simple example: ```lua -inp=5; -- dimensionality of one sequence element +inp=5; -- dimensionality of one sequence element outp=1; -- number of derived features for one sequence element kw=1; -- kernel only operates on one sequence element per step dw=1; -- we step once and go on to the next sequence element @@ -93,8 +94,8 @@ which gives: -0.9872 -0.6808 -0.9403 --0.9680 --0.6901 +-0.9680 +-0.6901 -0.6387 [torch.Tensor of dimension 7x1] ``` @@ -128,7 +129,7 @@ module = nn.TemporalMaxPooling(kW, [dW]) Applies 1D max-pooling operation in `kW` regions by step size `dW` steps. Input sequence composed of `nInputFrame` frames. The `input` tensor in -`forward(input)` is expected to be a 2D tensor (`nInputFrame x inputFrameSize`) +`forward(input)` is expected to be a 2D tensor (`nInputFrame x inputFrameSize`) or a 3D tensor (`nBatchFrame x nInputFrame x inputFrameSize`). If the input sequence is a 2D tensor of dimension `nInputFrame x inputFrameSize`, the output sequence will be @@ -185,16 +186,16 @@ module = nn.LookupTable(nIndex, size1, [size2], [size3], ...) 
``` This layer is a particular case of a convolution, where the width of the convolution would be `1`. -When calling `forward(input)`, it assumes `input` is a 1D or 2D tensor filled with indices. +When calling `forward(input)`, it assumes `input` is a 1D or 2D tensor filled with indices. If the input is a matrix, then each row is assumed to be an input sample of given batch. Indices start at `1` and can go up to `nIndex`. For each index, it outputs a corresponding `Tensor` of size specified by `sizes` (a `LongStorage`) or `size1 x size2 x...`. -Given a 1D input, the output tensors are concatenated, +Given a 1D input, the output tensors are concatenated, generating a `n x size1 x size2 x ... x sizeN` tensor, where `n` -is the size of a 1D `input` tensor. +is the size of a 1D `input` tensor. -Again with a 1D input, when only `size1` is provided, the `forward(input)` is equivalent to +Again with a 1D input, when only `size1` is provided, the `forward(input)` is equivalent to performing the following matrix-matrix multiplication in an efficient manner: ```lua M P @@ -205,7 +206,7 @@ where `M` is a 2D matrix `size1 x nIndex` containing the parameters of the looku 1D example: ```lua -- a lookup table containing 10 tensors of size 3 - module = nn.LookupTable(10, 3) + module = nn.LookupTable(10, 3) input = torch.Tensor{1,2,1,10} print(module:forward(input)) @@ -221,14 +222,14 @@ Outputs something like: ``` Note that the first row vector is the same as the 3rd one! -Given a 2D input tensor of size `m x n`, the output is a `m x n x size1 x size2 x ... x sizeN` -tensor, where `m` is the number of samples in +Given a 2D input tensor of size `m x n`, the output is a `m x n x size1 x size2 x ... x sizeN` +tensor, where `m` is the number of samples in the batch and `n` is the number of indices per sample. 2D example: ```lua -- a lookup table containing 10 tensors of size 3 - module = nn.LookupTable(10, 3) + module = nn.LookupTable(10, 3) -- a batch of 2 samples of 4 indices each input = torch.Tensor({{1,2,4,5},{4,3,2,10}}) @@ -237,13 +238,13 @@ the batch and `n` is the number of indices per sample. Outputs something like: ```lua -(1,.,.) = +(1,.,.) = -0.0570 -1.5354 1.8555 -0.9067 1.3392 0.6275 1.9662 0.4645 -0.8111 0.1103 1.7811 1.5969 -(2,.,.) = +(2,.,.) = 1.9662 0.4645 -0.8111 0.0026 -1.4547 -0.5154 -0.9067 1.3392 0.6275 @@ -453,7 +454,7 @@ Applies a 2D up-sampling over an input image composed of several input planes. T The parameters are the following: * `scale`: The upscale ratio. Must be a positive integer -The up-scaling method is simple nearest neighbor, ie: +The up-scaling method is simple nearest neighbor, ie: ```lua output(u,v) = input(floor((u-1)/scale)+1, floor((v-1)/scale)+1) @@ -508,10 +509,53 @@ w2=image.display(processed) ``` ![](image/lena.jpg)![](image/lenap.jpg) +<a name="nn.SpatialBatchNormalization"/> +## SpatialBatchNormalization ## + +`module` = `nn.SpatialBatchNormalization(N [,eps] [, momentum])` + where N = number of input feature maps +giving N = 0 disables the learnable affine transform. +eps is a small value added to the standard-deviation to avoid divide-by-zero. 
Defaults to 1e-5 + +Implements Batch Normalization as described in the paper: + "Batch Normalization: Accelerating Deep Network Training + by Reducing Internal Covariate Shift" + by Sergey Ioffe, Christian Szegedy + +The operation implemented is: +``` + y = ( x - mean(x) ) + -------------------- * gamma + beta + standard-deviation(x) +``` +where the mean and standard-deviation are calculated per feature-map over the mini-batches and pixels +and where gamma and beta are learnable parameter vectors of size N (where N = number of feature maps). +The learning of gamma and beta is optional. + + In training time, this layer keeps a running estimate of it's computed mean and std. + The running sum is kept with a default momentup of 0.1 (unless over-ridden) + In test time, this running mean/std is used to normalize. + + + +The module only accepts 4D inputs. + +```lua +-- with learnable parameters +model = nn.SpatialBatchNormalization(m) +A = torch.randn(b, m) +B = model.forward(B) -- C will be of size `b x m` + +-- without learnable parameters +model = nn.SpatialBatchNormalization(0) +A = torch.randn(b, m) +B = model.forward(B) -- C will be of size `b x m` +``` + <a name="nn.VolumetricModules"/> ## Volumetric Modules ## Excluding and optional batch dimension, volumetric layers expect a 4D Tensor as input. The -first dimension is the number of features (e.g. `frameSize`), the second is sequential (e.g. `time`) and the +first dimension is the number of features (e.g. `frameSize`), the second is sequential (e.g. `time`) and the last two dimenstions are spatial (e.g. `height x width`). These are commonly used for processing videos (sequences of images). <a name="nn.VolumetricConvolution"/> diff --git a/doc/simple.md b/doc/simple.md index 999880e..f649464 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -1,6 +1,6 @@ <a name="nn.simplelayers.dok"/> # Simple layers # -Simple Modules are used for various tasks like adapting Tensor methods +Simple Modules are used for various tasks like adapting Tensor methods and providing affine transformations : * Parameterized Modules : * [Linear](#nn.Linear) : a linear transformation ; @@ -29,10 +29,12 @@ and providing affine transformations : * [Square](#nn.Square) : an element-wise square operation ; * [Sqrt](#nn.Sqrt) : an element-wise [sqrt](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchsqrtres-x) operation ; * [MM](#nn.MM) : matrix-matrix multiplication (also supports batches of matrices) ; +* Normalization modules: + * [BatchNormalization](#nn.BatchNormalization) - mean/std normalization over the mini-batch inputs, with an optional affine transform that follows * Miscellaneous Modules : * [Identity](#nn.Identity) : forward input as-is to output (useful with [ParallelTable](table.md#nn.ParallelTable)); * [Dropout](#nn.Dropout) : masks parts of the `input` using binary samples from a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution ; - + <a name="nn.Linear"/> ## Linear ## @@ -77,9 +79,9 @@ applying the linear transformation is performed with: Applies a linear transformation to the incoming sparse data, i.e. _y= Ax+b_. The `input` tensor given in `forward(input)` must -be a sparse vector represented as 2D tensor of the form +be a sparse vector represented as 2D tensor of the form torch.Tensor(N, 2) where the pairs represent indices and values. -The SparseLinear layer is useful when the number of input +The SparseLinear layer is useful when the number of input dimensions is very large and the input data is sparse. 
You can create a sparse linear layer in the following way: @@ -87,9 +89,9 @@ You can create a sparse linear layer in the following way: ```lua module= nn.SparseLinear(10000,2) -- 10000 inputs, 2 outputs ``` -The sparse linear module may be used as part of a larger network, -and apart from the form of the input, -[SparseLinear](#nn.SparseLinear) +The sparse linear module may be used as part of a larger network, +and apart from the form of the input, +[SparseLinear](#nn.SparseLinear) operates in exactly the same way as the [Linear](#nn.Linear) layer. A sparse input vector may be created as so.. @@ -107,9 +109,9 @@ A sparse input vector may be created as so.. ``` -The first column contains indices, the second column contains -values in a a vector where all other elements are zeros. The -indices should not exceed the stated dimensions of the input to the +The first column contains indices, the second column contains +values in a a vector where all other elements are zeros. The +indices should not exceed the stated dimensions of the input to the layer (10000 in the example). @@ -118,17 +120,17 @@ layer (10000 in the example). `module` = `nn.Dropout(p)` -During training, `Dropout` masks parts of the `input` using binary samples from +During training, `Dropout` masks parts of the `input` using binary samples from a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution. Each `input` element has a probability of `p` of being dropped, i.e having its -commensurate output element be zero. This has proven an effective technique for -regularization and preventing the co-adaptation of neurons -(see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)). +commensurate output element be zero. This has proven an effective technique for +regularization and preventing the co-adaptation of neurons +(see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)). -Furthermore, the ouputs are scaled by a factor of `1/(1-p)` during training. This allows the +Furthermore, the ouputs are scaled by a factor of `1/(1-p)` during training. This allows the `input` to be simply forwarded as-is during evaluation. -In this example, we demonstrate how the call to [forward](module.md#output-forwardinput) samples +In this example, we demonstrate how the call to [forward](module.md#output-forwardinput) samples different `outputs` to dropout (the zeros) given the same `input`: ```lua module = nn.Dropout() @@ -162,7 +164,7 @@ module = nn.Dropout() ``` In both cases the `gradOutput` and `input` are scaled by `1/(1-p)`, which in this case is `2`. -During [evaluation](module.md#evaluate), `Dropout` does nothing more than +During [evaluation](module.md#evaluate), `Dropout` does nothing more than forward the input such that all elements of the input are considered. ```lua > module:evaluate() @@ -185,13 +187,13 @@ We can return to training our model by first calling [Module:training()](module. ``` -When used, `Dropout` should normally be applied to the input of parameterized -[Modules](module.md#nn.Module) like [Linear](#nn.Linear) +When used, `Dropout` should normally be applied to the input of parameterized +[Modules](module.md#nn.Module) like [Linear](#nn.Linear) or [SpatialConvolution](convolution.md#nn.SpatialConvolution). A `p` of `0.5` (the default) is usually okay for hidden layers. `Dropout` can sometimes be used successfully on the dataset inputs with a `p` around `0.2`. -It sometimes works best following [Transfer](transfer.md) Modules -like [ReLU](transfer.md#nn.ReLU). 
All this depends a great deal on the dataset so its up +It sometimes works best following [Transfer](transfer.md) Modules +like [ReLU](transfer.md#nn.ReLU). All this depends a great deal on the dataset so its up to the user to try different combinations. @@ -221,15 +223,15 @@ gnuplot.grid(true) Applies a bias term to the incoming data, i.e. _y_i= x_i + b_i, or if _scalar=true_ then uses a single bias term, -_y_i= x_i + b. +_y_i= x_i + b. Example: ```lua -y=torch.Tensor(5); +y=torch.Tensor(5); mlp=nn.Sequential() mlp:add(nn.Add(5)) -function gradUpdate(mlp, x, y, criterion, learningRate) +function gradUpdate(mlp, x, y, criterion, learningRate) local pred = mlp:forward(x) local err = criterion:forward(pred, y) local gradCriterion = criterion:backward(pred, y) @@ -241,7 +243,7 @@ end for i=1,10000 do x=torch.rand(5) - y:copy(x); + y:copy(x); for i=1,5 do y[i]=y[i]+i; end err=gradUpdate(mlp,x,y,nn.MSECriterion(),0.01) end @@ -256,7 +258,7 @@ gives the output: 5.0000 [torch.Tensor of dimension 5] ``` -i.e. the network successfully learns the input _x_ has been shifted +i.e. the network successfully learns the input _x_ has been shifted to produce the output _y_. @@ -266,15 +268,15 @@ to produce the output _y_. `module` = `Mul()` Applies a _single_ scaling factor to the incoming data, i.e. -_y= w x_, where _w_ is a scalar. +_y= w x_, where _w_ is a scalar. Example: ```lua -y=torch.Tensor(5); +y=torch.Tensor(5); mlp=nn.Sequential() mlp:add(nn.Mul()) -function gradUpdate(mlp, x, y, criterion, learningRate) +function gradUpdate(mlp, x, y, criterion, learningRate) local pred = mlp:forward(x) local err = criterion:forward(pred,y) local gradCriterion = criterion:backward(pred,y); @@ -307,7 +309,7 @@ pi. Applies a component-wise multiplication to the incoming data, i.e. `y_i = w_i * x_i`. Argument `size` can be one or many numbers (sizes) -or a `torch.LongStorage`. For example, `nn.CMul(3,4,5)` is equivalent to +or a `torch.LongStorage`. For example, `nn.CMul(3,4,5)` is equivalent to `nn.CMul(torch.LongStorage{3,4,5})`. Example: @@ -315,10 +317,10 @@ Example: mlp=nn.Sequential() mlp:add(nn.CMul(5)) -y=torch.Tensor(5); +y=torch.Tensor(5); sc=torch.Tensor(5); for i=1,5 do sc[i]=i; end -- scale input with this -function gradUpdate(mlp,x,y,criterion,learningRate) +function gradUpdate(mlp,x,y,criterion,learningRate) local pred = mlp:forward(x) local err = criterion:forward(pred,y) local gradCriterion = criterion:backward(pred,y); @@ -394,7 +396,7 @@ then an `nxq` matrix would be output. Outputs the Euclidean distance of the input to `outputSize` centers, i.e. this layer has the weights `w_j`, for `j` = `1`,..,`outputSize`, where -`w_j` are vectors of dimension `inputSize`. +`w_j` are vectors of dimension `inputSize`. The distance `y_j` between center `j` and input `x` is formulated as `y_j = || w_j - x ||`. @@ -406,10 +408,10 @@ The distance `y_j` between center `j` and input `x` is formulated as This module is similar to [Euclidean](#nn.Euclidean), but additionally learns a separate diagonal covariance matrix across the -features of the input space _for each center_. +features of the input space _for each center_. -In other words, for each of the `outputSize` centers `w_j`, there is -a diagonal covariance matrices `c_j`, for `j` = `1`,..,`outputSize`, +In other words, for each of the `outputSize` centers `w_j`, there is +a diagonal covariance matrices `c_j`, for `j` = `1`,..,`outputSize`, where `c_j` are stored as vectors of size `inputSize`. 
The distance `y_j` between center `j` and input `x` is formulated as @@ -420,8 +422,8 @@ The distance `y_j` between center `j` and input `x` is formulated as `module` = `Identity()` -Creates a module that returns whatever is input to it as output. -This is useful when combined with the module +Creates a module that returns whatever is input to it as output. +This is useful when combined with the module [ParallelTable](table.md#nn.ParallelTable) in case you do not wish to do anything to one of the input Tensors. Example: @@ -429,7 +431,7 @@ Example: mlp=nn.Identity() print(mlp:forward(torch.ones(5,2))) ``` -gives the output: +gives the output: ```lua 1 1 1 1 @@ -440,10 +442,10 @@ gives the output: ``` Here is a more useful example, where one can implement a network which also computes a Criterion using this module: -```lua +```lua pred_mlp=nn.Sequential(); -- A network that makes predictions given x. -pred_mlp:add(nn.Linear(5,4)) -pred_mlp:add(nn.Linear(4,3)) +pred_mlp:add(nn.Linear(5,4)) +pred_mlp:add(nn.Linear(4,3)) xy_mlp=nn.ParallelTable();-- A network for predictions and for keeping the xy_mlp:add(pred_mlp) -- true label for comparison with a criterion @@ -457,14 +459,14 @@ mlp:add(cr_wrap) -- and then applies the criterion. for i=1,100 do -- Do a few training iterations x=torch.ones(5); -- Make input features. - y=torch.Tensor(3); + y=torch.Tensor(3); y:copy(x:narrow(1,1,3)) -- Make output label. err=mlp:forward{x,y} -- Forward both input and output. print(err) -- Print error from criterion. - mlp:zeroGradParameters(); -- Do backprop... - mlp:backward({x, y} ); - mlp:updateParameters(0.05); + mlp:zeroGradParameters(); -- Do backprop... + mlp:backward({x, y} ); + mlp:updateParameters(0.05); end ``` @@ -476,9 +478,9 @@ end This layer copies the input to output with type casting from input type from `inputType` to `outputType`. Unless `forceCopy` is true, when the first two arguments are the same, the input isn't copied, only transfered -as the output. The default `forceCopy` is false. +as the output. The default `forceCopy` is false. When `dontCast` is true, a call to `nn.Copy:type(type)` will not cast -the module's `output` and `gradInput` Tensors to the new type. The default +the module's `output` and `gradInput` Tensors to the new type. The default is false. <a name="nn.Narrow"/> @@ -543,16 +545,16 @@ torch> =o `module` = `Reshape(dimension1, dimension2, ... [, batchMode])` Reshapes an `nxpxqx..` Tensor into a `dimension1xdimension2x...` Tensor, -taking the elements column-wise. +taking the elements column-wise. -The optional last argument `batchMode`, -when `true` forces the first dimension of the input to be considered -the batch dimension, and thus keep its size fixed. This is necessary when -dealing with batch sizes of one. When `false`, it forces the -entire input (including the first dimension) to be reshaped to the -input size. Default `batchMode=nil`, which means that the module -considers inputs with more elements than the produce of provided sizes, -i.e. `dimension1xdimension2x...`, to be batches. +The optional last argument `batchMode`, +when `true` forces the first dimension of the input to be considered +the batch dimension, and thus keep its size fixed. This is necessary when +dealing with batch sizes of one. When `false`, it forces the +entire input (including the first dimension) to be reshaped to the +input size. Default `batchMode=nil`, which means that the module +considers inputs with more elements than the produce of provided sizes, +i.e. 
`dimension1xdimension2x...`, to be batches. Example: ```lua @@ -731,17 +733,17 @@ gives the output: ``` This can be used in conjunction with [Concat](containers.md#nn.Concat) -to emulate the behavior -of [Parallel](containers.md#nn.Parallel), or to select various parts of an input Tensor to +to emulate the behavior +of [Parallel](containers.md#nn.Parallel), or to select various parts of an input Tensor to perform operations on. Here is a fairly complicated example: ```lua mlp=nn.Sequential(); -c=nn.Concat(2) +c=nn.Concat(2) for i=1,10 do local t=nn.Sequential() t:add(nn.Select(1,i)) - t:add(nn.Linear(3,2)) + t:add(nn.Linear(3,2)) t:add(nn.Reshape(2,1)) c:add(t) end @@ -759,7 +761,7 @@ for i=1,10000 do -- Train for a few iterations err=criterion:forward(pred,y) gradCriterion = criterion:backward(pred,y); mlp:zeroGradParameters(); - mlp:backward(x, gradCriterion); + mlp:backward(x, gradCriterion); mlp:updateParameters(0.01); print(err) end @@ -853,3 +855,45 @@ A = torch.randn(b, m, n) B = torch.randn(b, n, p) C = model.forward({A, B}) -- C will be of size `b x m x n` ``` + +<a name="nn.BatchNormalization"/> +## BatchNormalization ## + +`module` = `nn.BatchNormalization(N [, eps] [, momentum])` + where N = dimensionality of input +giving N = 0 disables the learnable affine transform. +eps is a small value added to the standard-deviation to avoid divide-by-zero. Defaults to 1e-5 + +In training time, this layer keeps a running estimate of it's computed mean and std. +The running sum is kept with a default momentup of 0.1 (unless over-ridden) +In test time, this running mean/std is used to normalize. + + +Implements Batch Normalization as described in the paper: + "Batch Normalization: Accelerating Deep Network Training + by Reducing Internal Covariate Shift" + by Sergey Ioffe, Christian Szegedy + +The operation implemented is: +``` + y = ( x - mean(x) ) + -------------------- * gamma + beta + standard-deviation(x) + eps +``` +where the mean and standard-deviation are calculated per-dimension over the mini-batches +and where gamma and beta are learnable parameter vectors of size N (where N = input dimensionality). +The learning of gamma and beta is optional. + +The module only accepts 2D inputs. 
+ +```lua +-- with learnable parameters +model = nn.BatchNormalization(m) +A = torch.randn(b, m) +B = model.forward(B) -- C will be of size `b x m` + +-- without learnable parameters +model = nn.BatchNormalization(0) +A = torch.randn(b, m) +B = model.forward(B) -- C will be of size `b x m` +``` @@ -18,6 +18,7 @@ include('Select.lua') include('Narrow.lua') include('Replicate.lua') include('Transpose.lua') +include('BatchNormalization.lua') include('Copy.lua') include('Min.lua') @@ -86,6 +87,7 @@ include('SpatialDivisiveNormalization.lua') include('SpatialContrastiveNormalization.lua') include('SpatialZeroPadding.lua') include('SpatialUpSamplingNearest.lua') +include('SpatialBatchNormalization.lua') include('VolumetricConvolution.lua') include('VolumetricMaxPooling.lua') @@ -3195,6 +3195,112 @@ function nntest.CosineEmbeddingCriterion() equal(grads[2], zero, 'gradient should be zero') end +function nntest.BatchNormalization() + local nframes = torch.random(50,70) + local indim = torch.random(1,10) + local input = torch.zeros(nframes, indim):uniform() + local module = nn.BatchNormalization(indim) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + local err = jac.testJacobianParameters(module, input, + module.weight, module.gradWeight) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianParameters(module, input, + module.bias, module.gradBias) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianUpdateParameters(module, input, module.weight) + mytester:assertlt(err,precision, 'error on weight [direct update] ') + + local err = jac.testJacobianUpdateParameters(module, input, module.bias) + mytester:assertlt(err,precision, 'error on bias [direct update] ') + + for t,err in pairs(jac.testAllUpdate(module, input, + 'weight', 'gradWeight')) do + mytester:assertlt(err, precision, string.format( + 'error on weight [%s]', t)) + end + + for t,err in pairs(jac.testAllUpdate(module, input, + 'bias', 'gradBias')) do + mytester:assertlt(err, precision, string.format( + 'error on bias [%s]', t)) + end + + -- IO + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- batch norm without affine transform + module = nn.BatchNormalization(0) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + -- IO + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err ') +end + +function nntest.SpatialBatchNormalization() + local nframes = torch.random(1,10) + local indim = torch.random(1,4) + local ini = torch.random(1,5) + local inj = torch.random(1,5) + local input = torch.zeros(nframes, indim, ini, inj):uniform() + local module = nn.SpatialBatchNormalization(indim) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + local err = jac.testJacobianParameters(module, input, + module.weight, module.gradWeight) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianParameters(module, input, + module.bias, module.gradBias) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianUpdateParameters(module, input, module.weight) + mytester:assertlt(err,precision, 'error on weight [direct update] ') + + local err = jac.testJacobianUpdateParameters(module, input, module.bias) + mytester:assertlt(err,precision, 'error on bias [direct update] ') + + for t,err in pairs(jac.testAllUpdate(module, input, + 'weight', 'gradWeight')) do + mytester:assertlt(err, precision, string.format( + 'error on weight [%s]', t)) + end + + for t,err in pairs(jac.testAllUpdate(module, input, + 'bias', 'gradBias')) do + mytester:assertlt(err, precision, string.format( + 'error on bias [%s]', t)) + end + + -- IO + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- batch norm without affine transform + module = nn.SpatialBatchNormalization(0) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + -- IO + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') +end + mytester:add(nntest) if not nn then |
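Beyond the Jacobian tests above, here is a short, hedged sketch of the training/evaluation behaviour these modules implement (the running mean/std bookkeeping). It assumes, as the Dropout documentation in this repository describes, that `Module:evaluate()` sets `self.train = false`; the batch size and input dimensionality are illustrative only.

```lua
-- Sketch of the running-statistics behaviour (not part of this commit).
require 'nn'

local bn = nn.BatchNormalization(4)

-- Training-mode forward passes update running_mean and running_std
-- with momentum 0.1 (running_std actually stores 1/std, see updateOutput).
for i = 1, 10 do
   bn:forward(torch.randn(32, 4))
end
print(bn.running_mean)
print(bn.running_std)

-- Evaluation: inputs are normalized with the running estimates instead of
-- per-batch statistics. Assumes :evaluate() toggles self.train to false.
bn:evaluate()                  -- equivalently: bn.train = false
local y = bn:forward(torch.randn(32, 4))
```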