diff options
Diffstat (limited to 'CTCCriterion.lua')
-rw-r--r-- | CTCCriterion.lua | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua index 22f13af..0aa4ae5 100644 --- a/CTCCriterion.lua +++ b/CTCCriterion.lua @@ -47,7 +47,8 @@ function CTCCriterion:updateGradInput(output, labels) if (output:type() == 'torch.CudaTensor') then gpu_ctc(acts, grads, labels, sizes) else - cpu_ctc(acts:float(), grads:float(), labels, sizes) + grads = grads:float() + cpu_ctc(acts:float(), grads, labels, sizes) end self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output) return self.gradInput |