github.com/clementfarabet/lua---nnx.git
author     Soumith Chintala <soumith@gmail.com>    2016-05-17 04:44:48 +0300
committer  Soumith Chintala <soumith@gmail.com>    2016-05-17 04:44:48 +0300
commit     73a49c80b5c090230dd1605e20c683cbf04e684e (patch)
tree       b99170716e816f10e732ebb1add9999b8c79cce8
parent     f1517225248670c947209ee9fffd731914c96bd2 (diff)
parent     a0306ea4172331a007c51d250b9e5bbd58074284 (diff)
Merge pull request #61 from SeanNaren/master
CTCCriterion variable length and batchFirst support
-rw-r--r--  CTCCriterion.lua   | 119
-rw-r--r--  test/test-all.lua  |  53
2 files changed, 98 insertions, 74 deletions
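
For orientation, a minimal usage sketch of the criterion as it looks after this merge (illustrative values, not taken from the diff): forward now takes a sizes tensor giving the true length of each sequence in the batch, and input defaults to seqLength x batch x inputDim.

require 'nn'
require 'nnx'

local criterion = nn.CTCCriterion()                   -- default layout: seqLength x batch x inputDim
local acts = torch.Tensor(3, 2, 5):zero()             -- 3 time steps, batch of 2, 5 classes
local targets = {{1}, {3, 3}}                         -- one label sequence per batch element
local sizes = torch.Tensor({1, 3})                    -- true length of each sequence (variable length)
local loss = criterion:forward(acts, targets, sizes)
local gradInput = criterion:backward(acts, targets)   -- same shape as acts
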
diff --git a/CTCCriterion.lua b/CTCCriterion.lua
index efabf13..6d4a2ec 100644
--- a/CTCCriterion.lua
+++ b/CTCCriterion.lua
@@ -1,55 +1,87 @@
------------------------------------------------------------------------
---[[ CTCCriterion ]]--
+--[[ CTCCriterion ]] --
-- CTC Alignment for sequence data where input and labels do not align.
-- Useful for speech recognition on a phoneme/character level basis.
--- Inputs assumed are in the form of batch x time x inputdim.
+-- Inputs assumed are in the form of seqLength x batch x inputDim.
+-- If batchFirst = true then input in the form of batch x seqLength x inputDim.
-- Targets assumed in the form of {{1,2},{3,4}} where {1,2} is for the first
--- element.
+-- element and so forth.
------------------------------------------------------------------------
local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion')
-CTCCriterion.dim = 3
-
-function CTCCriterion:__init()
- parent.__init(self)
+function CTCCriterion:__init(batchFirst)
require 'warp_ctc'
+ parent.__init(self)
self.acts = torch.Tensor()
- self.convertedGradients = torch.Tensor()
+ self.batchFirst = batchFirst or false
+end
+
+function CTCCriterion:forward(input, target, sizes)
+ return self:updateOutput(input, target, sizes)
end
-function CTCCriterion:updateOutput(output, labels)
- assert(output:nDimension() == CTCCriterion.dim, "Output must be a tensor of (batch x time x inputdim), recieved " .. output:nDimension() .. " dimensions")
- local tensorSizes = output:size()
- local acts = self:createCTCBatch(output, tensorSizes)
- local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2])
- sizes = torch.totable(sizes)
- if (output:type() == 'torch.CudaTensor') then
- local grads = torch.CudaTensor()
- self.output = sumCosts(gpu_ctc(acts, grads, labels, sizes))
+function CTCCriterion:updateOutput(input, target, sizes)
+ assert(sizes,
+ "You must pass the size of each sequence in the batch as a tensor")
+ local acts = self.acts
+ if input:dim() == 3 then
+ acts:resizeAs(input):copy(input)
+ if self.batchFirst then
+ acts = acts:transpose(1, 2)
+ acts = self:makeContiguous(acts)
+ end
+ acts:view(acts, acts:size(1) * acts:size(2), -1)
+ end
+ assert(acts:nDimension() == 2)
+ self.sizes = torch.totable(sizes)
+ self.gradInput = acts.new():resizeAs(acts):zero()
+ if input:type() == 'torch.CudaTensor' then
+ self.output = sumCosts(gpu_ctc(acts, self.gradInput, target, self.sizes))
else
- local grads = torch.Tensor()
- self.output = sumCosts(cpu_ctc(acts:float(), grads:float(), labels, sizes))
+ acts = acts:float()
+ self.gradInput = self.gradInput:float()
+ self.output = sumCosts(cpu_ctc(acts, self.gradInput, target, self.sizes))
end
-
return self.output
end
-function CTCCriterion:updateGradInput(output, labels)
- local tensorSizes = output:size()
- local acts = self:createCTCBatch(output, tensorSizes)
- local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2])
- sizes = torch.totable(sizes)
- local grads = acts:clone():zero()
- if (output:type() == 'torch.CudaTensor') then
- gpu_ctc(acts, grads, labels, sizes)
- else
- grads = grads:float()
- cpu_ctc(acts:float(), grads, labels, sizes)
+function CTCCriterion:updateGradInput(input, target)
+ if input:dim() == 2 then -- (seqLen * batchSize) x outputDim
+ return self.gradInput
+ end
+ if self.batchFirst then -- batchSize x seqLen x outputDim
+ self.gradInput = inverseInterleave(self.gradInput, input:size(1))
+ else -- seqLen x batchSize x outputDim
+ self.gradInput:view(self.gradInput, input:size(1), input:size(2), -1)
end
- self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output)
return self.gradInput
end
+function CTCCriterion:makeContiguous(input)
+ if not input:isContiguous() then
+ self._input = self._input or input.new()
+ self._input:typeAs(input):resizeAs(input):copy(input)
+ input = self._input
+ end
+ return input
+end
+
+function inverseInterleave(tensor, batchSize)
+ local sizes = torch.LongStorage(3)
+ sizes[1] = batchSize
+ sizes[2] = tensor:size(1) / batchSize
+ sizes[3] = tensor:size(2)
+ local result = tensor.new():resize(sizes):zero()
+ local counter = 1
+ for i = 1, sizes[2] do
+ for j = 1, sizes[1] do
+ result[j][i]:copy(tensor[counter])
+ counter = counter + 1
+ end
+ end
+ return result
+end
+
--If batching occurs multiple costs are returned. We sum the costs and return.
function sumCosts(list)
local acc
@@ -61,27 +93,4 @@ function sumCosts(list)
end
end
return acc
-end
-
---[[
--- Converts the outputs into batch warp-ctc format seen at the end of the README here:
--- https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
- ]]
-function CTCCriterion:createCTCBatch(output, sizes)
- self.acts:resize(sizes[1] * sizes[2], sizes[3]):zero()
- local output = output:transpose(1, 2)
- self.acts = torch.reshape(output, sizes[1] * sizes[2], sizes[3])
- return self.acts
-end
-
-function CTCCriterion:revertBatching(gradients, sizes)
- self.convertedGradients:resize(sizes[1], sizes[2], sizes[3]):zero()
- local counter = 1
- for i = 1, sizes[2] do
- for j = 1, sizes[1] do
- self.convertedGradients[j][i] = gradients[counter]
- counter = counter + 1
- end
- end
- return self.convertedGradients
end
\ No newline at end of file
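
A companion sketch for the new batchFirst flag, assuming only the input/gradInput layout changes to batch x seqLength x inputDim; inverseInterleave above is what maps warp-ctc's flat gradient back into that layout. Values are again illustrative.

local criterion = nn.CTCCriterion(true)               -- batchFirst = true
local acts = torch.Tensor(2, 3, 5):uniform()          -- batch of 2, 3 time steps, 5 classes
local targets = {{1}, {3, 3}}
local sizes = torch.Tensor({1, 3})
local loss = criterion:forward(acts, targets, sizes)
local gradInput = criterion:backward(acts, targets)   -- returned as 2 x 3 x 5, matching the input layout
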
diff --git a/test/test-all.lua b/test/test-all.lua
index 7486344..edc69aa 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -525,25 +525,40 @@ function nnxtest.TreeNLLCriterion()
end
function nnxtest.CTCCriterion()
- local criterion = nn.CTCCriterion()
- local acts = torch.Tensor({{{0,0,0,0,0}}})
- local targets = {{1}}
- mytester:eq(criterion:updateOutput(acts,targets), 1.6094379425049, 0, "CTCCriterion.smallTest")
- local acts =
- torch.Tensor({{{1,2,3,4,5}, {6,7,8,9,10}, {11,12,13,14,15}}})
- local targets = {{3,3}}
- mytester:eq(criterion:updateOutput(acts,targets), 7.355742931366, 0, "CTCCriterion.mediumTest")
- local acts = torch.Tensor({{{-5,-4,-3,-2,-1}, {-10,-9,-8,-7,-6}, {-15,-14,-13,-12,-11}}})
- local targets = {{2,3}}
- mytester:eq(criterion:updateOutput(acts,targets), 4.938850402832, 0, "CTCCriterion.mediumNegativeTest")
- local acts =
- torch.Tensor({
- {{0,0,0,0,0},{0,0,0,0,0},{0,0,0,0,0}},
- {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}},
- {{-5,-4,-3,-2,-1},{-10,-9,-8,-7,-6},{-15,-14,-13,-12,-11}}
- })
- local targets = {{1},{3,3},{2,3}}
- mytester:eq(criterion:updateOutput(acts,targets), 15.331147670746, 0, "CTCCriterion.batchTest")
+ local criterion = nn.CTCCriterion()
+ local acts = torch.Tensor({{{0,0,0,0,0}}}):transpose(1, 2):contiguous() -- input is seqLength x batch x inputDim
+ local targets = {{1}}
+ local sizes = torch.Tensor({1})
+ mytester:eq(criterion:updateOutput(acts, targets, sizes), 1.6094379425049, precision, "CTCCriterion.smallTest")
+ local acts =
+ torch.Tensor({{{1,2,3,4,5}, {6,7,8,9,10}, {11,12,13,14,15}}})
+ local targets = {{3,3}}
+ local sizes = torch.Tensor({3})
+ mytester:eq(criterion:updateOutput(acts, targets, sizes), 7.355742931366, precision, "CTCCriterion.mediumTest")
+ local acts = torch.Tensor({{{-5,-4,-3,-2,-1}, {-10,-9,-8,-7,-6}, {-15,-14,-13,-12,-11}}}):transpose(1, 2):contiguous()
+ local targets = {{2,3}}
+ local sizes = torch.Tensor({3})
+ mytester:eq(criterion:updateOutput(acts, targets, sizes), 4.9388499259949, precision, "CTCCriterion.mediumNegativeTest")
+ local acts =
+ torch.Tensor({
+ {{0,0,0,0,0},{0,0,0,0,0},{0,0,0,0,0}},
+ {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}},
+ {{-5,-4,-3,-2,-1},{-10,-9,-8,-7,-6},{-15,-14,-13,-12,-11}}
+ }):transpose(1, 2):contiguous()
+ local targets = {{1},{3,3},{2,3}}
+ local sizes = torch.Tensor({1,3,3})
+ mytester:eq(criterion:updateOutput(acts, targets, sizes), 13.904030799866, precision, "CTCCriterion.batchTest")
+ local gradOutputNorm = criterion:updateGradInput(acts, targets, sizes)
+ criterion = nn.CTCCriterion(true) -- batchFirst true, input is batch x seqLength x inputDim
+ local batchFirstActs =
+ torch.Tensor({
+ {{0,0,0,0,0},{0,0,0,0,0},{0,0,0,0,0}},
+ {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}},
+ {{-5,-4,-3,-2,-1},{-10,-9,-8,-7,-6},{-15,-14,-13,-12,-11}}
+ })
+ mytester:eq(criterion:updateOutput(batchFirstActs, targets, sizes), 13.904030799866, precision, "CTCCriterion.batchFirstTest")
+ local gradOutputBatchFirst = criterion:updateGradInput(acts, targets, sizes)
+ mytester:assertTensorEq(gradOutputBatchFirst:transpose(1, 2), gradOutputNorm, precision, "CTCCriterion.gradCheckTest")
end
local function blur(mean, stdv, size)