diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-04-16 00:52:18 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-04-16 00:52:18 +0300 |
commit | 6283ef3d9fed244422788d0329998a485e4f2a75 (patch) | |
tree | 29c3a937f2f5bacbb88467b17fdfb0c23c22ed8a | |
parent | 54a41874947775391d233c8f02719d0da8798f6c (diff) | |
parent | 7e5378f234f918c362a0ba465e48af9cf7b9c7ea (diff) |
Merge pull request #59 from SeanNaren/master
CTCCriterion optimisation
-rw-r--r-- | CTCCriterion.lua | 21 |
1 files changed, 6 insertions, 15 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua index 0aa4ae5..efabf13 100644 --- a/CTCCriterion.lua +++ b/CTCCriterion.lua @@ -21,10 +21,8 @@ function CTCCriterion:updateOutput(output, labels) assert(output:nDimension() == CTCCriterion.dim, "Output must be a tensor of (batch x time x inputdim), recieved " .. output:nDimension() .. " dimensions") local tensorSizes = output:size() local acts = self:createCTCBatch(output, tensorSizes) - local sizes = {} -- For each batch we state the number of time steps. - for x = 1, tensorSizes[1] do - table.insert(sizes, tensorSizes[2]) - end + local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2]) + sizes = torch.totable(sizes) if (output:type() == 'torch.CudaTensor') then local grads = torch.CudaTensor() self.output = sumCosts(gpu_ctc(acts, grads, labels, sizes)) @@ -39,10 +37,8 @@ end function CTCCriterion:updateGradInput(output, labels) local tensorSizes = output:size() local acts = self:createCTCBatch(output, tensorSizes) - local sizes = {} - for x = 1, tensorSizes[1] do - table.insert(sizes, tensorSizes[2]) - end + local sizes = torch.Tensor(tensorSizes[1]):fill(tensorSizes[2]) + sizes = torch.totable(sizes) local grads = acts:clone():zero() if (output:type() == 'torch.CudaTensor') then gpu_ctc(acts, grads, labels, sizes) @@ -73,13 +69,8 @@ end ]] function CTCCriterion:createCTCBatch(output, sizes) self.acts:resize(sizes[1] * sizes[2], sizes[3]):zero() - local counter = 1 - for i = 1, sizes[2] do - for j = 1, sizes[1] do - self.acts[counter] = output[j][i] - counter = counter + 1 - end - end + local output = output:transpose(1, 2) + self.acts = torch.reshape(output, sizes[1] * sizes[2], sizes[3]) return self.acts end |