diff options
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 145 |
1 files changed, 133 insertions, 12 deletions
@@ -39,6 +39,7 @@ for test_name, component in pairs(tostringTestModules) do end end + function nntest.Add() local inj_vals = {math.random(3,5), 1} -- Also test the inj = 1 spatial case local ini = math.random(3,5) @@ -310,7 +311,7 @@ function nntest.PReLU() for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( - 'error on weight [%s]', t)) + 'error on weight [%s]', t)) end -- 2D @@ -328,7 +329,7 @@ function nntest.PReLU() for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( - 'error on weight [%s]', t)) + 'error on weight [%s]', t)) end -- 4D @@ -347,7 +348,7 @@ function nntest.PReLU() for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( - 'error on weight [%s]', t)) + 'error on weight [%s]', t)) end -- IO @@ -1067,7 +1068,7 @@ function nntest.LogSoftmax() local module = nn.LogSoftMax() local err = jac.testJacobian(module,input) - mytester:assertlt(err,1e-3, 'error on state ') + mytester:assertlt(err, 1e-3, 'error on state ') local ferr,berr = jac.testIO(module,input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') @@ -2999,7 +3000,7 @@ function nntest.AddConstant() -- Test BPROP local err = jac.testJacobian(mod, input) mytester:assertlt(err, precision, 'bprop error ') - + -- inplace comparisons local ini = math.random(3,5) local inj = math.random(3,5) @@ -3024,7 +3025,7 @@ function nntest.AddConstant() local gradInput1 = module1:backward(input1, gradOutput1) local gradInput2 = module2:backward(input2, gradOutput2) - mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), + mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), torch.typename(module1) .. ' - in-place backward err ') local input1 = torch.rand(ink, inj, ini) @@ -3034,7 +3035,7 @@ function nntest.AddConstant() module1:backward(module1.output,torch.rand(input1:size())) local err = (input1-input2):abs():max() - mytester:asserteq(err, 0, torch.typename(module1) .. + mytester:asserteq(err, 0, torch.typename(module1) .. ' - inplace input change err ') end @@ -3056,7 +3057,7 @@ function nntest.MulConstant() -- Test BPROP local err = jac.testJacobian(mod, input) mytester:assertlt(err, precision, 'bprop error ') - + -- inplace comparisons local ini = math.random(3,5) local inj = math.random(3,5) @@ -3075,13 +3076,13 @@ function nntest.MulConstant() local out1 = module1:forward(input1) local out2 = module2:forward(input2) - mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) .. + mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) .. ' - in-place forward err ') local gradInput1 = module1:backward(input1, gradOutput1) local gradInput2 = module2:backward(input2, gradOutput2) - - mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), + + mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), torch.typename(module1) .. ' - in-place backward err ') local input1 = torch.rand(ink, inj, ini) @@ -3091,7 +3092,7 @@ function nntest.MulConstant() module1:backward(module1.output,torch.rand(input1:size())) local err = (input1-input2):abs():max() - mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) .. + mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) .. ' - inplace input change err ') end @@ -4154,6 +4155,126 @@ function nntest.addSingletonDimension() "invalid dimension not detected") end +function nntest.Typecast() + local function make_network() + local seq = nn.Sequential() + seq:add(nn.Linear(15, 10)) + seq:add(nn.Linear(15, 10)) + seq.modules[1].bias:fill(1) + seq.modules[2].bias:fill(2) + return seq + end + + -- make sure that the typecasts aren't nops + assert(torch.getdefaulttensortype() == 'torch.DoubleTensor') + + -- basic net + local net = make_network() + net.modules[1].empty_tensor = torch.Tensor() + net:float() + assert(net.modules[1].bias:type() == 'torch.FloatTensor', + net.modules[1].bias:type()) + assert(net.modules[1].empty_tensor:type() == 'torch.FloatTensor') + assert(net.modules[1].bias ~= net.modules[2].bias) + net.modules[1].bias:fill(3) + assert(net.modules[1].bias[1] == 3) + assert(net.modules[2].bias[1] == 2) + + -- shared tensors remain shared + local net = make_network() + net.modules[2].bias = net.modules[1].bias + net:float() + assert(net.modules[1].bias:type() == 'torch.FloatTensor') + assert(net.modules[1].bias == net.modules[2].bias) + assert(net.modules[1].bias[1] == 1) + + -- shared storages remain shared + local net = make_network() + net.modules[2].bias:set(net.modules[1].bias) + local net = net:float() + assert(net.modules[1].bias:type() == 'torch.FloatTensor') + assert(net.modules[1].bias ~= net.modules[2].bias) + net.modules[1].bias:fill(3) + assert(net.modules[1].bias[1] == 3) + assert(net.modules[2].bias[1] == 3) + + -- tricky: overlapping views on the same storage are preserved + local net = make_network() + local overlap_storage = torch.Tensor(15):fill(1) + net.modules[1].bias = overlap_storage:narrow(1, 1, 10) + net.modules[2].bias = overlap_storage:narrow(1, 6, 10) + net:float() + assert(net.modules[1].bias:type() == 'torch.FloatTensor') + assert(net.modules[1].bias ~= net.modules[2].bias) + net.modules[1].bias:fill(3) + assert(net.modules[1].bias[1] == 3) + assert(net.modules[2].bias[1] == 3) + assert(net.modules[2].bias[6] == 1) -- only the first 5 elements overlapped + + -- check recursiveType on a table + local net1 = make_network() + local net2 = make_network() + net2.modules[1].bias:set(net1.modules[1].bias) + net1:float() + net2:float() + net1.modules[1].bias:fill(3) + assert(net2.modules[1].bias[1] == 1) + + local net1 = make_network() + local net2 = make_network() + net2.modules[1].bias:set(net1.modules[1].bias) + + local tensorCache = {} + net1:type('torch.FloatTensor', tensorCache) + net2:type('torch.FloatTensor', tensorCache) + net1.modules[1].bias:fill(3) + assert(net2.modules[1].bias[1] == 3) + + local net1 = make_network() + local net2 = make_network() + net2.modules[1].bias:set(net1.modules[1].bias) + + nn.utils.recursiveType({net1, net2}, 'torch.FloatTensor') + net1.modules[1].bias:fill(3) + assert(net2.modules[1].bias[1] == 3) + + -- smoke test some modules with custom type methods + local custom_type_modules = { + nn.MixtureTable(3), + nn.ConcatTable(), + nn.Copy(), + nn.Copy(nil, nil, nil, true), + nn.SpatialContrastiveNormalization(), + nn.DotProduct(), + nn.PairwiseDistance(1), + nn.SpatialDivisiveNormalization(), + nn.SpatialSubtractiveNormalization() + } + for _, module in ipairs(custom_type_modules) do + module:float() + end +end + +function nntest.Module_apply() + local s = nn.Sequential() + s:add(nn.Linear(10,10)) + local s2 = nn.Sequential() + s2:add(nn.Linear(10,5)) + s:add(s2) + s:add(nn.Tanh()) + + local seen = 0 + s:apply(function(module) + if torch.type(module) == 'nn.Linear' then + module.bias:resize(20) + seen = seen + 1 + end + end) + mytester:asserteq(seen, 2) + mytester:asserteq(s.modules[1].bias:size(1), 20) + mytester:asserteq(s2.modules[1].bias:size(1), 20) +end + mytester:add(nntest) if not nn then |