Welcome to mirror list, hosted at ThFree Co, Russian Federation.

CTCCriterion.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 1449db037f7f7600cdb2df303db9f0db539b8dc5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
require 'warp_ctc'

-- nn.CTCCriterion: a CTC loss criterion (subclass of nn.Criterion) that
-- delegates the cost/gradient computation to the warp-ctc bindings
-- (cpu_ctc / gpu_ctc) made available by require 'warp_ctc'.
-- `parent` is nn.Criterion, used for the base-class constructor.
local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion')

--- Construct the criterion and its scratch buffers.
-- self.acts is filled by createCTCBatch and self.convertedGradients by
-- revertBatching; both are resized there on every call and reused.
function CTCCriterion:__init()
    parent.__init(self)
    self.acts, self.convertedGradients = torch.Tensor(), torch.Tensor()
end

--- Forward pass: compute the total CTC cost of a batch.
-- @param output activations indexed as output[batch][time] (output:size()
--   gives batch count, time steps, values per step)
-- @param labels label sequences passed straight through to warp-ctc
-- @return the summed cost over all sequences (also stored in self.output)
function CTCCriterion:updateOutput(output, labels)
    local dims = output:size()
    local batchedActs = self:createCTCBatch(output, dims)

    -- One length entry per sequence in the batch; every sequence spans
    -- the full time dimension.
    local seqLengths = {}
    for b = 1, dims[1] do
        seqLengths[b] = dims[2]
    end

    -- An empty gradient tensor is passed: only the costs are needed here.
    local costs
    if output:type() ~= 'torch.CudaTensor' then
        costs = cpu_ctc(batchedActs:float(), torch.Tensor():float(), labels, seqLengths)
    else
        costs = gpu_ctc(batchedActs, torch.CudaTensor(), labels, seqLengths)
    end

    self.output = sumCosts(costs)
    return self.output
end

--- Backward pass: compute the gradient of the CTC cost w.r.t. `output`.
-- @param output activations indexed as output[batch][time] (output:size()
--   gives batch count, time steps, values per step)
-- @param labels label sequences passed straight through to warp-ctc
-- @return gradient with the same shape and type as `output`
--   (also stored in self.gradInput)
function CTCCriterion:updateGradInput(output, labels)
    local tensorSizes = output:size()
    local acts = self:createCTCBatch(output, tensorSizes)
    -- One length entry per sequence; every sequence spans all time steps.
    local sizes = {}
    for x = 1, tensorSizes[1] do
        sizes[x] = tensorSizes[2]
    end
    -- gpu_ctc / cpu_ctc write the gradients into the tensor they are given;
    -- grads is read back below.
    local grads = acts:clone():zero()
    if (output:type() == 'torch.CudaTensor') then
        gpu_ctc(acts, grads, labels, sizes)
    else
        -- :float() returns a NEW tensor whenever grads is not already a
        -- FloatTensor (e.g. the default DoubleTensor). Keep a reference to
        -- the tensor cpu_ctc actually fills; otherwise the computed gradients
        -- are discarded and grads stays all zeros.
        grads = grads:float()
        cpu_ctc(acts:float(), grads, labels, sizes)
    end
    self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output)
    return self.gradInput
end

--- Sum the list of costs returned by warp-ctc into a single loss value.
-- With batching, one cost is returned per sequence.
-- Kept global on purpose: CTCCriterion:updateOutput is defined earlier in
-- this file and resolves `sumCosts` by name when it runs.
-- @tparam table list sequence of costs (presumably numbers — confirm
--   against the warp_ctc binding)
-- @return the sum of the costs, or 0 for an empty list (previously nil,
--   which made the criterion output nil)
function sumCosts(list)
    local total
    for _, cost in ipairs(list) do
        if total == nil then
            total = cost
        else
            total = total + cost
        end
    end
    return total or 0
end

--[[
-- Converts the outputs into batch warp-ctc format seen at the end of the README here:
-- https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
--
-- `output` is indexed as output[b][t], where sizes = output:size():
-- sizes[1] sequences, sizes[2] time steps, sizes[3] values per step.
-- The result is a (sizes[2] * sizes[1]) x sizes[3] matrix in time-major
-- order: row (t - 1) * sizes[1] + b holds output[b][t].
-- Returns self.acts, a buffer reused (and overwritten) on every call.
 ]]
function CTCCriterion:createCTCBatch(output, sizes)
    local batchSize, seqLength, featSize = sizes[1], sizes[2], sizes[3]
    -- Shape the buffer as (time, batch, feat) so one copy from the
    -- transposed view of `output` produces the time-major layout, then
    -- flatten it in place. resize keeps the data because the element count
    -- does not change. This replaces a Lua loop doing one tensor copy per row.
    self.acts:resize(seqLength, batchSize, featSize)
    self.acts:copy(output:transpose(1, 2))
    self.acts:resize(seqLength * batchSize, featSize)
    return self.acts
end

--- Inverse of createCTCBatch: unpack time-major warp-ctc rows back into the
-- original output layout.
-- @param gradients (sizes[2] * sizes[1]) x sizes[3] matrix where row
--   (t - 1) * sizes[1] + b belongs to sequence b at time step t
-- @param sizes output:size() of the original input (batch, time, values)
-- @return self.convertedGradients, a sizes[1] x sizes[2] x sizes[3] buffer
--   reused across calls, where [b][t] holds that row
function CTCCriterion:revertBatching(gradients, sizes)
    local batchSize, seqLength, featSize = sizes[1], sizes[2], sizes[3]
    -- View the flat rows as (time, batch, feat), then copy the transposed
    -- (batch, time, feat) view in one pass instead of looping per row.
    -- copy also converts between tensor types (e.g. Float -> Double).
    local timeMajor = gradients:contiguous():view(seqLength, batchSize, featSize)
    self.convertedGradients:resize(batchSize, seqLength, featSize)
    self.convertedGradients:copy(timeMajor:transpose(1, 2))
    return self.convertedGradients
end