Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorBoris Fomitchev <bfomitchev@nvidia.com>2015-11-13 01:25:43 +0300
committerBoris Fomitchev <bfomitchev@nvidia.com>2015-11-13 01:25:43 +0300
commit413634aa8e27d4daed18d03e56da20046c62ce66 (patch)
tree497c41f3f912e64459d3d013f939c7178f65b1f5 /test
parent09b428e5896f62f700e24aa3393ebdac75982f30 (diff)
Natalia's fixed for BN. Added bntest.lua
Diffstat (limited to 'test')
-rw-r--r--test/bntest.lua19
-rw-r--r--test/test.lua28
2 files changed, 47 insertions, 0 deletions
diff --git a/test/bntest.lua b/test/bntest.lua
new file mode 100644
index 0000000..8ebd1fa
--- /dev/null
+++ b/test/bntest.lua
@@ -0,0 +1,19 @@
+require 'cunn'
+require 'cudnn'
+
+local h=5
+local w=5
+local bsz=4
+local from=4
+local input = torch.randn(bsz,from,h,w):cuda()
+local gradOutput = torch.randn(bsz,from,h,w):cuda()
+local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+local groundtruth = gbn:forward(input)
+local rescuda = cbn:forward(input)
+local resgrad = cbn:backward(input, gradOutput)
+local groundgrad = gbn:backward(input, gradOutput)
+local error = (rescuda:float() - groundtruth:float()):abs():max()
+print("error",error)
+error = (resgrad:float() - groundgrad:float()):abs():max()
+print("error back",error)
diff --git a/test/test.lua b/test/test.lua
index 8c97ffa..9b85499 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -764,6 +764,34 @@ function cudnntest.SpatialLogSoftMax()
mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
end
+
+function cudnntest.SpatialBatchNormalization()
+ -- batch
+ local h = 4 --math.random(5,10)
+ local w = 4 --math.random(5,10)
+ local bsz = 4 --math.random(1, 32)
+ local from = 4 --math.random(1, 32)
+ local input = torch.randn(bsz,from,h,w):cuda()
+ local gradOutput = torch.randn(bsz,from,h,w):cuda()
+ local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+ local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+
+ local rescuda = cbn:forward(input)
+ local groundtruth = gbn:forward(input)
+ local resgrad = cbn:backward(input, gradOutput)
+ local groundgrad = gbn:backward(input, gradOutput)
+
+
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(),
+ precision_backward, 'error in batch normalization (backward) ')
+
+end
+
+
function cudnntest.SpatialCrossEntropyCriterion()
-- batch
local numLabels = math.random(5,10)