Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2017-03-31 17:49:06 +0300
committerTrevor Killeen <killeentm@gmail.com>2017-04-04 17:48:24 +0300
commitbd72272c7fade737de84d25b604aa9f8b77e9fe7 (patch)
tree1e8d282a834d8e4894e85960f8e75aa629de92d5
parent51239c7bf128a77d5c2c8cddb6b67ab3ba8664bf (diff)
implement pack/pad utility functions
-rw-r--r--RNN.lua98
-rw-r--r--test/test_rnn.lua56
2 files changed, 154 insertions, 0 deletions
diff --git a/RNN.lua b/RNN.lua
index f22849a..428c81d 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -247,8 +247,106 @@ function RNN:resetStates()
end
end
+-- Packs a padded TxBx* tensor (or BxTx* if batchFirst) into a "packed
+-- sequence", where T is the length of the longest sequence, B is the
+-- batch size, and * is any number of trailing dimensions.
+--
+-- lengths is a table of per-sequence lengths, which must be sorted in
+-- decreasing order (an error is raised otherwise).
+--
+-- returns a table {packed, bszpts} where packed is a tensor of size
+-- (sum of lengths x *) and bszpts is a list of batch sizes per timestep,
+-- i.e. the number of sequences with at least that many timesteps.
+function RNN:packPaddedSequence(input, lengths, batchFirst)
+   if batchFirst then
+      input = input:transpose(1, 2)
+   end
+
+   local batches = {}
+   local bszpts = {}
+   -- walk the (decreasing) lengths array from its tail, i.e. from the
+   -- shortest sequence towards the longest
+   local lengthsIdx = #lengths
+   local currentLength = lengths[lengthsIdx]
+
+   local steps = input:size(1)
+   local bsz = input:size(2)
+   if bsz ~= #lengths then
+      error("lengths array has incorrect size (expected: " .. bsz .. " but found: " .. #lengths .. ")")
+   end
+
+   for ts = 1, steps do
+      -- only the first bsz sequences are still active at this timestep
+      table.insert(batches, input[ts]:narrow(1, 1, bsz))
+      table.insert(bszpts, bsz)
+      -- retire every sequence ending exactly at this timestep; a while loop
+      -- because several sequences may share the same length
+      while ts == currentLength do
+         if lengthsIdx == 0 then
+            currentLength = nil
+            break
+         else
+            lengthsIdx = lengthsIdx - 1
+            bsz = bsz - 1
+            local nextLength = lengths[lengthsIdx]
+            if currentLength ~= nil and nextLength ~= nil and currentLength > nextLength then
+               error("lengths array has to be sorted in decreasing order")
+            end
+            currentLength = lengths[lengthsIdx]
+         end
+      end
+
+      -- all sequences exhausted; remaining timesteps are pure padding
+      if currentLength == nil then
+         break
+      end
+   end
+   return {torch.cat(batches, 1), bszpts}
+end
+
+-- An inverse operation to packPaddedSequence(...) above. Takes a packed
+-- sequence (a {packedTensor, bszpts} table in the format returned by
+-- packPaddedSequence) and reconverts it into the padded TxBx* tensor
+-- (or BxTx* if batchFirst) together with the decreasing lengths array.
+function RNN:padPackedSequence(seq, batchFirst)
+ local data, bszpts = unpack(seq)
+ -- bszpts is non-increasing, so bszpts[1] is the full batch size B
+ local maxBatchSize = bszpts[1]
+ -- output shape: T x B x (feature dims of one packed row)
+ local outputSize = torch.LongStorage(2 + data[1]:nDimension())
+ outputSize[1] = #bszpts
+ outputSize[2] = maxBatchSize
+ for i = 1, data[1]:nDimension() do
+ outputSize[i + 2] = data[1]:size(i)
+ end
+ -- zero-filled so slots past each sequence's length remain 0-padding
+ local output = torch.Tensor():typeAs(data):resize(outputSize):zero()
+
+ local lengths = {}
+ local offset = 1
+ local pbsz = bszpts[1]
+ local bsz = nil
+
+ local i = 1
+ while i <= #bszpts do
+ bsz = bszpts[i]
+ -- copy the bsz packed rows of timestep i into the first bsz batch slots
+ output[i]:narrow(1, 1, bsz):copy(data:narrow(1, offset, bsz))
+ offset = offset + bsz
+
+ -- each drop in batch size means (pbsz - bsz) sequences ended at step i-1
+ local dec = pbsz - bsz
+ for j = 1, dec do
+ table.insert(lengths, i - 1)
+ end
+ pbsz = bsz
+ i = i + 1
+ end
+ -- sequences still active after the final step have the maximum length
+ -- (here i == #bszpts + 1, so i - 1 == #bszpts)
+ for j = 1, bsz do
+ table.insert(lengths, i - 1)
+ end
+
+ -- lengths were collected shortest-first; reverse to the decreasing order
+ -- expected by packPaddedSequence's lengths argument
+ local reversed = {}
+ for i = #lengths, 1, -1 do
+ table.insert(reversed, lengths[i])
+ end
+
+ if batchFirst then
+ output = output:transpose(1, 2)
+ end
+ return output, reversed
+end
function RNN:updateOutput(input)
if (self.batchFirst) then
diff --git a/test/test_rnn.lua b/test/test_rnn.lua
index 0d0b37b..b761507 100644
--- a/test/test_rnn.lua
+++ b/test/test_rnn.lua
@@ -258,6 +258,62 @@ function getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numbe
return checkSums
end
+-- Round-trip test for packPaddedSequence / padPackedSequence, in both the
+-- default TxBx* layout and the batchFirst (BxTx*) layout.
+function cudnntest.testPackPadSequences()
+   -- T = 4 timesteps, B = 5 sequences, feature size = 3
+   local padded = torch.CudaIntTensor({
+      {{101, 102, 103},
+       {201, 202, 203},
+       {301, 302, 303},
+       {401, 402, 403},
+       {501, 502, 503}},
+      {{104, 105, 106},
+       {204, 205, 206},
+       {304, 305, 306},
+       {  0,   0,   0},
+       {  0,   0,   0}},
+      {{107, 108, 109},
+       {207, 208, 209},
+       {  0,   0,   0},
+       {  0,   0,   0},
+       {  0,   0,   0}},
+      {{110, 111, 112},
+       {  0,   0,   0},
+       {  0,   0,   0},
+       {  0,   0,   0},
+       {  0,   0,   0}},
+   })
+   local seqLengths = {4, 3, 2, 1, 1}
+
+   local wantPacked = torch.CudaIntTensor({
+      {101, 102, 103}, {201, 202, 203}, {301, 302, 303}, {401, 402, 403}, {501, 502, 503},
+      {104, 105, 106}, {204, 205, 206}, {304, 305, 306},
+      {107, 108, 109}, {207, 208, 209},
+      {110, 111, 112}
+   })
+   local wantBSPT = {5, 3, 2, 1}
+
+   local packed = cudnn.RNN:packPaddedSequence(padded, seqLengths)
+   local gotPacked, gotBSPT = unpack(packed)
+   mytester:assertTensorEq(wantPacked, gotPacked)
+   mytester:assertTableEq(wantBSPT, gotBSPT)
+
+   local unpacked, gotLengths = cudnn.RNN:padPackedSequence(packed)
+   mytester:assertTensorEq(padded, unpacked)
+   mytester:assertTableEq(seqLengths, gotLengths)
+
+   -- exercise the batchFirst path on the BxTx* layout
+   local batchFirstInput = padded:transpose(1, 2)
+   packed = cudnn.RNN:packPaddedSequence(batchFirstInput, seqLengths, true)
+   gotPacked, gotBSPT = unpack(packed)
+   mytester:assertTensorEq(wantPacked, gotPacked)
+   mytester:assertTableEq(wantBSPT, gotBSPT)
+
+   unpacked, gotLengths = cudnn.RNN:padPackedSequence(packed, true)
+   mytester:assertTensorEq(batchFirstInput, unpacked)
+   mytester:assertTableEq(seqLengths, gotLengths)
+end
+
+
mytester = torch.Tester()
mytester:add(cudnntest)
mytester:run()