github.com/soumith/cudnn.torch.git
author    Trevor Killeen <killeentm@gmail.com>  2017-03-31 23:35:53 +0300
committer Trevor Killeen <killeentm@gmail.com>  2017-04-04 17:49:25 +0300
commit    fb62faf355ad94c4f26c0ccbda8cb45fb645b7e0 (patch)
tree      1ff5a03bd06ee1229ceff843e9fb06052bb563d2
parent    bd72272c7fade737de84d25b604aa9f8b77e9fe7 (diff)
refactor forward to potentially accept packed tensor + sequence lengths
 RNN.lua | 192 ++++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 143 insertions(+), 49 deletions(-)
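Background for the change (editor's note): a "packed" batch stores variable-length sequences without padding. Sequences are sorted by decreasing length and their valid timesteps are concatenated into a single sum(lengths) x inputSize matrix, accompanied by a list of batch sizes per timestep. A minimal sketch of how that per-timestep list is derived from sorted lengths; the helper name batchSizesFromLengths is illustrative, not part of the module:

-- For sequence lengths {4, 3, 1} (sorted descending), the packed batch has
-- 4 + 3 + 1 = 8 rows and batchSizes = {3, 2, 2, 1}: three sequences are
-- active at t=1, two at t=2 and t=3, and one at t=4.
local function batchSizesFromLengths(lengths)
   local batchSizes = {}
   for t = 1, lengths[1] do          -- lengths[1] is the longest sequence
      local active = 0
      for _, len in ipairs(lengths) do
         if len >= t then active = active + 1 end
      end
      batchSizes[t] = active
   end
   return batchSizes
end
-- batchSizesFromLengths({4, 3, 1}) --> {3, 2, 2, 1}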
diff --git a/RNN.lua b/RNN.lua
index 428c81d..45e137e 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -28,6 +28,7 @@ function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst, dropout, remem
self.batchFirst = batchFirst or false -- Set to true for batch x time x inputdim.
self.rememberStates = rememberStates or false
self.sync = true
+ self.inputPacked = false
self.gradInput = torch.CudaTensor()
self.output = torch.CudaTensor()
self.weight = torch.CudaTensor()
@@ -50,7 +51,8 @@ function RNN:reset(stdv)
self:resetDropoutDescriptor()
self:resetRNNDescriptor()
- self:resetIODescriptors()
+ self:resetInputDescriptor()
+ self:resetOutputDescriptor()
local weightSizePtr = ffi.new("size_t[1]")
errcheck('cudnnGetRNNParamsSize',
@@ -141,28 +143,60 @@ function RNN:resetWeightDescriptor()
)
end
-function RNN:resetIODescriptors()
+function RNN:resetInputDescriptor(input, batchSizes)
self.xDescs = self:createTensorDescriptors(self.seqLength)
- self.yDescs = self:createTensorDescriptors(self.seqLength)
- for i = 0, self.seqLength - 1 do
- local dim = torch.IntTensor({ self.miniBatch,self.inputSize, 1})
- local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
- errcheck('cudnnSetTensorNdDescriptor',
- self.xDescs[i],
- self.datatype,
- 3,
- dim:data(),
- stride:data())
-
- local dim = torch.IntTensor({self.miniBatch, self.hiddenSize * self.numDirections, 1})
- local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
- errcheck('cudnnSetTensorNdDescriptor',
- self.yDescs[i],
- self.datatype,
- 3,
- dim:data(),
- stride:data())
+ if self.inputPacked and input ~= nil and batchSizes ~= nil then
+ assert(#batchSizes == self.seqLength)
+ for i = 0, self.seqLength - 1 do
+ -- tensor shape is (# of sequences in the batch at the timestep, inputSize, 1 (for cudnn))
+ local dim = torch.IntTensor({batchSizes[i+1], input:size(2), 1})
+ local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
+ errcheck('cudnnSetTensorNdDescriptor',
+ self.xDescs[i],
+ self.datatype,
+ 3,
+ dim:data(),
+ stride:data())
+ end
+ else
+ for i = 0, self.seqLength - 1 do
+ local dim = torch.IntTensor({ self.miniBatch,self.inputSize, 1})
+ local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
+ errcheck('cudnnSetTensorNdDescriptor',
+ self.xDescs[i],
+ self.datatype,
+ 3,
+ dim:data(),
+ stride:data())
+ end
+ end
+end
+
+function RNN:resetOutputDescriptor(output, batchSizes)
+ self.yDescs = self:createTensorDescriptors(self.seqLength)
+ if self.inputPacked and output ~= nil and batchSizes ~= nil then
+ for i = 0, self.seqLength - 1 do
+ local dim = torch.IntTensor({batchSizes[i+1], self.hiddenSize * self.numDirections, 1})
+ local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
+ errcheck('cudnnSetTensorNdDescriptor',
+ self.yDescs[i],
+ self.datatype,
+ 3,
+ dim:data(),
+ stride:data())
+ end
+ else
+ for i = 0, self.seqLength - 1 do
+ local dim = torch.IntTensor({self.miniBatch, self.hiddenSize * self.numDirections, 1})
+ local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
+ errcheck('cudnnSetTensorNdDescriptor',
+ self.yDescs[i],
+ self.datatype,
+ 3,
+ dim:data(),
+ stride:data())
+ end
end
end
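To make the packed branch above concrete, here is what the per-timestep descriptor dimensions work out to for a small example (editor's illustration with assumed values, not output from the commit):

-- With batchSizes = {3, 2, 2, 1} and inputSize = 5, resetInputDescriptor sets
-- one 3D descriptor per timestep over the packed 8 x 5 input:
--   t=1: dim {3, 5, 1}, t=2: dim {2, 5, 1}, t=3: dim {2, 5, 1}, t=4: dim {1, 5, 1}
-- all with stride {5, 1, 1}, i.e. consecutive rows of the packed matrix. The
-- dense branch instead uses dim {miniBatch, inputSize, 1} for every timestep.
local batchSizes, inputSize = {3, 2, 2, 1}, 5
for t, bsz in ipairs(batchSizes) do
   print(string.format("t=%d: dim {%d, %d, 1}, stride {%d, 1, 1}", t, bsz, inputSize, inputSize))
end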
@@ -220,10 +254,6 @@ function RNN:makeContiguous(input, gradOutput)
return input, gradOutput
end
-function RNN:resizeOutput(tensor)
- return tensor:resize(self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections)
-end
-
function RNN:resizeHidden(tensor)
return tensor:resize(self.numLayers * self.numDirections, self.miniBatch, self.hiddenSize)
end
@@ -348,11 +378,44 @@ function RNN:padPackedSequence(seq, batchFirst)
return output, reversed
end
+-- it feels a little dirty setting this function on the class as opposed
+-- to having it be functional, but because we need to access class state,
+-- here we are...
+function RNN:deriveOutputSize(input)
+ if self.inputPacked then
+ return torch.LongStorage({input:size(1), self.hiddenSize * self.numDirections})
+ else
+ return torch.LongStorage({self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections})
+ end
+end
+
+-- updateOutput takes either of the following as inputs:
+--
+-- 1. A seqLength x miniBatch x inputSize Tensor, where seqLength is the
+-- length of the sequence for every input in the batch, miniBatch is the
+-- number of elements in the batch, and inputSize is the size of the input vectors
+-- at each time step
+--
+-- OR
+--
+-- 2. A table containing a packed tensor and a list of batch sizes per timestep. In this
+-- case we are supporting variable length sequences for the forward pass. This table
+-- is the output from packPaddedSequence(...) above
function RNN:updateOutput(input)
- if (self.batchFirst) then
- input = input:transpose(1, 2)
- end
- assert(input:dim() == 3, 'input must have 3 dimensions: seqLength, miniBatch, inputSize')
+ local inputPacked = (type(input) == 'table')
+ local switched = self.inputPacked ~= inputPacked
+ self.inputPacked = inputPacked
+
+ if self.batchFirst and not self.inputPacked then
+ input = input:transpose(1, 2)
+ end
+
+ if self.inputPacked then
+ assert(input[1]:dim() == 2, 'packed input must have two dimensions: sum(sequence lengths), inputSize')
+ else
+ assert(input:dim() == 3, 'input must have 3 dimensions: seqLength, miniBatch, inputSize')
+ end
+
assert(self.dropout == 0 or cudnn.version >= 5103, 'dropout supported only in cudnn v5.1 and above')
-- Decide which descriptors/tensors need to be updated.
local resetRNN = not self.dropoutDesc or not self.rnnDesc
@@ -360,19 +423,56 @@ function RNN:updateOutput(input)
local resetHC = not self.hxDesc or not self.hyDesc or not self.cxDesc or not self.cyDesc
local resetWeight = not self.wDesc
- if input:size(1) ~= self.seqLength then
- self.seqLength = input:size(1)
- resetIO = true
- end
-
- if input:size(2) ~= self.miniBatch then
- self.miniBatch = input:size(2)
- resetIO = true
- resetHC = true
+ if self.inputPacked then
+ -- Handle resets for packed input
+
+ -- In the case of packed inputs, the sequence length is the length of the batch-sizes-per-timestep list.
+ -- We need to reset the IO descriptors if this has changed.
+ if #input[2] ~= self.seqLength then
+ self.seqLength = #input[2]
+ resetIO = true
+ end
+
+ -- Similarly, the miniBatch "size" is the batch size at the first timestep (when all
+ -- sequences are in the batch, regardless of length). If this has changed then we need
+ -- to reset both the IO descriptors and the hidden/cell descriptors
+ if input[2][1] ~= self.miniBatch then
+ self.miniBatch = input[2][1]
+ resetIO = true
+ resetHC = true
+ end
+ assert(input[1]:size(2) == self.inputSize, 'Incorrect input size!')
+ else
+ -- Handle resets for standard (i.e. not packed) input
+
+ -- If the length of the sequences in this input batch differ from the previous batch
+ -- we need to: reset the IO descriptors to describe the new size of the input and
+ -- output Tensors in the seqLength dimension
+ if input:size(1) ~= self.seqLength then
+ self.seqLength = input:size(1)
+ resetIO = true
+ end
+
+ -- If the batch size has changed we need to:
+ -- 1. Update the IO descriptors to describe the new size of the input and output Tensors in the
+ -- batchSize dimension
+ -- 2. Reset the size of the hidden/cell descriptors so they can store batchSize states
+ if input:size(2) ~= self.miniBatch then
+ self.miniBatch = input:size(2)
+ resetIO = true
+ resetHC = true
+ end
+ assert(input:size(3) == self.inputSize, 'Incorrect input size!')
end
- assert(input:size(3) == self.inputSize, 'Incorrect input size!')
-
+ -- Make sure input is contiguous
+ local x = self:makeContiguous(self.inputPacked and input[1] or input)
+ local oSize = self:deriveOutputSize(x)
+ -- packed output is 2D (sum(batchSizes) x hiddenSize*numDirections) and needs a 2D stride
+ local oStride = self.inputPacked and torch.LongStorage({oSize[2], 1}) or torch.LongStorage({oSize[2] * oSize[3], oSize[3], 1})
+ self.output:resize(oSize, oStride)
+ local y = self.output
+ local w = self.weight
+ local bszpts = self.inputPacked and input[2]
-- Update descriptors/tensors
if resetRNN then
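For reference, the output layouts produced by the resize above differ between the two modes; a short sketch with assumed sizes (editor's illustration, not part of the commit):

-- Dense mode: output is seqLength x miniBatch x H (H = hiddenSize * numDirections)
-- with contiguous stride {miniBatch * H, H, 1}. Packed mode: output mirrors the
-- packed input, a sum(batchSizes) x H matrix with stride {H, 1}.
local H, miniBatch = 8, 3                    -- assumed values
local denseStride  = {miniBatch * H, H, 1}
local packedStride = {H, 1}
print(denseStride[1], denseStride[2], packedStride[1])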
@@ -380,7 +480,8 @@ function RNN:updateOutput(input)
self:resetRNNDescriptor()
end
if resetIO then
- self:resetIODescriptors(input)
+ self:resetInputDescriptor(x, bszpts)
+ self:resetOutputDescriptor(y, bszpts)
end
if resetHC then
self:resetHiddenDescriptors()
@@ -390,13 +491,6 @@ function RNN:updateOutput(input)
self:resetWeightDescriptor()
end
- local x = self:makeContiguous(input)
- local oSize = torch.LongStorage({self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections})
- local oStride = torch.LongStorage({self.miniBatch * self.hiddenSize * self.numDirections, self.hiddenSize * self.numDirections, 1})
- self.output:resize(oSize, oStride)
- local y = self.output
- local w = self.weight
-
-- Optionally use hiddenInput/cellInput parameters
if self.rememberStates then
if self.hiddenOutput:nDimension() == 3 and self.hiddenOutput:size(1) == self.numLayers * self.numDirections and
@@ -483,7 +577,7 @@ function RNN:updateOutput(input)
wsSize)
end
if self.sync then cutorch.synchronize() end
- if (self.batchFirst) then
+ if self.batchFirst and not self.inputPacked then
self.output = self.output:transpose(1, 2)
end
return self.output
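Putting the two entry points together, a minimal usage sketch (editor's illustration: the shapes and values are made up, and it assumes a cudnn RNN module built on this class plus the packPaddedSequence helper referenced in the hunks above):

require 'cudnn'

local inputSize, hiddenSize, numLayers = 5, 8, 1
local rnn = cudnn.RNN(inputSize, hiddenSize, numLayers)

-- 1. Dense input: seqLength x miniBatch x inputSize.
local dense = torch.CudaTensor(4, 3, inputSize):uniform()
local out = rnn:updateOutput(dense)            -- 4 x 3 x hiddenSize

-- 2. Packed input: a {packedTensor, batchSizes} table, e.g. as returned by
--    packPaddedSequence. For lengths {4, 3, 1} the packed tensor stacks the
--    8 valid timesteps into an 8 x inputSize matrix.
local packed = { torch.CudaTensor(8, inputSize):uniform(), {3, 2, 2, 1} }
local outPacked = rnn:updateOutput(packed)     -- 8 x hiddenSize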