diff options
author | Sam Gross <sgross@fb.com> | 2016-01-08 04:14:07 +0300 |
---|---|---|
committer | Sam Gross <sgross@fb.com> | 2016-01-08 04:14:07 +0300 |
commit | 69d2b6824ee18b672132661e9e162e88af6f8c6b (patch) | |
tree | e56b493f6ec64bb442e2b11bc841409027b80359 | |
parent | a412cb2fe19f3b3aadab35672e485f53130879e3 (diff) |
Fix cudnn.SpatialBatchNormalization after nn change
-rw-r--r-- | SpatialBatchNormalization.lua | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua index 9148873..8020195 100644 --- a/SpatialBatchNormalization.lua +++ b/SpatialBatchNormalization.lua @@ -1,13 +1,29 @@ -local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.SpatialBatchNormalization') +local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.Module') local ffi = require 'ffi' local errcheck = cudnn.errcheck function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine) - parent.__init(self, nFeature, eps, momentum, affine) + parent.__init(self) + assert(nFeature and type(nFeature) == 'number', + 'Missing argument #1: Number of feature planes. ') + assert(nFeature ~= 0, 'To set affine=false call BatchNormalization' + .. '(nFeature, eps, momentum, false) ') + assert(affine == nil or affine == true, 'only affine supported') + self.mode = 'CUDNN_BATCHNORM_SPATIAL' self.nFeature = nFeature + self.eps = eps or 1e-5 + self.train = true + self.momentum = momentum or 0.1 self.save_mean = torch.Tensor(nFeature) self.save_std = torch.Tensor(nFeature) + self.running_mean = torch.zeros(nFeature) + self.running_std = torch.ones(nFeature) + self.weight = torch.Tensor(nFeature) + self.bias = torch.Tensor(nFeature) + self.gradWeight = torch.Tensor(nFeature) + self.gradBias = torch.Tensor(nFeature) + self:reset() end function SpatialBatchNormalization:createIODescriptors(input) @@ -29,6 +45,11 @@ end local one = torch.FloatTensor({1}); local zero = torch.FloatTensor({0}); +function SpatialBatchNormalization:reset() + self.weight:uniform() + self.bias:zero() +end + function SpatialBatchNormalization:updateOutput(input) self:createIODescriptors(input) |