Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test/test.lua')
-rw-r--r--test/test.lua28
1 files changed, 28 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 8c97ffa..9b85499 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -764,6 +764,34 @@ function cudnntest.SpatialLogSoftMax()
mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
end
+
+function cudnntest.SpatialBatchNormalization()
+ -- batch
+ local h = 4 --math.random(5,10)
+ local w = 4 --math.random(5,10)
+ local bsz = 4 --math.random(1, 32)
+ local from = 4 --math.random(1, 32)
+ local input = torch.randn(bsz,from,h,w):cuda()
+ local gradOutput = torch.randn(bsz,from,h,w):cuda()
+ local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+ local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+
+ local rescuda = cbn:forward(input)
+ local groundtruth = gbn:forward(input)
+ local resgrad = cbn:backward(input, gradOutput)
+ local groundgrad = gbn:backward(input, gradOutput)
+
+
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(),
+ precision_backward, 'error in batch normalization (backward) ')
+
+end
+
+
function cudnntest.SpatialCrossEntropyCriterion()
-- batch
local numLabels = math.random(5,10)