Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--SpatialBatchNormalization.lua11
-rw-r--r--test/test.lua42
2 files changed, 33 insertions, 20 deletions
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index 093e20b..53e8f7e 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -30,6 +30,17 @@ function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine)
self.mode = 'CUDNN_BATCHNORM_SPATIAL'
end
+function SpatialBatchNormalization:reset()
+ if self.weight then
+ self.weight:uniform()
+ end
+ if self.bias then
+ self.bias:zero()
+ end
+ self.running_mean:zero()
+ self.running_std:fill(1)
+end
+
function SpatialBatchNormalization:createIODescriptors(input)
assert(input:dim() == 4)
assert(torch.typename(self.weight) == 'torch.CudaTensor' and torch.typename(self.bias) == 'torch.CudaTensor',
diff --git a/test/test.lua b/test/test.lua
index 493ce53..10ceabb 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1136,27 +1136,29 @@ end
function cudnntest.SpatialBatchNormalization()
-- batch
- local h = math.random(5,10)
- local w = math.random(5,10)
- local bsz = math.random(1, 32)
- local from = math.random(1, 32)
- local input = torch.randn(bsz,from,h,w):cuda()
- local gradOutput = torch.randn(bsz,from,h,w):cuda()
- local cbn = cudnn.SpatialBatchNormalization(from, 1e-3):cuda()
- local gbn = nn.SpatialBatchNormalization(from, 1e-3):cuda()
- cbn.weight:copy(gbn.weight)
- cbn.bias:copy(gbn.bias)
- local rescuda = cbn:forward(input)
- local groundtruth = gbn:forward(input)
- local resgrad = cbn:backward(input, gradOutput)
- local groundgrad = gbn:backward(input, gradOutput)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(1, 32)
+ local from = math.random(1, 32)
+ local input = torch.randn(bsz,from,h,w):cuda()
+ local gradOutput = torch.randn(bsz,from,h,w):cuda()
+ local cbn = cudnn.SpatialBatchNormalization(from, 1e-3):cuda()
+ local gbn = nn.SpatialBatchNormalization(from, 1e-3):cuda()
+ cbn.weight:copy(gbn.weight)
+ cbn.bias:copy(gbn.bias)
+ mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init')
+ mytester:asserteq(cbn.running_std:mean(), 1, 'error on BN running_std init')
+ local rescuda = cbn:forward(input)
+ local groundtruth = gbn:forward(input)
+ local resgrad = cbn:backward(input, gradOutput)
+ local groundgrad = gbn:backward(input, gradOutput)
- local error = rescuda:float() - groundtruth:float()
- mytester:assertlt(error:abs():max(),
- precision_forward, 'error in batch normalization (forward) ')
- error = resgrad:float() - groundgrad:float()
- mytester:assertlt(error:abs():max(),
- precision_backward, 'error in batch normalization (backward) ')
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(),
+ precision_backward, 'error in batch normalization (backward) ')
end
function cudnntest.SpatialCrossEntropyCriterion()