
github.com/soumith/cudnn.torch.git
author    Natalia Gimelshein <ngimelshein@nvidia.com>  2016-06-30 21:47:31 +0300
committer Natalia Gimelshein <ngimelshein@nvidia.com>  2016-06-30 21:47:31 +0300
commit    69f2eb6122be4a54d49137a59c23595fa24e6d1a (patch)
tree      bff7b7497d2b03d5ffab030d4054f9a090c9b747
parent    fdef3826f243192e98a462208f7df06d07267290 (diff)
re-enable dropout for 5.1, minor changes to batchFirst
 RNN.lua | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)
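A usage sketch (not part of the patch): how the re-enabled dropout path might be exercised, assuming cudnn.LSTM from cudnn.torch as the concrete RNN subclass and that cudnn.version reports >= 5103 on cuDNN 5.1, matching the assertions in the diff below; the sizes and the post-construction field write are illustrative only.

    require 'cunn'
    require 'cudnn'

    -- cuDNN applies RNN dropout between stacked layers, so use numLayers > 1.
    local seqLength, miniBatch, inputSize, hiddenSize = 5, 4, 10, 20
    local rnn = cudnn.LSTM(inputSize, hiddenSize, 2):cuda()
    if cudnn.version >= 5103 then
       rnn.dropout = 0.5  -- honored only on cuDNN v5.1 and above after this commit
    end
    local input = torch.CudaTensor(seqLength, miniBatch, inputSize):uniform()
    local output = rnn:forward(input)  -- seqLength x miniBatch x hiddenSize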
diff --git a/RNN.lua b/RNN.lua
index c5e3efd..17faa3a 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -238,7 +238,7 @@ function RNN:updateOutput(input)
input = input:transpose(1, 2)
end
assert(input:dim() == 3, 'input must have 3 dimensions: seqLength, miniBatch, inputSize')
- assert(self.dropout == 0, 'dropout currently not supported')
+ assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn v5.1 and above')
-- Decide which descriptors/tensors need to be updated.
local resetRNN = not self.dropoutDesc or not self.rnnDesc
local resetIO = not self.xDescs or not self.yDescs
@@ -277,11 +277,8 @@ function RNN:updateOutput(input)
local x = self:makeContiguous(input)
local oSize = torch.LongStorage({self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections})
- if not self.output:isContiguous() then
- self.output = self.output:transpose(1,2)
- assert(self.output:isContiguous())
- end
- self.output:resize(oSize)
+ local oStride = torch.LongStorage({self.miniBatch * self.hiddenSize * self.numDirections, self.hiddenSize * self.numDirections, 1})
+ self.output:resize(oSize, oStride)
local y = self.output
local w = self.weight
local hy = self:resizeHidden(self.hiddenOutput):zero()
@@ -372,7 +369,7 @@ function RNN:updateGradInput(input, gradOutput)
gradOutput = gradOutput:transpose(1, 2)
self.output = self.output:transpose(1, 2)
end
- assert(self.dropout == 0, 'dropout currently not supported')
+ assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn v5.1 and above')
assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
@@ -458,7 +455,7 @@ function RNN:accGradParameters(input, gradOutput, scale)
end
scale = scale or 1
if scale == 0 then return end
- assert(self.dropout == 0, 'dropout currently not supported')
+ assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn v5.1 and above')
assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
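The strided resize in the second hunk can be reproduced in isolation (a sketch on a plain torch.Tensor, not the module itself): resizing with an explicit stride leaves the tensor contiguous in (seqLength, miniBatch, hidden) order regardless of its prior layout, which is what lets the patch drop the transpose-then-assert step.

    require 'torch'

    local seqLength, miniBatch, hidden = 5, 4, 20
    -- Same size/stride construction as the patch (hidden stands in for
    -- hiddenSize * numDirections).
    local oSize   = torch.LongStorage({seqLength, miniBatch, hidden})
    local oStride = torch.LongStorage({miniBatch * hidden, hidden, 1})

    local output = torch.Tensor(3, 2)   -- arbitrary prior shape
    output:resize(oSize, oStride)       -- in-place resize with explicit strides
    print(output:isContiguous())        -- true: the strides describe a packed layout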