diff options
Diffstat (limited to 'RNN.lua')
-rw-r--r-- | RNN.lua | 18 |
1 files changed, 11 insertions, 7 deletions
@@ -687,7 +687,7 @@ function RNN:updateGradInput(input, gradOutput) end function RNN:accGradParameters(input, gradOutput, scale) - if (self.batchFirst) then + if self.batchFirst and not self.inputPacked then input = input:transpose(1, 2) gradOutput = gradOutput:transpose(1, 2) self.output = self.output:transpose(1, 2) @@ -695,15 +695,19 @@ function RNN:accGradParameters(input, gradOutput, scale) scale = scale or 1 if scale == 0 then return end assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn 5.1 and above') - assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize') - assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!') - assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!') - assert(input:size(3) == self.inputSize, 'input has incorrect size!') + if self.inputPacked then + assert(input[1]:dim() == 2, 'packed input must have two dimensions: sum(sequence lengths), inputSize') + else + assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize') + assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!') + assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!') + assert(input:size(3) == self.inputSize, 'input has incorrect size!') + end assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!') assert(self.train, 'accGradParameters can only be called when training!') - local x, dy = self:makeContiguous(input, gradOutput) + local x, dy = self:makeContiguous(self.inputPacked and input[1] or input, gradOutput) local hx = self.hiddenInput local y = self.output local dw = self.gradWeight @@ -759,7 +763,7 @@ function RNN:accGradParameters(input, gradOutput, scale) scaleTensor:data()) end - if (self.batchFirst) then + if self.batchFirst and not self.inputPacked then gradOutput = gradOutput:transpose(1, 2) self.output = self.output:transpose(1, 2) end |