diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-12 18:00:14 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-12 18:00:58 +0300 |
commit | 278aa716e2327920f5c6b3035d8013e140098cbc (patch) | |
tree | d50271abebb6edf69072165a0deb829b99212e99 | |
parent | 60c85023e872318904c0f53af30f93648c7258df (diff) |
running_var to running_std in BN
-rw-r--r-- | SpatialBatchNormalization.lua | 48 |
1 files changed, 29 insertions, 19 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua index c25cd92..82a0e2d 100644 --- a/SpatialBatchNormalization.lua +++ b/SpatialBatchNormalization.lua @@ -1,13 +1,33 @@ -local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.SpatialBatchNormalization') +local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.Module') local ffi = require 'ffi' local errcheck = cudnn.errcheck -SpatialBatchNormalization.__version = 2 - function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine) - parent.__init(self, nFeature, eps, momentum, affine) + parent.__init(self) + assert(nFeature and type(nFeature) == 'number', + 'Missing argument #1: Number of feature planes. ') + assert(nFeature ~= 0, 'To set affine=false call SpatialBatchNormalization' + .. '(nFeature, eps, momentum, false) ') + if affine ~= nil then + assert(type(affine) == 'boolean', 'affine has to be true/false') + self.affine = affine + else + self.affine = true + end + self.eps = eps or 1e-5 + self.train = true + self.momentum = momentum or 0.1 + + self.running_mean = torch.zeros(nFeature) + self.running_std = torch.ones(nFeature) + if self.affine then + self.weight = torch.Tensor(nFeature) + self.bias = torch.Tensor(nFeature) + self.gradWeight = torch.Tensor(nFeature) + self.gradBias = torch.Tensor(nFeature) + self:reset() + end self.mode = 'CUDNN_BATCHNORM_SPATIAL' - self.nFeature = nFeature self.save_mean = torch.Tensor(nFeature) self.save_std = torch.Tensor(nFeature) end @@ -19,12 +39,13 @@ function SpatialBatchNormalization:createIODescriptors(input) if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then + local nFeature = self.running_mean:numel() self.iSize = input:size() self.output:resizeAs(input) self.gradInput:resizeAs(input) self.iDesc = cudnn.toDescriptor(input) self.oDesc = cudnn.toDescriptor(self.output) - self.sDesc = cudnn.toDescriptor(self.bias:view(1, self.nFeature, 1, 1)) + self.sDesc = cudnn.toDescriptor(self.bias:view(1, nFeature, 1, 1)) end end @@ -40,13 +61,13 @@ function SpatialBatchNormalization:updateOutput(input) cudnn.getHandle(), self.mode, one:data(), zero:data(), self.iDesc[0], input:data(), self.oDesc[0], self.output:data(), self.sDesc[0], self.weight:data(), self.bias:data(), - self.momentum, self.running_mean:data(), self.running_var:data(), self.eps, self.save_mean:data(), self.save_std:data()); + self.momentum, self.running_mean:data(), self.running_std:data(), self.eps, self.save_mean:data(), self.save_std:data()); else errcheck('cudnnBatchNormalizationForwardInference', cudnn.getHandle(), self.mode, one:data(), zero:data(), self.iDesc[0], input:data(), self.oDesc[0], self.output:data(), self.sDesc[0], self.weight:data(), self.bias:data(), - self.running_mean:data(), self.running_var:data(), self.eps); + self.running_mean:data(), self.running_std:data(), self.eps); end return self.output end @@ -88,14 +109,3 @@ function SpatialBatchNormalization:write(f) end f:writeObject(var) end - -function SpatialBatchNormalization:read(file, version) - parent.read(self, file) - if version < 2 then - if self.running_std then - -- for models before https://github.com/soumith/cudnn.torch/pull/101 - self.running_var = self.running_std - self.running_std = nil - end - end -end |