Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean <seannaren>2016-04-02 18:02:58 +0300
committerSean <seannaren>2016-04-02 18:02:58 +0300
commitf450325e81056bc209b2d38037c7d105a5efc813 (patch)
treec2caf8e6682336c2adaeaffe480925383e739b39
parent81d533254e39738ec95d27da96b12aa93eaaa725 (diff)
added base CTCCriterion
-rw-r--r--CTCCriterion.lua85
1 file changed, 85 insertions, 0 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua
new file mode 100644
index 0000000..1449db0
--- /dev/null
+++ b/CTCCriterion.lua
@@ -0,0 +1,85 @@
+require 'warp_ctc'
+
--- Connectionist Temporal Classification criterion backed by Baidu's
-- warp-ctc bindings (cpu_ctc / gpu_ctc).
local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion')

--- Initialises the criterion and its reusable scratch buffers.
function CTCCriterion:__init()
    parent.__init(self)
    -- Scratch tensors reused across forward/backward calls so the
    -- batching helpers do not reallocate on every invocation.
    self.convertedGradients = torch.Tensor()
    self.acts = torch.Tensor()
end
+
--- Computes the summed CTC cost for a batch of activations.
-- @param output activation tensor; indexed as (batch, time, features) by
--        createCTCBatch — assumed layout, confirm against callers.
-- @param labels table of target label sequences, one per batch element.
-- @return total cost summed over the batch (stored in self.output).
function CTCCriterion:updateOutput(output, labels)
    local dims = output:size()
    local batchedActs = self:createCTCBatch(output, dims)
    -- warp-ctc expects one time-step count per utterance; every utterance
    -- here spans the full time dimension.
    local stepCounts = {}
    for _ = 1, dims[1] do
        stepCounts[#stepCounts + 1] = dims[2]
    end
    if output:type() == 'torch.CudaTensor' then
        -- NOTE(review): the empty gradient tensor presumably tells warp-ctc
        -- to skip the backward pass — confirm against warp-ctc docs.
        local unusedGrads = torch.CudaTensor()
        self.output = sumCosts(gpu_ctc(batchedActs, unusedGrads, labels, stepCounts))
    else
        local unusedGrads = torch.Tensor()
        self.output = sumCosts(cpu_ctc(batchedActs:float(), unusedGrads:float(), labels, stepCounts))
    end

    return self.output
end
+
--- Computes the gradient of the CTC cost w.r.t. the input activations.
-- @param output activation tensor, same layout as in updateOutput.
-- @param labels table of target label sequences, one per batch element.
-- @return gradient tensor with the same shape and type as `output`.
function CTCCriterion:updateGradInput(output, labels)
    local tensorSizes = output:size()
    local acts = self:createCTCBatch(output, tensorSizes)
    local sizes = {}
    for x = 1, tensorSizes[1] do
        table.insert(sizes, tensorSizes[2])
    end
    local grads = acts:clone():zero()
    if (output:type() == 'torch.CudaTensor') then
        gpu_ctc(acts, grads, labels, sizes)
    else
        -- BUG FIX: :float() returns a *copy* when the tensor is not already
        -- a FloatTensor, so warp-ctc previously wrote the gradients into an
        -- anonymous temporary and the zeroed `grads` was returned.  Keep
        -- references to the float views and copy the results back (when the
        -- tensors are already float, :float() returns them unchanged and the
        -- copy is a no-op onto itself).
        local floatActs = acts:float()
        local floatGrads = grads:float()
        cpu_ctc(floatActs, floatGrads, labels, sizes)
        grads:copy(floatGrads)
    end
    self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output)
    return self.gradInput
end
+
--- Sums the per-utterance costs returned by warp-ctc into one scalar.
-- Deliberately global: updateOutput (defined earlier in the file) resolves
-- it by global lookup at call time.
-- Returns nil for an empty list, matching the original behavior.
function sumCosts(list)
    local total
    for index, cost in ipairs(list) do
        -- Costs are numbers, so the and/or ternary is safe here
        -- (0 is truthy in Lua).
        total = (index == 1) and cost or (total + cost)
    end
    return total
end
+
--[[
-- Converts the outputs into batch warp-ctc format seen at the end of the README here:
-- https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
 ]]
function CTCCriterion:createCTCBatch(output, sizes)
    local batchSize, steps, features = sizes[1], sizes[2], sizes[3]
    self.acts:resize(batchSize * steps, features):zero()
    for b = 1, batchSize do
        for t = 1, steps do
            -- Rows are time-major: all batch entries for time step t are
            -- contiguous, i.e. row = (t - 1) * batchSize + b.
            self.acts[(t - 1) * batchSize + b] = output[b][t]
        end
    end
    return self.acts
end
+
--- Inverse of createCTCBatch: maps the flat, time-major warp-ctc gradient
-- rows back into a (batch, time, features) tensor.
function CTCCriterion:revertBatching(gradients, sizes)
    local batchSize, steps, features = sizes[1], sizes[2], sizes[3]
    self.convertedGradients:resize(batchSize, steps, features):zero()
    for b = 1, batchSize do
        for t = 1, steps do
            -- Same row mapping as createCTCBatch: row = (t - 1) * batchSize + b.
            self.convertedGradients[b][t] = gradients[(t - 1) * batchSize + b]
        end
    end
    return self.convertedGradients
end