github.com/torch/nn.git

author     Sam Gross <sgross@fb.com>   2016-01-05 07:57:45 +0300
committer  Sam Gross <sgross@fb.com>   2016-01-05 22:45:11 +0300
commit     a142233b8bebc7a4acfa0ad6a66c1d400803034e (patch)
tree       bcef518cdd9125c5c68d493bbfc5873261b72192 /SpatialBatchNormalization.lua
parent     31f71b92881ad1ef3356a5077725f0541f765340 (diff)
Add C implementation of SpatialBatchNormalization
This is primarily to support the fast, memory-efficient CUDA implementation. Some other changes include making weight and bias each individually optional, and averaging the variances instead of the inverse standard deviations.
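For context, a minimal usage sketch of the constructor change described in the diff's docstring below (the feature count and input shape are illustrative, and the C backend added by this commit must be built):

require 'nn'

-- with learnable per-plane weight (gamma) and bias (beta), the default:
local bn = nn.SpatialBatchNormalization(16)

-- without learnable parameters: pass affine = false explicitly
-- (replaces the old nFeature = 0 convention):
local bnPlain = nn.SpatialBatchNormalization(16, 1e-5, 0.1, false)

local input  = torch.randn(4, 16, 8, 8)   -- minibatch x planes x height x width
local output = bn:forward(input)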
Diffstat (limited to 'SpatialBatchNormalization.lua')
-rw-r--r--  SpatialBatchNormalization.lua | 167
1 file changed, 62 insertions(+), 105 deletions(-)
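One note before the diff on the statistics change: the old code accumulated running_std = 1/sqrt(var + eps) (see the removed self.std:add(self.eps):sqrt():pow(-1) below), while the new code accumulates running_var directly. The conversion used by the new BN:read at the end of the diff follows from inverting that expression; a small self-contained check, with eps and the plane count chosen arbitrarily:

require 'torch'

local eps = 1e-5
local var = torch.rand(16):add(0.5)                       -- per-plane variances
local running_std = var:clone():add(eps):sqrt():pow(-1)   -- old buffer: 1/sqrt(var + eps)

-- invert: running_std^(-2) = var + eps; this is the v:cmul(v):pow(-1) in BN:read
local recovered = running_std:clone():cmul(running_std):pow(-1)
print((recovered - var):abs():max())                      -- ~1e-5, i.e. the folded-in eps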
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index a5eac14..e844385 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -18,9 +18,9 @@
Usage:
with learnable parameters: nn.BatchNormalization(N [,eps] [,momentum])
where N = dimensionality of input
- without learnable parameters: nn.BatchNormalization(0 [,eps] [,momentum])
+ without learnable parameters: nn.BatchNormalization(N [,eps] [,momentum], false)
- eps is a small value added to the standard-deviation to avoid divide-by-zero.
+ eps is a small value added to the variance to avoid divide-by-zero.
Defaults to 1e-5
During training, this layer keeps a running estimate of its computed mean and std.
@@ -30,13 +30,15 @@
]]--
local BN,parent = torch.class('nn.SpatialBatchNormalization', 'nn.Module')
+BN.__version = 2
+
function BN:__init(nFeature, eps, momentum, affine)
parent.__init(self)
assert(nFeature and type(nFeature) == 'number',
'Missing argument #1: Number of feature planes. ')
assert(nFeature ~= 0, 'To set affine=false call SpatialBatchNormalization'
.. '(nFeature, eps, momentum, false) ')
- if affine ~=nil then
+ if affine ~= nil then
assert(type(affine) == 'boolean', 'affine has to be true/false')
self.affine = affine
else
@@ -47,7 +49,7 @@ function BN:__init(nFeature, eps, momentum, affine)
self.momentum = momentum or 0.1
self.running_mean = torch.zeros(nFeature)
- self.running_std = torch.ones(nFeature)
+ self.running_var = torch.ones(nFeature)
if self.affine then
self.weight = torch.Tensor(nFeature)
self.bias = torch.Tensor(nFeature)
@@ -58,127 +60,82 @@ function BN:__init(nFeature, eps, momentum, affine)
end
function BN:reset()
- self.weight:uniform()
- self.bias:zero()
- self.running_mean:zero()
- self.running_std:fill(1)
+ if self.weight then
+ self.weight:uniform()
+ end
+ if self.bias then
+ self.bias:zero()
+ end
end
function BN:updateOutput(input)
assert(input:dim() == 4, 'only mini-batch supported (4D tensor), got '
.. input:dim() .. 'D tensor instead')
- local nBatch = input:size(1)
- local nFeature = input:size(2)
- local iH = input:size(3)
- local iW = input:size(4)
- -- buffers that are reused
- self.buffer = self.buffer or input.new()
self.output:resizeAs(input)
- if self.train == false then
- self.output:copy(input)
- self.buffer:repeatTensor(self.running_mean:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
- self.output:add(-1, self.buffer)
- self.buffer:repeatTensor(self.running_std:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
- self.output:cmul(self.buffer)
- else -- training mode
- self.buffer2 = self.buffer2 or input.new()
- self.centered = self.centered or input.new()
- self.centered:resizeAs(input)
- self.std = self.std or input.new()
- self.normalized = self.normalized or input.new()
- self.normalized:resizeAs(input)
- self.gradInput:resizeAs(input)
- -- calculate mean over mini-batch, over feature-maps
- local in_folded = input:view(nBatch, nFeature, iH * iW)
- self.buffer:mean(in_folded, 1)
- self.buffer2:mean(self.buffer, 3)
- self.running_mean:mul(1 - self.momentum):add(self.momentum, self.buffer2) -- add to running mean
- self.buffer:repeatTensor(self.buffer2:view(1, nFeature, 1, 1),
- nBatch, 1, iH, iW)
-
- -- subtract mean
- self.centered:add(input, -1, self.buffer) -- x - E(x)
-
- -- calculate standard deviation over mini-batch
- self.buffer:copy(self.centered):cmul(self.buffer) -- [x - E(x)]^2
- local buf_folded = self.buffer:view(nBatch,nFeature,iH*iW)
- self.std:mean(self.buffer2:mean(buf_folded, 1), 3)
- self.std:add(self.eps):sqrt():pow(-1) -- 1 / E([x - E(x)]^2)
- self.running_std:mul(1 - self.momentum):add(self.momentum, self.std) -- add to running stdv
- self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1),
- nBatch, 1, iH, iW)
-
- -- divide standard-deviation + eps
- self.output:cmul(self.centered, self.buffer)
- self.normalized:copy(self.output)
- end
-
- if self.affine then
- -- multiply with gamma and add beta
- self.buffer:repeatTensor(self.weight:view(1, nFeature, 1, 1),
- nBatch, 1, iH, iW)
- self.output:cmul(self.buffer)
- self.buffer:repeatTensor(self.bias:view(1, nFeature, 1, 1),
- nBatch, 1, iH, iW)
- self.output:add(self.buffer)
- end
+ self.save_mean = self.save_mean or input.new():resizeAs(self.running_mean)
+ self.save_std = self.save_std or input.new():resizeAs(self.running_var)
+
+ input.nn.SpatialBatchNormalization_updateOutput(
+ input,
+ self.output,
+ self.weight,
+ self.bias,
+ self.train,
+ self.eps,
+ self.momentum,
+ self.running_mean,
+ self.running_var,
+ self.save_mean,
+ self.save_std)
return self.output
end
-function BN:updateGradInput(input, gradOutput)
+local function backward(self, input, gradOutput, scale, gradInput, gradWeight, gradBias)
assert(input:dim() == 4, 'only mini-batch supported')
assert(gradOutput:dim() == 4, 'only mini-batch supported')
assert(self.train == true, 'should be in training mode when self.train is true')
- local nBatch = input:size(1)
- local nFeature = input:size(2)
- local iH = input:size(3)
- local iW = input:size(4)
-
- self.gradInput:cmul(self.centered, gradOutput)
- local gi_folded = self.gradInput:view(nBatch, nFeature, iH * iW)
- self.buffer2:mean(self.buffer:mean(gi_folded, 1), 3)
- self.gradInput:repeatTensor(self.buffer2:view(1, nFeature, 1, 1),
- nBatch, 1, iH, iW)
- self.gradInput:cmul(self.centered):mul(-1)
- self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1),
- nBatch, 1, iH, iW)
- self.gradInput:cmul(self.buffer):cmul(self.buffer)
-
- self.buffer:mean(gradOutput:view(nBatch, nFeature, iH*iW), 1)
- self.buffer2:mean(self.buffer, 3)
- self.buffer:repeatTensor(self.buffer2:view(1, nFeature, 1, 1),
- nBatch, 1, iH, iW)
- self.gradInput:add(gradOutput):add(-1, self.buffer)
- self.buffer:repeatTensor(self.std:view(1, nFeature, 1, 1),
- nBatch, 1, iH, iW)
- self.gradInput:cmul(self.buffer)
+ assert(self.save_mean and self.save_std, 'must call :updateOutput() first')
- if self.affine then
- self.buffer:repeatTensor(self.weight:view(1, nFeature, 1, 1),
- nBatch, 1, iH, iW)
- self.gradInput:cmul(self.buffer)
+ scale = scale or 1
+ if gradInput then
+ gradInput:resizeAs(gradOutput)
end
+ input.nn.SpatialBatchNormalization_backward(
+ input,
+ gradOutput,
+ gradInput,
+ gradWeight,
+ gradBias,
+ self.weight,
+ self.save_mean,
+ self.save_std,
+ scale)
+
return self.gradInput
end
+function BN:backward(input, gradOutput, scale)
+ return backward(self, input, gradOutput, scale, self.gradInput, self.gradWeight, self.gradBias)
+end
+
+function BN:updateGradInput(input, gradOutput)
+ return backward(self, input, gradOutput, 1, self.gradInput)
+end
+
function BN:accGradParameters(input, gradOutput, scale)
- if self.affine then
- scale = scale or 1.0
- local nBatch = input:size(1)
- local nFeature = input:size(2)
- local iH = input:size(3)
- local iW = input:size(4)
- self.buffer2:resizeAs(self.normalized):copy(self.normalized)
- self.buffer2 = self.buffer2:cmul(gradOutput):view(nBatch, nFeature, iH*iW)
- self.buffer:sum(self.buffer2, 1) -- sum over mini-batch
- self.buffer2:sum(self.buffer, 3) -- sum over pixels
- self.gradWeight:add(scale, self.buffer2)
-
- self.buffer:sum(gradOutput:view(nBatch, nFeature, iH*iW), 1)
- self.buffer2:sum(self.buffer, 3)
- self.gradBias:add(scale, self.buffer2) -- sum over mini-batch
+ return backward(self, input, gradOutput, scale, nil, self.gradWeight, self.gradBias)
+end
+
+function BN:read(file, version)
+   local var = file:readObject()
+   for k,v in pairs(var) do
+      if version < 2 and k == 'running_std' then
+         k = 'running_var'
+         v = v:cmul(v):pow(-1)
+      end
+      self[k] = v
+   end
+end
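Finally, a minimal sketch of how the refactored gradient entry points above are meant to be called (shapes illustrative; assumes the C backend from this commit is built):

require 'nn'

local bn = nn.SpatialBatchNormalization(16)
local input      = torch.randn(4, 16, 8, 8)
local gradOutput = torch.randn(4, 16, 8, 8)

bn:forward(input)                 -- fills save_mean/save_std for the backward pass
bn:zeroGradParameters()

-- combined path: gradInput plus accumulated gradWeight/gradBias in one C call
local gradInput = bn:backward(input, gradOutput)

-- or the two halves separately:
bn:updateGradInput(input, gradOutput)      -- gradInput only
bn:accGradParameters(input, gradOutput, 1) -- accumulate gradWeight/gradBias only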