diff options
author | soumith <soumith@fb.com> | 2015-10-20 22:48:37 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-10-20 22:49:35 +0300 |
commit | 596d8cbca19a425c51c1b69d0bfac1387324ac38 (patch) | |
tree | 7a6e72c328c0b1fcb9ff0958d8f486854761ddf5 /test | |
parent | 5caa2641e5f6371391159f3d379dbe93c8dda5f0 (diff) |
adding SpatialCrossEntropyCriterion
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 0a2fb01..8c97ffa 100644 --- a/test/test.lua +++ b/test/test.lua @@ -730,6 +730,74 @@ function cudnntest.LogSoftMax_batch() precision_backward, 'error on state (backward) ') end +function cudnntest.SpatialLogSoftMax() + -- batch + local numLabels = math.random(5,10) + local h = math.random(5,10) + local w = math.random(5,10) + local bsz = math.random(3, 7) + local input = torch.zeros(bsz, numLabels, h, w):normal():cuda() + local target = torch.zeros(bsz, numLabels, h, w):normal():cuda() + + local cri = cudnn.SpatialLogSoftMax():cuda() + local gcri = nn.LogSoftMax():cuda() + + local op = cri:forward(input, target) + local gi = cri:backward(input, target) + + local gop = op:clone():zero() + local ggi = gi:clone():zero() + + for i=1,h do + for j=1,w do + local i1 = input[{{}, {}, {i}, {j}}]:contiguous():squeeze() + local t1 = target[{{}, {}, {i}, {j}}]:contiguous():squeeze() + local gop1 = gcri:forward(i1, t1) + local ggi1 = gcri:backward(i1, t1) + gop[{{}, {}, {i}, {j}}]:copy(gop1) + ggi[{{}, {}, {i}, {j}}]:copy(ggi1) + end + end + local err = (gi - ggi):abs():max() + mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward') + local err = (op - gop):abs():max() + mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward') +end + +function cudnntest.SpatialCrossEntropyCriterion() + -- batch + local numLabels = math.random(5,10) + local h = math.random(5,10) + local w = math.random(5,10) + local bsz = math.random(3, 7) + local input = torch.zeros(bsz, numLabels, h, w):normal():cuda() + local target = torch.Tensor(bsz, h, w):random(1, numLabels):cuda() + + local cri = cudnn.SpatialCrossEntropyCriterion():cuda() + + local gcri = nn.CrossEntropyCriterion():cuda() + + local op = cri:forward(input, target) + local gi = cri:backward(input, target) + + local ggi = gi:clone():zero() + + for i=1,h do + for j=1,w do + local i1 = input[{{}, {}, {i}, {j}}]:contiguous():squeeze() + local t1 = target[{{}, {i}, {j}}]:contiguous():squeeze() + local gop1 = gcri:forward(i1, t1) + local ggi1 = gcri:backward(i1, t1) + ggi[{{}, {}, {i}, {j}}]:copy(ggi1) + end + end + local err = (gi - ggi):abs():max() + mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward') + +end + + + function cudnntest.functional_bias2D() local bs = math.random(1,32) local from = math.random(1,32) |