
github.com/soumith/cudnn.torch.git
author    Soumith Chintala <soumith@gmail.com>    2016-06-04 21:02:50 +0300
committer Soumith Chintala <soumith@gmail.com>    2016-06-04 21:02:50 +0300
commit    97520f7ce3cb2477b6190ab2c559f00cfa612968 (patch)
tree      c5d00781abcb7bb20e144ebf37df129a81d84e38
parent    9a5c2f02c7b0415f021d568caccd65002f6658ee (diff)
parent    b01cfb5cb3b376340817a12d1f4bab988935dece (diff)
Merge pull request #169 from SeanNaren/R5R5
Fixed error with batchFirst assertion
-rw-r--r--  RNN.lua | 14
1 file changed, 6 insertions(+), 8 deletions(-)
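Why the assertion changed: under batchFirst, the old updateGradInput transposed self.output in place, so a later gradOutput:isSameSizeAs(self.output) check could compare against the already-transposed output and trip the assertion. The patch drops the in-place transpose and checks gradOutput against the expected (seqLength, miniBatch, hiddenSize * numDirections) shape directly. Below is a minimal sketch of how the new check behaves; the sizes are made up, and only the two assertion lines mirror the patch itself.

-- Sketch of the replacement shape check (hypothetical sizes).
local torch = require 'torch'

local seqLength, miniBatch, hiddenSize, numDirections = 5, 4, 8, 1

-- Time-major gradOutput, as the cuDNN descriptors expect:
local gradOutput = torch.Tensor(seqLength, miniBatch, hiddenSize * numDirections)

-- The new assertion compares against an explicit LongStorage instead of
-- self.output, which the old code had transposed in place under batchFirst:
local expectedSize = torch.LongStorage{seqLength, miniBatch, hiddenSize * numDirections}
assert(gradOutput:isSize(expectedSize), 'gradOutput has incorrect size!')

-- A batch-major tensor (what a batchFirst caller holds before the module
-- transposes it back) correctly fails the same check:
assert(not gradOutput:transpose(1, 2):isSize(expectedSize))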
diff --git a/RNN.lua b/RNN.lua
index 8d651e4..1c644a9 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -210,7 +210,7 @@ function RNN:resetCellDescriptors()
 end
 
 function RNN:makeContiguous(input, gradOutput)
-   if not input:isContiguous() then
+   if input and not input:isContiguous() then
       self._input = self._input or input.new()
       self._input:typeAs(input):resizeAs(input):copy(input)
       input = self._input
@@ -361,17 +361,15 @@ function RNN:updateGradInput(input, gradOutput)
    if (self.batchFirst) then
       input = input:transpose(1, 2)
       gradOutput = gradOutput:transpose(1, 2)
-      self.output = self.output:transpose(1, 2)
    end
    assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
    assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
    assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
    assert(input:size(3) == self.inputSize, 'input has incorrect size!')
-
-   assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!')
    assert(self.train, 'updateGradInput can only be called when training!')
-
-   local x, dy = self:makeContiguous(input, gradOutput)
+   local expectedSize = torch.LongStorage {self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections}
+   assert(gradOutput:isSize(expectedSize), 'gradOutput has incorrect size!')
+   local x, dy = self:makeContiguous(nil, gradOutput) -- No need to calculate x.
    local y = self.output
    local w = self.weight
    local dx = self.gradInput:resizeAs(input)
@@ -451,8 +449,8 @@ function RNN:accGradParameters(input, gradOutput, scale)
    assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
    assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
    assert(input:size(3) == self.inputSize, 'input has incorrect size!')
-
-   assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!')
+   local expectedSize = torch.LongStorage {self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections}
+   assert(gradOutput:isSize(expectedSize), 'gradOutput has incorrect size!')
    assert(self.train, 'accGradParameters can only be called when training!')
    local x, dy = self:makeContiguous(input, gradOutput)
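The makeContiguous change supports the call above: guarding on input lets updateGradInput pass nil and skip copying an input tensor it no longer needs. The following is a self-contained sketch of the guarded helper, assuming a plain torch environment; the buffers table is a hypothetical stand-in for the module, and the gradOutput branch is paraphrased, since the patch only touches the input guard.

-- Hypothetical stand-in for the module's buffer fields:
local torch = require 'torch'
local buffers = {}

local function makeContiguous(input, gradOutput)
   -- The patched guard: a nil input is simply passed through.
   if input and not input:isContiguous() then
      buffers._input = buffers._input or input.new()
      buffers._input:typeAs(input):resizeAs(input):copy(input)
      input = buffers._input
   end
   -- Paraphrase of the (unchanged) gradOutput branch:
   if gradOutput and not gradOutput:isContiguous() then
      buffers._gradOutput = buffers._gradOutput or gradOutput.new()
      buffers._gradOutput:typeAs(gradOutput):resizeAs(gradOutput):copy(gradOutput)
      gradOutput = buffers._gradOutput
   end
   return input, gradOutput
end

-- updateGradInput only needs a contiguous dy, so x stays nil:
local nonContig = torch.Tensor(4, 5, 8):transpose(1, 2)
local x, dy = makeContiguous(nil, nonContig)
assert(x == nil and dy:isContiguous())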