-- CTCCriterion.lua
-- Provenance: patch f450325e81056bc209b2d38037c7d105a5efc813
-- ("added base CTCCriterion", Sean, Sat 2 Apr 2016).

require 'warp_ctc'

local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion')

--- Criterion wrapping Baidu's warp-ctc bindings (cpu_ctc / gpu_ctc).
-- Network output is expected as a (batch x time x classes) tensor and the
-- targets as a table of label sequences, one per batch element
-- (batch format per the warp-ctc torch binding tutorial).
function CTCCriterion:__init()
    parent.__init(self)
    self.acts = torch.Tensor()               -- reusable buffer: interleaved activations
    self.convertedGradients = torch.Tensor() -- reusable buffer: de-interleaved gradients
end

-- Builds the per-sample time-step count table required by warp-ctc.
-- Every batch element is assumed to span the full time dimension
-- (tensorSizes[2]) -- no variable-length support yet.
local function ctcSizes(tensorSizes)
    local sizes = {}
    for _ = 1, tensorSizes[1] do
        table.insert(sizes, tensorSizes[2])
    end
    return sizes
end

--- Forward pass: total CTC cost summed over the batch.
-- @param output network activations, (batch x time x classes)
-- @param labels table of target label sequences, one per batch element
-- @return summed CTC cost (number)
function CTCCriterion:updateOutput(output, labels)
    local tensorSizes = output:size()
    local acts = self:createCTCBatch(output, tensorSizes)
    local sizes = ctcSizes(tensorSizes)
    if output:type() == 'torch.CudaTensor' then
        -- An empty gradient tensor tells warp-ctc to skip the backward pass.
        self.output = sumCosts(gpu_ctc(acts, torch.CudaTensor(), labels, sizes))
    else
        self.output = sumCosts(cpu_ctc(acts:float(), torch.FloatTensor(), labels, sizes))
    end
    return self.output
end

--- Backward pass: gradient of the CTC cost w.r.t. the activations.
-- @param output network activations, (batch x time x classes)
-- @param labels table of target label sequences, one per batch element
-- @return gradInput with the same shape and type as `output`
function CTCCriterion:updateGradInput(output, labels)
    local tensorSizes = output:size()
    local acts = self:createCTCBatch(output, tensorSizes)
    local sizes = ctcSizes(tensorSizes)
    if output:type() == 'torch.CudaTensor' then
        local grads = acts:clone():zero()
        gpu_ctc(acts, grads, labels, sizes)
        self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output)
    else
        -- BUG FIX: the original called cpu_ctc(acts:float(), grads:float(), ...).
        -- When `grads` was not already a FloatTensor (Torch defaults to
        -- DoubleTensor), :float() returned a *copy*; warp-ctc wrote the
        -- gradients into that copy and the zeroed `grads` was reverted,
        -- yielding an all-zero gradInput. Keep explicit float tensors so the
        -- gradients we revert are the ones warp-ctc actually produced.
        local floatActs = acts:float()
        local grads = floatActs:clone():zero()
        cpu_ctc(floatActs, grads, labels, sizes)
        self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output)
    end
    return self.gradInput
end

--If batching occurs multiple costs are returned. We sum the costs and return.
-- Sums the per-sample costs returned by warp-ctc (one entry per batch
-- element). Returns nil for an empty list, same as the original.
function sumCosts(list)
    local total = list[1]
    for i = 2, #list do
        total = total + list[i]
    end
    return total
end

--[[
-- Interleaves (batch x time x classes) activations into the 2-D
-- (time-major, batch-interleaved) layout warp-ctc expects; see:
-- https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
 ]]
function CTCCriterion:createCTCBatch(output, sizes)
    local batch, steps, classes = sizes[1], sizes[2], sizes[3]
    self.acts:resize(batch * steps, classes):zero()
    local row = 0
    for t = 1, steps do
        for b = 1, batch do
            row = row + 1
            self.acts[row] = output[b][t]
        end
    end
    return self.acts
end

-- Inverse of createCTCBatch: scatters warp-ctc's interleaved gradient rows
-- back into a (batch x time x classes) tensor matching the network output.
function CTCCriterion:revertBatching(gradients, sizes)
    local batch, steps, classes = sizes[1], sizes[2], sizes[3]
    self.convertedGradients:resize(batch, steps, classes):zero()
    local row = 0
    for t = 1, steps do
        for b = 1, batch do
            row = row + 1
            self.convertedGradients[b][t] = gradients[row]
        end
    end
    return self.convertedGradients
end