Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSeanNaren <seannaren@hotmail.com>2017-05-30 15:33:58 +0300
committerSeanNaren <seannaren@hotmail.com>2017-05-30 15:33:58 +0300
commite0122ccbbd701aee50183464c9b5acf39481b16e (patch)
tree58510b3c2fff04334d292005a6f733925ec4177e
parentc5c4d6adcdfe0d36404ca241dad66cb4a460f41f (diff)
Simplified batchfirst gradient modification
-rw-r--r--CTCCriterion.lua20
1 files changed, 2 insertions, 18 deletions
diff --git a/CTCCriterion.lua b/CTCCriterion.lua
index 9c24683..5d2ab4f 100644
--- a/CTCCriterion.lua
+++ b/CTCCriterion.lua
@@ -50,9 +50,9 @@ function CTCCriterion:updateGradInput(input, target)
return self.gradInput
end
if self.batchFirst then -- batchSize x seqLen x outputDim
- self.gradInput = inverseInterleave(self.gradInput, input:size(1))
+ self.gradInput = self.gradInput:view(input:size(2), input:size(1), -1):transpose(1, 2)
else -- seqLen x batchSize x outputDim
- self.gradInput:view(self.gradInput, input:size(1), input:size(2), -1)
+ self.gradInput:view(self.gradInput, input:size(1), input:size(2), -1)
end
return self.gradInput
end
@@ -66,22 +66,6 @@ function CTCCriterion:makeContiguous(input)
return input
end
-function inverseInterleave(tensor, batchSize)
- local sizes = torch.LongStorage(3)
- sizes[1] = batchSize
- sizes[2] = tensor:size(1) / batchSize
- sizes[3] = tensor:size(2)
- local result = tensor.new():resize(sizes):zero()
- local counter = 1
- for i = 1, sizes[2] do
- for j = 1, sizes[1] do
- result[j][i]:copy(tensor[counter])
- counter = counter + 1
- end
- end
- return result
-end
-
--If batching occurs multiple costs are returned. We sum the costs and return.
function sumCosts(list)
local acc