Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test/bntest.lua')
-rw-r--r--test/bntest.lua19
1 files changed, 19 insertions, 0 deletions
diff --git a/test/bntest.lua b/test/bntest.lua
new file mode 100644
index 0000000..8ebd1fa
--- /dev/null
+++ b/test/bntest.lua
@@ -0,0 +1,19 @@
+require 'cunn'
+require 'cudnn'
+
+local h=5
+local w=5
+local bsz=4
+local from=4
+local input = torch.randn(bsz,from,h,w):cuda()
+local gradOutput = torch.randn(bsz,from,h,w):cuda()
+local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+local groundtruth = gbn:forward(input)
+local rescuda = cbn:forward(input)
+local resgrad = cbn:backward(input, gradOutput)
+local groundgrad = gbn:backward(input, gradOutput)
+local error = (rescuda:float() - groundtruth:float()):abs():max()
+print("error",error)
+error = (resgrad:float() - groundgrad:float()):abs():max()
+print("error back",error)