github.com/soumith/cudnn.torch.git
author    soumith <soumith@fb.com>  2015-09-15 02:37:02 +0300
committer soumith <soumith@fb.com>  2015-09-15 02:37:02 +0300
commit    77f32a01edd42f9ca481263359e32f8a1d73f3d1 (patch)
tree      ea50d8acbc0b4b521916ebd163b7cb97b7da0acc
parent    f85c8e0d178baf0dab9deb982c76b95191620418 (diff)
functional interface for R3 as well
-rw-r--r--  functional.lua | 350
-rw-r--r--  test/test.lua  |  57
2 files changed, 400 insertions(+), 7 deletions(-)
diff --git a/functional.lua b/functional.lua
index 765f6fd..64c2e1c 100644
--- a/functional.lua
+++ b/functional.lua
@@ -3,6 +3,7 @@
-- There shouldn't be any reference to "self" in this file.
local cudnn = require 'cudnn.env'
+local ffi = require 'ffi'
local errcheck = cudnn.errcheck
cudnn.functional = {}
@@ -10,12 +11,18 @@ cudnn.functional = {}
local one = torch.FloatTensor({1});
local zero = torch.FloatTensor({0});
+local function Batch2D(t)
+ return t:view(1, t:size(1), t:size(2), t:size(3))
+end
+
-- accumulates the bias into output.
-- output is assumed to be allocated and given.
-cudnn.functional.SpatialBias_updateOutput = function(bias, output)
+cudnn.functional.bias2D_updateOutput = function(handle, bias, output)
+ output = output:dim() == 3 and Batch2D(output) or output
+
local biasDesc = cudnn.toDescriptor(bias:view(1, bias:nElement(),1,1))
local oDesc = cudnn.toDescriptor(output)
- errcheck('cudnnAddTensor', cudnn.getHandle(),
+ errcheck('cudnnAddTensor', handle,
'CUDNN_ADD_SAME_C',
one:data(), biasDesc[0], bias:data(),
one:data(), oDesc[0], output:data())
@@ -23,14 +30,349 @@ end
-- accumulates the gradients into gradBias.
-- gradBias is assumed to be allocated and given.
-cudnn.functional.SpatialBias_accGradParameters = function(gradOutput, gradBias, scale)
+cudnn.functional.bias2D_accGradParameters = function(handle, gradOutput, gradBias, scale)
+ gradOutput = gradOutput:dim() == 3 and Batch2D(gradOutput) or gradOutput
scale = scale or 1.0
local scaleT = torch.FloatTensor({scale})
local oDesc = cudnn.toDescriptor(gradOutput)
local biasDesc = cudnn.toDescriptor(gradBias:view(1, gradBias:nElement(),1,1))
- errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(),
+ errcheck('cudnnConvolutionBackwardBias', handle,
scaleT:data(),
oDesc[0], gradOutput:data(),
one:data(),
biasDesc[0], gradBias:data())
end
+
+-- Does a 2D convolution (updateOutput) on input with weight.
+-- output is assumed to be allocated and given.
+cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, output,
+ strideH, strideW, padH, padW, workspace)
+ input = input:dim() == 3 and Batch2D(input) or input
+ output = output:dim() == 3 and Batch2D(output) or output
+
+ -- create a weight descriptor
+ local weightDesc = ffi.new('struct cudnnFilterStruct*[1]')
+ errcheck('cudnnCreateFilterDescriptor', weightDesc)
+ local nOutputPlane, nInputPlane, kH, kW
+ = weight:size(1), weight:size(2), weight:size(3), weight:size(4)
+ local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW})
+ errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 4,
+ desc:data());
+ local function destroyWDesc(d)
+ errcheck('cudnnDestroyFilterDescriptor', d[0]);
+ end
+ ffi.gc(weightDesc, destroyWDesc)
+
+ -- create a convolution descriptor
+ local convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
+ errcheck('cudnnCreateConvolutionDescriptor', convDesc)
+ local pad = torch.IntTensor({padH, padW})
+ local stride = torch.IntTensor({strideH, strideW})
+ local upscale = torch.IntTensor({1,1})
+ errcheck('cudnnSetConvolutionNdDescriptor_v3', convDesc[0],
+ 2, pad:data(),
+ stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
+ 'CUDNN_DATA_FLOAT');
+ local function destroyConvDesc(d)
+ errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
+ end
+ ffi.gc(convDesc, destroyConvDesc)
+
+ -- create input descriptor
+ local iDesc = cudnn.toDescriptor(input)
+
+ -- create output descriptor
+ local oSize = torch.IntTensor(4)
+ errcheck('cudnnGetConvolutionNdForwardOutputDim',
+ convDesc[0], iDesc[0],
+ weightDesc[0], 4, oSize:data())
+ oSize = oSize:long()
+ assert(output:dim() == 4 and
+ output:size(1) == oSize[1] and
+ output:size(2) == oSize[2] and
+ output:size(3) == oSize[3] and
+ output:size(4) == oSize[4],
+ 'Output is of wrong size')
+ -- create descriptor for output
+ local oDesc = cudnn.toDescriptor(output)
+
+    -- pick a forward algorithm that fits within the workspace limit
+ local algType = ffi.new("cudnnConvolutionFwdAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT'
+ local algWorkspaceLimit = 0
+ if workspace then
+ algWorkspaceLimit = workspace:nElement() * 4 -- 4 = sizeof float
+ end
+ errcheck('cudnnGetConvolutionForwardAlgorithm',
+ handle,
+ iDesc[0], weightDesc[0],
+ convDesc[0], oDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+
+ -- do convolution
+ errcheck('cudnnConvolutionForward', handle,
+ one:data(),
+ iDesc[0], input:data(),
+ weightDesc[0], weight:data(),
+ convDesc[0], algType[0],
+ workspace and workspace:data() or nil, algWorkspaceLimit,
+ zero:data(),
+ oDesc[0], output:data());
+end
+
+-- Does a 2D Convolution (updateGradInput) on input, weight, output, gradOutput
+-- gradInput is assumed to be allocated and given.
+cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight, output, gradOutput, gradInput,
+ strideH, strideW, padH, padW)
+ input = input:dim() == 3 and Batch2D(input) or input
+ output = output:dim() == 3 and Batch2D(output) or output
+ gradOutput = gradOutput:dim() == 3 and Batch2D(gradOutput) or gradOutput
+ gradInput = gradInput:dim() == 3 and Batch2D(gradInput) or gradInput
+
+ -- create a weight descriptor
+ local weightDesc = ffi.new('struct cudnnFilterStruct*[1]')
+ errcheck('cudnnCreateFilterDescriptor', weightDesc)
+ local nOutputPlane, nInputPlane, kH, kW
+ = weight:size(1), weight:size(2), weight:size(3), weight:size(4)
+ local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW})
+ errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 4,
+ desc:data());
+ local function destroyWDesc(d)
+ errcheck('cudnnDestroyFilterDescriptor', d[0]);
+ end
+ ffi.gc(weightDesc, destroyWDesc)
+
+ -- create a convolution descriptor
+ local convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
+ errcheck('cudnnCreateConvolutionDescriptor', convDesc)
+ local pad = torch.IntTensor({padH, padW})
+ local stride = torch.IntTensor({strideH, strideW})
+ local upscale = torch.IntTensor({1,1})
+ errcheck('cudnnSetConvolutionNdDescriptor_v3', convDesc[0],
+ 2, pad:data(),
+ stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
+ 'CUDNN_DATA_FLOAT');
+ local function destroyConvDesc(d)
+ errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
+ end
+ ffi.gc(convDesc, destroyConvDesc)
+
+ -- create input, output descriptor
+ local iDesc = cudnn.toDescriptor(input)
+ local oDesc = cudnn.toDescriptor(output)
+
+ local algType = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
+
+ errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
+             handle,
+ weightDesc[0], oDesc[0],
+ convDesc[0], iDesc[0],
+ algSearchMode, 0, algType)
+
+ -- do convolution
+ errcheck('cudnnConvolutionBackwardData_v3', handle,
+ one:data(),
+ weightDesc[0], weight:data(),
+ oDesc[0], gradOutput:data(),
+ convDesc[0],
+ algType[0],
+             nil, 0, -- no workspace; nil maps to a NULL pointer via LuaJIT FFI
+ zero:data(),
+ iDesc[0], gradInput:data());
+end
+
+-- accumulates the gradients into gradWeight.
+-- gradWeight is assumed to be allocated and given.
+local scaleT = torch.FloatTensor(1):fill(1.0)
+cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradWeight, gradOutput,
+ strideH, strideW, padH, padW, scale)
+ input = input:dim() == 3 and Batch2D(input) or input
+ gradOutput = gradOutput:dim() == 3 and Batch2D(gradOutput) or gradOutput
+
+ scale = scale or 1.0
+ scaleT[1] = scale
+ -- create a weight descriptor
+ local weightDesc = ffi.new('struct cudnnFilterStruct*[1]')
+ errcheck('cudnnCreateFilterDescriptor', weightDesc)
+ local nOutputPlane, nInputPlane, kH, kW
+ = gradWeight:size(1), gradWeight:size(2), gradWeight:size(3), gradWeight:size(4)
+ local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW})
+ errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 4,
+ desc:data());
+ local function destroyWDesc(d)
+ errcheck('cudnnDestroyFilterDescriptor', d[0]);
+ end
+ ffi.gc(weightDesc, destroyWDesc)
+
+ -- create a convolution descriptor
+ local convDesc = ffi.new('struct cudnnConvolutionStruct*[1]')
+ errcheck('cudnnCreateConvolutionDescriptor', convDesc)
+ local pad = torch.IntTensor({padH, padW})
+ local stride = torch.IntTensor({strideH, strideW})
+ local upscale = torch.IntTensor({1,1})
+ errcheck('cudnnSetConvolutionNdDescriptor_v3', convDesc[0],
+ 2, pad:data(),
+ stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
+ 'CUDNN_DATA_FLOAT');
+ local function destroyConvDesc(d)
+ errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
+ end
+ ffi.gc(convDesc, destroyConvDesc)
+
+ -- create input, output descriptor
+ local iDesc = cudnn.toDescriptor(input)
+ local oDesc = cudnn.toDescriptor(gradOutput)
+
+ local algType = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
+ local algSearchMode = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
+ local algWorkspaceLimit = 0
+
+ errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
+             handle,
+ iDesc[0], oDesc[0],
+ convDesc[0], weightDesc[0],
+ algSearchMode, algWorkspaceLimit, algType)
+
+ -- do convolution
+ errcheck('cudnnConvolutionBackwardFilter_v3', handle,
+ scaleT:data(),
+ iDesc[0], input:data(),
+ oDesc[0], gradOutput:data(),
+ convDesc[0],
+ algType[0],
+             nil, 0, -- no workspace; nil maps to a NULL pointer via LuaJIT FFI
+ one:data(),
+ weightDesc[0], gradWeight:data());
+end
+
+-- Does a 2D pooling (updateOutput) on input.
+-- output is assumed to be allocated and given.
+cudnn.functional.Pooling_updateOutput = function(handle, mode, input, output,
+ kH, kW, dH, dW, padH, padW, ceil_mode)
+ input = input:dim() == 3 and Batch2D(input) or input
+ output = output:dim() == 3 and Batch2D(output) or output
+
+ padH = padH or 0
+ padW = padW or 0
+ ceil_mode = ceil_mode or false
+
+ local oW, oH
+ if ceil_mode then
+ oW = math.ceil((input:size(4)+padW*2 - kW)/dW + 1)
+ oH = math.ceil((input:size(3)+padH*2 - kH)/dH + 1)
+ else
+ oW = math.floor((input:size(4)+padW*2 - kW)/dW + 1)
+ oH = math.floor((input:size(3)+padH*2 - kH)/dH + 1)
+ end
+ assert(oH == output:size(3) and oW == output:size(4),
+ 'size mismatch: ' .. oH .. 'x' .. oW .. ' vs ' ..
+ output:size(3) .. 'x' .. output:size(4))
+
+ -- create pooling descriptor
+ local poolDesc = ffi.new('struct cudnnPoolingStruct*[1]')
+ errcheck('cudnnCreatePoolingDescriptor', poolDesc)
+ local ker = torch.IntTensor({kH, kW})
+ local str = torch.IntTensor({dH, dW})
+ local pad = torch.IntTensor({padH, padW})
+ errcheck('cudnnSetPoolingNdDescriptor', poolDesc[0], mode, 2,
+ ker:data(), pad:data(), str:data());
+ local function destroyPoolDesc(d)
+ errcheck('cudnnDestroyPoolingDescriptor', d[0]);
+ end
+ ffi.gc(poolDesc, destroyPoolDesc)
+
+ -- create input, output descriptor
+ local iDesc = cudnn.toDescriptor(input)
+ local oDesc = cudnn.toDescriptor(output)
+
+ -- pool
+ errcheck('cudnnPoolingForward', handle,
+ poolDesc[0],
+ one:data(),
+ iDesc[0], input:data(),
+ zero:data(),
+ oDesc[0], output:data());
+end
+
+cudnn.functional.MaxPooling2D_updateOutput = function(handle, input, output,
+ kH, kW, dH, dW, padH, padW, ceil_mode)
+ cudnn.functional.Pooling_updateOutput(handle, 'CUDNN_POOLING_MAX', input, output,
+ kH, kW, dH, dW, padH, padW, ceil_mode);
+end
+
+cudnn.functional.AveragePooling2D_updateOutput = function(handle, input, output,
+ kH, kW, dH, dW, padH, padW, ceil_mode)
+ cudnn.functional.Pooling_updateOutput(handle, 'CUDNN_POOLING_AVERAGE', input, output,
+ kH, kW, dH, dW, padH, padW, ceil_mode);
+end
+
+-- Does a 2D pooling (updateGradInput) on input.
+-- gradInput is assumed to be allocated and given.
+cudnn.functional.Pooling_updateGradInput = function(handle, mode, input, output, gradOutput, gradInput,
+ kH, kW, dH, dW, padH, padW, ceil_mode)
+ input = input:dim() == 3 and Batch2D(input) or input
+ output = output:dim() == 3 and Batch2D(output) or output
+ gradOutput = gradOutput:dim() == 3 and Batch2D(gradOutput) or gradOutput
+ gradInput = gradInput:dim() == 3 and Batch2D(gradInput) or gradInput
+
+ padH = padH or 0
+ padW = padW or 0
+ ceil_mode = ceil_mode or false
+
+ local oW, oH
+ if ceil_mode then
+ oW = math.ceil((input:size(4)+padW*2 - kW)/dW + 1)
+ oH = math.ceil((input:size(3)+padH*2 - kH)/dH + 1)
+ else
+ oW = math.floor((input:size(4)+padW*2 - kW)/dW + 1)
+ oH = math.floor((input:size(3)+padH*2 - kH)/dH + 1)
+ end
+ assert(oH == output:size(3) and oW == output:size(4),
+ 'size mismatch: ' .. oH .. 'x' .. oW .. ' vs ' ..
+ output:size(3) .. 'x' .. output:size(4))
+
+ -- create pooling descriptor
+ local poolDesc = ffi.new('struct cudnnPoolingStruct*[1]')
+ errcheck('cudnnCreatePoolingDescriptor', poolDesc)
+ local ker = torch.IntTensor({kH, kW})
+ local str = torch.IntTensor({dH, dW})
+ local pad = torch.IntTensor({padH, padW})
+ errcheck('cudnnSetPoolingNdDescriptor', poolDesc[0], mode, 2,
+ ker:data(), pad:data(), str:data());
+ local function destroyPoolDesc(d)
+ errcheck('cudnnDestroyPoolingDescriptor', d[0]);
+ end
+ ffi.gc(poolDesc, destroyPoolDesc)
+
+ -- create input, output descriptor
+ local iDesc = cudnn.toDescriptor(input)
+ local oDesc = cudnn.toDescriptor(output)
+
+ -- pool
+ errcheck('cudnnPoolingBackward',
+ handle, poolDesc[0],
+ one:data(),
+ oDesc[0], output:data(),
+ oDesc[0], gradOutput:data(),
+ iDesc[0], input:data(),
+ zero:data(),
+ iDesc[0], gradInput:data());
+end
+
+cudnn.functional.MaxPooling2D_updateGradInput = function(handle, input, output, gradOutput, gradInput,
+ kH, kW, dH, dW, padH, padW, ceil_mode)
+ cudnn.functional.Pooling_updateGradInput(handle, 'CUDNN_POOLING_MAX', input, output, gradOutput, gradInput,
+ kH, kW, dH, dW, padH, padW, ceil_mode);
+end
+
+cudnn.functional.AveragePooling2D_updateGradInput = function(handle, input, output, gradOutput, gradInput,
+ kH, kW, dH, dW, padH, padW, ceil_mode)
+ cudnn.functional.Pooling_updateGradInput(handle, 'CUDNN_POOLING_AVERAGE', input, output, gradOutput, gradInput,
+ kH, kW, dH, dW, padH, padW, ceil_mode);
+end
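
The pooling wrappers follow the same calling convention; a short sketch
continuing the tensors from the forward sketch above (sizes illustrative):

    -- 2x2 max pooling, stride 2; padH/padW default to 0, ceil_mode to false,
    -- so the pooled output is floor((28 - 2)/2 + 1) = 14 per side
    local pooled = torch.CudaTensor(16, 8, 14, 14):zero()
    cudnn.functional.MaxPooling2D_updateOutput(handle, output, pooled,
                                               2, 2, 2, 2) -- kH, kW, dH, dW
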
diff --git a/test/test.lua b/test/test.lua
index c2938de..4062425 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -724,7 +724,7 @@ function cudnntest.LogSoftMax_batch()
precision_backward, 'error on state (backward) ')
end
-function cudnntest.functional_SpatialBias()
+function cudnntest.functional_bias2D()
local bs = math.random(1,32)
local from = math.random(1,32)
local to = math.random(1,64)
@@ -742,7 +742,7 @@ function cudnntest.functional_SpatialBias()
mod.weight:zero()
local groundtruth = mod:forward(input)
local result = groundtruth:clone():zero()
- cudnn.functional.SpatialBias_updateOutput(mod.bias, result)
+ cudnn.functional.bias2D_updateOutput(cudnn.getHandle(), mod.bias, result)
local error = result:float() - groundtruth:float()
mytester:assertlt(error:abs():max(),
precision_forward, 'error on forward ')
@@ -752,12 +752,63 @@ function cudnntest.functional_SpatialBias()
mod:backward(input, gradOutput, scale)
local groundtruth = mod.gradBias
local result = groundtruth:clone():zero()
- cudnn.functional.SpatialBias_accGradParameters(gradOutput, result, scale)
+ cudnn.functional.bias2D_accGradParameters(cudnn.getHandle(), gradOutput, result, scale)
error = result:float() - groundtruth:float()
mytester:assertlt(error:abs():max(),
precision_backward, 'error on accGradParameters ')
end
+function cudnntest.functional_convolution2d()
+ local a=cudnn.SpatialConvolution(3,16,5,5):cuda()
+ a.bias:zero();
+ local input = torch.randn(10,3,10,10):cuda()
+ a:zeroGradParameters()
+ a:forward(input);
+ local output = a.output:clone():normal()
+ local gradOutput = a.output:clone():normal()
+ local gradInput = a:backward(input, gradOutput):clone():normal()
+ local gradWeight = a.gradWeight:clone():zero()
+ cudnn.functional.Convolution2D_updateOutput(cudnn.getHandle(), input,
+ a.weight, output, a.dH,
+ a.dW, a.padH, a.padW)
+ mytester:assertlt((output - a.output):abs():max(),
+ precision_forward, 'error on forward ')
+
+ cudnn.functional.Convolution2D_updateGradInput(cudnn.getHandle(), input,
+ a.weight, output, gradOutput,
+ gradInput,
+ a.dH, a.dW, a.padH, a.padW)
+ mytester:assertlt((gradInput - a.gradInput):abs():max(),
+ precision_forward, 'error on updateGradInput ')
+
+ cudnn.functional.Convolution2D_accGradParameters(cudnn.getHandle(), input,
+ gradWeight, gradOutput,
+ a.dH, a.dW, a.padH, a.padW)
+ mytester:assertlt((gradWeight - a.gradWeight):abs():max(),
+ precision_forward, 'error on accGradParameters ')
+end
+
+function cudnntest.functional_maxpooling2d()
+ local a=cudnn.SpatialMaxPooling(2,2,2,2):cuda()
+ local input = torch.randn(10,3,10,10):cuda()
+ a:forward(input);
+ local output = a.output:clone():normal()
+ local gradOutput = a.output:clone():normal()
+ local gradInput = a:backward(input, gradOutput):clone():normal()
+ cudnn.functional.MaxPooling2D_updateOutput(cudnn.getHandle(), input,
+ output, a.kH, a.kW,
+ a.dH, a.dW, a.padH, a.padW)
+ mytester:assertlt((output - a.output):abs():max(),
+ precision_forward, 'error on forward ')
+
+ cudnn.functional.MaxPooling2D_updateGradInput(cudnn.getHandle(), input,
+ output, gradOutput, gradInput,
+ a.kH, a.kW, a.dH, a.dW,
+ a.padH, a.padW)
+ mytester:assertlt((gradInput - a.gradInput):abs():max(),
+ precision_forward, 'error on updateGradInput ')
+end
+
torch.setdefaulttensortype('torch.FloatTensor')
math.randomseed(os.time())
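
The backward path mirrors what the new tests exercise; a hedged sketch reusing
the tensors from the forward sketch (all gradient buffers pre-allocated, per
the API contract):

    local gradOutput = output:clone():normal()
    local gradInput  = input:clone():zero()
    local gradWeight = weight:clone():zero()
    local gradBias   = bias:clone():zero()
    cudnn.functional.Convolution2D_updateGradInput(handle, input, weight, output,
                                                   gradOutput, gradInput,
                                                   1, 1, 0, 0) -- strides and pads
    cudnn.functional.Convolution2D_accGradParameters(handle, input, gradWeight,
                                                     gradOutput, 1, 1, 0, 0)
    cudnn.functional.bias2D_accGradParameters(handle, gradOutput, gradBias, 1.0)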