Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-10-20 22:48:37 +0300
committersoumith <soumith@fb.com>2015-10-20 22:49:35 +0300
commit596d8cbca19a425c51c1b69d0bfac1387324ac38 (patch)
tree7a6e72c328c0b1fcb9ff0958d8f486854761ddf5 /test
parent5caa2641e5f6371391159f3d379dbe93c8dda5f0 (diff)
adding SpatialCrossEntropyCriterion
Diffstat (limited to 'test')
-rw-r--r--test/test.lua68
1 files changed, 68 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 0a2fb01..8c97ffa 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -730,6 +730,74 @@ function cudnntest.LogSoftMax_batch()
precision_backward, 'error on state (backward) ')
end
+function cudnntest.SpatialLogSoftMax()
+ -- batch
+ local numLabels = math.random(5,10)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(3, 7)
+ local input = torch.zeros(bsz, numLabels, h, w):normal():cuda()
+ local target = torch.zeros(bsz, numLabels, h, w):normal():cuda()
+
+ local cri = cudnn.SpatialLogSoftMax():cuda()
+ local gcri = nn.LogSoftMax():cuda()
+
+ local op = cri:forward(input, target)
+ local gi = cri:backward(input, target)
+
+ local gop = op:clone():zero()
+ local ggi = gi:clone():zero()
+
+ for i=1,h do
+ for j=1,w do
+ local i1 = input[{{}, {}, {i}, {j}}]:contiguous():squeeze()
+ local t1 = target[{{}, {}, {i}, {j}}]:contiguous():squeeze()
+ local gop1 = gcri:forward(i1, t1)
+ local ggi1 = gcri:backward(i1, t1)
+ gop[{{}, {}, {i}, {j}}]:copy(gop1)
+ ggi[{{}, {}, {i}, {j}}]:copy(ggi1)
+ end
+ end
+ local err = (gi - ggi):abs():max()
+ mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
+ local err = (op - gop):abs():max()
+ mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
+end
+
+function cudnntest.SpatialCrossEntropyCriterion()
+ -- batch
+ local numLabels = math.random(5,10)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(3, 7)
+ local input = torch.zeros(bsz, numLabels, h, w):normal():cuda()
+ local target = torch.Tensor(bsz, h, w):random(1, numLabels):cuda()
+
+ local cri = cudnn.SpatialCrossEntropyCriterion():cuda()
+
+ local gcri = nn.CrossEntropyCriterion():cuda()
+
+ local op = cri:forward(input, target)
+ local gi = cri:backward(input, target)
+
+ local ggi = gi:clone():zero()
+
+ for i=1,h do
+ for j=1,w do
+ local i1 = input[{{}, {}, {i}, {j}}]:contiguous():squeeze()
+ local t1 = target[{{}, {i}, {j}}]:contiguous():squeeze()
+ local gop1 = gcri:forward(i1, t1)
+ local ggi1 = gcri:backward(i1, t1)
+ ggi[{{}, {}, {i}, {j}}]:copy(ggi1)
+ end
+ end
+ local err = (gi - ggi):abs():max()
+ mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
+
+end
+
+
+
function cudnntest.functional_bias2D()
local bs = math.random(1,32)
local from = math.random(1,32)