Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test/test-all.lua')
-rw-r--r--test/test-all.lua12
1 files changed, 10 insertions, 2 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index edc69aa..80ed910 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -547,7 +547,7 @@ function nnxtest.CTCCriterion()
}):transpose(1, 2):contiguous()
local targets = {{1},{3,3},{2,3}}
local sizes = torch.Tensor({1,3,3})
- mytester:eq(criterion:updateOutput(acts, targets, sizes), 13.904030799866, precision, "CTCCriterion.batchTest")
+ mytester:eq(criterion:updateOutput(acts, targets, sizes), 13.904030799866 / 3, precision, "CTCCriterion.batchTest")
local gradOutputNorm = criterion:updateGradInput(acts, targets, sizes)
criterion = nn.CTCCriterion(true) -- batchFirst true, input is batch x seqLength x inputDim
local batchFirstActs =
@@ -556,9 +556,17 @@ function nnxtest.CTCCriterion()
{{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}},
{{-5,-4,-3,-2,-1},{-10,-9,-8,-7,-6},{-15,-14,-13,-12,-11}}
})
- mytester:eq(criterion:updateOutput(batchFirstActs, targets, sizes), 13.904030799866, precision, "CTCCriterion.batchFirstTest")
+ mytester:eq(criterion:updateOutput(batchFirstActs, targets, sizes), 13.904030799866 / 3, precision, "CTCCriterion.batchFirstTest")
local gradOutputBatchFirst = criterion:updateGradInput(acts, targets, sizes)
mytester:assertTensorEq(gradOutputBatchFirst:transpose(1, 2), gradOutputNorm, precision, "CTCCriterion.gradCheckTest")
+ torch.Tensor({
+ {0,0,0,0,0},{1,2,3,4,5},{-5,-4,-3,-2,-1},
+ {0,0,0,0,0},{6,7,8,9,10},{-10,-9,-8,-7,-6},
+ {0,0,0,0,0},{11,12,13,14,15},{-15,-14,-13,-12,-11},
+ })
+ mytester:eq(criterion:updateOutput(batchFirstActs, targets, sizes), 13.904030799866 / 3, precision, "CTCCriterion.batchFirstTest")
+ local gradOutputBatchFirst = criterion:updateGradInput(acts, targets, sizes)
+ mytester:assertTensorEq(gradOutputBatchFirst:transpose(1, 2), gradOutputNorm, precision, "CTCCriterion.2DTensorTest")
end
local function blur(mean, stdv, size)