diff options
-rw-r--r-- | CTCCriterion.lua | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua index 1449db0..7f628de 100644 --- a/CTCCriterion.lua +++ b/CTCCriterion.lua @@ -1,7 +1,16 @@ require 'warp_ctc' - +------------------------------------------------------------------------ +--[[ CTCCriterion ]]-- +-- CTC Alignment for sequence data where input and labels do not align. +-- Useful for speech recognition on a phoneme/character level basis. +-- Inputs are assumed to be in the form of batch x time x inputdim. +-- Targets are assumed to be in the form of {{1,2},{3,4}} where {1,2} is for the first +-- element. +------------------------------------------------------------------------ local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion') +CTCCriterion.dim = 3 + function CTCCriterion:__init() parent.__init(self) self.acts = torch.Tensor() @@ -9,6 +18,7 @@ function CTCCriterion:__init() end function CTCCriterion:updateOutput(output, labels) + assert(output:nDimension() == CTCCriterion.dim, "Output must be a tensor of (batch x time x inputdim), recieved " .. output:nDimension() .. " dimensions") local tensorSizes = output:size() local acts = self:createCTCBatch(output, tensorSizes) local sizes = {} -- For each batch we state the number of time steps. |