diff options
author | Sean <seannaren> | 2016-04-02 18:15:10 +0300 |
---|---|---|
committer | Sean <seannaren> | 2016-04-02 18:15:10 +0300 |
commit | 8f9bafe8b09132436f006d701c7edc1f976dce8c (patch) | |
tree | de02739c8c71f007833b4dabc69fcfa7c4233a7d | |
parent | 734c810eb49027922e719c01426bb881c031baa6 (diff) |
added more documentation and assertion check on input dim
-rw-r--r-- | CTCCriterion.lua | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua index 1449db0..7f628de 100644 --- a/CTCCriterion.lua +++ b/CTCCriterion.lua @@ -1,7 +1,16 @@ require 'warp_ctc' - +------------------------------------------------------------------------ +--[[ CTCCriterion ]]-- +-- CTC Alignment for sequence data where input and labels do not align. +-- Useful for speech recognition on a phoneme/character level basis. +-- Inputs assumed are in the form of batch x time x inputdim. +-- Targets assumed in the form of {{1,2},{3,4}} where {1,2} is for the first +-- element. +------------------------------------------------------------------------ local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion') +CTCCriterion.dim = 3 + function CTCCriterion:__init() parent.__init(self) self.acts = torch.Tensor() @@ -9,6 +18,7 @@ function CTCCriterion:__init() end function CTCCriterion:updateOutput(output, labels) + assert(output:nDimension() == CTCCriterion.dim, "Output must be a tensor of (batch x time x inputdim), recieved " .. output:nDimension() .. " dimensions") local tensorSizes = output:size() local acts = self:createCTCBatch(output, tensorSizes) local sizes = {} -- For each batch we state the number of time steps. |