Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-02-12 18:27:22 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-02-12 18:27:22 +0300
commit45fec278bff32e3788232c0d70c0b5311bc95b57 (patch)
tree1b3d0538059d7ea61a7d4ffcc64436554c88c41f /SpatialBatchNormalization.lua
parent28c2f6e76a0d3671ce127197c25e39c5ee4be627 (diff)
clear save_mean and save_std in BN
Diffstat (limited to 'SpatialBatchNormalization.lua')
-rw-r--r--SpatialBatchNormalization.lua19
1 files changed, 16 insertions, 3 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index 82a0e2d..093e20b 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -28,8 +28,6 @@ function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine)
self:reset()
end
self.mode = 'CUDNN_BATCHNORM_SPATIAL'
- self.save_mean = torch.Tensor(nFeature)
- self.save_std = torch.Tensor(nFeature)
end
function SpatialBatchNormalization:createIODescriptors(input)
@@ -56,6 +54,11 @@ local scaleTens = torch.FloatTensor(1);
function SpatialBatchNormalization:updateOutput(input)
self:createIODescriptors(input)
+ self.save_mean = self.save_mean or input.new()
+ self.save_mean:resizeAs(self.running_mean)
+ self.save_std = self.save_std or input.new()
+ self.save_std:resizeAs(self.running_std)
+
if self.train then
errcheck('cudnnBatchNormalizationForwardTraining',
cudnn.getHandle(), self.mode, one:data(), zero:data(),
@@ -99,13 +102,23 @@ end
function SpatialBatchNormalization:accGradParameters(input, gradOutput, scale)
end
-function SpatialBatchNormalization:write(f)
+function SpatialBatchNormalization:clearDesc()
self.iDesc = nil
self.oDesc = nil
self.sDesc = nil
+end
+
+function SpatialBatchNormalization:write(f)
+ self:clearDesc()
local var = {}
for k,v in pairs(self) do
var[k] = v
end
f:writeObject(var)
end
+
+function SpatialBatchNormalization:clearState()
+ self:clearDesc()
+ nn.utils.clear(self, 'save_mean', 'save_std')
+ return parent.clearState(self)
+end