author | SeanNaren <taz838@hotmail.co.uk> | 2016-04-20 00:21:04 +0300
---|---|---
committer | SeanNaren <taz838@hotmail.co.uk> | 2016-04-20 00:21:04 +0300
commit | 20adf78d4e235b979bfd4fd1b38db31713ae2e21 | (patch)
tree | 090975b4b1b65b60c3f66f397c4eb59eb91a6c73 |
parent | 796ea1fc9741c9031d60c2511ef0ecfd671956c5 | (diff)
Updated gradOutput assertion
-rw-r--r-- | RNN.lua | 10
1 file changed, 2 insertions, 8 deletions
```diff
@@ -365,15 +365,9 @@ function RNN:updateGradInput(input, gradOutput)
    assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
    assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
    assert(input:size(3) == self.inputSize, 'input has incorrect size!')
-
-   if (self.batchFirst) then
-      assert(gradOutput:isSameSizeAs(self.output:transpose(1, 2)), 'gradOutput has incorrect size!')
-
-   else
-      assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!')
-   end
    assert(self.train, 'updateGradInput can only be called when training!')
-
+   local expectedSize = torch.LongStorage {self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections}
+   assert(gradOutput:isSize(expectedSize), 'gradOutput has incorrect size!')
    local x, dy = self:makeContiguous(input, gradOutput)
    local y = self.output
    local w = self.weight
```
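For context, the new check asserts a fixed time-major layout (seqLength x miniBatch x hiddenSize * numDirections) instead of comparing `gradOutput` against `self.output`'s shape, with or without a transpose. The snippet below is a minimal sketch of that check in plain Torch7, no cuDNN required; the dimension values are illustrative assumptions, not taken from the commit:

```lua
-- Minimal sketch of the shape check introduced by this commit.
-- Dimension values are illustrative, not from the commit.
require 'torch'

local seqLength, miniBatch = 5, 4
local hiddenSize, numDirections = 8, 2

-- Expected layout is time-major: seqLength x miniBatch x features.
local expectedSize = torch.LongStorage({seqLength, miniBatch, hiddenSize * numDirections})

local gradOutput = torch.Tensor(seqLength, miniBatch, hiddenSize * numDirections)
-- isSize compares the tensor's dimensions against a LongStorage elementwise.
assert(gradOutput:isSize(expectedSize), 'gradOutput has incorrect size!')  -- passes

-- A batch-first tensor (miniBatch x seqLength x features) no longer passes:
local batchFirst = gradOutput:transpose(1, 2)
print(batchFirst:isSize(expectedSize))  -- false here, since seqLength ~= miniBatch
```

One consequence worth noting: the removed branch accepted a batch-first `gradOutput` when `self.batchFirst` was set, while the new assertion requires the time-major shape unconditionally, which suggests `gradOutput` is expected to already be transposed to time-major by the time this check runs.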