Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'CTCCriterion.lua')
-rw-r--r--CTCCriterion.lua3
1 files changed, 2 insertions, 1 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua
index 22f13af..0aa4ae5 100644
--- a/CTCCriterion.lua
+++ b/CTCCriterion.lua
@@ -47,7 +47,8 @@ function CTCCriterion:updateGradInput(output, labels)
if (output:type() == 'torch.CudaTensor') then
gpu_ctc(acts, grads, labels, sizes)
else
- cpu_ctc(acts:float(), grads:float(), labels, sizes)
+ grads = grads:float()
+ cpu_ctc(acts:float(), grads, labels, sizes)
end
self.gradInput = self:revertBatching(grads, tensorSizes):typeAs(output)
return self.gradInput