From fb62faf355ad94c4f26c0ccbda8cb45fb645b7e0 Mon Sep 17 00:00:00 2001
From: Trevor Killeen
Date: Fri, 31 Mar 2017 13:35:53 -0700
Subject: refactor forward to potentially accept packed tensor + sequence lengths

---
 RNN.lua | 192 +++++++++++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 143 insertions(+), 49 deletions(-)

diff --git a/RNN.lua b/RNN.lua
index 428c81d..45e137e 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -28,6 +28,7 @@ function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, remem
    self.batchFirst = batchFirst or false -- Set to true for batch x time x inputdim.
    self.rememberStates = rememberStates or false
    self.sync = true
+   self.inputPacked = false
    self.gradInput = torch.CudaTensor()
    self.output = torch.CudaTensor()
    self.weight = torch.CudaTensor()
@@ -50,7 +51,8 @@ function RNN:reset(stdv)

    self:resetDropoutDescriptor()
    self:resetRNNDescriptor()
-   self:resetIODescriptors()
+   self:resetInputDescriptor()
+   self:resetOutputDescriptor()

    local weightSizePtr = ffi.new("size_t[1]")
    errcheck('cudnnGetRNNParamsSize',
@@ -141,28 +143,60 @@ function RNN:resetWeightDescriptor()
    )
 end

-function RNN:resetIODescriptors()
+function RNN:resetInputDescriptor(input, batchSizes)
    self.xDescs = self:createTensorDescriptors(self.seqLength)
-   self.yDescs = self:createTensorDescriptors(self.seqLength)
-   for i = 0, self.seqLength - 1 do
-      local dim = torch.IntTensor({ self.miniBatch,self.inputSize, 1})
-      local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
-      errcheck('cudnnSetTensorNdDescriptor',
-               self.xDescs[i],
-               self.datatype,
-               3,
-               dim:data(),
-               stride:data())
-
-      local dim = torch.IntTensor({self.miniBatch, self.hiddenSize * self.numDirections, 1})
-      local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
-      errcheck('cudnnSetTensorNdDescriptor',
-               self.yDescs[i],
-               self.datatype,
-               3,
-               dim:data(),
-               stride:data())
-   end
+   if self.inputPacked and input ~= nil and batchSizes ~= nil then
+      assert(#batchSizes == self.seqLength)
+      for i = 0, self.seqLength - 1 do
+         -- tensor shape is (# of sequences in the batch at this timestep, inputSize, 1 (for cudnn))
+         local dim = torch.IntTensor({batchSizes[i+1], input:size(2), 1})
+         local stride = torch.IntTensor({dim[3] * dim[2], dim[3], 1})
+         errcheck('cudnnSetTensorNdDescriptor',
+                  self.xDescs[i],
+                  self.datatype,
+                  3,
+                  dim:data(),
+                  stride:data())
+      end
+   else
+      for i = 0, self.seqLength - 1 do
+         local dim = torch.IntTensor({self.miniBatch, self.inputSize, 1})
+         local stride = torch.IntTensor({dim[3] * dim[2], dim[3], 1})
+         errcheck('cudnnSetTensorNdDescriptor',
+                  self.xDescs[i],
+                  self.datatype,
+                  3,
+                  dim:data(),
+                  stride:data())
+      end
+   end
+end
+
+function RNN:resetOutputDescriptor(output, batchSizes)
+   self.yDescs = self:createTensorDescriptors(self.seqLength)
+   if self.inputPacked and output ~= nil and batchSizes ~= nil then
+      for i = 0, self.seqLength - 1 do
+         local dim = torch.IntTensor({batchSizes[i+1], self.hiddenSize * self.numDirections, 1})
+         local stride = torch.IntTensor({dim[3] * dim[2], dim[3], 1})
+         errcheck('cudnnSetTensorNdDescriptor',
+                  self.yDescs[i],
+                  self.datatype,
+                  3,
+                  dim:data(),
+                  stride:data())
+      end
+   else
+      for i = 0, self.seqLength - 1 do
+         local dim = torch.IntTensor({self.miniBatch, self.hiddenSize * self.numDirections, 1})
+         local stride = torch.IntTensor({dim[3] * dim[2], dim[3], 1})
+         errcheck('cudnnSetTensorNdDescriptor',
+                  self.yDescs[i],
+                  self.datatype,
+                  3,
+                  dim:data(),
+                  stride:data())
+      end
+   end
 end

@@ -220,10 +254,6 @@ function RNN:makeContiguous(input, gradOutput)
    return input, gradOutput
 end

-function RNN:resizeOutput(tensor)
-   return tensor:resize(self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections)
-end
-
 function RNN:resizeHidden(tensor)
    return tensor:resize(self.numLayers * self.numDirections, self.miniBatch, self.hiddenSize)
 end
@@ -348,11 +378,44 @@ function RNN:padPackedSequence(seq, batchFirst)
    return output, reversed
 end

+-- it feels a little dirty setting this function on the class as opposed
+-- to having it be functional, but because we need to access class state,
+-- here we are...
+function RNN:deriveOutputSize(input)
+   if self.inputPacked then
+      return torch.LongStorage({input:size(1), self.hiddenSize * self.numDirections})
+   else
+      return torch.LongStorage({self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections})
+   end
+end
+
+-- updateOutput takes either of the following as input:
+--
+-- 1. A seqLength x miniBatch x inputSize Tensor, where seqLength is the
+--    length of the sequence for every input in the batch, miniBatch is the
+--    number of elements in the batch, and inputSize is the size of the input
+--    vectors at each time step
+--
+-- OR
+--
+-- 2. A table containing a packed tensor and a list of batch sizes per timestep,
+--    i.e. the output of packPaddedSequence(...) above. In this case we support
+--    variable-length sequences for the forward pass.
 function RNN:updateOutput(input)
-   if (self.batchFirst) then
-      input = input:transpose(1, 2)
-   end
-   assert(input:dim() == 3, 'input must have 3 dimensions: seqLength, miniBatch, inputSize')
+   local inputPacked = (type(input) == 'table')
+   local switched = self.inputPacked ~= inputPacked
+   self.inputPacked = inputPacked
+
+   if self.batchFirst and not self.inputPacked then
+      input = input:transpose(1, 2)
+   end
+
+   if self.inputPacked then
+      assert(input[1]:dim() == 2, 'packed input must have two dimensions: sum(sequence lengths), inputSize')
+   else
+      assert(input:dim() == 3, 'input must have 3 dimensions: seqLength, miniBatch, inputSize')
+   end
+
    assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn v5.1 and above')

    -- Decide which descriptors/tensors need to be updated.
    local resetRNN = not self.dropoutDesc or not self.rnnDesc
@@ -360,19 +423,56 @@
    local resetHC = not self.hxDesc or not self.hyDesc or not self.cxDesc or not self.cyDesc
    local resetWeight = not self.wDesc

-   if input:size(1) ~= self.seqLength then
-      self.seqLength = input:size(1)
-      resetIO = true
-   end
-
-   if input:size(2) ~= self.miniBatch then
-      self.miniBatch = input:size(2)
-      resetIO = true
-      resetHC = true
-   end
-   assert(input:size(3) == self.inputSize, 'Incorrect input size!')
-
+   if self.inputPacked then
+      -- Handle resets for packed input
+
+      -- In the case of packed inputs, the sequence length is the length of the
+      -- batch-sizes-per-timestep list. We need to reset the IO descriptors if
+      -- this has changed.
+      if #input[2] ~= self.seqLength then
+         self.seqLength = #input[2]
+         resetIO = true
+      end
+
+      -- Similarly, the miniBatch "size" is the batch size at the first timestep
+      -- (when all sequences are in the batch, regardless of length). If this has
+      -- changed then we need to reset both the IO descriptors and the
+      -- hidden/cell descriptors.
+      if input[2][1] ~= self.miniBatch then
+         self.miniBatch = input[2][1]
+         resetIO = true
+         resetHC = true
+      end
+      assert(input[1]:size(2) == self.inputSize, 'Incorrect input size!')
+   else
+      -- Handle resets for standard (i.e. not packed) input
+
+      -- If the length of the sequences in this input batch differs from the
+      -- previous batch we need to reset the IO descriptors to describe the new
+      -- size of the input and output Tensors in the seqLength dimension.
+      if input:size(1) ~= self.seqLength then
+         self.seqLength = input:size(1)
+         resetIO = true
+      end
+
+      -- If the batch size has changed we need to:
+      -- 1. Update the IO descriptors to describe the new size of the input and
+      --    output Tensors in the batchSize dimension
+      -- 2. Reset the size of the hidden/cell descriptors so they can store
+      --    batchSize states
+      if input:size(2) ~= self.miniBatch then
+         self.miniBatch = input:size(2)
+         resetIO = true
+         resetHC = true
+      end
+      assert(input:size(3) == self.inputSize, 'Incorrect input size!')
+   end
+
+   -- Make sure input is contiguous
+   local x = self:makeContiguous(self.inputPacked and input[1] or input)
+   local oSize = self:deriveOutputSize(x)
+   local oStride = self.inputPacked and torch.LongStorage({oSize[2], 1}) or torch.LongStorage({oSize[2] * oSize[3], oSize[3], 1})
+   self.output:resize(oSize, oStride)
+   local y = self.output
+   local w = self.weight
+   local bszpts = self.inputPacked and input[2]

    -- Update descriptors/tensors
    if resetRNN then
@@ -380,7 +480,8 @@
       self:resetDropoutDescriptor()
       self:resetRNNDescriptor()
    end
    if resetIO then
-      self:resetIODescriptors(input)
+      self:resetInputDescriptor(x, bszpts)
+      self:resetOutputDescriptor(y, bszpts)
    end
    if resetHC then
@@ -390,13 +491,6 @@
       self:resetHiddenDescriptors()
    end
    if resetWeight then
       self:resetWeightDescriptor()
    end

-   local x = self:makeContiguous(input)
-   local oSize = torch.LongStorage({self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections})
-   local oStride = torch.LongStorage({self.miniBatch * self.hiddenSize * self.numDirections, self.hiddenSize * self.numDirections, 1})
-   self.output:resize(oSize, oStride)
-   local y = self.output
-   local w = self.weight
-
    -- Optionally use hiddenInput/cellInput parameters
    if self.rememberStates then
       if self.hiddenOutput:nDimension() == 3 and self.hiddenOutput:size(1) == self.numLayers * self.numDirections and
@@ -483,7 +577,7 @@
                wsSize)
    end
    if self.sync then cutorch.synchronize() end
-   if (self.batchFirst) then
+   if self.batchFirst and not self.inputPacked then
       self.output = self.output:transpose(1, 2)
    end
    return self.output
--
cgit v1.2.3
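
A quick end-to-end sketch of how the new packed forward path is meant to be driven. This is illustrative only: it assumes the packPaddedSequence helper referenced in the comments above returns the {packedTensor, batchSizesPerTimestep} table that updateOutput now consumes, and all sizes below are made up.

    require 'cudnn'

    -- Hypothetical sizes: inputSize = 5, hiddenSize = 10, 1 layer.
    local rnn = cudnn.RNN(5, 10, 1)

    -- Padded batch: seqLength x miniBatch x inputSize, sequences sorted by
    -- decreasing length (lengths 4, 3, 2 here).
    local padded = torch.CudaTensor(4, 3, 5):uniform()
    local lengths = {4, 3, 2}

    -- Assumed to produce {packed, batchSizes}, where packed is a
    -- (4 + 3 + 2) x 5 tensor and batchSizes is {3, 3, 2, 1}.
    local packedInput = rnn:packPaddedSequence(padded, lengths)

    -- Forward with the packed table; per deriveOutputSize the output is
    -- sum(lengths) x (hiddenSize * numDirections), i.e. 9 x 10 here.
    local output = rnn:forward(packedInput)

    -- padPackedSequence (signature assumed from the hunk context above)
    -- re-pads the packed output when a dense seqLength x miniBatch x
    -- hiddenSize tensor is needed downstream.
    local paddedOutput = rnn:padPackedSequence({output, packedInput[2]})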
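
The descriptor bookkeeping hinges on the batch-sizes-per-timestep list (bszpts). The helper below is hypothetical, not part of the patch; it just makes the arithmetic concrete: for descending sequence lengths, batchSizes[t] counts the sequences still active at timestep t, and resetInputDescriptor then shapes the i-th (0-indexed) xDesc as (batchSizes[i+1], inputSize, 1).

    -- Hypothetical helper, not part of the patch: derive per-timestep batch
    -- sizes from a list of sequence lengths sorted in decreasing order.
    local function batchSizesFromLengths(lengths)
       local batchSizes = {}
       for t = 1, lengths[1] do              -- lengths[1] is the longest
          local active = 0
          for _, len in ipairs(lengths) do
             if len >= t then active = active + 1 end
          end
          batchSizes[t] = active
       end
       return batchSizes
    end

    local bsz = batchSizesFromLengths({4, 3, 2})
    print(table.concat(bsz, ', '))           -- 3, 3, 2, 1
    -- Note sum(bsz) = 9 = 4 + 3 + 2, matching the first dimension of the
    -- packed tensor, which is why deriveOutputSize can return
    -- {input:size(1), hiddenSize * numDirections} for packed input.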