Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
authorSean <seannaren>2016-04-02 18:57:17 +0300
committerSean <seannaren>2016-04-02 18:57:17 +0300
commit16377ab8b93d2db883b8e44f1f7bc9d78467a656 (patch)
tree58f398716d2b88cffe283adf4cd99e2d27391a36
parent5c17882b179119d67bb547ea0926eb0913442c7d (diff)
added base tests for CTCCriterion
-rw-r--r--test/test-all.lua69
1 file changed, 46 insertions, 23 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index c1ca354..2e77db4 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -102,10 +102,10 @@ local function template_SpatialReSamplingEx(up, mode)
local module = nn.SpatialReSamplingEx({owidth=owidth_, oheight=oheight_,
xDim=xdim, yDim = ydim, mode=mode})
local input = torch.rand(dims)
-
+
local err = nn.Jacobian.testJacobian(module, input)
mytester:assertlt(err, precision, 'error on state ')
-
+
local ferr, berr = nn.Jacobian.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
@@ -172,21 +172,21 @@ function nnxtest.SpatialReSampling_1()
local batchSize = math.random(4,8)
local input2 = torch.rand(batchSize,fanin,sizey,sizex)
input2[2]:copy(input)
-
+
local output = module:forward(input):clone()
local output2 = module:forward(input2)
mytester:assertTensorEq(output, output2[2], 0.00001, 'SpatialResampling batch forward err')
-
+
local gradInput = module:backward(input, output):clone()
local gradInput2 = module:backward(input2, output2)
mytester:assertTensorEq(gradInput, gradInput2[2], 0.00001, 'SpatialResampling batch backward err')
-
+
-- test rwidth/rheight
local input = torch.randn(3,8,10)
local module = nn.SpatialReSampling{rwidth=0.5,rheight=0.5}
local output = module:forward(input)
mytester:assertTableEq(output:size():totable(), {3, 4, 5}, 0.00000001, 'SpatialResampling batch rwidth/rheight err')
-
+
local input = torch.randn(2,3,8,10)
local module = nn.SpatialReSampling{rwidth=0.5,rheight=0.5}
local output = module:forward(input)
@@ -408,7 +408,7 @@ local function template_SpatialMatching(channels, iwidth, iheight, maxw, maxh, f
local input = torch.rand(2, channels, iheight, iwidth)
local err = nn.Jacobian.testJacobian(module, input)
mytester:assertlt(err, precision, 'error on state ')
-
+
local ferr, berr = nn.Jacobian.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
@@ -426,7 +426,7 @@ function nnxtest.SoftMaxTree()
local grad = torch.randn(5)
local root_id = 29
local hierarchy={
- [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
+ [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
[2]=torch.IntTensor{6,7,8}, [3]=torch.IntTensor{9,10,11},
[4]=torch.IntTensor{12,13,14}, [5]=torch.IntTensor{15,16,17},
[6]=torch.IntTensor{18,19,20}, [7]=torch.IntTensor{21,22,23},
@@ -439,7 +439,7 @@ function nnxtest.SoftMaxTree()
local indices = {3,3,4}
local parentIds = {29,2,8}
local linears = {}
-
+
for i,parentId in ipairs(parentIds) do
local s = nn.Sequential()
local linear = nn.Linear(100,hierarchy[parentId]:size(1))
@@ -512,7 +512,7 @@ end
function nnxtest.TreeNLLCriterion()
local input = torch.randn(5,10)
local target = torch.ones(5) --all targets are 1
- local c = nn.TreeNLLCriterion()
+ local c = nn.TreeNLLCriterion()
-- the targets are actually ignored (SoftMaxTree uses them before TreeNLLCriterion)
local err = c:forward(input, target)
gradInput = c:backward(input, target)
@@ -524,6 +524,29 @@ function nnxtest.TreeNLLCriterion()
mytester:assertTensorEq(gradInput2:narrow(2,1,1), gradInput, 0.00001)
end
+function nnxtest.CTCCriterion()
+ local criterion = nn.CTCCriterion()
+ local acts = torch.Tensor({{{0,0,0,0,0}}})
+ local targets = {{1}}
+ mytester:eq(criterion:updateOutput(acts,targets), 1.6094379425049, 0, "CTCCriterion.smallTest")
+ local acts =
+ torch.Tensor({{{1,2,3,4,5}, {6,7,8,9,10}, {11,12,13,14,15}}})
+ local targets = {{3,3}}
+ mytester:eq(criterion:updateOutput(acts,targets), 7.355742931366, 0, "CTCCriterion.mediumTest")
+ local acts = torch.Tensor({{{-5,-4,-3,-2,-1}, {-10,-9,-8,-7,-6}, {-15,-14,-13,-12,-11}}})
+ local targets = {{2,3}}
+ mytester:eq(criterion:updateOutput(acts,targets), 4.938850402832, 0, "CTCCriterion.mediumNegativeTest")
+ local acts =
+ torch.Tensor({
+ {{0,0,0,0,0},{0,0,0,0,0},{0,0,0,0,0}},
+ {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}},
+ {{-5,-4,-3,-2,-1},{-10,-9,-8,-7,-6},{-15,-14,-13,-12,-11}}
+ })
+ local targets = {{1},{3,3},{2,3}}
+ mytester:eq(criterion:updateOutput(acts,targets), 15.331147670746, 0, "CTCCriterion.batchTest")
+end
+
+
local function blur(mean, stdv, size)
local range = torch.range(1,size):float()
local a = 1/(stdv*math.sqrt(2*math.pi))
@@ -532,10 +555,10 @@ local function blur(mean, stdv, size)
end
function nnxtest.Balance()
- local inputSize = 7
+ local inputSize = 7
local batchSize = 3
local nBatch = 1
-
+
local input = torch.randn(batchSize, inputSize):mul(0.1):float()
for i=1,batchSize do
input[i]:add(blur(3, 1, inputSize):float())
@@ -546,34 +569,34 @@ function nnxtest.Balance()
local gradOutput = torch.randn(batchSize, inputSize):float()
local bl = nn.Balance(nBatch)
bl:float()
-
+
local output = bl:forward(input)
local p_y = output:sum(1):div(output:sum())
mytester:assert(p_y:std() < 0.02)
mytester:assert(math.abs(p_y:sum() - 1) < 0.000001)
-
+
local gradInput = bl:backward(input, gradOutput)
end
function nnxtest.MultiSoftMax()
- local inputSize = 7
+ local inputSize = 7
local nSoftmax = 5
local batchSize = 3
-
+
local input = torch.randn(batchSize, nSoftmax, inputSize)
local gradOutput = torch.randn(batchSize, nSoftmax, inputSize)
local msm = nn.MultiSoftMax()
-
+
local output = msm:forward(input)
local gradInput = msm:backward(input, gradOutput)
mytester:assert(output:isSameSizeAs(input))
mytester:assert(gradOutput:isSameSizeAs(gradInput))
-
+
local sm = nn.SoftMax()
local input2 = input:view(batchSize*nSoftmax, inputSize)
local output2 = sm:forward(input2)
local gradInput2 = sm:backward(input2, gradOutput:view(batchSize*nSoftmax, inputSize))
-
+
mytester:assertTensorEq(output, output2, 0.000001)
mytester:assertTensorEq(gradInput, gradInput2, 0.000001)
end
@@ -585,14 +608,14 @@ function nnxtest.PushPullTable()
local gradOutput = torch.randn(5)
local root_id = 29
local hierarchy={
- [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
+ [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
[2]=torch.IntTensor{6,7,8}, [3]=torch.IntTensor{9,10,11},
[4]=torch.IntTensor{12,13,14}, [5]=torch.IntTensor{15,16,17},
[6]=torch.IntTensor{18,19,20}, [7]=torch.IntTensor{21,22,23},
[8]=torch.IntTensor{24,25,26,27,28}
}
local smt = nn.SoftMaxTree(100, hierarchy, root_id)
- -- create a network where inputs are fed through softmaxtree
+ -- create a network where inputs are fed through softmaxtree
-- and targets are teleported (pushed then pulled) to softmaxtree
local mlp = nn.Sequential()
local linear = nn.Linear(50,100)
@@ -618,7 +641,7 @@ function nnxtest.PushPullTable()
mytester:assertTensorEq(output, output2, 0.00001, "push/pull forward error")
mytester:assertTensorEq(gradInput[1], gradInput[1], 0.00001, "push/pull backward error")
mytester:assertTensorEq(gradInput[2], gradInput[2], 0.00001, "push/pull backward error")
-
+
-- test multi-pull case
local mlp = nn.Sequential()
local push = nn.PushTable(2)
@@ -635,7 +658,7 @@ function nnxtest.PushPullTable()
mytester:assertTensorEq(output[4], inputTable[2], 0.00001, "push/pull multi-forward error")
local gradOutput = {inputTable[2]:clone(), inputTable[1]:clone(), inputTable[2]:clone(), inputTable[2]:clone()}
local gradInput = mlp:backward(inputTable, gradOutput)
- local gradInput2 = inputTable[2]:clone():mul(3)
+ local gradInput2 = inputTable[2]:clone():mul(3)
mytester:assertTensorEq(gradInput[1], gradInput[1], 0.00001, "push/pull multi-backward error")
mytester:assertTensorEq(gradInput[2], gradInput[2], 0.00001, "push/pull multi-backward error")
end