Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorsoumith <soumith@gmail.com>2015-11-06 01:44:15 +0300
committersoumith <soumith@gmail.com>2015-11-06 01:44:15 +0300
commit507c3a35e3bfbba2af102fbead0c2fb41e9db9b0 (patch)
treeadd28ed63809ef677e224128e31706bb9df7dc4f /test
parent3532bf81ab15df72a6dc900a1788d3e1d3d1fa2e (diff)
integrating changes from master
Diffstat (limited to 'test')
-rw-r--r--test/test.lua67
1 files changed, 67 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 0a2fb01..6dbd0d8 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -730,6 +730,73 @@ function cudnntest.LogSoftMax_batch()
precision_backward, 'error on state (backward) ')
end
+function cudnntest.SpatialLogSoftMax()
+ -- batch
+ local numLabels = math.random(5,10)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(3, 7)
+ local input = torch.zeros(bsz, numLabels, h, w):normal():cuda()
+ local target = torch.zeros(bsz, numLabels, h, w):normal():cuda()
+
+ local cri = cudnn.SpatialLogSoftMax():cuda()
+ local gcri = nn.LogSoftMax():cuda()
+
+ local op = cri:forward(input, target)
+ local gi = cri:backward(input, target)
+
+ local gop = op:clone():zero()
+ local ggi = gi:clone():zero()
+
+ for i=1,h do
+ for j=1,w do
+ local i1 = input[{{}, {}, {i}, {j}}]:contiguous():squeeze()
+ local t1 = target[{{}, {}, {i}, {j}}]:contiguous():squeeze()
+ local gop1 = gcri:forward(i1, t1)
+ local ggi1 = gcri:backward(i1, t1)
+ gop[{{}, {}, {i}, {j}}]:copy(gop1)
+ ggi[{{}, {}, {i}, {j}}]:copy(ggi1)
+ end
+ end
+ local err = (gi - ggi):abs():max()
+ mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
+ local err = (op - gop):abs():max()
+ mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
+end
+
+function cudnntest.SpatialCrossEntropyCriterion()
+ -- batch
+ local numLabels = math.random(5,10)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(3, 7)
+ local input = torch.zeros(bsz, numLabels, h, w):normal():cuda()
+ local target = torch.Tensor(bsz, h, w):random(1, numLabels):cuda()
+
+ local cri = cudnn.SpatialCrossEntropyCriterion():cuda()
+
+ local gcri = nn.CrossEntropyCriterion():cuda()
+
+ local op = cri:forward(input, target)
+ local gi = cri:backward(input, target)
+
+ local ggi = gi:clone():zero()
+
+ for i=1,h do
+ for j=1,w do
+ local i1 = input[{{}, {}, {i}, {j}}]:contiguous():squeeze()
+ local t1 = target[{{}, {i}, {j}}]:contiguous():squeeze()
+ local gop1 = gcri:forward(i1, t1)
+ local ggi1 = gcri:backward(i1, t1)
+ ggi[{{}, {}, {i}, {j}}]:copy(ggi1)
+ end
+ end
+ local err = (gi - ggi):abs():max()
+ mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
+
+end
+
+
function cudnntest.functional_bias2D()
local bs = math.random(1,32)
local from = math.random(1,32)