github.com/soumith/cudnn.torch.git
author     Boris Fomitchev <bfomitchev@nvidia.com>  2016-09-15 11:36:52 +0300
committer  Boris Fomitchev <bfomitchev@nvidia.com>  2016-09-15 11:36:52 +0300
commit     89abf80596dff69787de759000ac4d480691c9e2 (patch)
tree       e54f8900efd98ebc3d220b412c9426c8b7ea549d
parent     b146c809cab5a943b802002822f72c3ea6ed0def (diff)
parent     8c112dfe7adb26cad7f10c3f0919234a3ffd7b70 (diff)
Merge remote-tracking branch 'upstream/master' into find_ex
-rw-r--r--  README.md                             4
-rw-r--r--  RNN.lua                               6
-rw-r--r--  VolumetricCrossEntropyCriterion.lua  63
-rw-r--r--  VolumetricLogSoftMax.lua              7
-rw-r--r--  VolumetricSoftMax.lua                47
-rw-r--r--  convert.lua                           5
-rw-r--r--  ffi.lua                              32
-rw-r--r--  init.lua                              3
-rw-r--r--  test/test.lua                        79
9 files changed, 235 insertions, 11 deletions
diff --git a/README.md b/README.md
index c55e44a..303619f 100644
--- a/README.md
+++ b/README.md
@@ -31,12 +31,16 @@ cudnn.SoftMax(fastMode [= false]) -- SoftMax across each image (just like nn.SoftMax)
cudnn.LogSoftMax() -- LogSoftMax across each image (just like nn.LogSoftMax)
cudnn.SpatialSoftMax(fastMode [= false]) -- SoftMax across feature-maps (per spatial location)
cudnn.SpatialLogSoftMax() -- LogSoftMax across feature-maps (per spatial location)
+cudnn.VolumetricSoftMax(fastMode [= false]) -- SoftMax across feature-maps (per spatial location)
+cudnn.VolumetricLogSoftMax() -- LogSoftMax across feature-maps (per spatial location)
cudnn.SpatialCrossEntropyCriterion() -- A spatial version of LogSoftMax + ClassNLLCriterion in one shot
+cudnn.VolumetricCrossEntropyCriterion() -- A volumetric version of LogSoftMax + ClassNLLCriterion in one shot
-- Batch Normalization
cudnn.BatchNormalization(nFeature, eps, momentum, affine) -- same arguments as https://github.com/torch/nn/blob/master/doc/simple.md#nn.BatchNormalization
cudnn.SpatialBatchNormalization(nFeature, eps, momentum, affine)
+cudnn.VolumetricBatchNormalization(nFeature, eps, momentum, affine)
-- Volumetric inputs (4D or 5D batched mode)
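
For reference, the new volumetric modules documented above can be exercised end to end. A minimal usage sketch (assuming cudnn.torch is installed and a CUDA device is available; tensor shapes follow the doc comment in VolumetricCrossEntropyCriterion.lua below):

require 'cudnn'

local bsz, nClasses, t, h, w = 2, 4, 3, 5, 5
local input  = torch.CudaTensor(bsz, nClasses, t, h, w):normal()
local target = torch.CudaTensor(bsz, t, h, w):random(1, nClasses)

-- LogSoftMax + ClassNLLCriterion in one shot, per (t, h, w) location
local crit = cudnn.VolumetricCrossEntropyCriterion()
local loss = crit:forward(input, target)
local gradInput = crit:backward(input, target) -- same 5D shape as input
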
diff --git a/RNN.lua b/RNN.lua
index 91b4228..af37afc 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -234,6 +234,12 @@ function RNN:resetStates()
if self.cellInput then
self.cellInput = nil
end
+ if self.gradHiddenOutput then
+ self.gradHiddenOutput = nil
+ end
+ if self.gradCellOutput then
+ self.gradCellOutput = nil
+ end
end
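
The two gradient fields cleared above matter when an RNN is reused across unrelated sequences. A usage sketch (assuming a cudnn RNN-family module such as cudnn.LSTM; the sizes are arbitrary):

require 'cudnn'

local rnn = cudnn.LSTM(10, 20, 2):cuda()         -- inputSize, hiddenSize, numLayers
local seqA = torch.CudaTensor(5, 3, 10):normal() -- seqLen x batch x inputSize
rnn:forward(seqA)
rnn:resetStates() -- drop hidden/cell state and, after this patch, their gradients too
local seqB = torch.CudaTensor(7, 3, 10):normal()
rnn:forward(seqB)
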
diff --git a/VolumetricCrossEntropyCriterion.lua b/VolumetricCrossEntropyCriterion.lua
new file mode 100644
index 0000000..3faee19
--- /dev/null
+++ b/VolumetricCrossEntropyCriterion.lua
@@ -0,0 +1,63 @@
+local VolumetricCrossEntropyCriterion, parent = torch.class('cudnn.VolumetricCrossEntropyCriterion', 'nn.Criterion')
+
+--[[
+ This criterion does the VolumetricCrossEntropyCriterion across
+ the feature dimension for a N-channel 3D image/video of TxHxW in size.
+
+ It only supports mini-batches (5D input, 4D target)
+
+ It does a LogSoftMax on the input (over the channel dimension),
+ so no LogSoftMax is needed in the network at the end
+
+ input = batchSize x nClasses x T x H x W
+ target = batchSize x T x H x W
+]]--
+
+function VolumetricCrossEntropyCriterion:__init(weights)
+ parent.__init(self)
+ self.scec = cudnn.SpatialCrossEntropyCriterion(weights)
+end
+
+local foldInput = function(input)
+ -- Fold time and height into one dimension
+ -- bdthw -> bd(t*h)w
+ input = input:view(input:size(1), input:size(2),
+ input:size(3)*input:size(4), input:size(5))
+ return input
+end
+
+local foldTarget = function(target)
+ -- Fold time and height into one dimension
+ -- bthw -> b(t*h)w
+ target = target:view(target:size(1), target:size(2)*target:size(3),
+ target:size(4))
+ return target
+end
+
+function VolumetricCrossEntropyCriterion:updateOutput(input, target)
+ assert(input:dim() == 5, 'mini-batch supported only')
+ assert(target:dim() == 4, 'mini-batch supported only')
+ assert(input:size(1) == target:size(1), 'input and target should be of same size')
+ assert(input:size(3) == target:size(2), 'input and target should be of same size')
+ assert(input:size(4) == target:size(3), 'input and target should be of same size')
+ assert(input:size(5) == target:size(4), 'input and target should be of same size')
+
+ -- Fold inputs and use spatial cross entropy criterion
+ self.scec:updateOutput(foldInput(input), foldTarget(target))
+ self.output = self.scec.output
+ return self.output
+end
+
+function VolumetricCrossEntropyCriterion:updateGradInput(input, target)
+ assert(input:dim() == 5, 'mini-batch supported only')
+ assert(target:dim() == 4, 'mini-batch supported only')
+ assert(input:size(1) == target:size(1), 'input and target should be of same size')
+ assert(input:size(3) == target:size(2), 'input and target should be of same size')
+ assert(input:size(4) == target:size(3), 'input and target should be of same size')
+ assert(input:size(5) == target:size(4), 'input and target should be of same size')
+
+ local originalInputSize = input:size()
+ self.scec:updateGradInput(foldInput(input), foldTarget(target))
+ self.gradInput = self.scec.gradInput:view(originalInputSize)
+ return self.gradInput
+end
\ No newline at end of file
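
The folding used above is only a reshape. A small sketch of the idea (CPU tensors for illustration; the sizes are arbitrary): viewing a contiguous b x d x t x h x w batch as b x d x (t*h) x w lets the existing SpatialCrossEntropyCriterion treat every (t, h, w) position as an independent spatial location, with no data copied:

local input  = torch.randn(2, 4, 3, 5, 5)     -- b x d x t x h x w
local folded = input:view(2, 4, 3 * 5, 5)     -- b x d x (t*h) x w
assert(folded:nElement() == input:nElement()) -- same storage, new shape
-- class scores at a given (t, h) location line up at folded row (t-1)*5 + h:
assert(folded[{1, 1, (2-1)*5 + 3, 4}] == input[{1, 1, 2, 3, 4}])
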
diff --git a/VolumetricLogSoftMax.lua b/VolumetricLogSoftMax.lua
new file mode 100644
index 0000000..a23ed60
--- /dev/null
+++ b/VolumetricLogSoftMax.lua
@@ -0,0 +1,7 @@
+local SoftMax, parent = torch.class('cudnn.VolumetricLogSoftMax', 'cudnn.VolumetricSoftMax')
+
+function SoftMax:__init(fast)
+ parent.__init(self, fast)
+ self.ssm.mode = 'CUDNN_SOFTMAX_MODE_CHANNEL'
+ self.ssm.algorithm = 'CUDNN_SOFTMAX_LOG'
+end
diff --git a/VolumetricSoftMax.lua b/VolumetricSoftMax.lua
new file mode 100644
index 0000000..7a463a2
--- /dev/null
+++ b/VolumetricSoftMax.lua
@@ -0,0 +1,47 @@
+local VolumetricSoftMax, parent = torch.class('cudnn.VolumetricSoftMax', 'nn.Module')
+
+function VolumetricSoftMax:__init(fast)
+ parent.__init(self)
+ self.ssm = cudnn.SpatialSoftMax(fast)
+end
+
+local fold = function(input)
+ -- Fold time and height into one dimension
+ if input:dim() == 4 then
+ -- dthw -> d(t*h)w
+ input = input:view(input:size(1), input:size(2)*input:size(3),
+ input:size(4))
+ else
+ -- bdthw -> bd(t*h)w
+ input = input:view(input:size(1), input:size(2),
+ input:size(3)*input:size(4), input:size(5))
+ end
+ return input
+end
+
+function VolumetricSoftMax:updateOutput(input)
+ assert(input:dim() == 4 or input:dim() == 5,
+ 'input should either be a 3d image or a minibatch of them')
+ local originalInputSize = input:size()
+
+ -- Apply SpatialSoftMax to folded input
+ self.ssm:updateOutput(fold(input))
+ self.output = self.ssm.output:view(originalInputSize)
+ return self.output
+end
+
+function VolumetricSoftMax:updateGradInput(input, gradOutput)
+ assert(input:dim() == 4 or input:dim() == 5,
+ 'input should either be a 3d image or a minibatch of them')
+
+ local originalInputSize = input:size()
+ self.ssm:updateGradInput(fold(input), fold(gradOutput))
+
+ self.gradInput = self.ssm.gradInput:view(originalInputSize)
+ return self.gradInput
+end
+
+function VolumetricSoftMax:clearState()
+ self.ssm:clearState()
+ return parent.clearState(self)
+end
\ No newline at end of file
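
A quick sanity sketch for the module above (again assuming a CUDA device): after VolumetricSoftMax, the probabilities at every (t, h, w) location should sum to 1 across the class dimension:

require 'cudnn'

local sm = cudnn.VolumetricSoftMax():cuda()
local x  = torch.CudaTensor(2, 4, 3, 5, 5):normal()
local p  = sm:forward(x)
print((p:sum(2) - 1):abs():max()) -- ~0 up to floating-point error
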
diff --git a/convert.lua b/convert.lua
index 49a6fa5..5368418 100644
--- a/convert.lua
+++ b/convert.lua
@@ -22,6 +22,10 @@ local layer_list = {
-- for example: net = cudnn.convert(net, cudnn)
function cudnn.convert(net, dst, exclusion_fn)
return net:replace(function(x)
+ if torch.type(x) == 'nn.gModule' then
+ io.stderr:write('Warning: cudnn.convert does not work with nngraph yet. Ignoring nn.gModule\n')
+ return x
+ end
local y = 0
local src = dst == nn and cudnn or nn
local src_prefix = src == nn and 'nn.' or 'cudnn.'
@@ -58,4 +62,3 @@ function cudnn.convert(net, dst, exclusion_fn)
return y == 0 and x or y
end)
end
-
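
A behavior sketch for the new nngraph guard (hypothetical toy graph, assuming nngraph is installed): an nn.gModule passed to cudnn.convert now triggers the warning and is handed back rather than swapped out:

require 'cudnn'
require 'nngraph'

local h   = nn.Identity()()
local out = nn.Tanh()(nn.Linear(10, 10)(h))
local g   = nn.gModule({h}, {out})

local converted = cudnn.convert(g, cudnn)
-- writes the warning to stderr; the returned module is still the nn.gModule
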
diff --git a/ffi.lua b/ffi.lua
index cf56341..d5b5f8c 100644
--- a/ffi.lua
+++ b/ffi.lua
@@ -1586,18 +1586,30 @@ cudnnStatus_t cudnnActivationBackward_v4(
]]
-local libnames = {'libcudnn.so.5', 'libcudnn.5.dylib'}
-local ok = false
-for i=1,#libnames do
- ok = pcall(function () cudnn.C = ffi.load(libnames[i]) end)
- if ok then break; end
-end
-
-if not ok then
- error([['libcudnn (R5) not found in library path.
+local CUDNN_PATH = os.getenv('CUDNN_PATH')
+if CUDNN_PATH then
+ print('Found Environment variable CUDNN_PATH = ' .. CUDNN_PATH)
+ cudnn.C = ffi.load(CUDNN_PATH)
+else
+
+ local libnames = {'libcudnn.so.5', 'libcudnn.5.dylib'}
+ local ok = false
+ for i=1,#libnames do
+ ok = pcall(function () cudnn.C = ffi.load(libnames[i]) end)
+ if ok then break; end
+ end
+
+ if not ok then
+ error([['libcudnn (R5) not found in library path.
Please install CuDNN from https://developer.nvidia.com/cuDNN
-Then make sure files named as libcudnn.so.5 or libcudnn.5.dylib are placed in your library load path (for example /usr/local/lib , or manually add a path to LD_LIBRARY_PATH)
+Then make sure files named as libcudnn.so.5 or libcudnn.5.dylib are placed in
+your library load path (for example /usr/local/lib , or manually add a path to LD_LIBRARY_PATH)
+
+Alternatively, set the path to libcudnn.so.5 or libcudnn.5.dylib
+to the environment variable CUDNN_PATH and rerun torch.
+For example: export CUDNN_PATH="/usr/local/cuda/lib64/libcudnn.so.5"
]])
+ end
end
-- check cuDNN version
diff --git a/init.lua b/init.lua
index 59418d0..54cf72b 100644
--- a/init.lua
+++ b/init.lua
@@ -289,6 +289,8 @@ require('cudnn.Tanh')
require('cudnn.Sigmoid')
require('cudnn.SpatialSoftMax')
require('cudnn.SpatialLogSoftMax')
+require('cudnn.VolumetricSoftMax')
+require('cudnn.VolumetricLogSoftMax')
require('cudnn.SoftMax')
require('cudnn.LogSoftMax')
require('cudnn.SpatialCrossMapLRN')
@@ -296,6 +298,7 @@ require('cudnn.BatchNormalization')
require('cudnn.SpatialBatchNormalization')
require('cudnn.VolumetricBatchNormalization')
require('cudnn.SpatialCrossEntropyCriterion')
+require('cudnn.VolumetricCrossEntropyCriterion')
require('cudnn.TemporalConvolution')
require('cudnn.RNN')
require('cudnn.RNNTanh')
diff --git a/test/test.lua b/test/test.lua
index 4b86bcb..8d3fc95 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -563,6 +563,45 @@ function cudnntest.SpatialLogSoftMax()
'error in difference between central difference and :backward')
end
+function cudnntest.VolumetricLogSoftMax()
+ -- batch
+ local numLabels = math.random(5,10)
+ local t = math.random(5,10)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(3, 7)
+ local input = torch.zeros(bsz, numLabels, t, h, w):normal():cuda()
+ local target = torch.zeros(bsz, numLabels, t, h, w):normal():cuda()
+
+ local cri = cast(cudnn.VolumetricLogSoftMax())
+ local gcri = nn.LogSoftMax():cuda()
+
+ local op = cri:forward(cast(input), cast(target))
+ local gi = cri:backward(cast(input), cast(target))
+
+ local gop = op:clone():zero()
+ local ggi = gi:clone():zero()
+
+ for i=1,t do
+ for j=1,h do
+ for k =1,w do
+ local i1 = input[{ {}, {}, {i}, {j}, {k} }]:contiguous():squeeze()
+ local t1 = target[{ {}, {}, {i}, {j}, {k} }]:contiguous():squeeze()
+ local gop1 = gcri:forward(i1, t1)
+ local ggi1 = gcri:backward(i1, t1)
+ gop[{ {}, {}, {i}, {j}, {k} }]:copy(gop1)
+ ggi[{ {}, {}, {i}, {j}, {k} }]:copy(ggi1)
+ end
+ end
+ end
+ local err = (gi - ggi):abs():max()
+ mytester:assertlt(err, testparams.precision_backward,
+ 'error in difference between central difference and :backward')
+ local err = (op - gop):abs():max()
+ mytester:assertlt(err, testparams.precision_backward,
+ 'error in difference between central difference and :backward')
+end
+
local function testBatchNormalization(moduleName, inputSize)
local input = torch.randn(table.unpack(inputSize)):cuda()
local gradOutput = torch.randn(table.unpack(inputSize)):cuda()
@@ -684,6 +723,46 @@ function cudnntest.SpatialCrossEntropyCriterion()
'error in difference between central difference and :backward')
end
+function cudnntest.VolumetricCrossEntropyCriterion()
+ if testparams.test_type ~= 'torch.CudaTensor' then return end
+ -- batch
+ local numLabels = math.random(5,10)
+ local t = math.random(5,10)
+ local h = math.random(5,10)
+ local w = math.random(5,10)
+ local bsz = math.random(3, 7)
+ local input = torch.zeros(bsz, numLabels, t, h, w):normal():cuda()
+ local target = torch.Tensor(bsz, t, h, w):random(1, numLabels):cuda()
+
+ local cri = cast(cudnn.VolumetricCrossEntropyCriterion())
+ local gcri = nn.CrossEntropyCriterion():cuda()
+
+ local op = cri:forward(cast(input), cast(target))
+ local gi = cri:backward(cast(input), cast(target))
+
+ local ggi = gi:clone():zero()
+
+ for i=1,t do
+ for j=1,h do
+ for k=1,w do
+ local i1 = input[{ {}, {}, {i}, {j}, {k} }]:contiguous():squeeze()
+ local t1 = target[{ {}, {i}, {j}, {k} }]:contiguous():squeeze()
+ local gop1 = gcri:forward(i1, t1)
+ local ggi1 = gcri:backward(i1, t1)
+ ggi[{ {}, {}, {i}, {j}, {k} }]:copy(ggi1)
+ end
+ end
+ end
+
+ -- Unlike cudnn.VolumetricCrossEntropyCriterion, nn.CrossEntropyCriterion
+ -- cannot average over the temporal and spatial dimensions because it is
+ -- applied per-location in a loop, so average the accumulated gradients here
+ ggi:div(t * h * w)
+
+ local err = (gi - ggi):abs():max()
+ mytester:assertlt(err, testparams.precision_backward,
+ 'error in difference between central difference and :backward')
+end
+
function cudnntest.functional_bias2D()
local bs = math.random(1,32)
local from = math.random(1,32)