Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'RNN.lua')
-rw-r--r--RNN.lua7
1 files changed, 5 insertions, 2 deletions
diff --git a/RNN.lua b/RNN.lua
index 45e137e..c61d84d 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -180,7 +180,7 @@ function RNN:resetOutputDescriptor(output, batchSizes)
local dim = torch.IntTensor({batchSizes[i+1], self.hiddenSize * self.numDirections, 1})
local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
errcheck('cudnnSetTensorNdDescriptor',
- self.xDescs[i],
+ self.yDescs[i],
self.datatype,
3,
dim:data(),
@@ -468,7 +468,9 @@ function RNN:updateOutput(input)
-- Make sure input is contiguous
local x = self:makeContiguous(self.inputPacked and input[1] or input)
local oSize = self:deriveOutputSize(x)
- local oStride = torch.LongStorage({oSize[2] * oSize[3], oSize[3], 1})
+ local oStride = self.inputPacked and
+ torch.LongStorage({oSize[2], 1}) or
+ torch.LongStorage({oSize[2] * oSize[3], oSize[3], 1})
self.output:resize(oSize, oStride)
local y = self.output
local w = self.weight
@@ -547,6 +549,7 @@ function RNN:updateOutput(input)
local elemSize = self.reserve:elementSize()
reserveSize = math.floor((reserveSize + elemSize - 1) / elemSize)
self.reserve:resize(reserveSize)
+
errcheck('cudnnRNNForwardTraining',
cudnn.getHandle(),
self.rnnDesc[0],