diff options
author | Sam Gross <sgross@fb.com> | 2016-01-05 07:57:45 +0300 |
---|---|---|
committer | Sam Gross <sgross@fb.com> | 2016-01-05 22:45:11 +0300 |
commit | a142233b8bebc7a4acfa0ad6a66c1d400803034e (patch) | |
tree | bcef518cdd9125c5c68d493bbfc5873261b72192 /SpatialBatchNormalization.lua | |
parent | 31f71b92881ad1ef3356a5077725f0541f765340 (diff) |
Add C implementation of SpatialBatchNormalization
This is primarily to support the fast, memory-efficient CUDA
implementation. Some other changes include making weight and bias each
individually optional and averaging the variances instead of the
inverse standard deviation.
Diffstat (limited to 'SpatialBatchNormalization.lua')
-rw-r--r-- | SpatialBatchNormalization.lua | 167 |
1 files changed, 62 insertions, 105 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua index a5eac14..e844385 100644 --- a/SpatialBatchNormalization.lua +++ b/SpatialBatchNormalization.lua @@ -18,9 +18,9 @@ Usage: with learnable parameters: nn.BatchNormalization(N [,eps] [,momentum]) where N = dimensionality of input - without learnable parameters: nn.BatchNormalization(0 [,eps] [,momentum]) + without learnable parameters: nn.BatchNormalization(N [,eps] [,momentum], false) - eps is a small value added to the standard-deviation to avoid divide-by-zero. + eps is a small value added to the variance to avoid divide-by-zero. Defaults to 1e-5 In training time, this layer keeps a running estimate of it's computed mean and std. @@ -30,13 +30,15 @@ ]]-- local BN,parent = torch.class('nn.SpatialBatchNormalization', 'nn.Module') +BN.__version = 2 + function BN:__init(nFeature, eps, momentum, affine) parent.__init(self) assert(nFeature and type(nFeature) == 'number', 'Missing argument #1: Number of feature planes. ') assert(nFeature ~= 0, 'To set affine=false call SpatialBatchNormalization' .. '(nFeature, eps, momentum, false) ') - if affine ~=nil then + if affine ~= nil then assert(type(affine) == 'boolean', 'affine has to be true/false') self.affine = affine else @@ -47,7 +49,7 @@ function BN:__init(nFeature, eps, momentum, affine) self.momentum = momentum or 0.1 self.running_mean = torch.zeros(nFeature) - self.running_std = torch.ones(nFeature) + self.running_var = torch.ones(nFeature) if self.affine then self.weight = torch.Tensor(nFeature) self.bias = torch.Tensor(nFeature) @@ -58,127 +60,82 @@ function BN:__init(nFeature, eps, momentum, affine) end function BN:reset() - self.weight:uniform() - self.bias:zero() - self.running_mean:zero() - self.running_std:fill(1) + if self.weight then + self.weight:uniform() + end + if self.bias then + self.bias:zero() + end end function BN:updateOutput(input) assert(input:dim() == 4, 'only mini-batch supported (4D tensor), got ' .. input:dim() .. 'D tensor instead') - local nBatch = input:size(1) - local nFeature = input:size(2) - local iH = input:size(3) - local iW = input:size(4) - -- buffers that are reused - self.buffer = self.buffer or input.new() self.output:resizeAs(input) - if self.train == false then - self.output:copy(input) - self.buffer:repeatTensor(self.running_mean:view(1, nFeature, 1, 1), nBatch, 1, iH, iW) - self.output:add(-1, self.buffer) - self.buffer:repeatTensor(self.running_std:view(1, nFeature, 1, 1), nBatch, 1, iH, iW) - self.output:cmul(self.buffer) - else -- training mode - self.buffer2 = self.buffer2 or input.new() - self.centered = self.centered or input.new() - self.centered:resizeAs(input) - self.std = self.std or input.new() - self.normalized = self.normalized or input.new() - self.normalized:resizeAs(input) - self.gradInput:resizeAs(input) - -- calculate mean over mini-batch, over feature-maps - local in_folded = input:view(nBatch, nFeature, iH * iW) - self.buffer:mean(in_folded, 1) - self.buffer2:mean(self.buffer, 3) - self.running_mean:mul(1 - self.momentum):add(self.momentum, self.buffer2) -- add to running mean - self.buffer:repeatTensor(self.buffer2:view(1, nFeature, 1, 1), - nBatch, 1, iH, iW) - - -- subtract mean - self.centered:add(input, -1, self.buffer) -- x - E(x) - - -- calculate standard deviation over mini-batch - self.buffer:copy(self.centered):cmul(self.buffer) -- [x - E(x)]^2 - local buf_folded = self.buffer:view(nBatch,nFeature,iH*iW) - self.std:mean(self.buffer2:mean(buf_folded, 1), 3) - self.std:add(self.eps):sqrt():pow(-1) -- 1 / E([x - E(x)]^2) - self.running_std:mul(1 - self.momentum):add(self.momentum, self.std) -- add to running stdv - self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1), - nBatch, 1, iH, iW) - - -- divide standard-deviation + eps - self.output:cmul(self.centered, self.buffer) - self.normalized:copy(self.output) - end - - if self.affine then - -- multiply with gamma and add beta - self.buffer:repeatTensor(self.weight:view(1, nFeature, 1, 1), - nBatch, 1, iH, iW) - self.output:cmul(self.buffer) - self.buffer:repeatTensor(self.bias:view(1, nFeature, 1, 1), - nBatch, 1, iH, iW) - self.output:add(self.buffer) - end + self.save_mean = self.save_mean or input.new():resizeAs(self.running_mean) + self.save_std = self.save_std or input.new():resizeAs(self.running_var) + + input.nn.SpatialBatchNormalization_updateOutput( + input, + self.output, + self.weight, + self.bias, + self.train, + self.eps, + self.momentum, + self.running_mean, + self.running_var, + self.save_mean, + self.save_std) return self.output end -function BN:updateGradInput(input, gradOutput) +local function backward(self, input, gradOutput, scale, gradInput, gradWeight, gradBias) assert(input:dim() == 4, 'only mini-batch supported') assert(gradOutput:dim() == 4, 'only mini-batch supported') assert(self.train == true, 'should be in training mode when self.train is true') - local nBatch = input:size(1) - local nFeature = input:size(2) - local iH = input:size(3) - local iW = input:size(4) - - self.gradInput:cmul(self.centered, gradOutput) - local gi_folded = self.gradInput:view(nBatch, nFeature, iH * iW) - self.buffer2:mean(self.buffer:mean(gi_folded, 1), 3) - self.gradInput:repeatTensor(self.buffer2:view(1, nFeature, 1, 1), - nBatch, 1, iH, iW) - self.gradInput:cmul(self.centered):mul(-1) - self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1), - nBatch, 1, iH, iW) - self.gradInput:cmul(self.buffer):cmul(self.buffer) - - self.buffer:mean(gradOutput:view(nBatch, nFeature, iH*iW), 1) - self.buffer2:mean(self.buffer, 3) - self.buffer:repeatTensor(self.buffer2:view(1, nFeature, 1, 1), - nBatch, 1, iH, iW) - self.gradInput:add(gradOutput):add(-1, self.buffer) - self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1), - nBatch, 1, iH, iW) - self.gradInput:cmul(self.buffer) + assert(self.save_mean and self.save_std, 'must call :updateOutput() first') - if self.affine then - self.buffer:repeatTensor(self.weight:view(1, nFeature, 1, 1), - nBatch, 1, iH, iW) - self.gradInput:cmul(self.buffer) + scale = scale or 1 + if gradInput then + gradInput:resizeAs(gradOutput) end + input.nn.SpatialBatchNormalization_backward( + input, + gradOutput, + gradInput, + gradWeight, + gradBias, + self.weight, + self.save_mean, + self.save_std, + scale) + return self.gradInput end +function BN:backward(input, gradOutput, scale) + return backward(self, input, gradOutput, scale, self.gradInput, self.gradWeight, self.gradBias) +end + +function BN:updateGradInput(input, gradOutput) + return backward(self, input, gradOutput, 1, self.gradInput) +end + function BN:accGradParameters(input, gradOutput, scale) - if self.affine then - scale = scale or 1.0 - local nBatch = input:size(1) - local nFeature = input:size(2) - local iH = input:size(3) - local iW = input:size(4) - self.buffer2:resizeAs(self.normalized):copy(self.normalized) - self.buffer2 = self.buffer2:cmul(gradOutput):view(nBatch, nFeature, iH*iW) - self.buffer:sum(self.buffer2, 1) -- sum over mini-batch - self.buffer2:sum(self.buffer, 3) -- sum over pixels - self.gradWeight:add(scale, self.buffer2) - - self.buffer:sum(gradOutput:view(nBatch, nFeature, iH*iW), 1) - self.buffer2:sum(self.buffer, 3) - self.gradBias:add(scale, self.buffer2) -- sum over mini-batch + return backward(self, input, gradOutput, scale, nil, self.gradWeight, self.gradBias) +end + +function BN:read(file, version) + local var = file:readObject() + for k,v in pairs(var) do + if version < 2 and k == 'running_std' then + k = 'running_var' + v = v:cmul(v):pow(-1) + end + self[k] = v end end |