diff options
Diffstat (limited to 'RNN.lua')
-rw-r--r-- | RNN.lua | 7 |
1 files changed, 5 insertions, 2 deletions
@@ -180,7 +180,7 @@ function RNN:resetOutputDescriptor(output, batchSizes) local dim = torch.IntTensor({batchSizes[i+1], self.hiddenSize * self.numDirections, 1}) local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1}) errcheck('cudnnSetTensorNdDescriptor', - self.xDescs[i], + self.yDescs[i], self.datatype, 3, dim:data(), @@ -468,7 +468,9 @@ function RNN:updateOutput(input) -- Make sure input is contiguous local x = self:makeContiguous(self.inputPacked and input[1] or input) local oSize = self:deriveOutputSize(x) - local oStride = torch.LongStorage({oSize[2] * oSize[3], oSize[3], 1}) + local oStride = self.inputPacked and + torch.LongStorage({oSize[2], 1}) or + torch.LongStorage({oSize[2] * oSize[3], oSize[3], 1}) self.output:resize(oSize, oStride) local y = self.output local w = self.weight @@ -547,6 +549,7 @@ function RNN:updateOutput(input) local elemSize = self.reserve:elementSize() reserveSize = math.floor((reserveSize + elemSize - 1) / elemSize) self.reserve:resize(reserveSize) + errcheck('cudnnRNNForwardTraining', cudnn.getHandle(), self.rnnDesc[0], |