github.com/soumith/cudnn.torch.git
author     Trevor Killeen <killeentm@gmail.com>  2017-04-04 17:11:52 +0300
committer  Trevor Killeen <killeentm@gmail.com>  2017-04-04 17:49:25 +0300
commit     63f10c2513d01da1f233e81cb89d55ec1c1b0c25 (patch)
tree       a0a60a3183b881ae0654754dadbb4cefa5b3e6d0
parent     16c2267bf6c0e5032eddb6e2a194200725337137 (diff)
parity for updateGradInput
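
This commit brings RNN:updateGradInput to parity with updateOutput for packed
variable-length input. Judging from the asserts added below, the packed
convention is a table whose first element is a 2D tensor of shape
(sum of sequence lengths) x inputSize, with no separate minibatch dimension.
A minimal sketch of that layout, assuming the table form implied by the
asserts (variable names here are illustrative, not a documented API):

require 'torch'

-- Four sequences of lengths {4, 3, 3, 1} with inputSize 4 (the shapes used
-- in the test below) pack into one 2D tensor with sum(lengths) = 11 rows.
local lengths   = {4, 3, 3, 1}
local inputSize = 4
local total = 0
for _, l in ipairs(lengths) do total = total + l end
local packedData  = torch.Tensor(total, inputSize)   -- 11 x 4, no batch dim
local packedInput = { packedData }                   -- input[1] holds the data
assert(packedInput[1]:dim() == 2)                    -- the check this patch adds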
-rw-r--r--  RNN.lua            29
-rw-r--r--  test/test_rnn.lua  27
2 files changed, 44 insertions, 12 deletions
diff --git a/RNN.lua b/RNN.lua
index c61d84d..ca0f9ce 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -587,24 +587,29 @@ function RNN:updateOutput(input)
 end
 
 function RNN:updateGradInput(input, gradOutput)
-   if (self.batchFirst) then
-      input = input:transpose(1, 2)
-      gradOutput = gradOutput:transpose(1, 2)
-      self.output = self.output:transpose(1, 2)
-   end
+   if self.batchFirst and not self.inputPacked then
+      input = input:transpose(1, 2)
+      gradOutput = gradOutput:transpose(1, 2)
+      self.output = self.output:transpose(1, 2)
+   end
    assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn v 5.1 and above')
-   assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
-   assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
-   assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
-   assert(input:size(3) == self.inputSize, 'input has incorrect size!')
+
+   if self.inputPacked then
+      assert(input[1]:dim() == 2, 'packed input must have two dimensions: sum(sequence lengths), inputSize')
+   else
+      assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
+      assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
+      assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
+      assert(input:size(3) == self.inputSize, 'input has incorrect size!')
+   end
    assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!')
    assert(self.train, 'updateGradInput can only be called when training!')
-   local x, dy = self:makeContiguous(input, gradOutput)
+   local x, dy = self:makeContiguous(self.inputPacked and input[1] or input, gradOutput)
    local y = self.output
    local w = self.weight
-   local dx = self.gradInput:resizeAs(input)
+   local dx = self.gradInput:resizeAs(self.inputPacked and input[1] or input)
    local hx = self.hiddenInput
    local cx = self.cellInput
    local dhy = self.gradHiddenOutput
@@ -674,7 +679,7 @@ function RNN:updateGradInput(input, gradOutput)
               wsPtr, wsSize,
               self.reserve:data(), self.reserve:size(1) * self.reserve:elementSize())
    if self.sync then cutorch.synchronize() end
-   if (self.batchFirst) then
+   if self.batchFirst and not self.inputPacked then
       self.gradInput = self.gradInput:transpose(1, 2)
       self.output = self.output:transpose(1, 2)
    end
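
Note that the batchFirst transposes are now skipped whenever the input is
packed; they only apply to regular padded 3D input. A minimal sketch of that
padded path, assuming cudnn.LSTM's fourth constructor argument enables
batchFirst (as used elsewhere in this library):

require 'cudnn'

-- Padded, batch-first input: (miniBatch, seqLength, inputSize).
local lstm = cudnn.LSTM(4, 10, 1, true)   -- inputSize, hiddenSize, numLayers, batchFirst
local input      = torch.CudaTensor(2, 5, 4):uniform()
local output     = lstm:updateOutput(input)               -- (2, 5, 10)
local gradOutput = output:clone():uniform()
-- updateGradInput transposes to time-major internally, runs the cudnn
-- backward-data call, then transposes self.gradInput back (the hunk above).
local gradInput  = lstm:updateGradInput(input, gradOutput)
assert(gradInput:isSameSizeAs(input))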
diff --git a/test/test_rnn.lua b/test/test_rnn.lua
index d2e0518..8ad982a 100644
--- a/test/test_rnn.lua
+++ b/test/test_rnn.lua
@@ -367,6 +367,16 @@ function rnntest.testVariableLengthSequences()
    local lengths = {4, 3, 3, 1}
    local maxLength = 4
 
+   -- Generate gradOutput based on input sizes
+   local gradOutput = torch.CudaTensor(11, 1, 10):uniform()
+   local indivGradOutputs = {
+      torch.cat({gradOutput:narrow(1, 1, 1), gradOutput:narrow(1, 5, 1), gradOutput:narrow(1, 8, 1), gradOutput:narrow(1, 11, 1)}, 1):clone(),
+      torch.cat({gradOutput:narrow(1, 2, 1), gradOutput:narrow(1, 6, 1), gradOutput:narrow(1, 9, 1)}, 1):clone(),
+      torch.cat({gradOutput:narrow(1, 3, 1), gradOutput:narrow(1, 7, 1), gradOutput:narrow(1, 10, 1)}, 1):clone(),
+      gradOutput:narrow(1, 4, 1):clone()
+   }
+   gradOutput = gradOutput:squeeze()
+
    local inputSize = 4
    local hiddenSize = 10
    local numLayers = 1
@@ -402,6 +412,7 @@ function rnntest.testVariableLengthSequences()
    local separate = {}
    local hids = {}
+   local indivGradInputs = {}
 
    for i, length in ipairs(lengths) do
       local inp = indivInputs[i]
@@ -409,6 +420,11 @@ function rnntest.testVariableLengthSequences()
       table.insert(separate, output)
       local hid = lstm2.hiddenOutput:clone()
       table.insert(hids, hid)
+
+      -- need to do backwards pass here too
+      local gradOutput = indivGradOutputs[i]
+      local gradInp = lstm2:updateGradInput(inp, gradOutput):clone()
+      table.insert(indivGradInputs, gradInp)
    end
 
    separate = torch.cat(separate, 1):squeeze()
    hids = torch.cat(hids, 1):squeeze()
@@ -440,6 +456,17 @@ function rnntest.testVariableLengthSequences()
    local hdiff = torch.csub(packedHiddenOutput, hids):abs():sum()
    mytester:assert(hdiff < 1e-7)
+
+   -- Step 2: update grad input as batch and individually
+
+   local packedGradInput = lstm:updateGradInput(packed, gradOutput)
+   local igiTestable = torch.cat(indivGradInputs, 1):squeeze(2)
+
+   for _, pair in ipairs(corresponding) do
+      local sep, batched = unpack(pair)
+      local diff = torch.csub(igiTestable[sep], packedGradInput[batched]):abs():sum()
+      mytester:assert(diff < 1e-7)
+   end
 end
 
 mytester = torch.Tester()
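
A note on the index arithmetic in indivGradOutputs above: the packed layout is
time-major, so timestep t of every sequence still active at step t is stored
contiguously. With lengths {4, 3, 3, 1} that puts sequence 1 at rows
{1, 5, 8, 11} and sequence 4 at row 4 only. A small sketch reproducing the
mapping, assuming lengths are sorted longest-first as in the test:

-- Row of (sequence s, timestep t) in the packed tensor, valid when
-- lengths[s] >= t: skip the rows used by all earlier timesteps, then s is
-- the position among the sequences still active at step t.
local lengths = {4, 3, 3, 1}
local function packedRow(s, t)
   local row = 0
   for step = 1, t - 1 do
      for _, l in ipairs(lengths) do
         if l >= step then row = row + 1 end
      end
   end
   return row + s
end
print(packedRow(1, 1), packedRow(1, 2), packedRow(1, 3), packedRow(1, 4))  --> 1  5  8  11
print(packedRow(2, 1), packedRow(2, 2), packedRow(2, 3))                   --> 2  6  9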