diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-12 18:27:22 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-12 18:27:22 +0300 |
commit | 45fec278bff32e3788232c0d70c0b5311bc95b57 (patch) | |
tree | 1b3d0538059d7ea61a7d4ffcc64436554c88c41f /SpatialBatchNormalization.lua | |
parent | 28c2f6e76a0d3671ce127197c25e39c5ee4be627 (diff) |
clear save_mean and save_std in BN
Diffstat (limited to 'SpatialBatchNormalization.lua')
-rw-r--r-- | SpatialBatchNormalization.lua | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua index 82a0e2d..093e20b 100644 --- a/SpatialBatchNormalization.lua +++ b/SpatialBatchNormalization.lua @@ -28,8 +28,6 @@ function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine) self:reset() end self.mode = 'CUDNN_BATCHNORM_SPATIAL' - self.save_mean = torch.Tensor(nFeature) - self.save_std = torch.Tensor(nFeature) end function SpatialBatchNormalization:createIODescriptors(input) @@ -56,6 +54,11 @@ local scaleTens = torch.FloatTensor(1); function SpatialBatchNormalization:updateOutput(input) self:createIODescriptors(input) + self.save_mean = self.save_mean or input.new() + self.save_mean:resizeAs(self.running_mean) + self.save_std = self.save_std or input.new() + self.save_std:resizeAs(self.running_std) + if self.train then errcheck('cudnnBatchNormalizationForwardTraining', cudnn.getHandle(), self.mode, one:data(), zero:data(), @@ -99,13 +102,23 @@ end function SpatialBatchNormalization:accGradParameters(input, gradOutput, scale) end -function SpatialBatchNormalization:write(f) +function SpatialBatchNormalization:clearDesc() self.iDesc = nil self.oDesc = nil self.sDesc = nil +end + +function SpatialBatchNormalization:write(f) + self:clearDesc() local var = {} for k,v in pairs(self) do var[k] = v end f:writeObject(var) end + +function SpatialBatchNormalization:clearState() + self:clearDesc() + nn.utils.clear(self, 'save_mean', 'save_std') + return parent.clearState(self) +end |