Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test/test.lua')
-rw-r--r--test/test.lua55
1 files changed, 43 insertions, 12 deletions
diff --git a/test/test.lua b/test/test.lua
index b7e5376..9b5cbde 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1241,19 +1241,50 @@ local function testBatchNormalization(moduleName, inputSize)
local gbn = nn[moduleName](inputSize[2], 1e-3):cuda()
cbn.weight:copy(gbn.weight)
cbn.bias:copy(gbn.bias)
- mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init')
- mytester:asserteq(cbn.running_var:mean(), 1, 'error on BN running_var init')
- local rescuda = cbn:forward(input)
- local groundtruth = gbn:forward(input)
- local resgrad = cbn:backward(input, gradOutput)
- local groundgrad = gbn:backward(input, gradOutput)
- local error = rescuda:float() - groundtruth:float()
- mytester:assertlt(error:abs():max(),
- precision_forward, 'error in batch normalization (forward) ')
- error = resgrad:float() - groundgrad:float()
- mytester:assertlt(error:abs():max(),
- precision_backward, 'error in batch normalization (backward) ')
+ local function testFWDBWD(cbn, gbn)
+ cbn:training()
+ gbn:training()
+ mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init')
+ mytester:asserteq(cbn.running_var:mean(), 1, 'error on BN running_var init')
+ local rescuda = cbn:forward(input)
+ local groundtruth = gbn:forward(input)
+ local resgrad = cbn:backward(input, gradOutput)
+ local groundgrad = gbn:backward(input, gradOutput)
+
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(),
+ precision_backward, 'error in batch normalization (backward) ')
+ error = cbn.running_mean:float() - gbn.running_mean:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (running_mean) ')
+ error = cbn.running_var:float() - gbn.running_var:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (running_var) ')
+ end
+
+ local function testFWD(cbn, gbn)
+ cbn:evaluate()
+ gbn:evaluate()
+ local rescuda = cbn:forward(input)
+ local groundtruth = gbn:forward(input)
+
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (forward) ')
+ end
+
+ testFWDBWD(cbn, gbn)
+ testFWD(cbn, gbn)
+ local cudnn2nn = cudnn.convert(cbn:clone(), nn)
+ mytester:asserteq(torch.type(cudnn2nn), 'nn.'..moduleName, 'cudnn to nn')
+ testFWD(cudnn2nn, gbn)
+ local nn2cudnn = cudnn.convert(gbn:clone(), cudnn)
+ mytester:asserteq(torch.type(nn2cudnn), 'cudnn.'..moduleName, 'cudnn to nn')
+ testFWD(nn2cudnn, gbn)
end
function cudnntest.BatchNormalization()