1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
|
require 'warp_ctc'
local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion')

--- Constructor for the warp-ctc backed criterion (uses cpu_ctc / gpu_ctc).
-- Initialises the parent nn.Criterion and the two scratch tensors that are
-- reused across calls: self.acts (filled by createCTCBatch) and
-- self.convertedGradients (filled by revertBatching).
function CTCCriterion:__init()
  parent.__init(self)
  self.acts, self.convertedGradients = torch.Tensor(), torch.Tensor()
end
--- Forward pass: total CTC cost of the batch.
-- The activations are flattened with createCTCBatch and handed to warp-ctc
-- (gpu_ctc for CudaTensor input, cpu_ctc otherwise). Only the returned costs
-- are used here; an empty grads tensor is passed.
-- @param output activations indexed output[batch][time] (dim 1 = batch,
--   dim 2 = time steps)
-- @param labels target label sequences, forwarded unchanged to warp-ctc
-- @return the summed cost (also stored in self.output)
function CTCCriterion:updateOutput(output, labels)
  local dims = output:size()
  local batchActs = self:createCTCBatch(output, dims)
  -- warp-ctc expects one sequence length per batch entry; every entry spans
  -- the full time dimension.
  local seqLengths = {}
  for b = 1, dims[1] do
    seqLengths[b] = dims[2]
  end
  local costs
  if output:type() ~= 'torch.CudaTensor' then
    local noGrads = torch.Tensor()
    costs = cpu_ctc(batchActs:float(), noGrads:float(), labels, seqLengths)
  else
    costs = gpu_ctc(batchActs, torch.CudaTensor(), labels, seqLengths)
  end
  self.output = sumCosts(costs)
  return self.output
end
--- Backward pass: gradient of the summed CTC cost with respect to `output`.
-- Activations are flattened with createCTCBatch, warp-ctc fills a gradient
-- buffer with the same flattened layout, and revertBatching restores the
-- layout of `output`.
-- @param output activations indexed output[batch][time] (dim 1 = batch,
--   dim 2 = time steps)
-- @param labels target label sequences, forwarded unchanged to warp-ctc
-- @return gradient converted to output's type (also stored in
--   self.gradInput; it may alias the self.convertedGradients buffer)
function CTCCriterion:updateGradInput(output, labels)
  local tensorSizes = output:size()
  local acts = self:createCTCBatch(output, tensorSizes)
  -- One sequence length per batch entry; each spans all tensorSizes[2] steps.
  local sizes = {}
  for x = 1, tensorSizes[1] do
    sizes[x] = tensorSizes[2]
  end
  local grads = acts:clone():zero()
  if output:type() == 'torch.CudaTensor' then
    gpu_ctc(acts, grads, labels, sizes)
  elseif grads:type() == 'torch.FloatTensor' then
    cpu_ctc(acts:float(), grads, labels, sizes)
  else
    -- cpu_ctc writes gradients into its second argument, but :float()
    -- returns a NEW tensor when grads is not already a FloatTensor (e.g. the
    -- default DoubleTensor type). Passing grads:float() directly discards the
    -- gradients and leaves grads all zeros, so fill a float buffer and copy
    -- the result back.
    local floatGrads = grads:float()
    cpu_ctc(acts:float(), floatGrads, labels, sizes)
    grads:copy(floatGrads)
  end
  self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output)
  return self.gradInput
end
--- Sums the costs returned by warp-ctc into a single scalar cost.
-- If batching occurs, warp-ctc returns multiple costs (one per sequence);
-- the criterion reports their sum.
-- Kept global (not local) because CTCCriterion:updateOutput, defined earlier
-- in this file, refers to it by this global name.
-- @param list array of costs
-- @return the sum of the entries, or 0 for an empty list (previously nil)
function sumCosts(list)
  local total = 0
  for _, cost in ipairs(list) do
    total = total + cost
  end
  return total
end
--[[
-- Converts the outputs into the batched warp-ctc format seen at the end of:
-- https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
--
-- `output` is indexed output[batch][time] and `sizes` is output:size(). The
-- result is a (sizes[1] * sizes[2]) x sizes[3] matrix whose rows are ordered
-- time-major: row (t - 1) * sizes[1] + b holds output[b][t].
--
-- @param output 3-D activations tensor
-- @param sizes output:size()
-- @return self.acts, a reused buffer that is overwritten on every call
]]
function CTCCriterion:createCTCBatch(output, sizes)
  self.acts:resize(sizes[1] * sizes[2], sizes[3])
  -- Viewing output as (time, batch, ...) makes its logical element order match
  -- the required row order, so a single bulk copy replaces the per-row Lua
  -- loop (which built a temporary select-view for every row). No zero-fill is
  -- needed because every element of the buffer is overwritten.
  self.acts:copy(output:transpose(1, 2))
  return self.acts
end
--- Inverse of createCTCBatch: scatters the flattened, time-major gradient
-- rows back into a tensor shaped like the original output.
-- Row (t - 1) * sizes[1] + b of `gradients` ends up at result[b][t].
-- @param gradients (sizes[1] * sizes[2]) x sizes[3] tensor from warp-ctc
-- @param sizes output:size() of the original activations
-- @return self.convertedGradients, a reused buffer overwritten on every call
function CTCCriterion:revertBatching(gradients, sizes)
  self.convertedGradients:resize(sizes[1], sizes[2], sizes[3])
  -- Copying into the (time, batch, ...) view writes elements in the same
  -- time-major order that createCTCBatch produced, so one bulk copy replaces
  -- the per-row loop; every element is overwritten, so no zero-fill is needed.
  self.convertedGradients:transpose(1, 2):copy(gradients)
  return self.convertedGradients
end
|