diff options
author | Trevor Killeen <killeentm@gmail.com> | 2017-03-31 17:49:06 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2017-04-04 17:48:24 +0300 |
commit | bd72272c7fade737de84d25b604aa9f8b77e9fe7 (patch) | |
tree | 1e8d282a834d8e4894e85960f8e75aa629de92d5 | |
parent | 51239c7bf128a77d5c2c8cddb6b67ab3ba8664bf (diff) |
implement pack/pad utility functions
-rw-r--r-- | RNN.lua | 98 | ||||
-rw-r--r-- | test/test_rnn.lua | 56 |
2 files changed, 154 insertions, 0 deletions
@@ -247,8 +247,106 @@ function RNN:resetStates() end end +-- input a TxBx* tensor (or BxTx* if batchFirst) where T is the length
+-- of the longest sequence, B is the batch size, and * is any number of
+-- dimensions.
+--
+-- lengths is a table of sequence lengths, which should be sorted in
+-- decreasing order.
+--
+-- returns a table containing a packed tensor of size (sum of lengths x *)
+-- and a list of batch sizes per timestep, i.e. the number of sequences
+-- with at least timestep elements.
+function RNN:packPaddedSequence(input, lengths, batchFirst)
+    if batchFirst then
+        input = input:transpose(1, 2)
+    end
+
+    local batches = {}
+    local bszpts = {}
+    local lengthsIdx = #lengths -- walk lengths backwards: shortest sequence first
+    local currentLength = lengths[lengthsIdx]
+
+    local steps = input:size(1)
+    local bsz = input:size(2)
+    if bsz ~= #lengths then
+        error("lengths array has incorrect size (expected: " .. bsz .. " but found: " .. #lengths ..")")
+    end
+
+    for ts = 1, steps do
+        table.insert(batches, input[ts]:narrow(1, 1, bsz))
+        table.insert(bszpts, bsz)
+        while ts == currentLength do -- drop every sequence that ends at this timestep
+            if lengthsIdx == 0 then
+                currentLength = nil
+                break
+            else
+                lengthsIdx = lengthsIdx - 1
+                bsz = bsz - 1
+                local nextLength = lengths[lengthsIdx]
+                if currentLength ~= nil and nextLength ~= nil and currentLength > nextLength then
+                    error("lengths array has to be sorted in decreasing order")
+                end
+                currentLength = lengths[lengthsIdx]
+            end
+        end
+
+        if currentLength == nil then
+            break
+        end
+    end
+    return {torch.cat(batches, 1), bszpts}
+end
+
+-- An inverse operation to packPaddedSequence(...) above. Takes a sequence (i.e. 
+-- a Tensor, bszpts table with the format as returned by packPaddedSequence and
+-- reconverts it into the TxBx* (or BxTx* if batchFirst) tensor and lengths array
+function RNN:padPackedSequence(seq, batchFirst)
+    local data, bszpts = unpack(seq)
+    local maxBatchSize = bszpts[1]
+    local outputSize = torch.LongStorage(2 + data[1]:nDimension()) -- output is T x B x (dims of one packed row)
+    outputSize[1] = #bszpts
+    outputSize[2] = maxBatchSize
+    for i = 1, data[1]:nDimension() do
+        outputSize[i + 2] = data[1]:size(i)
+    end
+    local output = torch.Tensor():typeAs(data):resize(outputSize):zero() -- zeros become the padding of short sequences
+
+    local lengths = {}
+    local offset = 1
+    local pbsz = bszpts[1] -- batch size at the previous timestep
+    local bsz = nil
+
+    local i = 1
+    while i <= #bszpts do
+        bsz = bszpts[i]
+        output[i]:narrow(1, 1, bsz):copy(data:narrow(1, offset, bsz))
+        offset = offset + bsz
+
+        local dec = pbsz - bsz -- sequences that ended after timestep i-1 have length i-1
+        for j = 1, dec do
+            table.insert(lengths, i - 1)
+        end
+        pbsz = bsz
+        i = i + 1
+    end
+    for j = 1, bsz do -- sequences still alive at the last timestep have full length
+        table.insert(lengths, i - 1)
+    end
+
+    -- reverse lengths list
+    local reversed = {}
+    for i = #lengths, 1, -1 do
+        table.insert(reversed, lengths[i])
+    end
+
+    if batchFirst then
+        output = output:transpose(1, 2)
+    end
+    return output, reversed
+end
 function RNN:updateOutput(input) if (self.batchFirst) then diff --git a/test/test_rnn.lua b/test/test_rnn.lua index 0d0b37b..b761507 100644 --- a/test/test_rnn.lua +++ b/test/test_rnn.lua @@ -258,6 +258,62 @@ function getRNNCheckSums(miniBatch, seqLength, hiddenSize, numberOfLayers, numbe return checkSums end +function cudnntest.testPackPadSequences()
+    -- T is 4, B = 5, vector size = 3
+    local input = torch.CudaIntTensor({
+        {{101, 102, 103},
+         {201, 202, 203},
+         {301, 302, 303},
+         {401, 402, 403},
+         {501, 502, 503}},
+        {{104, 105, 106},
+         {204, 205, 206},
+         {304, 305, 306},
+         {  0,   0,   0},
+         {  0,   0,   0}},
+        {{107, 108, 109},
+         {207, 208, 209},
+         {  0,   0,   0},
+         {  0,   0,   0},
+         {  0,   0,   0}},
+        {{110, 111, 112},
+         {  0,   0,   0},
+         {  0,   0,   0},
+         {  0,   0,   0},
+         {  0,   0,   0}},
+    })
+    local lengths = {4, 3, 2, 1, 1}
+
+    local 
expectedPacked = torch.CudaIntTensor({
+        {101, 102, 103}, {201, 202, 203}, {301, 302, 303}, {401, 402, 403}, {501, 502, 503},
+        {104, 105, 106}, {204, 205, 206}, {304, 305, 306},
+        {107, 108, 109}, {207, 208, 209},
+        {110, 111, 112}
+    })
+    local expectedBSPT = {5, 3, 2, 1}
+
+    local result = cudnn.RNN:packPaddedSequence(input, lengths) -- result is {packedTensor, bszpts}
+    local actualPacked, actualBSPT = unpack(result)
+    mytester:assertTensorEq(expectedPacked, actualPacked)
+    mytester:assertTableEq(expectedBSPT, actualBSPT)
+
+    local actualUnpacked, actualLengths = cudnn.RNN:padPackedSequence(result) -- round trip must recover the padded input
+    mytester:assertTensorEq(input, actualUnpacked)
+    mytester:assertTableEq(lengths, actualLengths)
+
+    -- test again with batchFirst
+    input = input:transpose(1, 2)
+
+    local result = cudnn.RNN:packPaddedSequence(input, lengths, true)
+    local actualPacked, actualBSPT = unpack(result)
+    mytester:assertTensorEq(expectedPacked, actualPacked)
+    mytester:assertTableEq(expectedBSPT, actualBSPT)
+
+    local actualUnpacked, actualLengths = cudnn.RNN:padPackedSequence(result, true)
+    mytester:assertTensorEq(input, actualUnpacked)
+    mytester:assertTableEq(lengths, actualLengths)
+end
+
 mytester = torch.Tester() mytester:add(cudnntest) mytester:run() |