Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoan Puigcerver <joapuipe@gmail.com>2016-07-01 18:29:40 +0300
committerJoan Puigcerver <joapuipe@gmail.com>2016-07-01 18:29:40 +0300
commit977a25499624e055e76fca9ff255c2a1c2b0be19 (patch)
treeeeca1a8e80a020e9158682a363a470ded7490330
parent27cc2ac4b28724ebd261f548836c7360c72a05a5 (diff)
parent66be8c95088608420c39bf9d367670eed3de5e4b (diff)
Merge branch 'master' of github.com:soumith/cudnn.torch
Conflicts: RNN.lua
-rw-r--r--CMakeLists.txt2
-rw-r--r--RNN.lua43
-rw-r--r--ffi.lua11
3 files changed, 41 insertions, 15 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index fc8224f..637af3c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -6,7 +6,7 @@ IF(LUAROCKS_PREFIX)
MESSAGE(STATUS "Prefix inferred from Luarocks: ${CMAKE_INSTALL_PREFIX}")
ENDIF()
FIND_PACKAGE(Torch REQUIRED)
-FIND_PACKAGE(CUDA 7.5 REQUIRED)
+FIND_PACKAGE(CUDA 7.0 REQUIRED)
FILE(GLOB luasrc *.lua)
SET(src "")
diff --git a/RNN.lua b/RNN.lua
index f22c9ef..5970388 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -2,6 +2,8 @@ local RNN, parent = torch.class('cudnn.RNN', 'nn.Module')
local ffi = require 'ffi'
local errcheck = cudnn.errcheck
+local DESCS = {'rnnDesc', 'dropoutDesc', 'wDesc', 'xDescs', 'yDescs', 'hxDesc', 'hyDesc', 'cxDesc', 'cyDesc'}
+
function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst)
parent.__init(self)
@@ -236,7 +238,7 @@ function RNN:updateOutput(input)
input = input:transpose(1, 2)
end
assert(input:dim() == 3, 'input must have 3 dimensions: seqLength, miniBatch, inputSize')
- -- assert(self.dropout == 0, 'dropout currently not supported')
+ assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn v5.1 and above')
-- Decide which descriptors/tensors need to be updated.
local resetRNN = not self.dropoutDesc or not self.rnnDesc
local resetIO = not self.xDescs or not self.yDescs
@@ -274,7 +276,10 @@ function RNN:updateOutput(input)
end
local x = self:makeContiguous(input)
- local y = self:resizeOutput(self.output)
+ local oSize = torch.LongStorage({self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections})
+ local oStride = torch.LongStorage({self.miniBatch * self.hiddenSize * self.numDirections, self.hiddenSize * self.numDirections, 1})
+ self.output:resize(oSize, oStride)
+ local y = self.output
local w = self.weight
local hy = self:resizeHidden(self.hiddenOutput):zero()
local cy = self:resizeHidden(self.cellOutput):zero()
@@ -364,7 +369,7 @@ function RNN:updateGradInput(input, gradOutput)
gradOutput = gradOutput:transpose(1, 2)
self.output = self.output:transpose(1, 2)
end
- -- assert(self.dropout == 0, 'dropout currently not supported')
+ assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn v 5.1 and above')
assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
@@ -437,6 +442,7 @@ function RNN:updateGradInput(input, gradOutput)
self.reserve:data(), self.reserve:size(1) * 4) -- sizeof(float)
if (self.batchFirst) then
self.gradInput = self.gradInput:transpose(1, 2)
+ self.output = self.output:transpose(1, 2)
end
return self.gradInput
end
@@ -445,10 +451,11 @@ function RNN:accGradParameters(input, gradOutput, scale)
if (self.batchFirst) then
input = input:transpose(1, 2)
gradOutput = gradOutput:transpose(1, 2)
+ self.output = self.output:transpose(1, 2)
end
scale = scale or 1
if scale == 0 then return end
- -- assert(self.dropout == 0, 'dropout currently not supported')
+ assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn 5.1 and above')
assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
@@ -502,28 +509,36 @@ function RNN:accGradParameters(input, gradOutput, scale)
self.dw:data(),
scaleTensor:data())
end
+
+ if (self.batchFirst) then
+ gradOutput = gradOutput:transpose(1, 2)
+ self.output = self.output:transpose(1, 2)
+ end
end
function RNN:clearDesc()
- self.dropoutDesc = nil
- self.rnnDesc = nil
- self.dropoutDesc = nil
- self.wDesc = nil
- self.xDescs = nil
- self.yDescs = nil
- self.hxDesc = nil
- self.hyDesc = nil
- self.cxDesc = nil
- self.cyDesc = nil
+ for _, desc in pairs(DESCS) do
+ self[desc] = nil
+ end
end
function RNN:write(f)
+ local pushDescs = {}
+ for _, desc in pairs(DESCS) do
+ pushDescs[desc] = self[desc]
+ end
+
self:clearDesc()
+
local var = {}
for k,v in pairs(self) do
var[k] = v
end
f:writeObject(var)
+
+ for desc, v in pairs(pushDescs) do
+ self[desc] = v
+ end
end
function RNN:clearState()
diff --git a/ffi.lua b/ffi.lua
index 9e2cd5e..2a589a8 100644
--- a/ffi.lua
+++ b/ffi.lua
@@ -1,3 +1,4 @@
+require 'cutorch'
local ffi = require 'ffi'
ffi.cdef[[
@@ -1602,9 +1603,19 @@ Then make sure files named as libcudnn.so.5 or libcudnn.5.dylib are placed in yo
]])
end
+-- check cuDNN version
cudnn.version = tonumber(cudnn.C.cudnnGetVersion())
if cudnn.version < 5005 then
error('These bindings are for version 5005 or above, '
.. 'while the loaded CuDNN is version: ' .. cudnn.version
.. ' \nAre you using an older version of CuDNN?')
end
+
+-- check GPU driver version
+local props = cutorch.getDeviceProperties(cutorch.getDevice())
+if cutorch.driverVersion and -- for backward compatibility
+ not(cutorch.driverVersion >= 7050 -- desktop GPUs
+ or (props.major == 5 and props.minor == 3 and cutorch.driverVersion >= 7000) ) -- Tegra X1
+then
+ error('Insufficient GPU driver version.')
+end