From 02cade60f8f40b84c8d65ea2616a317b1fa9590e Mon Sep 17 00:00:00 2001 From: soumith Date: Thu, 11 Feb 2016 18:50:44 -0500 Subject: serialization fix for older models --- SpatialBatchNormalization.lua | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) (limited to 'SpatialBatchNormalization.lua') diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua index d88da4e..c8adce7 100644 --- a/SpatialBatchNormalization.lua +++ b/SpatialBatchNormalization.lua @@ -2,6 +2,8 @@ local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormali local ffi = require 'ffi' local errcheck = cudnn.errcheck +SpatialBatchNormalization.__version = 2 + function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine) parent.__init(self, nFeature, eps, momentum, affine) self.mode = 'CUDNN_BATCHNORM_SPATIAL' @@ -53,13 +55,13 @@ local function backward(self,input,gradOutput, scale) assert(gradOutput:isContiguous()) self:createIODescriptors(input) scale = scale or 1 - scaleTens:fill(scale) + scaleTens:fill(scale) errcheck('cudnnBatchNormalizationBackward', cudnn.getHandle(), self.mode, one:data(), zero:data(), scaleTens:data(), one:data(), self.iDesc[0], input:data(), self.iDesc[0], gradOutput:data(), self.iDesc[0], self.gradInput:data(), -- input is bottom, gradOutput is topDiff, self.gradInput is resultBottomDiff self.sDesc[0], self.weight:data(), self.gradWeight:data(), self.gradBias:data(), - self.eps, self.save_mean:data(), self.save_std:data()); + self.eps, self.save_mean:data(), self.save_std:data()); return self.gradInput end @@ -86,3 +88,14 @@ function SpatialBatchNormalization:write(f) end f:writeObject(var) end + +function SpatialBatchNormalization:read(file, version) + parent.read(self, file) + if version < 2 then + if self.running_std then + -- for models before https://github.com/soumith/cudnn.torch/pull/101 + self.running_var = self.running_std:pow(-2):add(-self.eps) + self.running_std = nil + end + end +end -- cgit v1.2.3