Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <sgross@fb.com>2016-01-08 04:14:07 +0300
committerSam Gross <sgross@fb.com>2016-01-08 04:14:07 +0300
commit69d2b6824ee18b672132661e9e162e88af6f8c6b (patch)
treee56b493f6ec64bb442e2b11bc841409027b80359
parenta412cb2fe19f3b3aadab35672e485f53130879e3 (diff)
Fix cudnn.SpatialBatchNormalization after nn change
-rw-r--r--SpatialBatchNormalization.lua25
1 files changed, 23 insertions, 2 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index 9148873..8020195 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -1,13 +1,29 @@
-local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.SpatialBatchNormalization')
+local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.Module')
local ffi = require 'ffi'
local errcheck = cudnn.errcheck
function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine)
- parent.__init(self, nFeature, eps, momentum, affine)
+ parent.__init(self)
+ assert(nFeature and type(nFeature) == 'number',
+ 'Missing argument #1: Number of feature planes. ')
+ assert(nFeature ~= 0, 'To set affine=false call BatchNormalization'
+ .. '(nFeature, eps, momentum, false) ')
+ assert(affine == nil or affine == true, 'only affine supported')
+
self.mode = 'CUDNN_BATCHNORM_SPATIAL'
self.nFeature = nFeature
+ self.eps = eps or 1e-5
+ self.train = true
+ self.momentum = momentum or 0.1
self.save_mean = torch.Tensor(nFeature)
self.save_std = torch.Tensor(nFeature)
+ self.running_mean = torch.zeros(nFeature)
+ self.running_std = torch.ones(nFeature)
+ self.weight = torch.Tensor(nFeature)
+ self.bias = torch.Tensor(nFeature)
+ self.gradWeight = torch.Tensor(nFeature)
+ self.gradBias = torch.Tensor(nFeature)
+ self:reset()
end
function SpatialBatchNormalization:createIODescriptors(input)
@@ -29,6 +45,11 @@ end
local one = torch.FloatTensor({1});
local zero = torch.FloatTensor({0});
+function SpatialBatchNormalization:reset()
+ self.weight:uniform()
+ self.bias:zero()
+end
+
function SpatialBatchNormalization:updateOutput(input)
self:createIODescriptors(input)