From e0122ccbbd701aee50183464c9b5acf39481b16e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 30 May 2017 13:33:58 +0100 Subject: Simplified batchfirst gradient modification --- CTCCriterion.lua | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) (limited to 'CTCCriterion.lua') diff --git a/CTCCriterion.lua b/CTCCriterion.lua index 9c24683..5d2ab4f 100644 --- a/CTCCriterion.lua +++ b/CTCCriterion.lua @@ -50,9 +50,9 @@ function CTCCriterion:updateGradInput(input, target) return self.gradInput end if self.batchFirst then -- batchSize x seqLen x outputDim - self.gradInput = inverseInterleave(self.gradInput, input:size(1)) + self.gradInput = self.gradInput:view(input:size(2), input:size(1), -1):transpose(1, 2) else -- seqLen x batchSize x outputDim - self.gradInput:view(self.gradInput, input:size(1), input:size(2), -1) + self.gradInput:view(self.gradInput, input:size(1), input:size(2), -1) end return self.gradInput end @@ -66,22 +66,6 @@ function CTCCriterion:makeContiguous(input) return input end -function inverseInterleave(tensor, batchSize) - local sizes = torch.LongStorage(3) - sizes[1] = batchSize - sizes[2] = tensor:size(1) / batchSize - sizes[3] = tensor:size(2) - local result = tensor.new():resize(sizes):zero() - local counter = 1 - for i = 1, sizes[2] do - for j = 1, sizes[1] do - result[j][i]:copy(tensor[counter]) - counter = counter + 1 - end - end - return result -end - --If batching occurs multiple costs are returned. We sum the costs and return. function sumCosts(list) local acc -- cgit v1.2.3