diff options
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | SpatialCrossEntropyCriterion.lua | 78 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | test/test.lua | 68 |
4 files changed, 149 insertions, 0 deletions
--[[ NOTE(review): the original SOURCE is a line-collapsed web scrape of a
   multi-file commit diff (README.md, SpatialCrossEntropyCriterion.lua,
   init.lua, test/test.lua).  The Lua units carrying the commit's logic are
   reconstructed below with formatting restored.  The non-Lua parts of the
   commit were one README line:
      cudnn.SpatialCrossEntropyCriterion() -- A spatial version of LogSoftMax + ClassNLLCriterion in one shot
   and one init.lua line:
      include 'SpatialCrossEntropyCriterion.lua'
]]--

-- SpatialCrossEntropyCriterion.lua -----------------------------------------

require 'nn'

local SpatialCrossEntropyCriterion, parent =
   torch.class('cudnn.SpatialCrossEntropyCriterion', 'nn.Criterion')

--[[
   This criterion does the SpatialCrossEntropyCriterion across
   the feature dimension for an N-channel image of HxW in size.

   It only supports mini-batches (4D input, 3D target).

   It does a LogSoftMax on the input (over the channel dimension),
   so no LogSoftMax is needed in the network at the end.

   input  = batchSize x nClasses x H x W
   target = batchSize x H x W
]]--
function SpatialCrossEntropyCriterion:__init()
   parent.__init(self)
   -- per-location LogSoftMax over the channel dimension
   self.slsm = cudnn.SpatialLogSoftMax()
   -- NLL runs in sum mode; normalization is handled in this class so that
   -- updateOutput and updateGradInput stay consistent with each other.
   self.nll = nn.ClassNLLCriterion()
   self.nll.sizeAverage = false
   self.sizeAverage = true
end

-- bdhw -> (b*h*w) x d: fold batch/height/width into one sample dimension
-- so that every spatial location becomes an independent NLL sample.
local transpose = function(input)
   input = input:transpose(2,4):transpose(2,3):contiguous() -- bdhw -> bwhd -> bhwd
   input = input:view(input:size(1)*input:size(2)*input:size(3), input:size(4))
   return input
end

-- Inverse of transpose: (b*h*w) x d back to bdhw, sized like originalInput.
local transposeBack = function(input, originalInput)
   input = input:view(originalInput:size(1), originalInput:size(3),
                      originalInput:size(4), originalInput:size(2))
   input = input:transpose(2,4):transpose(3,4):contiguous() -- bhwd -> bdwh -> bdhw
   return input
end

-- Shared shape validation for updateOutput/updateGradInput (the original
-- duplicated these five asserts in both methods).
local function checkInputs(input, target)
   assert(input:dim() == 4, 'mini-batch supported only')
   assert(target:dim() == 3, 'mini-batch supported only')
   assert(input:size(1) == target:size(1), 'input and target should be of same size')
   assert(input:size(3) == target:size(2), 'input and target should be of same size')
   assert(input:size(4) == target:size(3), 'input and target should be of same size')
end

function SpatialCrossEntropyCriterion:updateOutput(input, target)
   checkInputs(input, target)
   -- apply SpatialLogSoftMax to input
   self.slsm:updateOutput(input)

   -- fold the height and width dims into the mini-batch dim
   self.nll:updateOutput(transpose(self.slsm.output), target:view(-1))
   self.output = self.nll.output
   -- BUG FIX: the original returned the un-normalized NLL sum here while
   -- updateGradInput divided the gradient by the batch size, so the
   -- gradient was not the derivative of the reported loss.  Apply the
   -- same normalization to the loss.
   if self.sizeAverage then
      self.output = self.output / input:size(1)
   end
   return self.output
end

function SpatialCrossEntropyCriterion:updateGradInput(input, target)
   checkInputs(input, target)
   -- NOTE: relies on self.slsm.output cached by updateOutput, which must
   -- have been called first (standard nn criterion protocol).
   self.nll:updateGradInput(transpose(self.slsm.output), target:view(-1))

   -- unfold the height and width dims back
   self.slsm:updateGradInput(input, transposeBack(self.nll.gradInput, input))
   self.gradInput = self.slsm.gradInput
   if self.sizeAverage then
      self.gradInput:div(input:size(1))
   end
   return self.gradInput
end

-- Propagate type conversion to the internal modules before converting self.
function SpatialCrossEntropyCriterion:type(type)
   if type then
      self.nll:type(type)
      self.slsm:type(type)
   end
   parent.type(self, type)
   return self
end

-- test/test.lua additions --------------------------------------------------

function cudnntest.SpatialLogSoftMax()
   -- Compares cudnn.SpatialLogSoftMax against nn.LogSoftMax applied
   -- independently at every spatial location of a batched input.
   -- NOTE: `target` is a random tensor used as the gradOutput for :backward().
   local numLabels = math.random(5,10)
   local h = math.random(5,10)
   local w = math.random(5,10)
   local bsz = math.random(3, 7)
   local input = torch.zeros(bsz, numLabels, h, w):normal():cuda()
   local target = torch.zeros(bsz, numLabels, h, w):normal():cuda()

   local cri = cudnn.SpatialLogSoftMax():cuda()
   local gcri = nn.LogSoftMax():cuda()

   local op = cri:forward(input, target)
   local gi = cri:backward(input, target)

   local gop = op:clone():zero()
   local ggi = gi:clone():zero()

   for i=1,h do
      for j=1,w do
         local i1 = input[{{}, {}, {i}, {j}}]:contiguous():squeeze()
         local t1 = target[{{}, {}, {i}, {j}}]:contiguous():squeeze()
         local gop1 = gcri:forward(i1, t1)
         local ggi1 = gcri:backward(i1, t1)
         gop[{{}, {}, {i}, {j}}]:copy(gop1)
         ggi[{{}, {}, {i}, {j}}]:copy(ggi1)
      end
   end
   -- BUG FIX: the original assertion messages claimed a "central difference"
   -- comparison for both checks; these compare :backward() and :forward()
   -- against the per-location nn.LogSoftMax reference.
   local err = (gi - ggi):abs():max()
   mytester:assertlt(err, precision_backward,
      'error on state (backward) vs. nn.LogSoftMax reference')
   err = (op - gop):abs():max()
   mytester:assertlt(err, precision_backward,
      'error on state (forward) vs. nn.LogSoftMax reference')
end

function cudnntest.SpatialCrossEntropyCriterion()
   -- Compares the criterion's gradient against nn.CrossEntropyCriterion
   -- evaluated independently at every spatial location.
   local numLabels = math.random(5,10)
   local h = math.random(5,10)
   local w = math.random(5,10)
   local bsz = math.random(3, 7)
   local input = torch.zeros(bsz, numLabels, h, w):normal():cuda()
   local target = torch.Tensor(bsz, h, w):random(1, numLabels):cuda()

   local cri = cudnn.SpatialCrossEntropyCriterion():cuda()
   local gcri = nn.CrossEntropyCriterion():cuda()

   -- forward populates the criterion state that backward relies on
   cri:forward(input, target)
   local gi = cri:backward(input, target)

   local ggi = gi:clone():zero()

   for i=1,h do
      for j=1,w do
         local i1 = input[{{}, {}, {i}, {j}}]:contiguous():squeeze()
         local t1 = target[{{}, {i}, {j}}]:contiguous():squeeze()
         -- forward must run before backward; its return value is unused
         -- (the original bound it to an unused local `gop1`)
         gcri:forward(i1, t1)
         local ggi1 = gcri:backward(i1, t1)
         ggi[{{}, {}, {i}, {j}}]:copy(ggi1)
      end
   end
   local err = (gi - ggi):abs():max()
   mytester:assertlt(err, precision_backward,
      'error on state (backward) vs. nn.CrossEntropyCriterion reference')
end

-- NOTE(review): the source chunk ends mid-definition here; the fragment is
-- reproduced verbatim and continues beyond the visible chunk.
function cudnntest.functional_bias2D()
   local bs = math.random(1,32)
   local from = math.random(1,32)