From 7ad74db1bf2d93edbc794b3f1de73e6db9470aad Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 21 Jan 2016 14:04:15 -0800 Subject: Calls updated to 4.0.5 --- SpatialBatchNormalization.lua | 50 ++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 29 deletions(-) (limited to 'SpatialBatchNormalization.lua') diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua index 8020195..d88da4e 100644 --- a/SpatialBatchNormalization.lua +++ b/SpatialBatchNormalization.lua @@ -1,29 +1,13 @@ -local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.Module') +local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.SpatialBatchNormalization') local ffi = require 'ffi' local errcheck = cudnn.errcheck function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine) - parent.__init(self) - assert(nFeature and type(nFeature) == 'number', - 'Missing argument #1: Number of feature planes. ') - assert(nFeature ~= 0, 'To set affine=false call BatchNormalization' - .. '(nFeature, eps, momentum, false) ') - assert(affine == nil or affine == true, 'only affine supported') - + parent.__init(self, nFeature, eps, momentum, affine) self.mode = 'CUDNN_BATCHNORM_SPATIAL' self.nFeature = nFeature - self.eps = eps or 1e-5 - self.train = true - self.momentum = momentum or 0.1 self.save_mean = torch.Tensor(nFeature) self.save_std = torch.Tensor(nFeature) - self.running_mean = torch.zeros(nFeature) - self.running_std = torch.ones(nFeature) - self.weight = torch.Tensor(nFeature) - self.bias = torch.Tensor(nFeature) - self.gradWeight = torch.Tensor(nFeature) - self.gradBias = torch.Tensor(nFeature) - self:reset() end function SpatialBatchNormalization:createIODescriptors(input) @@ -44,11 +28,7 @@ end local one = torch.FloatTensor({1}); local zero = torch.FloatTensor({0}); - -function SpatialBatchNormalization:reset() - self.weight:uniform() - self.bias:zero() -end +local scaleTens = torch.FloatTensor(1); function SpatialBatchNormalization:updateOutput(input) self:createIODescriptors(input) @@ -58,29 +38,41 @@ function SpatialBatchNormalization:updateOutput(input) cudnn.getHandle(), self.mode, one:data(), zero:data(), self.iDesc[0], input:data(), self.oDesc[0], self.output:data(), self.sDesc[0], self.weight:data(), self.bias:data(), - self.momentum, self.running_mean:data(), self.running_std:data(), self.eps, self.save_mean:data(), self.save_std:data()); + self.momentum, self.running_mean:data(), self.running_var:data(), self.eps, self.save_mean:data(), self.save_std:data()); else errcheck('cudnnBatchNormalizationForwardInference', cudnn.getHandle(), self.mode, one:data(), zero:data(), self.iDesc[0], input:data(), self.oDesc[0], self.output:data(), self.sDesc[0], self.weight:data(), self.bias:data(), - self.running_mean:data(), self.running_std:data(), self.eps); + self.running_mean:data(), self.running_var:data(), self.eps); end return self.output end -function SpatialBatchNormalization:updateGradInput(input, gradOutput) - assert(gradOutput:isContiguous()); +local function backward(self,input,gradOutput, scale) + assert(gradOutput:isContiguous()) self:createIODescriptors(input) + scale = scale or 1 + scaleTens:fill(scale) errcheck('cudnnBatchNormalizationBackward', - cudnn.getHandle(), self.mode, one:data(), zero:data(), + cudnn.getHandle(), self.mode, one:data(), zero:data(), scaleTens:data(), one:data(), self.iDesc[0], input:data(), self.iDesc[0], gradOutput:data(), self.iDesc[0], self.gradInput:data(), -- input is bottom, gradOutput is topDiff, self.gradInput is resultBottomDiff self.sDesc[0], self.weight:data(), self.gradWeight:data(), self.gradBias:data(), - self.eps, self.save_mean:data(), self.save_std:data()); + self.eps, self.save_mean:data(), self.save_std:data()); return self.gradInput end +function SpatialBatchNormalization:updateGradInput(input, gradOutput, scale) +-- will in fact update gradWeight and gradBias too, accGradParameters call is empty + return backward(self, input,gradOutput, scale) +end + + +function SpatialBatchNormalization:backward(input, gradOutput, scale) + return backward(self, input,gradOutput, scale) +end + function SpatialBatchNormalization:accGradParameters(input, gradOutput, scale) end -- cgit v1.2.3