Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-02-17 18:40:23 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-02-17 18:55:17 +0300
commitf15204373c76637684cb8a14f30d7ebb3bf63c62 (patch)
tree3b121f9d6724eaab4c259fd748730762bbcb3f6b /test
parent6eb3c2b93545d9b809845c10197f91f3f4a9886a (diff)
add forgotten reset for BN
Diffstat (limited to 'test')
-rw-r--r--test/test.lua42
1 files changed, 22 insertions, 20 deletions
diff --git a/test/test.lua b/test/test.lua
index 493ce53..10ceabb 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1136,27 +1136,29 @@ end
function cudnntest.SpatialBatchNormalization()
-- batch
- local h = math.random(5,10)
- local w = math.random(5,10)
- local bsz = math.random(1, 32)
- local from = math.random(1, 32)
- local input = torch.randn(bsz,from,h,w):cuda()
- local gradOutput = torch.randn(bsz,from,h,w):cuda()
- local cbn = cudnn.SpatialBatchNormalization(from, 1e-3):cuda()
- local gbn = nn.SpatialBatchNormalization(from, 1e-3):cuda()
- cbn.weight:copy(gbn.weight)
- cbn.bias:copy(gbn.bias)
- local rescuda = cbn:forward(input)
- local groundtruth = gbn:forward(input)
- local resgrad = cbn:backward(input, gradOutput)
- local groundgrad = gbn:backward(input, gradOutput)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(1, 32)
+ local from = math.random(1, 32)
+ local input = torch.randn(bsz,from,h,w):cuda()
+ local gradOutput = torch.randn(bsz,from,h,w):cuda()
+ local cbn = cudnn.SpatialBatchNormalization(from, 1e-3):cuda()
+ local gbn = nn.SpatialBatchNormalization(from, 1e-3):cuda()
+ cbn.weight:copy(gbn.weight)
+ cbn.bias:copy(gbn.bias)
+ mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init')
+ mytester:asserteq(cbn.running_std:mean(), 1, 'error on BN running_std init')
+ local rescuda = cbn:forward(input)
+ local groundtruth = gbn:forward(input)
+ local resgrad = cbn:backward(input, gradOutput)
+ local groundgrad = gbn:backward(input, gradOutput)
- local error = rescuda:float() - groundtruth:float()
- mytester:assertlt(error:abs():max(),
- precision_forward, 'error in batch normalization (forward) ')
- error = resgrad:float() - groundgrad:float()
- mytester:assertlt(error:abs():max(),
- precision_backward, 'error in batch normalization (backward) ')
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(),
+ precision_backward, 'error in batch normalization (backward) ')
end
function cudnntest.SpatialCrossEntropyCriterion()