diff options
-rw-r--r-- | CTCCriterion.lua | 95 | ||||
-rw-r--r-- | README.md | 35 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | test/test-all.lua | 22 |
4 files changed, 153 insertions, 0 deletions
------------------------------------------------------------------------
--[[ CTCCriterion ]]--
-- Connectionist Temporal Classification criterion backed by Baidu's
-- warp-ctc. Measures the loss between unaligned sequence predictions
-- and targets (e.g. phoneme/character-level speech recognition).
-- Input  : batch x time x inputdim tensor of network activations.
-- Target : table of label sequences, e.g. {{1,2},{3,4}} where {1,2}
--          labels the first batch element.
------------------------------------------------------------------------
local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion')

-- Required number of input dimensions (batch x time x inputdim).
CTCCriterion.dim = 3

function CTCCriterion:__init()
    parent.__init(self)
    require 'warp_ctc'
    -- Reusable buffers, resized on every call to avoid reallocation.
    self.acts = torch.Tensor()               -- batched (time-interleaved) activations
    self.convertedGradients = torch.Tensor() -- gradients restored to input layout
end

-- Sums the per-sample costs returned by warp-ctc (one cost per batch
-- element) into a single scalar. Returns 0 for an empty list (the
-- original returned nil in that case, which crashed downstream sums).
-- NOTE(review): deliberately left global for backward compatibility with
-- the original file, which also defined it globally.
function sumCosts(list)
    local acc = 0
    for _, v in ipairs(list) do
        acc = acc + v
    end
    return acc
end

-- Builds the per-sample sizes table expected by warp-ctc: every batch
-- element is assumed to span the full number of time steps.
local function uniformSizes(batchSize, timeSteps)
    local sizes = {}
    for _ = 1, batchSize do
        table.insert(sizes, timeSteps)
    end
    return sizes
end

-- Computes the total CTC cost of `output` (batch x time x inputdim)
-- against `labels` (table of label-index tables). Returns a number.
function CTCCriterion:updateOutput(output, labels)
    assert(output:nDimension() == CTCCriterion.dim,
        "Output must be a tensor of (batch x time x inputdim), received "
        .. output:nDimension() .. " dimensions")
    local tensorSizes = output:size()
    local acts = self:createCTCBatch(output, tensorSizes)
    local sizes = uniformSizes(tensorSizes[1], tensorSizes[2])
    if output:type() == 'torch.CudaTensor' then
        -- An empty gradient tensor presumably tells warp-ctc to skip the
        -- backward pass (cost only) -- confirm against the torch binding.
        local grads = torch.CudaTensor()
        self.output = sumCosts(gpu_ctc(acts, grads, labels, sizes))
    else
        local grads = torch.Tensor()
        self.output = sumCosts(cpu_ctc(acts:float(), grads:float(), labels, sizes))
    end
    return self.output
end

-- Computes the gradient of the CTC cost w.r.t. `output`, returned with
-- the same type and (batch x time x inputdim) shape as `output`.
function CTCCriterion:updateGradInput(output, labels)
    local tensorSizes = output:size()
    local acts = self:createCTCBatch(output, tensorSizes)
    local sizes = uniformSizes(tensorSizes[1], tensorSizes[2])
    local grads = acts:clone():zero()
    if output:type() == 'torch.CudaTensor' then
        gpu_ctc(acts, grads, labels, sizes)
    else
        -- BUGFIX: cpu_ctc needs FloatTensors, but `grads:float()` returns
        -- a *copy* when grads is not already float (the default tensor type
        -- is Double). The original code passed that anonymous copy, so the
        -- gradients warp-ctc wrote were discarded and gradInput came back
        -- all zeros. Keep a reference to the float copy and read it back.
        local floatGrads = grads:float()
        cpu_ctc(acts:float(), floatGrads, labels, sizes)
        grads = floatGrads
    end
    self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output)
    return self.gradInput
end

--[[
-- Converts `output` (batch x time x inputdim) into the interleaved
-- batch layout warp-ctc expects (rows ordered time step by time step,
-- batch elements interleaved within each step), as described at the end
-- of the README here:
-- https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
 ]]
function CTCCriterion:createCTCBatch(output, sizes)
    self.acts:resize(sizes[1] * sizes[2], sizes[3]):zero()
    local counter = 1
    for i = 1, sizes[2] do
        for j = 1, sizes[1] do
            self.acts[counter] = output[j][i]
            counter = counter + 1
        end
    end
    return self.acts
end

-- Inverse of createCTCBatch: restores interleaved warp-ctc gradients to
-- the (batch x time x inputdim) layout of the original input.
function CTCCriterion:revertBatching(gradients, sizes)
    self.convertedGradients:resize(sizes[1], sizes[2], sizes[3]):zero()
    local counter = 1
    for i = 1, sizes[2] do
        for j = 1, sizes[1] do
            self.convertedGradients[j][i] = gradients[counter]
            counter = counter + 1
        end
    end
    return self.convertedGradients
end
\ No newline at end of file @@ -10,6 +10,7 @@ This section includes documentation for the following objects: * [SoftMaxTree](#nnx.SoftMaxTree) : a hierarchical log-softmax Module; * [TreeNLLCriterion](#nnx.TreeNLLCriterion) : a negative log-likelihood Criterion for the SoftMaxTree; + * [CTCCriterion](#nnx.CTCCriterion) : a Connectionist Temporal Classification Criterion based on [warp-ctc](https://github.com/baidu-research/warp-ctc); * [PushTable (and PullTable)](#nnx.PushTable) : extracts a table element and inserts it later in the network; * [MultiSoftMax](#nnx.MultiSoftMax) : performs a softmax over the last dimension of a 2D or 3D input; * [SpatialReSampling](#nnx.SpatialReSampling) : performs bilinear resampling of a 3D or 4D input image; @@ -144,6 +145,40 @@ In some cases, this can simplify the digraph of Modules. Note that a PushTable can be associated to many PullTables, but each PullTable is associated to only one PushTable. +<a name='nnx.CTCCriterion'/> +### CTCCriterion ### +``` +criterion = nn.CTCCriterion() +``` +Creates a Criterion based on Baidu's [warp-ctc](https://github.com/baidu-research/warp-ctc) implementation. +This Module measures the loss between a 3D output of (batch x time x inputdim) and a target, without needing alignment of inputs and labels. +You must have warp-ctc installed, which can be done via luarocks: +``` +luarocks install http://raw.githubusercontent.com/baidu-research/warp-ctc/master/torch_binding/rocks/warp-ctc-scm-1.rockspec +``` +Supports cuda via: +``` +criterion = nn.CTCCriterion():cuda() +``` +Example: +``` +output = torch.Tensor({{{1,2,3,4,5},{6,7,8,9,10}}}) -- Tensor of size 1x2x5 (batch x time x inputdim). +label = {{1,3}} +ctcCriterion = nn.CTCCriterion() + +print(ctcCriterion:forward(output,label)) + +ctcCriterion = ctcCriterion:cuda() -- Switch to cuda implementation. 
+output = output:cuda() + +print(ctcCriterion:forward(output,label)) +``` + +gives the output: +``` +4.9038286209106 +4.9038290977478 +``` <a name='nnx.MultiSoftMax'/> ### MultiSoftMax ### This Module takes 2D or 3D input and performs a softmax over the last dimension. @@ -80,6 +80,7 @@ require('nnx.SuperCriterion') require('nnx.DistNLLCriterion') require('nnx.DistMarginCriterion') require('nnx.TreeNLLCriterion') +require('nnx.CTCCriterion') -- datasets: require('nnx.DataSet') diff --git a/test/test-all.lua b/test/test-all.lua index c1ca354..7486344 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -524,6 +524,28 @@ function nnxtest.TreeNLLCriterion() mytester:assertTensorEq(gradInput2:narrow(2,1,1), gradInput, 0.00001) end +function nnxtest.CTCCriterion() + local criterion = nn.CTCCriterion() + local acts = torch.Tensor({{{0,0,0,0,0}}}) + local targets = {{1}} + mytester:eq(criterion:updateOutput(acts,targets), 1.6094379425049, 0, "CTCCriterion.smallTest") + local acts = + torch.Tensor({{{1,2,3,4,5}, {6,7,8,9,10}, {11,12,13,14,15}}}) + local targets = {{3,3}} + mytester:eq(criterion:updateOutput(acts,targets), 7.355742931366, 0, "CTCCriterion.mediumTest") + local acts = torch.Tensor({{{-5,-4,-3,-2,-1}, {-10,-9,-8,-7,-6}, {-15,-14,-13,-12,-11}}}) + local targets = {{2,3}} + mytester:eq(criterion:updateOutput(acts,targets), 4.938850402832, 0, "CTCCriterion.mediumNegativeTest") + local acts = + torch.Tensor({ + {{0,0,0,0,0},{0,0,0,0,0},{0,0,0,0,0}}, + {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}}, + {{-5,-4,-3,-2,-1},{-10,-9,-8,-7,-6},{-15,-14,-13,-12,-11}} + }) + local targets = {{1},{3,3},{2,3}} + mytester:eq(criterion:updateOutput(acts,targets), 15.331147670746, 0, "CTCCriterion.batchTest") +end + local function blur(mean, stdv, size) local range = torch.range(1,size):float() local a = 1/(stdv*math.sqrt(2*math.pi)) |