Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-04-02 23:20:06 +0300
committerSoumith Chintala <soumith@gmail.com>2016-04-02 23:20:06 +0300
commit66d9f3e607e28a51a1a7da3c28e9a28680984e68 (patch)
treebd507a35bd2d20eb012f41982c252a26c8b5360a
parent81d533254e39738ec95d27da96b12aa93eaaa725 (diff)
parent3209358d90eda9cda19677059c2d7f875fe6766e (diff)
Merge pull request #55 from SeanNaren/master
CTC Based Criterion
-rw-r--r--CTCCriterion.lua95
-rw-r--r--README.md35
-rw-r--r--init.lua1
-rw-r--r--test/test-all.lua22
4 files changed, 153 insertions, 0 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua
new file mode 100644
index 0000000..22f13af
--- /dev/null
+++ b/CTCCriterion.lua
@@ -0,0 +1,95 @@
+------------------------------------------------------------------------
+--[[ CTCCriterion ]]--
+-- CTC Alignment for sequence data where input and labels do not align.
+-- Useful for speech recognition on a phoneme/character level basis.
+-- Inputs are assumed to be in the form batch x time x inputdim.
+-- Targets are assumed to be in the form {{1,2},{3,4}}, where {1,2} are the
+-- labels for the first batch element.
+------------------------------------------------------------------------
+local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion')
+
+CTCCriterion.dim = 3
+
+function CTCCriterion:__init() -- loads the warp-ctc binding and allocates reusable scratch buffers
+    parent.__init(self)
+    require 'warp_ctc' -- Baidu warp-ctc torch binding; defines the cpu_ctc/gpu_ctc globals used below
+    self.acts = torch.Tensor() -- scratch: activations reordered into warp-ctc's flattened batch layout
+    self.convertedGradients = torch.Tensor() -- scratch: gradients restored to (batch x time x inputdim)
+end
+
+-- Computes the CTC loss, summed over the batch. 'output' is (batch x time x inputdim);
+-- 'labels' is a table of label tables, one per batch element (see header comment).
+function CTCCriterion:updateOutput(output, labels)
+    assert(output:nDimension() == CTCCriterion.dim, "Output must be a tensor of (batch x time x inputdim), received " .. output:nDimension() .. " dimensions")
+    local tensorSizes = output:size()
+    local acts = self:createCTCBatch(output, tensorSizes) -- (batch*time) x inputdim, time-major rows
+    local sizes = {} -- For each batch we state the number of time steps.
+    for x = 1, tensorSizes[1] do
+        table.insert(sizes, tensorSizes[2]) -- every sequence assumed to span the full time dimension
+    end
+    if (output:type() == 'torch.CudaTensor') then
+        local grads = torch.CudaTensor() -- empty placeholder; presumably tells warp-ctc to skip gradients -- TODO confirm
+        self.output = sumCosts(gpu_ctc(acts, grads, labels, sizes))
+    else
+        local grads = torch.Tensor() -- empty placeholder, forward pass only
+        self.output = sumCosts(cpu_ctc(acts:float(), grads:float(), labels, sizes))
+    end
+
+    return self.output
+end
+
+function CTCCriterion:updateGradInput(output, labels) -- gradient of the summed CTC loss w.r.t. 'output'
+    local tensorSizes = output:size()
+    local acts = self:createCTCBatch(output, tensorSizes) -- same time-major flattening as updateOutput
+    local sizes = {}
+    for x = 1, tensorSizes[1] do
+        table.insert(sizes, tensorSizes[2]) -- full-length sequences; no per-sample length support
+    end
+    local grads = acts:clone():zero() -- non-empty buffer: warp-ctc writes gradients into it
+    if (output:type() == 'torch.CudaTensor') then
+        gpu_ctc(acts, grads, labels, sizes)
+    else
+        cpu_ctc(acts:float(), grads:float(), labels, sizes) -- NOTE(review): if grads is not already a FloatTensor, :float() returns a copy and the computed gradients land in that discarded copy — verify default tensor type
+    end
+    self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output) -- back to batch x time x inputdim
+    return self.gradInput
+end
+
+--If batching occurs multiple costs are returned. We sum the costs and return.
+function sumCosts(list) -- NOTE(review): intentionally global? a 'local' here would not be visible at the earlier call sites in this file
+    local acc
+    for k, v in ipairs(list) do
+        if 1 == k then
+            acc = v -- first cost seeds the accumulator
+        else
+            acc = acc + v
+        end
+    end
+    return acc -- nil when 'list' is empty
+end
+
+--[[
+-- Converts the outputs into batch warp-ctc format seen at the end of the README here:
+-- https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
+ ]]
+function CTCCriterion:createCTCBatch(output, sizes) -- flattens (batch x time x dim) into warp-ctc's 2D layout
+    self.acts:resize(sizes[1] * sizes[2], sizes[3]):zero()
+    local counter = 1
+    for i = 1, sizes[2] do -- outer loop over time steps: rows are grouped by time step
+        for j = 1, sizes[1] do -- inner loop over batch elements, interleaved within each time step
+            self.acts[counter] = output[j][i]
+            counter = counter + 1
+        end
+    end
+    return self.acts -- shared scratch buffer; overwritten on every call
+end
+
+function CTCCriterion:revertBatching(gradients, sizes) -- inverse of createCTCBatch: 2D rows back to (batch x time x dim)
+    self.convertedGradients:resize(sizes[1], sizes[2], sizes[3]):zero()
+    local counter = 1
+    for i = 1, sizes[2] do -- must mirror createCTCBatch's row order exactly
+        for j = 1, sizes[1] do
+            self.convertedGradients[j][i] = gradients[counter]
+            counter = counter + 1
+        end
+    end
+    return self.convertedGradients -- shared scratch buffer; overwritten on every call
+end \ No newline at end of file
diff --git a/README.md b/README.md
index c2d5777..d28d8b4 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ This section includes documentation for the following objects:
* [SoftMaxTree](#nnx.SoftMaxTree) : a hierarchical log-softmax Module;
* [TreeNLLCriterion](#nnx.TreeNLLCriterion) : a negative log-likelihood Criterion for the SoftMaxTree;
+ * [CTCCriterion](#nnx.CTCCriterion) : a Connectionist Temporal Classification Criterion based on [warp-ctc](https://github.com/baidu-research/warp-ctc);
* [PushTable (and PullTable)](#nnx.PushTable) : extracts a table element and inserts it later in the network;
* [MultiSoftMax](#nnx.MultiSoftMax) : performs a softmax over the last dimension of a 2D or 3D input;
* [SpatialReSampling](#nnx.SpatialReSampling) : performs bilinear resampling of a 3D or 4D input image;
@@ -144,6 +145,40 @@ In some cases, this can simplify the digraph of Modules. Note that
a PushTable can be associated to many PullTables, but each PullTable
is associated to only one PushTable.
+<a name='nnx.CTCCriterion'/>
+### CTCCriterion ###
+```
+criterion = nn.CTCCriterion()
+```
+Creates a Criterion based on Baidu's [warp-ctc](https://github.com/baidu-research/warp-ctc) implementation.
+This Module measures the loss between a 3D output of (batch x time x inputdim) and a target without needing alignment of inputs and labels.
+You must have warp-ctc installed; it can be installed via luarocks:
+```
+luarocks install https://raw.githubusercontent.com/baidu-research/warp-ctc/master/torch_binding/rocks/warp-ctc-scm-1.rockspec
+```
+Supports cuda via:
+```
+criterion = nn.CTCCriterion():cuda()
+```
+Example:
+```
+output = torch.Tensor({{{1,2,3,4,5},{6,7,8,9,10}}}) -- Tensor of size 1x1x5 (batch x time x inputdim).
+label = {{1,3}}
+ctcCriterion = nn.CTCCriterion()
+
+print(ctcCriterion:forward(output,label))
+
+ctcCriterion = ctcCriterion:cuda() -- Switch to cuda implementation.
+output = output:cuda()
+
+print(ctcCriterion:forward(output,label))
+```
+
+gives the output:
+```
+4.9038286209106
+4.9038290977478
+```
<a name='nnx.MultiSoftMax'/>
### MultiSoftMax ###
This Module takes 2D or 3D input and performs a softmax over the last dimension.
diff --git a/init.lua b/init.lua
index 719ad11..b1e874b 100644
--- a/init.lua
+++ b/init.lua
@@ -80,6 +80,7 @@ require('nnx.SuperCriterion')
require('nnx.DistNLLCriterion')
require('nnx.DistMarginCriterion')
require('nnx.TreeNLLCriterion')
+require('nnx.CTCCriterion')
-- datasets:
require('nnx.DataSet')
diff --git a/test/test-all.lua b/test/test-all.lua
index c1ca354..7486344 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -524,6 +524,28 @@ function nnxtest.TreeNLLCriterion()
mytester:assertTensorEq(gradInput2:narrow(2,1,1), gradInput, 0.00001)
end
+function nnxtest.CTCCriterion() -- smoke tests against hard-coded warp-ctc reference losses
+    local criterion = nn.CTCCriterion()
+    local acts = torch.Tensor({{{0,0,0,0,0}}}) -- 1 batch element, 1 time step, 5 classes
+    local targets = {{1}}
+    mytester:eq(criterion:updateOutput(acts,targets), 1.6094379425049, 0, "CTCCriterion.smallTest") -- -ln(1/5); zero tolerance assumes bit-exact reproducibility of warp-ctc
+    local acts =
+        torch.Tensor({{{1,2,3,4,5}, {6,7,8,9,10}, {11,12,13,14,15}}}) -- 1 x 3 x 5
+    local targets = {{3,3}}
+    mytester:eq(criterion:updateOutput(acts,targets), 7.355742931366, 0, "CTCCriterion.mediumTest")
+    local acts = torch.Tensor({{{-5,-4,-3,-2,-1}, {-10,-9,-8,-7,-6}, {-15,-14,-13,-12,-11}}}) -- negative activations
+    local targets = {{2,3}}
+    mytester:eq(criterion:updateOutput(acts,targets), 4.938850402832, 0, "CTCCriterion.mediumNegativeTest")
+    local acts =
+        torch.Tensor({
+            {{0,0,0,0,0},{0,0,0,0,0},{0,0,0,0,0}},
+            {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}},
+            {{-5,-4,-3,-2,-1},{-10,-9,-8,-7,-6},{-15,-14,-13,-12,-11}}
+        }) -- batch of 3: expected loss is the sum of the three single-sample losses above
+    local targets = {{1},{3,3},{2,3}}
+    mytester:eq(criterion:updateOutput(acts,targets), 15.331147670746, 0, "CTCCriterion.batchTest")
+end
+
local function blur(mean, stdv, size)
local range = torch.range(1,size):float()
local a = 1/(stdv*math.sqrt(2*math.pi))