github.com/torch/nn.git
author    soumith <soumith@fb.com>  2015-05-15 00:29:42 +0300
committer soumith <soumith@fb.com>  2015-06-03 08:08:12 +0300
commit    83a3815dc70255c978405e8e966d7b02d580cc11 (patch)
tree      0a947fe2d6889de9c5d77d8947bee3129cfb2754 /SpatialBatchNormalization.lua
parent    975360a3ddd0ee0fccf7b86a3ce5120f6a9c55bd (diff)
Make batchnorm clonable by adding the running estimates to the constructor.
Fix the batchnorm tests.
Diffstat (limited to 'SpatialBatchNormalization.lua')
-rw-r--r--  SpatialBatchNormalization.lua  26
1 file changed, 12 insertions(+), 14 deletions(-)
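For context, here is a minimal usage sketch of the constructor after this change (a sketch, assuming a working torch/nn install; the layer sizes are made up for illustration). Disabling the affine transform moves from passing nFeature = 0 to an explicit fourth argument:

require 'nn'

-- New signature: nFeature, eps, momentum, affine (affine defaults to true).
local bn = nn.SpatialBatchNormalization(16, 1e-5, 0.1, false)

-- running_mean/running_std are now allocated in the constructor,
-- so the module can be cloned before it has seen any training data.
local bn2 = bn:clone()

-- nBatch x nFeature x iH x iW input
local output = bn:forward(torch.randn(4, 16, 8, 8))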
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index 3f09c3f..cbc50d3 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -30,18 +30,24 @@
 ]]--
 local BN,parent = torch.class('nn.SpatialBatchNormalization', 'nn.Module')
-function BN:__init(nFeature, eps, momentum)
+function BN:__init(nFeature, eps, momentum, affine)
    parent.__init(self)
    assert(nFeature and type(nFeature) == 'number',
-          'Missing argument #1: Number of feature planes. ' ..
-          'Give 0 for no affine transform')
+          'Missing argument #1: Number of feature planes. ')
+   assert(nFeature ~= 0, 'To set affine=false call SpatialBatchNormalization'
+          .. '(nFeature, eps, momentum, false) ')
+   if affine ~= nil then
+      assert(type(affine) == 'boolean', 'affine has to be true/false')
+      self.affine = affine
+   else
+      self.affine = true
+   end
    self.eps = eps or 1e-5
    self.train = true
    self.momentum = momentum or 0.1
-   self.running_mean = torch.Tensor()
-   self.running_std = torch.Tensor()
-   if nFeature > 0 then self.affine = true end
+   self.running_mean = torch.zeros(nFeature)
+   self.running_std = torch.ones(nFeature)
    if self.affine then
       self.weight = torch.Tensor(nFeature)
       self.bias = torch.Tensor(nFeature)
@@ -75,20 +81,12 @@ function BN:updateOutput(input)
    self.output:resizeAs(input)
    self.gradInput:resizeAs(input)
    if self.train == false then
-      assert(self.running_mean:nDimension() ~= 0,
-             'Module never run on training data. First run on some training data before evaluating.')
       self.output:copy(input)
       self.buffer:repeatTensor(self.running_mean:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
       self.output:add(-1, self.buffer)
       self.buffer:repeatTensor(self.running_std:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
       self.output:cmul(self.buffer)
    else -- training mode
-      if self.running_mean:nDimension() == 0 then
-         self.running_mean:resize(nFeature):zero()
-      end
-      if self.running_std:nDimension() == 0 then
-         self.running_std:resize(nFeature):zero()
-      end
       -- calculate mean over mini-batch, over feature-maps
       local in_folded = input:view(nBatch, nFeature, iH * iW)
       self.buffer:mean(in_folded, 1)
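The second hunk is the direct payoff of the constructor change: because running_mean and running_std now start out as correctly sized zeros and ones, the lazy resize in training mode and the 'Module never run on training data' assertion in evaluation mode both become unnecessary. A rough sketch of the two modes (again assuming the hypothetical sizes above):

require 'nn'

local bn = nn.SpatialBatchNormalization(16)

bn:training()                          -- sets self.train = true
bn:forward(torch.randn(4, 16, 8, 8))   -- updates the running estimates

bn:evaluate()                          -- sets self.train = false
-- Evaluation normalizes with the running estimates; after this commit it
-- no longer errors if the module was never trained, it simply uses the
-- freshly initialized estimates instead.
local out = bn:forward(torch.randn(4, 16, 8, 8))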