Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean <seannaren>2016-04-02 18:15:10 +0300
committerSean <seannaren>2016-04-02 18:15:10 +0300
commit8f9bafe8b09132436f006d701c7edc1f976dce8c (patch)
treede02739c8c71f007833b4dabc69fcfa7c4233a7d
parent734c810eb49027922e719c01426bb881c031baa6 (diff)
added more documentation and assertion check on input dim
-rw-r--r--CTCCriterion.lua12
1 files changed, 11 insertions, 1 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua
index 1449db0..7f628de 100644
--- a/CTCCriterion.lua
+++ b/CTCCriterion.lua
@@ -1,7 +1,16 @@
require 'warp_ctc'
-
+------------------------------------------------------------------------
+--[[ CTCCriterion ]]--
+-- CTC Alignment for sequence data where input and labels do not align.
+-- Useful for speech recognition on a phoneme/character level basis.
+-- Inputs assumed are in the form of batch x time x inputdim.
+-- Targets assumed in the form of {{1,2},{3,4}} where {1,2} is for the first
+-- element.
+------------------------------------------------------------------------
local CTCCriterion, parent = torch.class('nn.CTCCriterion', 'nn.Criterion')
+CTCCriterion.dim = 3
+
function CTCCriterion:__init()
parent.__init(self)
self.acts = torch.Tensor()
@@ -9,6 +18,7 @@ function CTCCriterion:__init()
end
function CTCCriterion:updateOutput(output, labels)
+ assert(output:nDimension() == CTCCriterion.dim, "Output must be a tensor of (batch x time x inputdim), recieved " .. output:nDimension() .. " dimensions")
local tensorSizes = output:size()
local acts = self:createCTCBatch(output, tensorSizes)
local sizes = {} -- For each batch we state the number of time steps.