Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSeanNaren <taz838@hotmail.co.uk>2016-04-16 00:49:23 +0300
committerSeanNaren <taz838@hotmail.co.uk>2016-04-16 00:49:23 +0300
commit7e5378f234f918c362a0ba465e48af9cf7b9c7ea (patch)
tree29c3a937f2f5bacbb88467b17fdfb0c23c22ed8a
parent54a41874947775391d233c8f02719d0da8798f6c (diff)
Change creation of sizes and forward pass
-rw-r--r--CTCCriterion.lua21
1 files changed, 6 insertions, 15 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua
index 0aa4ae5..efabf13 100644
--- a/CTCCriterion.lua
+++ b/CTCCriterion.lua
@@ -21,10 +21,8 @@ function CTCCriterion:updateOutput(output, labels)
assert(output:nDimension() == CTCCriterion.dim, "Output must be a tensor of (batch x time x inputdim), recieved " .. output:nDimension() .. " dimensions")
local tensorSizes = output:size()
local acts = self:createCTCBatch(output, tensorSizes)
- local sizes = {} -- For each batch we state the number of time steps.
- for x = 1, tensorSizes[1] do
- table.insert(sizes, tensorSizes[2])
- end
+ local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2])
+ sizes = torch.totable(sizes)
if (output:type() == 'torch.CudaTensor') then
local grads = torch.CudaTensor()
self.output = sumCosts(gpu_ctc(acts, grads, labels, sizes))
@@ -39,10 +37,8 @@ end
function CTCCriterion:updateGradInput(output, labels)
local tensorSizes = output:size()
local acts = self:createCTCBatch(output, tensorSizes)
- local sizes = {}
- for x = 1, tensorSizes[1] do
- table.insert(sizes, tensorSizes[2])
- end
+ local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2])
+ sizes = torch.totable(sizes)
local grads = acts:clone():zero()
if (output:type() == 'torch.CudaTensor') then
gpu_ctc(acts, grads, labels, sizes)
@@ -73,13 +69,8 @@ end
]]
function CTCCriterion:createCTCBatch(output, sizes)
self.acts:resize(sizes[1] * sizes[2], sizes[3]):zero()
- local counter = 1
- for i = 1, sizes[2] do
- for j = 1, sizes[1] do
- self.acts[counter] = output[j][i]
- counter = counter + 1
- end
- end
+ local output = output:transpose(1, 2)
+ self.acts = torch.reshape(output, sizes[1] * sizes[2], sizes[3])
return self.acts
end