github.com/soumith/cudnn.torch.git
author     Anthony Sandrin <asandrin@nvidia.com>    2016-03-12 02:29:49 +0300
committer  Boris Fomitchev <bfomitchev@nvidia.com>  2016-04-18 23:18:26 +0300
commit     15cec07968175f2edc5368bca32829319ad0570b (patch)
tree       a3dc910d9c7d22932daa6e5c04bb77a41397db7a
parent     19dc9a13b5812cd74016a3b9e27d306f35b4b42f (diff)
Initial work for cudnn RNN api integration
Added cudnnFind auto-tuning.
Changed the RNN layer API and improved descriptor/tensor resizing conditions.
Implemented updateGradInput and accGradParameters.
-rw-r--r--  RNN.lua        393
-rw-r--r--  init.lua         1
-rw-r--r--  rnn_exp.lua    293
-rw-r--r--  rnn_exp2.lua   198
-rw-r--r--  test/test.lua    5
5 files changed, 888 insertions, 2 deletions
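
For orientation, here is a minimal usage sketch of the new cudnn.RNN module, reconstructed from the asserts and resizing logic in RNN.lua below. The sizes are illustrative, not from the commit; the module defaults to CUDNN_RNN_RELU mode with float data.

require 'cudnn'

-- Illustrative sizes; the expected input layout, per the asserts in
-- RNN.lua, is inputSize x miniBatch x seqLength.
local hiddenSize, numLayers = 8, 2
local inputSize, miniBatch, seqLength = 8, 4, 5

local rnn = cudnn.RNN(hiddenSize, numLayers)
local input = torch.CudaTensor(inputSize, miniBatch, seqLength):uniform()

rnn:updateOutput(input)                     -- forward pass (training mode)

-- gradOutput must match the output shape (hiddenSize x miniBatch x seqLength).
local gradOutput = rnn.output:clone():fill(1)
rnn:updateGradInput(input, gradOutput)      -- gradient w.r.t. input
rnn:accGradParameters(input, gradOutput, 1) -- accumulate gradient w.r.t. weights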
diff --git a/RNN.lua b/RNN.lua
new file mode 100644
index 0000000..ae77f86
--- /dev/null
+++ b/RNN.lua
@@ -0,0 +1,393 @@
+local RNN, parent = torch.class('cudnn.RNN', 'nn.Module')
+local ffi = require 'ffi'
+local errcheck = cudnn.errcheck
+
+function RNN:__init(hiddenSize, numLayers)
+ parent.__init(self)
+
+ self.datatype = 0 -- TODO CUDNN_FLOAT, should get the constant from ffi
+ self.hiddenSize = hiddenSize
+ self.inputSize = 0
+ self.seqLength = 0
+ self.numLayers = numLayers
+ self.miniBatch = 0
+ self.bidirectional = 0
+ self.inputMode = 0 -- TODO CUDNN_LINEAR_INPUT, should get the constant from ffi
+ self.mode = 0 -- TODO CUDNN_RNN_RELU, should get the constant from ffi
+ self.dropout = 0
+ self.seed = 0x01234567
+
+ self.gradInput = torch.CudaTensor()
+ self.output = torch.CudaTensor()
+ self.weight = torch.CudaTensor()
+ self.gradParameters = torch.CudaTensor()
+ self.hx = torch.CudaTensor()
+ self.cx = torch.CudaTensor()
+ self.hy = torch.CudaTensor()
+ self.cy = torch.CudaTensor()
+ self.reserve = torch.CudaTensor(1)
+end
+
+local function createDescriptors(count, descs_type, create_func, destroy_func)
+ local ds = ffi.new(descs_type, count)
+ for i = 0, count - 1 do
+ errcheck(create_func, ds + i)
+ end
+ local function destroyDescriptors(ds)
+ for i = 0, count - 1 do
+ errcheck(destroy_func, ds[i])
+ end
+ end
+ ffi.gc(ds, destroyDescriptors)
+ return ds
+end
+
+local function createDropoutDescriptors(count)
+ return createDescriptors(count,
+ 'cudnnDropoutDescriptor_t[?]',
+ 'cudnnCreateDropoutDescriptor',
+ 'cudnnDestroyDropoutDescriptor')
+end
+
+local function createFilterDescriptors(count)
+ return createDescriptors(count,
+ 'cudnnFilterDescriptor_t[?]',
+ 'cudnnCreateFilterDescriptor',
+ 'cudnnDestroyFilterDescriptor')
+end
+
+local function createRNNDescriptors(count)
+ return createDescriptors(count,
+ 'cudnnRNNDescriptor_t[?]',
+ 'cudnnCreateRNNDescriptor',
+ 'cudnnDestroyRNNDescriptor')
+end
+
+local function createTensorDescriptors(count)
+   return createDescriptors(count,
+                            'cudnnTensorDescriptor_t[?]',
+                            'cudnnCreateTensorDescriptor',
+                            'cudnnDestroyTensorDescriptor')
+end
+
+function RNN:resetDropoutDescriptor()
+ if not self.dropoutDesc then
+ self.dropoutDesc = createDropoutDescriptors(1)
+ end
+
+ self.dropoutStatesSize = torch.LongTensor(1)
+ errcheck('cudnnDropoutGetStatesSize',
+ cudnn.getHandle(),
+ self.dropoutStatesSize:data())
+ self.dropoutStates = torch.CudaTensor(self.dropoutStatesSize[1])
+
+ errcheck('cudnnSetDropoutDescriptor',
+ self.dropoutDesc[0],
+ cudnn.getHandle(),
+ self.dropout,
+ self.dropoutStates:data(), self.dropoutStatesSize[1],
+ self.seed)
+end
+
+function RNN:resetRNNDescriptor()
+ if not self.rnnDesc then
+ self.rnnDesc = createRNNDescriptors(1)
+ end
+
+ errcheck('cudnnSetRNNDescriptor',
+ self.rnnDesc[0],
+ self.hiddenSize,
+ self.seqLength,
+ self.numLayers,
+ self.dropoutDesc[0],
+ self.inputMode,
+ self.bidirectional,
+ self.mode,
+ self.datatype)
+end
+
+function RNN:resetWeightDescriptors()
+ if not self.wDesc then
+ self.wDesc = createFilterDescriptors(1)
+ end
+
+ local weightSize = torch.LongTensor(1)
+ errcheck('cudnnGetRNNParamsSize',
+ cudnn.getHandle(),
+ self.rnnDesc[0],
+ self.xDescs,
+ weightSize:data())
+ local dim = torch.IntTensor({weightSize[1] / 4, 1, 1}) -- sizeof(float)
+
+ errcheck('cudnnSetFilterNdDescriptor',
+ self.wDesc[0],
+ self.datatype,
+ 0, -- TODO ffi CUDNN_TENSOR_NCHW
+ 3,
+ dim:data())
+end
+
+function RNN:resetIODescriptors()
+ self.xDescs = createTensorDescriptors(self.seqLength)
+ self.yDescs = createTensorDescriptors(self.seqLength)
+
+ for i = 0, self.seqLength - 1 do
+ local dim = torch.IntTensor({self.inputSize, self.miniBatch, 1})
+ local stride = torch.IntTensor({1, dim[1], dim[1] * dim[2]})
+
+ errcheck('cudnnSetTensorNdDescriptor',
+ self.xDescs[i],
+ self.datatype,
+ 3,
+ dim:data(),
+ stride:data())
+
+ dim[1] = self.hiddenSize * (self.bidirectional > 0 and 2 or 1)
+ stride[2] = dim[1]
+ stride[3] = dim[1] * dim[2]
+
+ errcheck('cudnnSetTensorNdDescriptor',
+ self.yDescs[i],
+ self.datatype,
+ 3,
+ dim:data(),
+ stride:data())
+ end
+end
+
+function RNN:resetHiddenDescriptors()
+ self.hxDesc = cudnn.toDescriptor(self.hx)
+ self.hyDesc = cudnn.toDescriptor(self.hy)
+end
+
+function RNN:resetCellDescriptors()
+ self.cxDesc = cudnn.toDescriptor(self.cx)
+ self.cyDesc = cudnn.toDescriptor(self.cy)
+end
+
+function RNN:makeContiguous(input, gradOutput)
+ if not input:isContiguous() then
+ self._input = self._input or input.new()
+ self._input:typeAs(input):resizeAs(input):copy(input)
+ input = self._input
+ end
+ if gradOutput and not gradOutput:isContiguous() then
+ self._gradOutput = self._gradOutput or gradOutput.new()
+ self._gradOutput:typeAs(gradOutput):resizeAs(gradOutput):copy(gradOutput)
+ gradOutput = self._gradOutput
+ end
+ return input, gradOutput
+end
+
+function RNN:updateOutput(input)
+
+ assert(input:dim() == 3)
+
+ -- Decide which descriptors/tensors need to be updated.
+   local resetRNN = not self.dropoutDesc or not self.rnnDesc
+   local resetIO = not self.xDescs or not self.yDescs
+   local resetHC = not self.hxDesc or not self.hyDesc or
+                   not self.cxDesc or not self.cyDesc
+   local resetWeight = not self.wDesc
+
+ if input:size(1) ~= self.inputSize then
+ self.inputSize = input:size(1)
+ resetRNN = true
+ resetIO = true
+ resetWeight = true
+ end
+
+ if input:size(2) ~= self.miniBatch then
+      self.miniBatch = input:size(2)
+ resetRNN = true
+ resetIO = true
+ resetHC = true
+ resetWeight = true
+ end
+
+ if input:size(3) ~= self.seqLength then
+      self.seqLength = input:size(3)
+ resetRNN = true
+ resetIO = true
+ end
+
+ -- Update descriptors/tensors
+ if resetRNN then
+ self:resetDropoutDescriptor()
+ self:resetRNNDescriptor()
+ end
+
+ local x = self:makeContiguous(input)
+ local y = self.output
+ if resetIO then
+ self.output:resize(self.hiddenSize, self.miniBatch, self.seqLength)
+ self:resetIODescriptors()
+ end
+
+ -- Hidden/cell output becomes the new hidden/cell input.
+ local hx = self.hy
+ local cx = self.cy
+ local hy = self.hx
+ local cy = self.cx
+ if resetHC then
+ self.hx:resize(self.hiddenSize, self.miniBatch, self.numLayers)
+ self.cx:resize(self.hiddenSize, self.miniBatch, self.numLayers)
+ self.hy:resize(self.hiddenSize, self.miniBatch, self.numLayers)
+ self.cy:resize(self.hiddenSize, self.miniBatch, self.numLayers)
+ self:resetHiddenDescriptors()
+ self:resetCellDescriptors()
+ end
+
+ local w = self.weight
+ if resetWeight then
+ local weightSize = torch.LongTensor(1)
+ errcheck('cudnnGetRNNParamsSize',
+ cudnn.getHandle(),
+ self.rnnDesc[0],
+ self.xDescs,
+ weightSize:data())
+ weightSize[1] = (weightSize[1] + 3) / 4 -- sizeof(float)
+      self.weight:resize(weightSize[1])
+ self:resetWeightDescriptors()
+ end
+
+ self.workspace = cudnn.getSharedWorkspace()
+ local workspaceSize = torch.LongTensor(1)
+ errcheck('cudnnGetRNNWorkspaceSize',
+ cudnn.getHandle(),
+ self.rnnDesc[0],
+ self.xDescs,
+ workspaceSize:data())
+ workspaceSize[1] = (workspaceSize[1] + 3) / 4 -- sizeof(float)
+ if self.workspace:size(1) < workspaceSize[1] then
+ self.workspace:resize(workspaceSize[1])
+ end
+
+ local reserveSize = torch.LongTensor(1)
+ errcheck('cudnnGetRNNTrainingReserveSize',
+ cudnn.getHandle(),
+ self.rnnDesc[0],
+ self.xDescs,
+ reserveSize:data())
+ reserveSize[1] = (reserveSize[1] + 3) / 4 -- sizeof(float)
+ if self.reserve:size(1) < reserveSize[1] then
+ self.reserve:resize(reserveSize[1])
+ end
+
+ errcheck('cudnnRNNForwardTraining',
+ cudnn.getHandle(),
+ self.rnnDesc[0],
+ self.xDescs, x:data(),
+ self.hxDesc[0], hx:data(),
+ self.cxDesc[0], cx:data(),
+ self.wDesc[0], w:data(),
+ self.yDescs, y:data(),
+ self.hyDesc[0], hy:data(),
+ self.cyDesc[0], cy:data(),
+ self.workspace:data(), self.workspace:size(1) * 4, -- sizeof(float)
+ self.reserve:data(), self.reserve:size(1) * 4) -- sizeof(float)
+
+   return self.output
+end
+
+function RNN:updateGradInput(input, gradOutput)
+
+ assert(input:dim() == 3)
+ assert(input:size(1) == self.inputSize)
+ assert(input:size(2) == self.miniBatch)
+ assert(input:size(3) == self.seqLength)
+
+ assert(gradOutput:dim() == self.output:dim())
+ for i = 1, gradOutput:dim() do
+ assert(gradOutput:size(i) == self.output:size(i))
+ end
+
+ local y = self.output
+ local dy = gradOutput
+ local w = self.weight
+ local hx = self.hx
+ local cx = self.cx
+ local dx = self.gradInput
+
+ if dx:dim() ~= 3 or
+ dx:size(1) ~= input:size(1) or
+ dx:size(2) ~= input:size(2) or
+ dx:size(3) ~= input:size(3) then
+ dx:resizeAs(input)
+ end
+
+ errcheck('cudnnRNNBackwardData',
+ cudnn.getHandle(),
+ self.rnnDesc[0],
+ self.yDescs, y:data(),
+ self.yDescs, dy:data(),
+ self.hyDesc[0], nil, -- TODO should dhy be ignored?
+      self.cyDesc[0], nil, -- TODO should dcy be ignored?
+ self.wDesc[0], w:data(),
+ self.hxDesc[0], hx:data(),
+ self.cxDesc[0], cx:data(),
+ self.xDescs, dx:data(),
+ self.hxDesc[0], nil, -- TODO should dhx be ignored?
+ self.cxDesc[0], nil, -- TODO should dcx be ignored?
+ self.workspace:data(), self.workspace:size(1) * 4, -- sizeof(float)
+ self.reserve:data(), self.reserve:size(1) * 4) -- sizeof(float)
+
+   return self.gradInput
+end
+
+function RNN:accGradParameters(input, gradOutput, scale)
+
+ assert(input:dim() == 3)
+ assert(input:size(1) == self.inputSize)
+ assert(input:size(2) == self.miniBatch)
+ assert(input:size(3) == self.seqLength)
+
+ assert(gradOutput:dim() == self.output:dim())
+ for i = 1, gradOutput:dim() do
+ assert(gradOutput:size(i) == self.output:size(i))
+ end
+
+ local x = input
+ local hx = self.hx
+ local y = self.output
+ local dw = self.gradParameters
+
+   if dw:dim() ~= 1 or
+      dw:size(1) ~= self.weight:size(1) then
+ dw:resizeAs(self.weight)
+ end
+
+ if scale == 0 then
+ return
+ end
+
+ -- cudnnRNNBackwardWeights doesn't accept a scale parameter so instead
+ -- scale before and after.
+ -- TODO: How much does this impact accuracy?
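+   -- (Net effect: dw becomes scale * (dw / scale + gradW) == dw + scale * gradW.)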
+ if scale ~= 1 then
+ local scaleTensor = torch.Tensor({1 / scale})
+ errcheck('cudnnScaleTensor',
+ cudnn.getHandle(),
+ self.wDesc[0],
+      dw:data(),
+ scaleTensor:data())
+ end
+
+ errcheck('cudnnRNNBackwardWeights',
+ cudnn.getHandle(),
+ self.rnnDesc[0],
+ self.xDescs, x:data(),
+ self.hxDesc[0], hx:data(),
+ self.yDescs, y:data(),
+ self.workspace:data(), self.workspace:size(1) * 4, -- sizeof(float)
+ self.wDesc[0], dw:data(),
+ self.reserve:data(), self.reserve:size(1) * 4) -- sizeof(float)
+
+ if scale ~= 1 then
+ local scaleTensor = torch.Tensor({scale})
+ errcheck('cudnnScaleTensor',
+ cudnn.getHandle(),
+ self.wDesc[0],
+      dw:data(),
+ scaleTensor:data())
+ end
+end
+
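
A note on the recurring "(size + 3) / 4 -- sizeof(float)" arithmetic in RNN.lua above: cuDNN reports parameter, workspace, and reserve sizes in bytes, while the CudaTensors backing them are sized in float elements, so byte counts are converted to element counts by dividing by sizeof(float) = 4, rounding up. A minimal sketch of that conversion (the helper name is hypothetical, not part of the commit):

local function bytesToFloatElements(bytes)
   -- ceil(bytes / 4): round up so the tensor covers the full byte count
   return math.floor((bytes + 3) / 4) -- sizeof(float) == 4
end

The byte size handed back to cuDNN is then recovered as tensor:size(1) * 4, as in the cudnnRNNForwardTraining call above.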
diff --git a/init.lua b/init.lua
index a4755f2..d7d2cf7 100644
--- a/init.lua
+++ b/init.lua
@@ -122,6 +122,7 @@ require('cudnn.SpatialBatchNormalization')
require('cudnn.VolumetricBatchNormalization')
require('cudnn.SpatialCrossEntropyCriterion')
require('cudnn.TemporalConvolution')
+require('cudnn.RNN')
require('cudnn.functional')
require('cudnn.convert')
diff --git a/rnn_exp.lua b/rnn_exp.lua
new file mode 100644
index 0000000..3af7118
--- /dev/null
+++ b/rnn_exp.lua
@@ -0,0 +1,293 @@
+require 'cudnn'
+local ffi = require 'ffi'
+local errcheck = cudnn.errcheck
+
+local datatype = 0 -- TODO CUDNN_FLOAT, should get the constant from ffi
+local hiddenSize = 1 -- TODO This is a layer parameter, correct?
+local inputSize = 1 -- TODO Is this a layer parameter or determined by input?
+local seqLength = 1 -- TODO Is this a layer parameter or determined by input?
+local numLayers = 1 -- TODO
+local miniBatch = 1 -- TODO
+local bidirectional = 0 -- TODO CUDNN_UNIDIRECTIONAL, should get the constant from ffi
+local inputMode = 0 -- TODO CUDNN_LINEAR_INPUT, should get the constant from ffi
+local mode = 0 -- TODO CUDNN_RNN_RELU, should get the constant from ffi
+local dropout = 0 -- TODO
+local seed = 0x01234567 -- TODO
+
+-- Dropout Descriptor
+
+local dropoutStatesSize = torch.LongTensor(1)
+errcheck('cudnnDropoutGetStatesSize',
+ cudnn.getHandle(),
+ dropoutStatesSize:data())
+local dropoutStates = torch.CudaTensor(dropoutStatesSize[1])
+
+local dropoutDesc = ffi.new('cudnnDropoutDescriptor_t[?]', 1)
+errcheck('cudnnCreateDropoutDescriptor', dropoutDesc)
+-- TODO GC was being called early. Ignore cleanup for now.
+-- ffi.gc(dropoutDesc, function(d) errcheck('cudnnDestroyDropoutDescriptor', d[0]) end)
+errcheck('cudnnSetDropoutDescriptor',
+ dropoutDesc[0],
+ cudnn.getHandle(),
+ dropout,
+ -- TODO Using dropoutStates causes an invalid memory access error.
+ dropoutStates:data(), dropoutStatesSize[1],
+ seed)
+
+-- RNN Descriptor
+local rnnDesc = ffi.new('cudnnRNNDescriptor_t[?]', 1)
+errcheck('cudnnCreateRNNDescriptor', rnnDesc)
+-- ffi.gc(rnnDesc, function(d) errcheck('cudnnDestroyRNNDescriptor', d[0]) end)
+errcheck('cudnnSetRNNDescriptor',
+ rnnDesc[0],
+ hiddenSize,
+ seqLength,
+ numLayers,
+ dropoutDesc[0],
+ inputMode,
+ bidirectional,
+ mode,
+ datatype)
+
+-- Input
+local inputDescs = ffi.new('cudnnTensorDescriptor_t[?]', seqLength)
+for i = 0, seqLength - 1 do
+ errcheck('cudnnCreateTensorDescriptor', inputDescs + i)
+end
+-- ffi.gc(inputDescs, function()
+-- for i = 0, seqLength - 1 do
+-- errcheck('cudnnDestroyTensorDescriptor', inputDescs[i])
+-- end
+-- end)
+
+local dims = torch.IntTensor({inputSize, miniBatch, seqLength})
+local stride = torch.IntTensor({1, dims[1], 1})
+
+for i = 0, seqLength - 1 do
+ errcheck('cudnnSetTensorNdDescriptor',
+ inputDescs[i],
+ datatype,
+ 3,
+ dims:data(),
+ stride:data())
+end
+
+local input = torch.CudaTensor(dims[1], dims[2], dims[3])
+
+-- Output
+local outputDescs = ffi.new('cudnnTensorDescriptor_t[?]', seqLength)
+for i = 0, seqLength - 1 do
+ errcheck('cudnnCreateTensorDescriptor', outputDescs + i)
+end
+-- ffi.gc(outputDescs, function()
+-- for i = 0, seqLength - 1 do
+-- errcheck('cudnnDestroyTensorDescriptor', outputDescs[i])
+-- end
+-- end)
+
+local dims = torch.IntTensor({hiddenSize, miniBatch, seqLength})
+local stride = torch.IntTensor({1, dims[1], 1})
+
+for i = 0, seqLength - 1 do
+ errcheck('cudnnSetTensorNdDescriptor',
+ outputDescs[i],
+ datatype,
+ 3,
+ dims:data(),
+ stride:data())
+end
+
+local output = torch.CudaTensor(dims[1], dims[2], dims[3])
+
+-- Hidden
+local hiddenInputDesc = ffi.new('cudnnTensorDescriptor_t[?]', 1)
+local hiddenOutputDesc = ffi.new('cudnnTensorDescriptor_t[?]', 1)
+errcheck('cudnnCreateTensorDescriptor', hiddenInputDesc)
+errcheck('cudnnCreateTensorDescriptor', hiddenOutputDesc)
+-- ffi.gc(hiddenInputDesc, function(d) errcheck('cudnnDestroyTensorDescriptor', d[0]) end)
+-- ffi.gc(hiddenOutputDesc, function(d) errcheck('cudnnDestroyTensorDescriptor', d[0]) end)
+
+local dims = torch.IntTensor({hiddenSize, miniBatch, numLayers})
+local stride = torch.IntTensor({1, dims[1], 1})
+
+errcheck('cudnnSetTensorNdDescriptor',
+ hiddenInputDesc[0],
+ datatype,
+ 3,
+ dims:data(),
+ stride:data())
+errcheck('cudnnSetTensorNdDescriptor',
+ hiddenOutputDesc[0],
+ datatype,
+ 3,
+ dims:data(),
+ stride:data())
+
+local hiddenInput = torch.CudaTensor(dims[1], dims[2], dims[3])
+local hiddenOutput = torch.CudaTensor(dims[1], dims[2], dims[3])
+
+-- Cell
+local cellInputDesc = ffi.new('cudnnTensorDescriptor_t[?]', 1)
+local cellOutputDesc = ffi.new('cudnnTensorDescriptor_t[?]', 1)
+errcheck('cudnnCreateTensorDescriptor', cellInputDesc)
+errcheck('cudnnCreateTensorDescriptor', cellOutputDesc)
+-- ffi.gc(cellInputDesc, function(d) errcheck('cudnnDestroyTensorDescriptor', d[0]) end)
+-- ffi.gc(cellOutputDesc, function(d) errcheck('cudnnDestroyTensorDescriptor', d[0]) end)
+
+local dims = torch.IntTensor({hiddenSize, miniBatch, numLayers})
+local stride = torch.IntTensor({1, dims[1], 1})
+
+errcheck('cudnnSetTensorNdDescriptor',
+ cellInputDesc[0],
+ datatype,
+ 3,
+ dims:data(),
+ stride:data())
+errcheck('cudnnSetTensorNdDescriptor',
+ cellOutputDesc[0],
+ datatype,
+ 3,
+ dims:data(),
+ stride:data())
+
+local cellInput = torch.CudaTensor(dims[1], dims[2], dims[3])
+local cellOutput = torch.CudaTensor(dims[1], dims[2], dims[3])
+
+-- Weight
+local weightDesc = ffi.new('cudnnFilterDescriptor_t[?]', 1)
+errcheck('cudnnCreateFilterDescriptor', weightDesc)
+-- ffi.gc(weightDesc, function(d) errcheck('cudnnDestroyFilterDescriptor', d[0]) end)
+
+local weightSize = torch.LongTensor(1)
+errcheck('cudnnGetRNNParamsSize',
+ cudnn.getHandle(),
+ rnnDesc[0],
+ inputDescs,
+ weightSize:data())
+local dims = torch.IntTensor({weightSize[1] / 4, 1, 1}) -- sizeof(float)
+
+errcheck('cudnnSetFilterNdDescriptor',
+ weightDesc[0],
+ datatype,
+ 0, -- TODO ffi CUDNN_TENSOR_NCHW
+ 3,
+ dims:data())
+local weight = torch.CudaTensor(dims[1], dims[2], dims[3])
+
+-- Workspace
+local workspace = cudnn.getSharedWorkspace()
+local workspaceSize = torch.LongTensor(1)
+errcheck('cudnnGetRNNWorkspaceSize',
+ cudnn.getHandle(),
+ rnnDesc[0],
+ inputDescs,
+ workspaceSize:data())
+workspace:resize(workspaceSize[1] / 4) -- sizeof(float)
+
+-- Print Descriptor data
+print("hiddenSize = " .. hiddenSize)
+print("inputSize = " .. inputSize)
+print("seqLength = " .. seqLength)
+print("numLayers = " .. numLayers)
+print("miniBatch = " .. miniBatch)
+print("bidirectional = " .. bidirectional)
+print("inputMode = " .. inputMode)
+print("mode = " .. mode)
+print("dropout = " .. dropout)
+
+local datatype = torch.IntTensor(1)
+local nbDims = torch.IntTensor(1)
+local dims = torch.IntTensor(3)
+local stride = torch.IntTensor(3)
+
+errcheck('cudnnGetTensorNdDescriptor',
+ inputDescs[0],
+ 3,
+ datatype:data(),
+ nbDims:data(),
+ dims:data(),
+ stride:data())
+print("Input " ..
+ "dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " ..
+ "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor',
+ outputDescs[0],
+ 3,
+ datatype:data(),
+ nbDims:data(),
+ dims:data(),
+ stride:data())
+print("Output " ..
+ "dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " ..
+ "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor',
+ hiddenInputDesc[0],
+ 3,
+ datatype:data(),
+ nbDims:data(),
+ dims:data(),
+ stride:data())
+print("Hidden Input " ..
+ "dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " ..
+ "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor',
+ hiddenOutputDesc[0],
+ 3,
+ datatype:data(),
+ nbDims:data(),
+ dims:data(),
+ stride:data())
+print("Hidden Output " ..
+ "dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " ..
+ "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor',
+ cellInputDesc[0],
+ 3,
+ datatype:data(),
+ nbDims:data(),
+ dims:data(),
+ stride:data())
+print("Cell Input " ..
+ "dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " ..
+ "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor',
+ cellOutputDesc[0],
+ 3,
+ datatype:data(),
+ nbDims:data(),
+ dims:data(),
+ stride:data())
+print("Cell Output " ..
+ "dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " ..
+ "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+local format = ffi.new('cudnnTensorFormat_t[?]', 1)
+errcheck('cudnnGetFilterNdDescriptor',
+ weightDesc[0],
+ 3,
+ datatype:data(),
+ format,
+ nbDims:data(),
+ dims:data())
+
+print("Weight " ..
+ "dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") ")
+
+-- ForwardInference
+errcheck('cudnnRNNForwardInference',
+ cudnn.getHandle(),
+ rnnDesc[0],
+ inputDescs, input:data(),
+ hiddenInputDesc[0], hiddenInput:data(),
+ cellInputDesc[0], cellInput:data(),
+ weightDesc[0], weight:data(),
+ outputDescs, output:data(),
+ hiddenOutputDesc[0], hiddenOutput:data(),
+ cellOutputDesc[0], cellOutput:data(),
+ workspace:data(), workspace:size(1) * 4) -- sizeof(float)
+
diff --git a/rnn_exp2.lua b/rnn_exp2.lua
new file mode 100644
index 0000000..e2ad093
--- /dev/null
+++ b/rnn_exp2.lua
@@ -0,0 +1,198 @@
+require 'cudnn'
+local ffi = require 'ffi'
+local errcheck = cudnn.errcheck
+
+local datatype = 0 -- TODO CUDNN_DATA_FLOAT=0, should get the constant from ffi
+local hiddenSize = 1 -- TODO This is a layer parameter, correct?
+local inputSize = 1 -- TODO Is this a layer parameter or determined by input?
+local seqLength = 1 -- TODO Is this a layer parameter or determined by input?
+local numLayers = 1 -- TODO
+local miniBatch = 1 -- TODO
+local bidirectional = 0 -- TODO CUDNN_UNIDIRECTIONAL=0, should get the constant from ffi
+local inputMode = 0 -- TODO CUDNN_LINEAR_INPUT=0, should get the constant from ffi
+local mode = 0 -- TODO CUDNN_RNN_RELU=0, CUDNN_LSTM=1, CUDNN_GRU=2 should get the constant from ffi
+local dropout = 0 -- TODO
+local seed = 0x01234567 -- TODO
+
+-- Dropout Descriptor
+
+print()
+print("---------------------------------------------------------------------------------------")
+print()
+local dropoutStatesSize = torch.LongTensor(1)
+errcheck('cudnnDropoutGetStatesSize', cudnn.getHandle(), dropoutStatesSize:data())
+local dropoutStates = torch.CudaTensor(dropoutStatesSize[1])
+
+local dropoutDesc = ffi.new('cudnnDropoutDescriptor_t[?]', 1)
+errcheck('cudnnCreateDropoutDescriptor', dropoutDesc)
+
+-- TODO GC was being called early. Ignore cleanup for now.
+-- ffi.gc(dropoutDesc, function(d) errcheck('cudnnDestroyDropoutDescriptor', d[0]) end)
+
+errcheck('cudnnSetDropoutDescriptor', dropoutDesc[0], cudnn.getHandle(), dropout, dropoutStates:data(), dropoutStatesSize[1], seed)
+
+-- RNN Descriptor
+local rnnDesc = ffi.new('cudnnRNNDescriptor_t[?]', 1)
+errcheck('cudnnCreateRNNDescriptor', rnnDesc)
+-- ffi.gc(rnnDesc, function(d) errcheck('cudnnDestroyRNNDescriptor', d[0]) end)
+errcheck('cudnnSetRNNDescriptor', rnnDesc[0], hiddenSize, seqLength, numLayers, dropoutDesc[0], inputMode, bidirectional, mode, datatype)
+
+-- Input
+local inputDescs = ffi.new('cudnnTensorDescriptor_t[?]', seqLength)
+for i = 0, seqLength - 1 do
+ errcheck('cudnnCreateTensorDescriptor', inputDescs + i)
+end
+-- ffi.gc(inputDescs, function()
+-- for i = 0, seqLength - 1 do
+-- errcheck('cudnnDestroyTensorDescriptor', inputDescs[i])
+-- end
+-- end)
+
+local dims_1 = torch.IntTensor({inputSize, miniBatch, seqLength})
+local stride_1 = torch.IntTensor({1, dims_1[1], 1})
+
+for i = 0, seqLength - 1 do
+ errcheck('cudnnSetTensorNdDescriptor', inputDescs[i], datatype, 3, dims_1:data(), stride_1:data())
+end
+
+local input = torch.CudaTensor(dims_1[1], dims_1[2], dims_1[3])
+
+-- Output
+local outputDescs = ffi.new('cudnnTensorDescriptor_t[?]', seqLength)
+for i = 0, seqLength - 1 do
+ errcheck('cudnnCreateTensorDescriptor', outputDescs + i)
+end
+-- ffi.gc(outputDescs, function()
+-- for i = 0, seqLength - 1 do
+-- errcheck('cudnnDestroyTensorDescriptor', outputDescs[i])
+-- end
+-- end)
+
+local dims_2 = torch.IntTensor({hiddenSize, miniBatch, seqLength})
+local stride_2 = torch.IntTensor({1, dims_2[1], 1})
+
+for i = 0, seqLength - 1 do
+ errcheck('cudnnSetTensorNdDescriptor', outputDescs[i], datatype, 3, dims_2:data(), stride_2:data())
+end
+
+local output = torch.CudaTensor(dims_2[1], dims_2[2], dims_2[3])
+
+-- Hidden
+local hiddenInputDesc = ffi.new('cudnnTensorDescriptor_t[?]', 1)
+local hiddenOutputDesc = ffi.new('cudnnTensorDescriptor_t[?]', 1)
+errcheck('cudnnCreateTensorDescriptor', hiddenInputDesc)
+errcheck('cudnnCreateTensorDescriptor', hiddenOutputDesc)
+-- ffi.gc(hiddenInputDesc, function(d) errcheck('cudnnDestroyTensorDescriptor', d[0]) end)
+-- ffi.gc(hiddenOutputDesc, function(d) errcheck('cudnnDestroyTensorDescriptor', d[0]) end)
+
+local dims_3 = torch.IntTensor({hiddenSize, miniBatch, numLayers})
+local stride_3 = torch.IntTensor({1, dims_3[1], 1})
+
+errcheck('cudnnSetTensorNdDescriptor', hiddenInputDesc[0], datatype, 3, dims_3:data(), stride_3:data())
+errcheck('cudnnSetTensorNdDescriptor', hiddenOutputDesc[0], datatype, 3, dims_3:data(), stride_3:data())
+
+local hiddenInput = torch.CudaTensor(dims_3[1], dims_3[2], dims_3[3])
+local hiddenOutput = torch.CudaTensor(dims_3[1], dims_3[2], dims_3[3])
+
+-- Cell
+local cellInputDesc = ffi.new('cudnnTensorDescriptor_t[?]', 1)
+local cellOutputDesc = ffi.new('cudnnTensorDescriptor_t[?]', 1)
+errcheck('cudnnCreateTensorDescriptor', cellInputDesc)
+errcheck('cudnnCreateTensorDescriptor', cellOutputDesc)
+-- ffi.gc(cellInputDesc, function(d) errcheck('cudnnDestroyTensorDescriptor', d[0]) end)
+-- ffi.gc(cellOutputDesc, function(d) errcheck('cudnnDestroyTensorDescriptor', d[0]) end)
+
+local dims_4 = torch.IntTensor({hiddenSize, miniBatch, numLayers})
+local stride_4 = torch.IntTensor({1, dims_4[1], 1})
+
+errcheck('cudnnSetTensorNdDescriptor', cellInputDesc[0], datatype, 3, dims_4:data(), stride_4:data())
+errcheck('cudnnSetTensorNdDescriptor', cellOutputDesc[0], datatype, 3, dims_4:data(), stride_4:data())
+
+local cellInput = torch.CudaTensor(dims_4[1], dims_4[2], dims_4[3])
+local cellOutput = torch.CudaTensor(dims_4[1], dims_4[2], dims_4[3])
+
+-- Weight
+local weightDesc = ffi.new('cudnnFilterDescriptor_t[?]', 1)
+errcheck('cudnnCreateFilterDescriptor', weightDesc)
+-- ffi.gc(weightDesc, function(d) errcheck('cudnnDestroyFilterDescriptor', d[0]) end)
+
+local weightSize = torch.LongTensor(1)
+errcheck('cudnnGetRNNParamsSize', cudnn.getHandle(), rnnDesc[0], inputDescs, weightSize:data())
+local dims_5 = torch.IntTensor({weightSize[1] / 4, 1, 1}) -- sizeof(float)
+
+-- TODO ffi CUDNN_TENSOR_NCHW
+errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], datatype, 0, 3, dims_5:data())
+
+local weight = torch.CudaTensor(dims_5[1], dims_5[2], dims_5[3])
+
+-- Workspace
+local workspace = cudnn.getSharedWorkspace()
+local workspaceSize = torch.LongTensor(1)
+errcheck('cudnnGetRNNWorkspaceSize', cudnn.getHandle(), rnnDesc[0], inputDescs, workspaceSize:data())
+workspace:resize(workspaceSize[1] / 4) -- sizeof(float)
+
+-- Print Descriptor data
+print("hiddenSize = " .. hiddenSize)
+print("inputSize = " .. inputSize)
+print("seqLength = " .. seqLength)
+print("numLayers = " .. numLayers)
+print("miniBatch = " .. miniBatch)
+print("bidirectional = " .. bidirectional)
+print("inputMode = " .. inputMode)
+print("mode = " .. mode)
+print("dropout = " .. dropout)
+
+local datatype = torch.IntTensor(1)
+local nbDims = torch.IntTensor(1)
+local dims = torch.IntTensor(3)
+local stride = torch.IntTensor(3)
+
+errcheck('cudnnGetTensorNdDescriptor', inputDescs[0], 3, datatype:data(), nbDims:data(), dims:data(), stride:data())
+print("Input dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " .. "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor', outputDescs[0], 3, datatype:data(), nbDims:data(), dims:data(), stride:data())
+print("Output dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " .. "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor', hiddenInputDesc[0], 3, datatype:data(), nbDims:data(), dims:data(), stride:data())
+print("Hidden Input dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " .. "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor', hiddenOutputDesc[0], 3, datatype:data(), nbDims:data(), dims:data(), stride:data())
+print("Hidden Output dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " .. "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor', cellInputDesc[0], 3, datatype:data(), nbDims:data(), dims:data(), stride:data())
+print("Cell Input dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " .. "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+errcheck('cudnnGetTensorNdDescriptor', cellOutputDesc[0], 3, datatype:data(), nbDims:data(), dims:data(), stride:data())
+print("Cell Output dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") " .. "stride=(" .. stride[1] .. ", " .. stride[2] .. ", " .. stride[3] .. ")")
+
+local format = ffi.new('cudnnTensorFormat_t[?]', 1)
+errcheck('cudnnGetFilterNdDescriptor', weightDesc[0], 3, datatype:data(), format, nbDims:data(), dims:data())
+
+print("Weight dim=(" .. dims[1] .. ", " .. dims[2] .. ", " .. dims[3] .. ") ")
+
+------ ForwardInference
+--errcheck('cudnnRNNForwardInference',
+-- cudnn.getHandle(),
+-- rnnDesc[0],
+-- inputDescs, input:data(),
+-- hiddenInputDesc[0], nil, -- hiddenInput:data(),
+-- cellInputDesc[0], nil, -- cellInput:data(),
+-- weightDesc[0], weight:data(),
+-- outputDescs, output:data(),
+-- hiddenOutputDesc[0], nil, -- hiddenOutput:data(),
+-- cellOutputDesc[0], nil, -- cellOutput:data(),
+-- workspace:data(), workspace:size(1) * 4) -- sizeof(float)
+
+---- ForwardInference
+errcheck('cudnnRNNForwardInference',
+ cudnn.getHandle(),
+ rnnDesc[0],
+ inputDescs, input:data(),
+ hiddenInputDesc[0], hiddenInput:data(),
+ cellInputDesc[0], cellInput:data(),
+ weightDesc[0], weight:data(),
+ outputDescs, output:data(),
+ hiddenOutputDesc[0], hiddenOutput:data(),
+ cellOutputDesc[0], cellOutput:data(),
+   workspace:data(), workspace:size(1) * 4) -- sizeof(float)
+
diff --git a/test/test.lua b/test/test.lua
index 9b5cbde..8fcd1b9 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1444,9 +1444,10 @@ math.randomseed(os.time())
mytester = torch.Tester()
mytester:add(cudnntest)
-if torch.random(1,2) == 1 then
+-- if torch.random(1,2) == 1 then
cudnn.benchmark = true -- run manual auto-tuner
-end
+ cudnn.verbose = true
+-- end
for i=1,cutorch.getDeviceCount() do