diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-06-04 21:02:50 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-06-04 21:02:50 +0300 |
commit | 97520f7ce3cb2477b6190ab2c559f00cfa612968 (patch) | |
tree | c5d00781abcb7bb20e144ebf37df129a81d84e38 | |
parent | 9a5c2f02c7b0415f021d568caccd65002f6658ee (diff) | |
parent | b01cfb5cb3b376340817a12d1f4bab988935dece (diff) |
Merge pull request #169 from SeanNaren/R5R5
Fixed error with batchFirst assertion
-rw-r--r-- | RNN.lua | 14 |
1 files changed, 6 insertions, 8 deletions
@@ -210,7 +210,7 @@ function RNN:resetCellDescriptors() end function RNN:makeContiguous(input, gradOutput) - if not input:isContiguous() then + if input and not input:isContiguous() then self._input = self._input or input.new() self._input:typeAs(input):resizeAs(input):copy(input) input = self._input @@ -361,17 +361,15 @@ function RNN:updateGradInput(input, gradOutput) if (self.batchFirst) then input = input:transpose(1, 2) gradOutput = gradOutput:transpose(1, 2) - self.output = self.output:transpose(1, 2) end assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize') assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!') assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!') assert(input:size(3) == self.inputSize, 'input has incorrect size!') - - assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!') assert(self.train, 'updateGradInput can only be called when training!') - - local x, dy = self:makeContiguous(input, gradOutput) + local expectedSize = torch.LongStorage {self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections} + assert(gradOutput:isSize(expectedSize), 'gradOutput has incorrect size!') + local x, dy = self:makeContiguous(nil, gradOutput) -- No need to calculate x. local y = self.output local w = self.weight local dx = self.gradInput:resizeAs(input) @@ -451,8 +449,8 @@ function RNN:accGradParameters(input, gradOutput, scale) assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!') assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!') assert(input:size(3) == self.inputSize, 'input has incorrect size!') - - assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!') + local expectedSize = torch.LongStorage {self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections} + assert(gradOutput:isSize(expectedSize), 'gradOutput has incorrect size!') assert(self.train, 'accGradParameters can only be called when training!') local x, dy = self:makeContiguous(input, gradOutput) |