Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua120
1 files changed, 92 insertions, 28 deletions
diff --git a/test.lua b/test.lua
index f8437dc..3f37dac 100644
--- a/test.lua
+++ b/test.lua
@@ -2009,38 +2009,44 @@ function nntest.SpatialSubSampling()
end
function nntest.SpatialMaxPooling()
- local from = math.random(1,5)
- local ki = math.random(1,4)
- local kj = math.random(1,4)
- local si = math.random(1,3)
- local sj = math.random(1,3)
- local outi = math.random(4,5)
- local outj = math.random(4,5)
- local ini = (outi-1)*si+ki
- local inj = (outj-1)*sj+kj
+ for _,ceil_mode in pairs({true,false}) do
+ local from = math.random(1,5)
+ local ki = math.random(1,4)
+ local kj = math.random(1,4)
+ local si = math.random(1,3)
+ local sj = math.random(1,3)
+ local outi = math.random(4,5)
+ local outj = math.random(4,5)
+ local padW = math.min(math.random(0,1),math.floor(ki/2))
+ local padH = math.min(math.random(0,1),math.floor(kj/2))
+ local ini = (outi-1)*si+ki-2*padW
+ local inj = (outj-1)*sj+kj-2*padH
+
+ local ceil_string = ceil_mode and 'ceil' or 'floor'
+ local module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH)
+ if ceil_mode then module:ceil() else module:floor() end
+ local input = torch.rand(from,inj,ini)
- local module = nn.SpatialMaxPooling(ki,kj,si,sj)
- local input = torch.rand(from,ini,inj)
-
- local err = jac.testJacobian(module, input)
- mytester:assertlt(err, precision, 'error on state ')
-
- local ferr, berr = jac.testIO(module, input)
- mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
- mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state ')
- -- batch
- local nbatch = math.random(2,5)
- input = torch.rand(nbatch,from,ini,inj)
- module = nn.SpatialMaxPooling(ki,kj,si,sj)
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
- local err = jac.testJacobian(module, input)
- mytester:assertlt(err, precision, 'error on state (Batch) ')
+ -- batch
+ local nbatch = math.random(2,5)
+ input = torch.rand(nbatch,from,inj,ini)
+ module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH)
+ if ceil_mode then module:ceil() else module:floor() end
- local ferr, berr = jac.testIO(module, input)
- mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
- mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state (Batch)')
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
+ end
end
function nntest.SpatialAveragePooling()
@@ -3012,6 +3018,13 @@ function nntest.SplitTable()
module = nn.SplitTable(d, 2)
mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d)
end
+
+ -- Negative indices
+ local module = nn.SplitTable(-3)
+ local input = torch.randn(3,4,5)
+ mytester:asserteq(#module:forward(input), 3, "negative index")
+ local input = torch.randn(2,3,4,5)
+ mytester:asserteq(#module:forward(input), 3, "negative index (minibatch)")
end
function nntest.SelectTable()
@@ -3040,8 +3053,20 @@ function nntest.SelectTable()
equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
end
- module:float()
+
+ -- test negative index
+ local idx = -2
+ module = nn.SelectTable(idx)
+ local output = module:forward(input)
+ equal(output, input[#input+idx+1], "output dimension " .. idx)
+ local gradInput = module:backward(input, gradOutputs[#input+idx+1])
+ equal(gradInput[#input+idx+1], gradOutputs[#input+idx+1], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[#input+idx+1]], zeros[nonIdx[#input+idx+1]], "gradInput[nonIdx] dimension " .. idx)
+
+ -- test typecast
local idx = #input
+ module = nn.SelectTable(idx)
+ module:float()
local output = module:forward(input)
equal(output, input[idx], "type output")
local gradInput = module:backward(input, gradOutputs[idx])
@@ -3180,6 +3205,45 @@ function nntest.MixtureTable()
end
end
+function nntest.NarrowTable()
+ local input = torch.randn(3,10,4)
+ local gradOutput = torch.randn(3,3,4)
+ local nt = nn.NarrowTable(5,3)
+ local seq = nn.Sequential()
+ seq:add(nn.SplitTable(1,2))
+ seq:add(nt)
+ seq:add(nn.JoinTable(1,1))
+ seq:add(nn.Reshape(3,3,4))
+ local seq2 = nn.Narrow(2,5,3)
+ local output = seq:forward(input)
+ local gradInput = seq:backward(input, gradOutput)
+ local output2 = seq2:forward(input)
+ local gradInput2 = seq2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput err")
+
+ -- now try it with a smaller input
+ local input = input:narrow(2, 1, 8)
+ local output = seq:forward(input)
+ local gradInput = seq:backward(input, gradOutput)
+ local output2 = seq2:forward(input)
+ local gradInput2 = seq2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable small output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable small gradInput err")
+
+ -- test type-cast
+ local input = input:float()
+ local gradOutput = gradOutput:float()
+ seq:float()
+ seq2:float()
+ local output = seq:forward(input)
+ local gradInput = seq:backward(input, gradOutput)
+ local output2 = seq2:forward(input)
+ local gradInput2 = seq2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output float err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput float err")
+end
+
function nntest.View()
local input = torch.rand(10)
local template = torch.rand(5,2)