diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-17 18:56:48 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-17 18:56:48 +0300 |
commit | f1cfa7ca2d379ac9c336147a59b45fbf2039ffbf (patch) | |
tree | 3b121f9d6724eaab4c259fd748730762bbcb3f6b /SpatialBatchNormalization.lua | |
parent | 6eb3c2b93545d9b809845c10197f91f3f4a9886a (diff) | |
parent | f15204373c76637684cb8a14f30d7ebb3bf63c62 (diff) |
Merge pull request #116 from szagoruyko/bn-reset-fix
add forgotten reset for Batch Normalization
Diffstat (limited to 'SpatialBatchNormalization.lua')
-rw-r--r-- | SpatialBatchNormalization.lua | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua index 093e20b..53e8f7e 100644 --- a/SpatialBatchNormalization.lua +++ b/SpatialBatchNormalization.lua @@ -30,6 +30,17 @@ function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine) self.mode = 'CUDNN_BATCHNORM_SPATIAL' end +function SpatialBatchNormalization:reset() + if self.weight then + self.weight:uniform() + end + if self.bias then + self.bias:zero() + end + self.running_mean:zero() + self.running_std:fill(1) +end + function SpatialBatchNormalization:createIODescriptors(input) assert(input:dim() == 4) assert(torch.typename(self.weight) == 'torch.CudaTensor' and torch.typename(self.bias) == 'torch.CudaTensor', |