Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2017-04-04 17:44:48 +0300
committerTrevor Killeen <killeentm@gmail.com>2017-04-04 17:49:25 +0300
commitabd3fe3f822fb07f055473bfda99a9a2ac2cf76d (patch)
treedf3ed3c816dd459cadefe13a07cbf38638a47685
parent63f10c2513d01da1f233e81cb89d55ec1c1b0c25 (diff)
implement for accGradParameters
-rw-r--r--RNN.lua18
-rw-r--r--test/test_rnn.lua5
2 files changed, 16 insertions, 7 deletions
diff --git a/RNN.lua b/RNN.lua
index ca0f9ce..0db5ec4 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -687,7 +687,7 @@ function RNN:updateGradInput(input, gradOutput)
end
function RNN:accGradParameters(input, gradOutput, scale)
- if (self.batchFirst) then
+ if self.batchFirst and not self.inputPacked then
input = input:transpose(1, 2)
gradOutput = gradOutput:transpose(1, 2)
self.output = self.output:transpose(1, 2)
@@ -695,15 +695,19 @@ function RNN:accGradParameters(input, gradOutput, scale)
scale = scale or 1
if scale == 0 then return end
assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn 5.1 and above')
- assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
- assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
- assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
- assert(input:size(3) == self.inputSize, 'input has incorrect size!')
+ if self.inputPacked then
+ assert(input[1]:dim() == 2, 'packed input must have two dimensions: sum(sequence lengths), inputSize')
+ else
+ assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
+ assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
+ assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
+ assert(input:size(3) == self.inputSize, 'input has incorrect size!')
+ end
assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!')
assert(self.train, 'accGradParameters can only be called when training!')
- local x, dy = self:makeContiguous(input, gradOutput)
+ local x, dy = self:makeContiguous(self.inputPacked and input[1] or input, gradOutput)
local hx = self.hiddenInput
local y = self.output
local dw = self.gradWeight
@@ -759,7 +763,7 @@ function RNN:accGradParameters(input, gradOutput, scale)
scaleTensor:data())
end
- if (self.batchFirst) then
+ if self.batchFirst and not self.inputPacked then
gradOutput = gradOutput:transpose(1, 2)
self.output = self.output:transpose(1, 2)
end
diff --git a/test/test_rnn.lua b/test/test_rnn.lua
index 8ad982a..cabac77 100644
--- a/test/test_rnn.lua
+++ b/test/test_rnn.lua
@@ -409,6 +409,8 @@ function rnntest.testVariableLengthSequences()
local packed = cudnn.RNN:packPaddedSequence(input, lengths)
local packedOutput = lstm:updateOutput(packed)
local packedHiddenOutput = lstm.hiddenOutput:clone()
+ -- could use padPackedSequence here, but for testing simplicity, we'll just
+ -- operate on the returned results
local separate = {}
local hids = {}
@@ -467,6 +469,9 @@ function rnntest.testVariableLengthSequences()
local diff = torch.csub(igiTestable[sep], packedGradInput[batched]):abs():sum()
mytester:assert(diff < 1e-7)
end
+
+ -- Step 3: Basically verify that accGradParameters works for batch
+ lstm:accGradParameters(packed, gradOutput)
end
mytester = torch.Tester()