
github.com/soumith/cudnn.torch.git
author    Soumith Chintala <soumith@gmail.com>  2016-08-25 17:31:26 +0300
committer GitHub <noreply@github.com>           2016-08-25 17:31:26 +0300
commit    440f0d5f63bbb2f9dadd3fa8fd0b520d8f38abdf
tree      49cd215565529664fff473ae92e022466436a1da
parent    33c36f0e9672a9b1e5e1db71497297c9c4be3dcd
parent    7eeeeed054d31f8a58ddcb8f53a93dcd9dd3ee5b

Merge pull request #239 from qureai/master

Volumetric softmax and cross entropy criterion
-rw-r--r--  README.md                            |  4
-rw-r--r--  VolumetricCrossEntropyCriterion.lua  | 63
-rw-r--r--  VolumetricLogSoftMax.lua             |  7
-rw-r--r--  VolumetricSoftMax.lua                | 47
-rw-r--r--  init.lua                             |  3
-rw-r--r--  test/test.lua                        | 79
6 files changed, 203 insertions(+), 0 deletions(-)
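
For orientation, here is a minimal usage sketch of the modules this merge adds. The tensor sizes are illustrative assumptions; the shape conventions follow the docstring in VolumetricCrossEntropyCriterion.lua below.

    require 'cudnn'

    -- illustrative sizes: batch of 2, 4 classes, volume of T=8, H=16, W=16
    local numClasses = 4
    local input  = torch.zeros(2, numClasses, 8, 16, 16):normal():cuda()
    local target = torch.Tensor(2, 8, 16, 16):random(1, numClasses):cuda()

    -- softmax over the class/channel dimension at every (t, h, w) location
    local sm    = cudnn.VolumetricSoftMax():cuda()
    local probs = sm:forward(input)              -- same size as input

    -- LogSoftMax + ClassNLLCriterion in one shot (mini-batches only)
    local crit = cudnn.VolumetricCrossEntropyCriterion():cuda()
    local loss = crit:forward(input, target)
    local grad = crit:backward(input, target)    -- same size as input
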
diff --git a/README.md b/README.md
index c55e44a..303619f 100644
--- a/README.md
+++ b/README.md
@@ -31,12 +31,16 @@ cudnn.SoftMax(fastMode [= false]) -- SoftMax across each image (just li
cudnn.LogSoftMax() -- LogSoftMax across each image (just like nn.LogSoftMax)
cudnn.SpatialSoftMax(fastMode [= false]) -- SoftMax across feature-maps (per spatial location)
cudnn.SpatialLogSoftMax() -- LogSoftMax across feature-maps (per spatial location)
+cudnn.VolumetricSoftMax(fastMode [= false]) -- SoftMax across feature-maps (per volumetric location)
+cudnn.VolumetricLogSoftMax() -- LogSoftMax across feature-maps (per volumetric location)
cudnn.SpatialCrossEntropyCriterion() -- A spatial version of LogSoftMax + ClassNLLCriterion in one shot
+cudnn.VolumetricCrossEntropyCriterion() -- A volumetric version of LogSoftMax + ClassNLLCriterion in one shot
-- Batch Normalization
cudnn.BatchNormalization(nFeature, eps, momentum, affine) -- same arguments as https://github.com/torch/nn/blob/master/doc/simple.md#nn.BatchNormalization
cudnn.SpatialBatchNormalization(nFeature, eps, momentum, affine)
+cudnn.VolumetricBatchNormalization(nFeature, eps, momentum, affine)
-- Volumetric inputs (4D or 5D batched mode)
diff --git a/VolumetricCrossEntropyCriterion.lua b/VolumetricCrossEntropyCriterion.lua
new file mode 100644
index 0000000..3faee19
--- /dev/null
+++ b/VolumetricCrossEntropyCriterion.lua
@@ -0,0 +1,63 @@
+local VolumetricCrossEntropyCriterion, parent = torch.class('cudnn.VolumetricCrossEntropyCriterion', 'nn.Criterion')
+
+--[[
+ This criterion computes cross entropy across
+ the feature dimension for an N-channel 3D image/video of size T x H x W.
+
+ It only supports mini-batches (5D input, 4D target)
+
+ It does a LogSoftMax on the input (over the channel dimension),
+ so no LogSoftMax is needed in the network at the end
+
+ input = batchSize x nClasses x T x H x W
+ target = batchSize x T x H x W
+]]--
+
+function VolumetricCrossEntropyCriterion:__init(weights)
+ parent.__init(self)
+ self.scec = cudnn.SpatialCrossEntropyCriterion(weights)
+end
+
+local foldInput = function(input)
+ -- Fold time and height into one dimension
+ -- bdthw -> bd(t*h)w
+ input = input:view(input:size(1), input:size(2),
+ input:size(3)*input:size(4), input:size(5))
+ return input
+end
+
+local foldTarget = function(target)
+ -- Fold time and height into one dimension
+ -- bthw -> b(t*h)w
+ target = target:view(target:size(1), target:size(2)*target:size(3),
+ target:size(4))
+ return target
+end
+
+function VolumetricCrossEntropyCriterion:updateOutput(input, target)
+ assert(input:dim() == 5, 'mini-batch supported only')
+ assert(target:dim() == 4, 'mini-batch supported only')
+ assert(input:size(1) == target:size(1), 'input and target should be of same size')
+ assert(input:size(3) == target:size(2), 'input and target should be of same size')
+ assert(input:size(4) == target:size(3), 'input and target should be of same size')
+ assert(input:size(5) == target:size(4), 'input and target should be of same size')
+
+ -- Fold inputs and use spatial cross entropy criterion
+ self.scec:updateOutput(foldInput(input), foldTarget(target))
+ self.output = self.scec.output
+ return self.output
+end
+
+function VolumetricCrossEntropyCriterion:updateGradInput(input, target)
+ assert(input:dim() == 5, 'mini-batch supported only')
+ assert(target:dim() == 4, 'mini-batch supported only')
+ assert(input:size(1) == target:size(1), 'input and target should be of same size')
+ assert(input:size(3) == target:size(2), 'input and target should be of same size')
+ assert(input:size(4) == target:size(3), 'input and target should be of same size')
+ assert(input:size(5) == target:size(4), 'input and target should be of same size')
+
+ local originalInputSize = input:size()
+ self.scec:updateGradInput(foldInput(input), foldTarget(target))
+ self.gradInput = self.scec.gradInput:view(originalInputSize)
+ return self.gradInput
+end
\ No newline at end of file
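
Since the criterion only folds the time and height dimensions and defers to cudnn.SpatialCrossEntropyCriterion, a quick equivalence check can be sketched as follows (sizes are again illustrative):

    local b, d, t, h, w = 2, 3, 4, 5, 6
    local input  = torch.zeros(b, d, t, h, w):normal():cuda()
    local target = torch.Tensor(b, t, h, w):random(1, d):cuda()

    -- bdthw -> bd(t*h)w, mirroring foldInput/foldTarget above
    local foldedInput  = input:view(b, d, t * h, w)
    local foldedTarget = target:view(b, t * h, w)

    local vcec = cudnn.VolumetricCrossEntropyCriterion():cuda()
    local scec = cudnn.SpatialCrossEntropyCriterion():cuda()

    -- both losses should agree, since the volumetric criterion folds internally
    print(vcec:forward(input, target), scec:forward(foldedInput, foldedTarget))
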
diff --git a/VolumetricLogSoftMax.lua b/VolumetricLogSoftMax.lua
new file mode 100644
index 0000000..a23ed60
--- /dev/null
+++ b/VolumetricLogSoftMax.lua
@@ -0,0 +1,7 @@
+local SoftMax, parent = torch.class('cudnn.VolumetricLogSoftMax', 'cudnn.VolumetricSoftMax')
+
+function SoftMax:__init(fast)
+ parent.__init(self, fast)
+ self.ssm.mode = 'CUDNN_SOFTMAX_MODE_CHANNEL'
+ self.ssm.algorithm = 'CUDNN_SOFTMAX_LOG'
+end
diff --git a/VolumetricSoftMax.lua b/VolumetricSoftMax.lua
new file mode 100644
index 0000000..7a463a2
--- /dev/null
+++ b/VolumetricSoftMax.lua
@@ -0,0 +1,47 @@
+local VolumetricSoftMax, parent = torch.class('cudnn.VolumetricSoftMax', 'nn.Module')
+
+function VolumetricSoftMax:__init(fast)
+ parent.__init(self)
+ self.ssm = cudnn.SpatialSoftMax(fast)
+end
+
+local fold = function(input)
+ -- Fold time and height into one dimension
+ if input:dim() == 4 then
+ -- dthw -> d(t*h)w
+ input = input:view(input:size(1), input:size(2)*input:size(3),
+ input:size(4))
+ else
+ -- bdthw -> bd(t*h)w
+ input = input:view(input:size(1), input:size(2),
+ input:size(3)*input:size(4), input:size(5))
+ end
+ return input
+end
+
+function VolumetricSoftMax:updateOutput(input)
+ assert(input:dim() == 4 or input:dim() == 5,
+ 'input should either be a 3d image or a minibatch of them')
+ local originalInputSize = input:size()
+
+ -- Apply SpatialSoftMax to folded input
+ self.ssm:updateOutput(fold(input))
+ self.output = self.ssm.output:view(originalInputSize)
+ return self.output
+end
+
+function VolumetricSoftMax:updateGradInput(input, gradOutput)
+ assert(input:dim() == 4 or input:dim() == 5,
+ 'input should either be a 3d image or a minibatch of them')
+
+ local originalInputSize = input:size()
+ self.ssm:updateGradInput(fold(input), fold(gradOutput))
+
+ self.gradInput = self.ssm.gradInput:view(originalInputSize)
+ return self.gradInput
+end
+
+function VolumetricSoftMax:clearState()
+ self.ssm:clearState()
+ return parent.clearState(self)
+end
\ No newline at end of file
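
One way to sanity-check the folding trick: after the forward pass, the class scores at every (t, h, w) location should sum to one over the channel dimension (sizes illustrative):

    local input = torch.zeros(2, 4, 3, 5, 5):normal():cuda()
    local sm  = cudnn.VolumetricSoftMax():cuda()
    local out = sm:forward(input)

    -- summing over the channel dimension should give (approximately) all ones
    local sums = out:sum(2)
    print((sums - 1):abs():max())   -- expect ~0 up to floating-point error
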
diff --git a/init.lua b/init.lua
index e69eade..bbb17a3 100644
--- a/init.lua
+++ b/init.lua
@@ -185,6 +185,8 @@ require('cudnn.Tanh')
require('cudnn.Sigmoid')
require('cudnn.SpatialSoftMax')
require('cudnn.SpatialLogSoftMax')
+require('cudnn.VolumetricSoftMax')
+require('cudnn.VolumetricLogSoftMax')
require('cudnn.SoftMax')
require('cudnn.LogSoftMax')
require('cudnn.SpatialCrossMapLRN')
@@ -192,6 +194,7 @@ require('cudnn.BatchNormalization')
require('cudnn.SpatialBatchNormalization')
require('cudnn.VolumetricBatchNormalization')
require('cudnn.SpatialCrossEntropyCriterion')
+require('cudnn.VolumetricCrossEntropyCriterion')
require('cudnn.TemporalConvolution')
require('cudnn.RNN')
require('cudnn.RNNTanh')
diff --git a/test/test.lua b/test/test.lua
index 0c9e852..7986e9f 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -562,6 +562,45 @@ function cudnntest.SpatialLogSoftMax()
'error in difference between central difference and :backward')
end
+function cudnntest.VolumetricLogSoftMax()
+ -- batch
+ local numLabels = math.random(5,10)
+ local t = math.random(5,10)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(3, 7)
+ local input = torch.zeros(bsz, numLabels, t, h, w):normal():cuda()
+ local target = torch.zeros(bsz, numLabels, t, h, w):normal():cuda()
+
+ local cri = cast(cudnn.VolumetricLogSoftMax())
+ local gcri = nn.LogSoftMax():cuda()
+
+ local op = cri:forward(cast(input), cast(target))
+ local gi = cri:backward(cast(input), cast(target))
+
+ local gop = op:clone():zero()
+ local ggi = gi:clone():zero()
+
+ for i=1,t do
+ for j=1,h do
+ for k =1,w do
+ local i1 = input[{ {}, {}, {i}, {j}, {k} }]:contiguous():squeeze()
+ local t1 = target[{ {}, {}, {i}, {j}, {k} }]:contiguous():squeeze()
+ local gop1 = gcri:forward(i1, t1)
+ local ggi1 = gcri:backward(i1, t1)
+ gop[{ {}, {}, {i}, {j}, {k} }]:copy(gop1)
+ ggi[{ {}, {}, {i}, {j}, {k} }]:copy(ggi1)
+ end
+ end
+ end
+ local err = (gi - ggi):abs():max()
+ mytester:assertlt(err, testparams.precision_backward,
+ 'error in difference between reference gradient and :backward')
+ local err = (op - gop):abs():max()
+ mytester:assertlt(err, testparams.precision_backward,
+ 'error in difference between reference output and :forward')
+end
+
local function testBatchNormalization(moduleName, inputSize)
local input = torch.randn(table.unpack(inputSize)):cuda()
local gradOutput = torch.randn(table.unpack(inputSize)):cuda()
@@ -683,6 +722,46 @@ function cudnntest.SpatialCrossEntropyCriterion()
'error in difference between central difference and :backward')
end
+function cudnntest.VolumetricCrossEntropyCriterion()
+ if testparams.test_type ~= 'torch.CudaTensor' then return end
+ -- batch
+ local numLabels = math.random(5,10)
+ local t = math.random(5,10)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(3, 7)
+ local input = torch.zeros(bsz, numLabels, t, h, w):normal():cuda()
+ local target = torch.Tensor(bsz, t, h, w):random(1, numLabels):cuda()
+
+ local cri = cast(cudnn.VolumetricCrossEntropyCriterion())
+ local gcri = nn.CrossEntropyCriterion():cuda()
+
+ local op = cri:forward(cast(input), cast(target))
+ local gi = cri:backward(cast(input), cast(target))
+
+ local ggi = gi:clone():zero()
+
+ for i=1,t do
+ for j=1,h do
+ for k=1,w do
+ local i1 = input[{ {}, {}, {i}, {j}, {k} }]:contiguous():squeeze()
+ local t1 = target[{ {}, {i}, {j}, {k} }]:contiguous():squeeze()
+ local gop1 = gcri:forward(i1, t1)
+ local ggi1 = gcri:backward(i1, t1)
+ ggi[{ {}, {}, {i}, {j}, {k} }]:copy(ggi1)
+ end
+ end
+ end
+
+ -- nn.CrossEntropyCriterion, unlike cudnn.VolumetricCrossEntropyCriterion, cannot
+ -- average over the spatio-temporal dimensions because it is applied per location
+ -- in a loop, so normalize the accumulated gradient here
+ ggi:div(t * h * w)
+
+ local err = (gi - ggi):abs():max()
+ mytester:assertlt(err, testparams.precision_backward,
+ 'error in difference between reference gradient and :backward')
+end
+
function cudnntest.functional_bias2D()
local bs = math.random(1,32)
local from = math.random(1,32)