diff options
Diffstat (limited to 'test/test.lua')
-rw-r--r-- | test/test.lua | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 8c97ffa..9b85499 100644 --- a/test/test.lua +++ b/test/test.lua @@ -764,6 +764,34 @@ function cudnntest.SpatialLogSoftMax() mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward') end + +function cudnntest.SpatialBatchNormalization() + -- batch + local h = 4 --math.random(5,10) + local w = 4 --math.random(5,10) + local bsz = 4 --math.random(1, 32) + local from = 4 --math.random(1, 32) + local input = torch.randn(bsz,from,h,w):cuda() + local gradOutput = torch.randn(bsz,from,h,w):cuda() + local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda() + local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda() + + local rescuda = cbn:forward(input) + local groundtruth = gbn:forward(input) + local resgrad = cbn:backward(input, gradOutput) + local groundgrad = gbn:backward(input, gradOutput) + + + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), + precision_forward, 'error in batch normalization (forward) ') + error = resgrad:float() - groundgrad:float() + mytester:assertlt(error:abs():max(), + precision_backward, 'error in batch normalization (backward) ') + +end + + function cudnntest.SpatialCrossEntropyCriterion() -- batch local numLabels = math.random(5,10) |