diff options
author | Sam Gross <sgross@fb.com> | 2015-12-22 21:30:06 +0300 |
---|---|---|
committer | Sam Gross <sgross@fb.com> | 2015-12-22 21:31:39 +0300 |
commit | 8c92b43e9633f385a3b71268b177db700ba44290 (patch) | |
tree | d2d9ee9ab5071b7e680c4110debfd3ce561d6b8d /SpatialBatchNormalization.lua | |
parent | bfecee58b2e03fd3ff212b5ef83c00a6b9abcbba (diff) |
Nil out the userdata 'oData' before serialization
Diffstat (limited to 'SpatialBatchNormalization.lua')
-rw-r--r-- | SpatialBatchNormalization.lua | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua index bb3b934..9148873 100644 --- a/SpatialBatchNormalization.lua +++ b/SpatialBatchNormalization.lua @@ -11,10 +11,10 @@ function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine) end function SpatialBatchNormalization:createIODescriptors(input) - assert(input:dim() == 4) + assert(input:dim() == 4) assert(torch.typename(self.weight) == 'torch.CudaTensor' and torch.typename(self.bias) == 'torch.CudaTensor', 'Only CUDA tensors are supported for cudnn.SpatialBatchNormalization!') - if not self.iDesc or not self.oDesc or + if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then self.iSize = input:size() @@ -30,7 +30,7 @@ local one = torch.FloatTensor({1}); local zero = torch.FloatTensor({0}); function SpatialBatchNormalization:updateOutput(input) - self:createIODescriptors(input) + self:createIODescriptors(input) if self.train then errcheck('cudnnBatchNormalizationForwardTraining', @@ -65,6 +65,7 @@ end function SpatialBatchNormalization:write(f) self.iDesc = nil + self.oDesc = nil self.sDesc = nil local var = {} for k,v in pairs(self) do |