1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
------------------------------------------------------------------------
--[[ CTCCriterion ]]--
-- CTC Alignment for sequence data where input and labels do not align.
-- Useful for speech recognition on a phoneme/character level basis.
-- Inputs assumed are in the form of batch x time x inputdim.
-- Targets assumed in the form of {{1,2},{3,4}} where {1,2} is for the first
-- element.
------------------------------------------------------------------------
local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion')
CTCCriterion.dim = 3
function CTCCriterion:__init()
parent.__init(self)
require 'warp_ctc'
self.acts = torch.Tensor()
self.convertedGradients = torch.Tensor()
end
function CTCCriterion:updateOutput(output, labels)
assert(output:nDimension() == CTCCriterion.dim, "Output must be a tensor of (batch x time x inputdim), recieved " .. output:nDimension() .. " dimensions")
local tensorSizes = output:size()
local acts = self:createCTCBatch(output, tensorSizes)
local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2])
sizes = torch.totable(sizes)
if (output:type() == 'torch.CudaTensor') then
local grads = torch.CudaTensor()
self.output = sumCosts(gpu_ctc(acts, grads, labels, sizes))
else
local grads = torch.Tensor()
self.output = sumCosts(cpu_ctc(acts:float(), grads:float(), labels, sizes))
end
return self.output
end
function CTCCriterion:updateGradInput(output, labels)
local tensorSizes = output:size()
local acts = self:createCTCBatch(output, tensorSizes)
local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2])
sizes = torch.totable(sizes)
local grads = acts:clone():zero()
if (output:type() == 'torch.CudaTensor') then
gpu_ctc(acts, grads, labels, sizes)
else
grads = grads:float()
cpu_ctc(acts:float(), grads, labels, sizes)
end
self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output)
return self.gradInput
end
--If batching occurs multiple costs are returned. We sum the costs and return.
function sumCosts(list)
local acc
for k, v in ipairs(list) do
if 1 == k then
acc = v
else
acc = acc + v
end
end
return acc
end
--[[
-- Converts the outputs into batch warp-ctc format seen at the end of the README here:
-- https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
]]
function CTCCriterion:createCTCBatch(output, sizes)
self.acts:resize(sizes[1] * sizes[2], sizes[3]):zero()
local output = output:transpose(1, 2)
self.acts = torch.reshape(output, sizes[1] * sizes[2], sizes[3])
return self.acts
end
function CTCCriterion:revertBatching(gradients, sizes)
self.convertedGradients:resize(sizes[1], sizes[2], sizes[3]):zero()
local counter = 1
for i = 1, sizes[2] do
for j = 1, sizes[1] do
self.convertedGradients[j][i] = gradients[counter]
counter = counter + 1
end
end
return self.convertedGradients
end
|