diff options
author | Boris Fomitchev <bfomitchev@nvidia.com> | 2015-11-13 01:25:43 +0300 |
---|---|---|
committer | Boris Fomitchev <bfomitchev@nvidia.com> | 2015-11-13 01:25:43 +0300 |
commit | 413634aa8e27d4daed18d03e56da20046c62ce66 (patch) | |
tree | 497c41f3f912e64459d3d013f939c7178f65b1f5 /test | |
parent | 09b428e5896f62f700e24aa3393ebdac75982f30 (diff) |
Natalia's fixed for BN. Added bntest.lua
Diffstat (limited to 'test')
-rw-r--r-- | test/bntest.lua | 19 | ||||
-rw-r--r-- | test/test.lua | 28 |
2 files changed, 47 insertions, 0 deletions
diff --git a/test/bntest.lua b/test/bntest.lua new file mode 100644 index 0000000..8ebd1fa --- /dev/null +++ b/test/bntest.lua @@ -0,0 +1,19 @@ +require 'cunn' +require 'cudnn' + +local h=5 +local w=5 +local bsz=4 +local from=4 +local input = torch.randn(bsz,from,h,w):cuda() +local gradOutput = torch.randn(bsz,from,h,w):cuda() +local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda() +local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda() +local groundtruth = gbn:forward(input) +local rescuda = cbn:forward(input) +local resgrad = cbn:backward(input, gradOutput) +local groundgrad = gbn:backward(input, gradOutput) +local error = (rescuda:float() - groundtruth:float()):abs():max() +print("error",error) +error = (resgrad:float() - groundgrad:float()):abs():max() +print("error back",error) diff --git a/test/test.lua b/test/test.lua index 8c97ffa..9b85499 100644 --- a/test/test.lua +++ b/test/test.lua @@ -764,6 +764,34 @@ function cudnntest.SpatialLogSoftMax() mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward') end + +function cudnntest.SpatialBatchNormalization() + -- batch + local h = 4 --math.random(5,10) + local w = 4 --math.random(5,10) + local bsz = 4 --math.random(1, 32) + local from = 4 --math.random(1, 32) + local input = torch.randn(bsz,from,h,w):cuda() + local gradOutput = torch.randn(bsz,from,h,w):cuda() + local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda() + local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda() + + local rescuda = cbn:forward(input) + local groundtruth = gbn:forward(input) + local resgrad = cbn:backward(input, gradOutput) + local groundgrad = gbn:backward(input, gradOutput) + + + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), + precision_forward, 'error in batch normalization (forward) ') + error = resgrad:float() - groundgrad:float() + mytester:assertlt(error:abs():max(), + precision_backward, 'error in batch normalization (backward) ') + +end + + function cudnntest.SpatialCrossEntropyCriterion() -- batch local numLabels = math.random(5,10) |