Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSeanNaren <taz838@hotmail.co.uk>2016-04-20 00:21:04 +0300
committerSeanNaren <taz838@hotmail.co.uk>2016-04-20 00:21:04 +0300
commit20adf78d4e235b979bfd4fd1b38db31713ae2e21 (patch)
tree090975b4b1b65b60c3f66f397c4eb59eb91a6c73
parent796ea1fc9741c9031d60c2511ef0ecfd671956c5 (diff)
Updated gradOutput assertion
-rw-r--r--RNN.lua10
1 files changed, 2 insertions, 8 deletions
diff --git a/RNN.lua b/RNN.lua
index 1b945fc..af1272d 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -365,15 +365,9 @@ function RNN:updateGradInput(input, gradOutput)
assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
assert(input:size(3) == self.inputSize, 'input has incorrect size!')
-
- if (self.batchFirst) then
- assert(gradOutput:isSameSizeAs(self.output:transpose(1, 2)), 'gradOutput has incorrect size!')
-
- else
- assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!')
- end
assert(self.train, 'updateGradInput can only be called when training!')
-
+ local expectedSize = torch.LongStorage {self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections}
+ assert(gradOutput:isSize(expectedSize), 'gradOutput has incorrect size!')
local x, dy = self:makeContiguous(input, gradOutput)
local y = self.output
local w = self.weight