diff options
author | SeanNaren <taz838@hotmail.co.uk> | 2016-05-13 21:51:00 +0300 |
---|---|---|
committer | SeanNaren <taz838@hotmail.co.uk> | 2016-05-13 21:51:00 +0300 |
commit | c4006a5f32fae62f8e7bd261d93f99d47987b5f3 (patch) | |
tree | 618021f6a5f8c8c3f3b7feb96f2416627be6cb3e /CTCCriterion.lua | |
parent | f1517225248670c947209ee9fffd731914c96bd2 (diff) |
Updated CTCCriterion with variable lengths and batchFirst
Diffstat (limited to 'CTCCriterion.lua')
-rw-r--r-- | CTCCriterion.lua | 119 |
1 files changed, 64 insertions, 55 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua index efabf13..6d4a2ec 100644 --- a/CTCCriterion.lua +++ b/CTCCriterion.lua @@ -1,55 +1,87 @@ ------------------------------------------------------------------------ ---[[ CTCCriterion ]]-- +--[[ CTCCriterion ]] -- -- CTC Alignment for sequence data where input and labels do not align. -- Useful for speech recognition on a phoneme/character level basis. --- Inputs assumed are in the form of batch x time x inputdim. +-- Inputs assumed are in the form of seqLength x batch x inputDim. +-- If batchFirst = true then input in the form of batch x seqLength x inputDim. -- Targets assumed in the form of {{1,2},{3,4}} where {1,2} is for the first --- element. +-- element and so forth. ------------------------------------------------------------------------ local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion') -CTCCriterion.dim = 3 - -function CTCCriterion:__init() - parent.__init(self) +function CTCCriterion:__init(batchFirst) require 'warp_ctc' + parent.__init(self) self.acts = torch.Tensor() - self.convertedGradients = torch.Tensor() + self.batchFirst = batchFirst or false +end + +function CTCCriterion:forward(input, target, sizes) + return self:updateOutput(input, target, sizes) end -function CTCCriterion:updateOutput(output, labels) - assert(output:nDimension() == CTCCriterion.dim, "Output must be a tensor of (batch x time x inputdim), recieved " .. output:nDimension() .. " dimensions") - local tensorSizes = output:size() - local acts = self:createCTCBatch(output, tensorSizes) - local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2]) - sizes = torch.totable(sizes) - if (output:type() == 'torch.CudaTensor') then - local grads = torch.CudaTensor() - self.output = sumCosts(gpu_ctc(acts, grads, labels, sizes)) +function CTCCriterion:updateOutput(input, target, sizes) + assert(sizes, + "You must pass the size of each sequence in the batch as a tensor") + local acts = self.acts + if input:dim() == 3 then + acts:resizeAs(input):copy(input) + if self.batchFirst then + acts = acts:transpose(1, 2) + acts = self:makeContiguous(acts) + end + acts:view(acts, acts:size(1) * acts:size(2), -1) + end + assert(acts:nDimension() == 2) + self.sizes = torch.totable(sizes) + self.gradInput = acts.new():resizeAs(acts):zero() + if input:type() == 'torch.CudaTensor' then + self.output = sumCosts(gpu_ctc(acts, self.gradInput, target, self.sizes)) else - local grads = torch.Tensor() - self.output = sumCosts(cpu_ctc(acts:float(), grads:float(), labels, sizes)) + acts = acts:float() + self.gradInput = self.gradInput:float() + self.output = sumCosts(cpu_ctc(acts, self.gradInput, target, self.sizes)) end - return self.output end -function CTCCriterion:updateGradInput(output, labels) - local tensorSizes = output:size() - local acts = self:createCTCBatch(output, tensorSizes) - local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2]) - sizes = torch.totable(sizes) - local grads = acts:clone():zero() - if (output:type() == 'torch.CudaTensor') then - gpu_ctc(acts, grads, labels, sizes) - else - grads = grads:float() - cpu_ctc(acts:float(), grads, labels, sizes) +function CTCCriterion:updateGradInput(input, target) + if input:dim() == 2 then -- (seqLen * batchSize) x outputDim + return self.gradInput + end + if self.batchFirst then -- batchSize x seqLen x outputDim + self.gradInput = inverseInterleave(self.gradInput, input:size(1)) + else -- seqLen x batchSize x outputDim + self.gradInput:view(self.gradInput, input:size(1), input:size(2), -1) end - self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output) return self.gradInput end +function CTCCriterion:makeContiguous(input) + if not input:isContiguous() then + self._input = self._input or input.new() + self._input:typeAs(input):resizeAs(input):copy(input) + input = self._input + end + return input +end + +function inverseInterleave(tensor, batchSize) + local sizes = torch.LongStorage(3) + sizes[1] = batchSize + sizes[2] = tensor:size(1) / batchSize + sizes[3] = tensor:size(2) + local result = tensor.new():resize(sizes):zero() + local counter = 1 + for i = 1, sizes[2] do + for j = 1, sizes[1] do + result[j][i]:copy(tensor[counter]) + counter = counter + 1 + end + end + return result +end + --If batching occurs multiple costs are returned. We sum the costs and return. function sumCosts(list) local acc @@ -61,27 +93,4 @@ function sumCosts(list) end end return acc -end - ---[[ --- Converts the outputs into batch warp-ctc format seen at the end of the README here: --- https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md - ]] -function CTCCriterion:createCTCBatch(output, sizes) - self.acts:resize(sizes[1] * sizes[2], sizes[3]):zero() - local output = output:transpose(1, 2) - self.acts = torch.reshape(output, sizes[1] * sizes[2], sizes[3]) - return self.acts -end - -function CTCCriterion:revertBatching(gradients, sizes) - self.convertedGradients:resize(sizes[1], sizes[2], sizes[3]):zero() - local counter = 1 - for i = 1, sizes[2] do - for j = 1, sizes[1] do - self.convertedGradients[j][i] = gradients[counter] - counter = counter + 1 - end - end - return self.convertedGradients end
\ No newline at end of file |