github.com/soumith/cudnn.torch.git
author    Boris Fomitchev <bfomitchev@nvidia.com>    2016-01-22 01:04:15 +0300
committer Boris Fomitchev <bfomitchev@nvidia.com>    2016-01-22 01:04:15 +0300
commit    7ad74db1bf2d93edbc794b3f1de73e6db9470aad (patch)
tree      4723ecc91890eb6667853e738ef4854a81a9c851 /SpatialBatchNormalization.lua
parent    056ed8965ce9557db146a475e0ef4772b3afda77 (diff)
Calls updated to cuDNN 4.0.5
Diffstat (limited to 'SpatialBatchNormalization.lua')
-rw-r--r--   SpatialBatchNormalization.lua   50
1 file changed, 21 insertions, 29 deletions
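
In short, this patch re-bases cudnn.SpatialBatchNormalization on nn.SpatialBatchNormalization: construction, the weight/bias parameters and the running statistics are now expected to come from the parent's __init (so the module tracks running_var instead of running_std), and the backward path gains an optional scale for the parameter gradients. A minimal usage sketch of the module after this change, assuming a working cutorch/cudnn install; sizes and hyper-parameters are illustrative:

require 'cudnn'
-- construction is unchanged for callers; __init simply forwards to nn.SpatialBatchNormalization
local bn = cudnn.SpatialBatchNormalization(64, 1e-5, 0.1):cuda()
bn:training()                                        -- use the ForwardTraining path
local input = torch.randn(16, 64, 32, 32):cuda()     -- NCHW minibatch
local output = bn:forward(input)                     -- updates bn.running_mean / bn.running_var
local gradOutput = torch.randn(16, 64, 32, 32):cuda()
local gradInput = bn:backward(input, gradOutput)     -- also fills bn.gradWeight / bn.gradBias
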
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index 8020195..d88da4e 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -1,29 +1,13 @@
-local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.Module')
+local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.SpatialBatchNormalization')
local ffi = require 'ffi'
local errcheck = cudnn.errcheck
function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine)
- parent.__init(self)
- assert(nFeature and type(nFeature) == 'number',
- 'Missing argument #1: Number of feature planes. ')
- assert(nFeature ~= 0, 'To set affine=false call BatchNormalization'
- .. '(nFeature, eps, momentum, false) ')
- assert(affine == nil or affine == true, 'only affine supported')
-
+ parent.__init(self, nFeature, eps, momentum, affine)
self.mode = 'CUDNN_BATCHNORM_SPATIAL'
self.nFeature = nFeature
- self.eps = eps or 1e-5
- self.train = true
- self.momentum = momentum or 0.1
self.save_mean = torch.Tensor(nFeature)
self.save_std = torch.Tensor(nFeature)
- self.running_mean = torch.zeros(nFeature)
- self.running_std = torch.ones(nFeature)
- self.weight = torch.Tensor(nFeature)
- self.bias = torch.Tensor(nFeature)
- self.gradWeight = torch.Tensor(nFeature)
- self.gradBias = torch.Tensor(nFeature)
- self:reset()
end
function SpatialBatchNormalization:createIODescriptors(input)
@@ -44,11 +28,7 @@ end
local one = torch.FloatTensor({1});
local zero = torch.FloatTensor({0});
-
-function SpatialBatchNormalization:reset()
- self.weight:uniform()
- self.bias:zero()
-end
+local scaleTens = torch.FloatTensor(1);
function SpatialBatchNormalization:updateOutput(input)
self:createIODescriptors(input)
@@ -58,29 +38,41 @@ function SpatialBatchNormalization:updateOutput(input)
cudnn.getHandle(), self.mode, one:data(), zero:data(),
self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
self.sDesc[0], self.weight:data(), self.bias:data(),
- self.momentum, self.running_mean:data(), self.running_std:data(), self.eps, self.save_mean:data(), self.save_std:data());
+ self.momentum, self.running_mean:data(), self.running_var:data(), self.eps, self.save_mean:data(), self.save_std:data());
else
errcheck('cudnnBatchNormalizationForwardInference',
cudnn.getHandle(), self.mode, one:data(), zero:data(),
self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
self.sDesc[0], self.weight:data(), self.bias:data(),
- self.running_mean:data(), self.running_std:data(), self.eps);
+ self.running_mean:data(), self.running_var:data(), self.eps);
end
return self.output
end
-function SpatialBatchNormalization:updateGradInput(input, gradOutput)
- assert(gradOutput:isContiguous());
+local function backward(self,input,gradOutput, scale)
+ assert(gradOutput:isContiguous())
self:createIODescriptors(input)
+ scale = scale or 1
+ scaleTens:fill(scale)
errcheck('cudnnBatchNormalizationBackward',
- cudnn.getHandle(), self.mode, one:data(), zero:data(),
+ cudnn.getHandle(), self.mode, one:data(), zero:data(), scaleTens:data(), one:data(),
self.iDesc[0], input:data(), self.iDesc[0], gradOutput:data(), self.iDesc[0], self.gradInput:data(),
-- input is bottom, gradOutput is topDiff, self.gradInput is resultBottomDiff
self.sDesc[0], self.weight:data(), self.gradWeight:data(), self.gradBias:data(),
- self.eps, self.save_mean:data(), self.save_std:data());
+ self.eps, self.save_mean:data(), self.save_std:data());
return self.gradInput
end
+function SpatialBatchNormalization:updateGradInput(input, gradOutput, scale)
+-- will in fact update gradWeight and gradBias too, accGradParameters call is empty
+ return backward(self, input,gradOutput, scale)
+end
+
+
+function SpatialBatchNormalization:backward(input, gradOutput, scale)
+ return backward(self, input,gradOutput, scale)
+end
+
function SpatialBatchNormalization:accGradParameters(input, gradOutput, scale)
end
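
Note on the gradient path after this patch: updateGradInput, backward and the (now empty) accGradParameters all route through the single cudnnBatchNormalizationBackward call above, which writes gradInput and accumulates gradWeight/gradBias in one call; the optional scale argument is handed to cuDNN through the scaleTens buffer. A hedged training-step sketch, reusing bn, input and gradOutput from the example above the diff (the scale value and learning rate are illustrative):

bn:zeroGradParameters()                                  -- gradWeight/gradBias are accumulated (beta = 1), so clear them first
local gradInput = bn:backward(input, gradOutput, 0.5)    -- 0.5 is forwarded to cuDNN via scaleTens
bn:updateParameters(0.01)                                -- plain SGD step on the accumulated gradients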