diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-17 18:40:23 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-17 18:55:17 +0300 |
commit | f15204373c76637684cb8a14f30d7ebb3bf63c62 (patch) | |
tree | 3b121f9d6724eaab4c259fd748730762bbcb3f6b /test | |
parent | 6eb3c2b93545d9b809845c10197f91f3f4a9886a (diff) |
add forgotten reset for BN
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 42 |
1 files changed, 22 insertions, 20 deletions
diff --git a/test/test.lua b/test/test.lua index 493ce53..10ceabb 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1136,27 +1136,29 @@ end function cudnntest.SpatialBatchNormalization() -- batch - local h = math.random(5,10) - local w = math.random(5,10) - local bsz = math.random(1, 32) - local from = math.random(1, 32) - local input = torch.randn(bsz,from,h,w):cuda() - local gradOutput = torch.randn(bsz,from,h,w):cuda() - local cbn = cudnn.SpatialBatchNormalization(from, 1e-3):cuda() - local gbn = nn.SpatialBatchNormalization(from, 1e-3):cuda() - cbn.weight:copy(gbn.weight) - cbn.bias:copy(gbn.bias) - local rescuda = cbn:forward(input) - local groundtruth = gbn:forward(input) - local resgrad = cbn:backward(input, gradOutput) - local groundgrad = gbn:backward(input, gradOutput) + local h = math.random(5,10) + local w = math.random(5,10) + local bsz = math.random(1, 32) + local from = math.random(1, 32) + local input = torch.randn(bsz,from,h,w):cuda() + local gradOutput = torch.randn(bsz,from,h,w):cuda() + local cbn = cudnn.SpatialBatchNormalization(from, 1e-3):cuda() + local gbn = nn.SpatialBatchNormalization(from, 1e-3):cuda() + cbn.weight:copy(gbn.weight) + cbn.bias:copy(gbn.bias) + mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init') + mytester:asserteq(cbn.running_std:mean(), 1, 'error on BN running_std init') + local rescuda = cbn:forward(input) + local groundtruth = gbn:forward(input) + local resgrad = cbn:backward(input, gradOutput) + local groundgrad = gbn:backward(input, gradOutput) - local error = rescuda:float() - groundtruth:float() - mytester:assertlt(error:abs():max(), - precision_forward, 'error in batch normalization (forward) ') - error = resgrad:float() - groundgrad:float() - mytester:assertlt(error:abs():max(), - precision_backward, 'error in batch normalization (backward) ') + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), + precision_forward, 'error in batch normalization (forward) ') + error = resgrad:float() - groundgrad:float() + mytester:assertlt(error:abs():max(), + precision_backward, 'error in batch normalization (backward) ') end function cudnntest.SpatialCrossEntropyCriterion() |