github.com/soumith/cudnn.torch.git
author     Sergey Zagoruyko <zagoruyko2@gmail.com>  2016-04-14 02:00:34 +0300
committer  Sergey Zagoruyko <zagoruyko2@gmail.com>  2016-04-16 18:27:36 +0300
commit     fa165cb48914a08f4f9c476009dda26a91679a9a (patch)
tree       dbe1506c7ffaf6b35439acfd3dc16b56658732ca /test
parent     a5c2b6fa7042fafa2d2179d5f7cccf06a026a15f (diff)

cudnn.convert for BN
Diffstat (limited to 'test')
 test/test.lua | 55 +++++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 43 insertions(+), 12 deletions(-)
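The commit extends cudnn.convert so that batch-normalization modules are translated between the nn and cudnn backends like other layers, which is what the new test asserts below. As a minimal sketch of that round trip (the network layout, layer sizes, and variable names here are illustrative assumptions, not taken from the repository):

-- Minimal sketch of the nn <-> cudnn round trip this commit enables for BN.
-- The network and sizes are illustrative assumptions.
require 'nn'
require 'cunn'
require 'cudnn'

local net = nn.Sequential()
   :add(nn.SpatialConvolution(3, 16, 3, 3))
   :add(nn.SpatialBatchNormalization(16))
   :add(nn.ReLU())
net:cuda()

-- nn -> cudnn: after this commit, BN modules are converted too
local cnet = cudnn.convert(net, cudnn)
print(torch.type(cnet:get(2)))   -- expected: cudnn.SpatialBatchNormalization

-- cudnn -> nn: the reverse direction, which the test also checks
local nnet = cudnn.convert(cnet, nn)
print(torch.type(nnet:get(2)))   -- expected: nn.SpatialBatchNormalization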
diff --git a/test/test.lua b/test/test.lua
index b7e5376..9b5cbde 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1241,19 +1241,50 @@ local function testBatchNormalization(moduleName, inputSize)
local gbn = nn[moduleName](inputSize[2], 1e-3):cuda()
cbn.weight:copy(gbn.weight)
cbn.bias:copy(gbn.bias)
- mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init')
- mytester:asserteq(cbn.running_var:mean(), 1, 'error on BN running_var init')
- local rescuda = cbn:forward(input)
- local groundtruth = gbn:forward(input)
- local resgrad = cbn:backward(input, gradOutput)
- local groundgrad = gbn:backward(input, gradOutput)
- local error = rescuda:float() - groundtruth:float()
- mytester:assertlt(error:abs():max(),
- precision_forward, 'error in batch normalization (forward) ')
- error = resgrad:float() - groundgrad:float()
- mytester:assertlt(error:abs():max(),
- precision_backward, 'error in batch normalization (backward) ')
+ local function testFWDBWD(cbn, gbn)
+ cbn:training()
+ gbn:training()
+ mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init')
+ mytester:asserteq(cbn.running_var:mean(), 1, 'error on BN running_var init')
+ local rescuda = cbn:forward(input)
+ local groundtruth = gbn:forward(input)
+ local resgrad = cbn:backward(input, gradOutput)
+ local groundgrad = gbn:backward(input, gradOutput)
+
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(),
+ precision_backward, 'error in batch normalization (backward) ')
+ error = cbn.running_mean:float() - gbn.running_mean:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (running_mean) ')
+ error = cbn.running_var:float() - gbn.running_var:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (running_var) ')
+ end
+
+ local function testFWD(cbn, gbn)
+ cbn:evaluate()
+ gbn:evaluate()
+ local rescuda = cbn:forward(input)
+ local groundtruth = gbn:forward(input)
+
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (forward) ')
+ end
+
+ testFWDBWD(cbn, gbn)
+ testFWD(cbn, gbn)
+ local cudnn2nn = cudnn.convert(cbn:clone(), nn)
+ mytester:asserteq(torch.type(cudnn2nn), 'nn.'..moduleName, 'cudnn to nn')
+ testFWD(cudnn2nn, gbn)
+ local nn2cudnn = cudnn.convert(gbn:clone(), cudnn)
+ mytester:asserteq(torch.type(nn2cudnn), 'cudnn.'..moduleName, 'nn to cudnn')
+ testFWD(nn2cudnn, gbn)
end
function cudnntest.BatchNormalization()
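The refactored test splits the BN check in two: testFWDBWD runs in training mode and also compares running_mean/running_var against the nn reference, while testFWD runs in evaluation mode so it can be reused on modules produced by cudnn.convert. A minimal sketch of the training/evaluate distinction this relies on, with illustrative sizes (not the test's own):

-- Sketch of the BN mode split the test depends on; sizes are illustrative.
require 'cudnn'

local bn = cudnn.SpatialBatchNormalization(16, 1e-3):cuda()
local x  = torch.CudaTensor(8, 16, 4, 4):normal()

bn:training()            -- training mode: normalize with batch statistics
bn:forward(x)            -- and update running_mean / running_var
print(bn.running_mean:mean(), bn.running_var:mean())  -- drifts from 0 / 1

bn:evaluate()            -- evaluation mode: use the frozen running stats,
local y = bn:forward(x)  -- so repeated forwards are deterministic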