diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-04-14 02:00:34 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-04-16 18:27:36 +0300 |
commit | fa165cb48914a08f4f9c476009dda26a91679a9a (patch) | |
tree | dbe1506c7ffaf6b35439acfd3dc16b56658732ca /test | |
parent | a5c2b6fa7042fafa2d2179d5f7cccf06a026a15f (diff) |
cudnn.convert for BN
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 55 |
1 files changed, 43 insertions, 12 deletions
diff --git a/test/test.lua b/test/test.lua index b7e5376..9b5cbde 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1241,19 +1241,50 @@ local function testBatchNormalization(moduleName, inputSize) local gbn = nn[moduleName](inputSize[2], 1e-3):cuda() cbn.weight:copy(gbn.weight) cbn.bias:copy(gbn.bias) - mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init') - mytester:asserteq(cbn.running_var:mean(), 1, 'error on BN running_var init') - local rescuda = cbn:forward(input) - local groundtruth = gbn:forward(input) - local resgrad = cbn:backward(input, gradOutput) - local groundgrad = gbn:backward(input, gradOutput) - local error = rescuda:float() - groundtruth:float() - mytester:assertlt(error:abs():max(), - precision_forward, 'error in batch normalization (forward) ') - error = resgrad:float() - groundgrad:float() - mytester:assertlt(error:abs():max(), - precision_backward, 'error in batch normalization (backward) ') + local function testFWDBWD(cbn, gbn) + cbn:training() + gbn:training() + mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init') + mytester:asserteq(cbn.running_var:mean(), 1, 'error on BN running_var init') + local rescuda = cbn:forward(input) + local groundtruth = gbn:forward(input) + local resgrad = cbn:backward(input, gradOutput) + local groundgrad = gbn:backward(input, gradOutput) + + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), + precision_forward, 'error in batch normalization (forward) ') + error = resgrad:float() - groundgrad:float() + mytester:assertlt(error:abs():max(), + precision_backward, 'error in batch normalization (backward) ') + error = cbn.running_mean:float() - gbn.running_mean:float() + mytester:assertlt(error:abs():max(), + precision_forward, 'error in batch normalization (running_mean) ') + error = cbn.running_var:float() - gbn.running_var:float() + mytester:assertlt(error:abs():max(), + precision_forward, 'error in batch normalization (running_var) ') + end + + local function testFWD(cbn, gbn) + cbn:evaluate() + gbn:evaluate() + local rescuda = cbn:forward(input) + local groundtruth = gbn:forward(input) + + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), + precision_forward, 'error in batch normalization (forward) ') + end + + testFWDBWD(cbn, gbn) + testFWD(cbn, gbn) + local cudnn2nn = cudnn.convert(cbn:clone(), nn) + mytester:asserteq(torch.type(cudnn2nn), 'nn.'..moduleName, 'cudnn to nn') + testFWD(cudnn2nn, gbn) + local nn2cudnn = cudnn.convert(gbn:clone(), cudnn) + mytester:asserteq(torch.type(nn2cudnn), 'cudnn.'..moduleName, 'cudnn to nn') + testFWD(nn2cudnn, gbn) end function cudnntest.BatchNormalization() |