From 63f10c2513d01da1f233e81cb89d55ec1c1b0c25 Mon Sep 17 00:00:00 2001
From: Trevor Killeen
Date: Tue, 4 Apr 2017 07:11:52 -0700
Subject: parity for updateGradInput

---
 RNN.lua           | 29 +++++++++++++++++------------
 test/test_rnn.lua | 27 +++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/RNN.lua b/RNN.lua
index c61d84d..ca0f9ce 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -587,24 +587,29 @@ function RNN:updateOutput(input)
 end
 
 function RNN:updateGradInput(input, gradOutput)
-   if (self.batchFirst) then
-      input = input:transpose(1, 2)
-      gradOutput = gradOutput:transpose(1, 2)
-      self.output = self.output:transpose(1, 2)
-   end
+   if self.batchFirst and not self.inputPacked then
+      input = input:transpose(1, 2)
+      gradOutput = gradOutput:transpose(1, 2)
+      self.output = self.output:transpose(1, 2)
+   end
    assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn v 5.1 and above')
-   assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
-   assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
-   assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
-   assert(input:size(3) == self.inputSize, 'input has incorrect size!')
+
+   if self.inputPacked then
+      assert(input[1]:dim() == 2, 'packed input must have two dimensions: sum(sequence lengths), inputSize')
+   else
+      assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
+      assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
+      assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
+      assert(input:size(3) == self.inputSize, 'input has incorrect size!')
+   end
    assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!')
    assert(self.train, 'updateGradInput can only be called when training!')
-   local x, dy = self:makeContiguous(input, gradOutput)
+   local x, dy = self:makeContiguous(self.inputPacked and input[1] or input, gradOutput)
    local y = self.output
    local w = self.weight
-   local dx = self.gradInput:resizeAs(input)
+   local dx = self.gradInput:resizeAs(self.inputPacked and input[1] or input)
    local hx = self.hiddenInput
    local cx = self.cellInput
    local dhy = self.gradHiddenOutput
@@ -674,7 +679,7 @@ function RNN:updateGradInput(input, gradOutput)
              wsPtr, wsSize,
              self.reserve:data(), self.reserve:size(1) * self.reserve:elementSize())
    if self.sync then cutorch.synchronize() end
-   if (self.batchFirst) then
+   if self.batchFirst and not self.inputPacked then
       self.gradInput = self.gradInput:transpose(1, 2)
       self.output = self.output:transpose(1, 2)
    end
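A note on the packed layout the new inputPacked branch relies on (editorial sketch, not part of the patch): per the input[1]:dim() == 2 assert above, a packed input is a table whose first element is a 2-D tensor of size sum(sequence lengths) x inputSize, stored time-major: one row per sequence active at timestep 1, then one row per sequence still active at timestep 2, and so on. The hypothetical helper below, written purely for illustration, recomputes which rows one sequence occupies under that layout, assuming sequences are sorted by decreasing length as in the test's lengths = {4, 3, 3, 1}:

-- Illustrative sketch only; packedRows is a hypothetical helper, not
-- repository code. Assumes lengths is sorted in decreasing order.
local function packedRows(lengths, seq)
   local rows, offset = {}, 0
   for t = 1, lengths[seq] do
      -- sequences still active at timestep t form this timestep's block
      local active = 0
      for _, len in ipairs(lengths) do
         if len >= t then active = active + 1 end
      end
      table.insert(rows, offset + seq)  -- row for sequence seq at timestep t
      offset = offset + active          -- skip past this timestep's block
   end
   return rows
end

print(table.concat(packedRows({4, 3, 3, 1}, 1), ', '))  -- 1, 5, 8, 11
print(table.concat(packedRows({4, 3, 3, 1}, 2), ', '))  -- 2, 6, 9

Those row sets are exactly the offsets the gradOutput:narrow(1, ...) calls in the test hunk below collect when splitting the packed gradOutput back into per-sequence gradients.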
diff --git a/test/test_rnn.lua b/test/test_rnn.lua
index d2e0518..8ad982a 100644
--- a/test/test_rnn.lua
+++ b/test/test_rnn.lua
@@ -367,6 +367,16 @@ function rnntest.testVariableLengthSequences()
    local lengths = {4, 3, 3, 1}
    local maxLength = 4
 
+   -- Generate gradOutput based on input sizes
+   local gradOutput = torch.CudaTensor(11, 1, 10):uniform()
+   local indivGradOutputs = {
+      torch.cat({gradOutput:narrow(1, 1, 1), gradOutput:narrow(1, 5, 1), gradOutput:narrow(1, 8, 1), gradOutput:narrow(1, 11, 1)}, 1):clone(),
+      torch.cat({gradOutput:narrow(1, 2, 1), gradOutput:narrow(1, 6, 1), gradOutput:narrow(1, 9, 1)}, 1):clone(),
+      torch.cat({gradOutput:narrow(1, 3, 1), gradOutput:narrow(1, 7, 1), gradOutput:narrow(1, 10, 1)}, 1):clone(),
+      gradOutput:narrow(1, 4, 1):clone()
+   }
+   gradOutput = gradOutput:squeeze()
+
    local inputSize = 4
    local hiddenSize = 10
    local numLayers = 1
@@ -402,6 +412,7 @@ function rnntest.testVariableLengthSequences()
 
    local separate = {}
    local hids = {}
+   local indivGradInputs = {}
 
    for i, length in ipairs(lengths) do
      local inp = indivInputs[i]
@@ -409,6 +420,11 @@ function rnntest.testVariableLengthSequences()
      table.insert(separate, output)
      local hid = lstm2.hiddenOutput:clone()
      table.insert(hids, hid)
+
+     -- need to do backwards pass here too
+     local gradOutput = indivGradOutputs[i]
+     local gradInp = lstm2:updateGradInput(inp, gradOutput):clone()
+     table.insert(indivGradInputs, gradInp)
    end
    separate = torch.cat(separate, 1):squeeze()
    hids = torch.cat(hids, 1):squeeze()
@@ -440,6 +456,17 @@ function rnntest.testVariableLengthSequences()
 
    local hdiff = torch.csub(packedHiddenOutput, hids):abs():sum()
    mytester:assert(hdiff < 1e7)
+
+   -- Step 2: update grad input as batch and individually
+
+   local packedGradInput = lstm:updateGradInput(packed, gradOutput)
+   local igiTestable = torch.cat(indivGradInputs, 1):squeeze(2)
+
+   for _, pair in ipairs(corresponding) do
+      local sep, batched = unpack(pair)
+      local diff = torch.csub(igiTestable[sep], packedGradInput[batched]):abs():sum()
+      mytester:assert(diff < 1e-7)
+   end
 end
 
 mytester = torch.Tester()
-- 
cgit v1.2.3
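Taken together, the RNN.lua hunks establish a simple dispatch contract for updateGradInput, mirroring the one updateOutput already follows for packed input. The sketch below restates that contract; it is inferred from the asserts and resizeAs calls in the patch and is not repository code:

-- Editorial sketch, not repository code: the tensor that gradInput is
-- resized against under each input convention. 'rnn' stands for any
-- cudnn RNN instance after a forward pass.
local function gradInputTemplate(rnn, input)
   if rnn.inputPacked then
      -- packed path: input is a table and input[1] is the 2-D packed data
      -- tensor of size sum(sequence lengths) x inputSize; gradInput mirrors
      -- it, and the batchFirst transposes are skipped since packed data has
      -- no separate batch dimension to swap
      return input[1]
   else
      -- padded path: input is the full seqLength x miniBatch x inputSize
      -- tensor (transposed first when batchFirst); gradInput mirrors it
      return input
   end
end

The practical consequence is that the batchFirst transposes never apply to packed input, which is why both transpose blocks in updateGradInput gained "and not self.inputPacked".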