Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-02-17 18:56:48 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-02-17 18:56:48 +0300
commitf1cfa7ca2d379ac9c336147a59b45fbf2039ffbf (patch)
tree3b121f9d6724eaab4c259fd748730762bbcb3f6b /SpatialBatchNormalization.lua
parent6eb3c2b93545d9b809845c10197f91f3f4a9886a (diff)
parentf15204373c76637684cb8a14f30d7ebb3bf63c62 (diff)
Merge pull request #116 from szagoruyko/bn-reset-fix
add forgotten reset for Batch Normalization
Diffstat (limited to 'SpatialBatchNormalization.lua')
-rw-r--r--SpatialBatchNormalization.lua11
1 files changed, 11 insertions, 0 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index 093e20b..53e8f7e 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -30,6 +30,17 @@ function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine)
self.mode = 'CUDNN_BATCHNORM_SPATIAL'
end
+function SpatialBatchNormalization:reset()
+ if self.weight then
+ self.weight:uniform()
+ end
+ if self.bias then
+ self.bias:zero()
+ end
+ self.running_mean:zero()
+ self.running_std:fill(1)
+end
+
function SpatialBatchNormalization:createIODescriptors(input)
assert(input:dim() == 4)
assert(torch.typename(self.weight) == 'torch.CudaTensor' and torch.typename(self.bias) == 'torch.CudaTensor',