Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@gmail.com>2016-02-12 02:50:44 +0300
committersoumith <soumith@gmail.com>2016-02-12 02:50:44 +0300
commit02cade60f8f40b84c8d65ea2616a317b1fa9590e (patch)
tree35ea69bfdf130e2fd045611651e58337453565d1
parent66a046907ca362745368b7ca8551a51062352d96 (diff)
serialization fix for older models
-rw-r--r--SpatialBatchNormalization.lua17
1 files changed, 15 insertions, 2 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index d88da4e..c8adce7 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -2,6 +2,8 @@ local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormali
local ffi = require 'ffi'
local errcheck = cudnn.errcheck
+SpatialBatchNormalization.__version = 2
+
function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine)
parent.__init(self, nFeature, eps, momentum, affine)
self.mode = 'CUDNN_BATCHNORM_SPATIAL'
@@ -53,13 +55,13 @@ local function backward(self,input,gradOutput, scale)
assert(gradOutput:isContiguous())
self:createIODescriptors(input)
scale = scale or 1
- scaleTens:fill(scale)
+ scaleTens:fill(scale)
errcheck('cudnnBatchNormalizationBackward',
cudnn.getHandle(), self.mode, one:data(), zero:data(), scaleTens:data(), one:data(),
self.iDesc[0], input:data(), self.iDesc[0], gradOutput:data(), self.iDesc[0], self.gradInput:data(),
-- input is bottom, gradOutput is topDiff, self.gradInput is resultBottomDiff
self.sDesc[0], self.weight:data(), self.gradWeight:data(), self.gradBias:data(),
- self.eps, self.save_mean:data(), self.save_std:data());
+ self.eps, self.save_mean:data(), self.save_std:data());
return self.gradInput
end
@@ -86,3 +88,14 @@ function SpatialBatchNormalization:write(f)
end
f:writeObject(var)
end
+
+function SpatialBatchNormalization:read(file, version)
+ parent.read(self, file)
+ if version < 2 then
+ if self.running_std then
+ -- for models before https://github.com/soumith/cudnn.torch/pull/101
+ self.running_var = self.running_std:pow(-2):add(-self.eps)
+ self.running_std = nil
+ end
+ end
+end