github.com/torch/nn.git
Diffstat (limited to 'test.lua')
-rw-r--r--  test.lua  145
1 file changed, 133 insertions(+), 12 deletions(-)
diff --git a/test.lua b/test.lua
index 3fff151..25be158 100644
--- a/test.lua
+++ b/test.lua
@@ -39,6 +39,7 @@ for test_name, component in pairs(tostringTestModules) do
end
end
+
function nntest.Add()
local inj_vals = {math.random(3,5), 1} -- Also test the inj = 1 spatial case
local ini = math.random(3,5)
@@ -310,7 +311,7 @@ function nntest.PReLU()
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
- 'error on weight [%s]', t))
+ 'error on weight [%s]', t))
end
-- 2D
@@ -328,7 +329,7 @@ function nntest.PReLU()
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
- 'error on weight [%s]', t))
+ 'error on weight [%s]', t))
end
-- 4D
@@ -347,7 +348,7 @@ function nntest.PReLU()
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
- 'error on weight [%s]', t))
+ 'error on weight [%s]', t))
end
-- IO
@@ -1067,7 +1068,7 @@ function nntest.LogSoftmax()
local module = nn.LogSoftMax()
local err = jac.testJacobian(module,input)
- mytester:assertlt(err,1e-3, 'error on state ')
+ mytester:assertlt(err, 1e-3, 'error on state ')
local ferr,berr = jac.testIO(module,input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
@@ -2999,7 +3000,7 @@ function nntest.AddConstant()
-- Test BPROP
local err = jac.testJacobian(mod, input)
mytester:assertlt(err, precision, 'bprop error ')
-
+
-- inplace comparisons
local ini = math.random(3,5)
local inj = math.random(3,5)
@@ -3024,7 +3025,7 @@ function nntest.AddConstant()
local gradInput1 = module1:backward(input1, gradOutput1)
local gradInput2 = module2:backward(input2, gradOutput2)
- mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
+ mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
torch.typename(module1) .. ' - in-place backward err ')
local input1 = torch.rand(ink, inj, ini)
@@ -3034,7 +3035,7 @@ function nntest.AddConstant()
module1:backward(module1.output,torch.rand(input1:size()))
local err = (input1-input2):abs():max()
- mytester:asserteq(err, 0, torch.typename(module1) ..
+ mytester:asserteq(err, 0, torch.typename(module1) ..
' - inplace input change err ')
end
@@ -3056,7 +3057,7 @@ function nntest.MulConstant()
-- Test BPROP
local err = jac.testJacobian(mod, input)
mytester:assertlt(err, precision, 'bprop error ')
-
+
-- inplace comparisons
local ini = math.random(3,5)
local inj = math.random(3,5)
@@ -3075,13 +3076,13 @@ function nntest.MulConstant()
local out1 = module1:forward(input1)
local out2 = module2:forward(input2)
- mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) ..
+ mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) ..
' - in-place forward err ')
local gradInput1 = module1:backward(input1, gradOutput1)
local gradInput2 = module2:backward(input2, gradOutput2)
-
- mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
+
+ mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
torch.typename(module1) .. ' - in-place backward err ')
local input1 = torch.rand(ink, inj, ini)
@@ -3091,7 +3092,7 @@ function nntest.MulConstant()
module1:backward(module1.output,torch.rand(input1:size()))
local err = (input1-input2):abs():max()
- mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) ..
+ mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) ..
' - inplace input change err ')
end
@@ -4154,6 +4155,126 @@ function nntest.addSingletonDimension()
"invalid dimension not detected")
end
+function nntest.Typecast()
+ local function make_network()
+ local seq = nn.Sequential()
+ seq:add(nn.Linear(15, 10))
+ seq:add(nn.Linear(15, 10))
+ seq.modules[1].bias:fill(1)
+ seq.modules[2].bias:fill(2)
+ return seq
+ end
+
+ -- make sure that the typecasts aren't nops
+ assert(torch.getdefaulttensortype() == 'torch.DoubleTensor')
+
+ -- basic net
+ local net = make_network()
+ net.modules[1].empty_tensor = torch.Tensor()
+ net:float()
+ assert(net.modules[1].bias:type() == 'torch.FloatTensor',
+ net.modules[1].bias:type())
+ assert(net.modules[1].empty_tensor:type() == 'torch.FloatTensor')
+ assert(net.modules[1].bias ~= net.modules[2].bias)
+ net.modules[1].bias:fill(3)
+ assert(net.modules[1].bias[1] == 3)
+ assert(net.modules[2].bias[1] == 2)
+
+ -- shared tensors remain shared
+ local net = make_network()
+ net.modules[2].bias = net.modules[1].bias
+ net:float()
+ assert(net.modules[1].bias:type() == 'torch.FloatTensor')
+ assert(net.modules[1].bias == net.modules[2].bias)
+ assert(net.modules[1].bias[1] == 1)
+
+ -- shared storages remain shared
+ local net = make_network()
+ net.modules[2].bias:set(net.modules[1].bias)
+ local net = net:float()
+ assert(net.modules[1].bias:type() == 'torch.FloatTensor')
+ assert(net.modules[1].bias ~= net.modules[2].bias)
+ net.modules[1].bias:fill(3)
+ assert(net.modules[1].bias[1] == 3)
+ assert(net.modules[2].bias[1] == 3)
+
+ -- tricky: overlapping views on the same storage are preserved
+ local net = make_network()
+ local overlap_storage = torch.Tensor(15):fill(1)
+ net.modules[1].bias = overlap_storage:narrow(1, 1, 10)
+ net.modules[2].bias = overlap_storage:narrow(1, 6, 10)
+ net:float()
+ assert(net.modules[1].bias:type() == 'torch.FloatTensor')
+ assert(net.modules[1].bias ~= net.modules[2].bias)
+ net.modules[1].bias:fill(3)
+ assert(net.modules[1].bias[1] == 3)
+ assert(net.modules[2].bias[1] == 3)
+ assert(net.modules[2].bias[6] == 1) -- only the first 5 elements overlapped
+
+ -- casting two nets separately breaks tensor sharing between them;
+ -- a shared tensorCache, or recursiveType on a table, preserves it
+ local net1 = make_network()
+ local net2 = make_network()
+ net2.modules[1].bias:set(net1.modules[1].bias)
+ net1:float()
+ net2:float()
+ net1.modules[1].bias:fill(3)
+ assert(net2.modules[1].bias[1] == 1)
+
+ local net1 = make_network()
+ local net2 = make_network()
+ net2.modules[1].bias:set(net1.modules[1].bias)
+
+ local tensorCache = {}
+ net1:type('torch.FloatTensor', tensorCache)
+ net2:type('torch.FloatTensor', tensorCache)
+ net1.modules[1].bias:fill(3)
+ assert(net2.modules[1].bias[1] == 3)
+
+ local net1 = make_network()
+ local net2 = make_network()
+ net2.modules[1].bias:set(net1.modules[1].bias)
+
+ nn.utils.recursiveType({net1, net2}, 'torch.FloatTensor')
+ net1.modules[1].bias:fill(3)
+ assert(net2.modules[1].bias[1] == 3)
+
+ -- smoke test some modules with custom type methods
+ local custom_type_modules = {
+ nn.MixtureTable(3),
+ nn.ConcatTable(),
+ nn.Copy(),
+ nn.Copy(nil, nil, nil, true),
+ nn.SpatialContrastiveNormalization(),
+ nn.DotProduct(),
+ nn.PairwiseDistance(1),
+ nn.SpatialDivisiveNormalization(),
+ nn.SpatialSubtractiveNormalization()
+ }
+ for _, module in ipairs(custom_type_modules) do
+ module:float()
+ end
+end
+
+function nntest.Module_apply()
+ local s = nn.Sequential()
+ s:add(nn.Linear(10,10))
+ local s2 = nn.Sequential()
+ s2:add(nn.Linear(10,5))
+ s:add(s2)
+ s:add(nn.Tanh())
+
+ local seen = 0
+ s:apply(function(module)
+ if torch.type(module) == 'nn.Linear' then
+ module.bias:resize(20)
+ seen = seen + 1
+ end
+ end)
+ mytester:asserteq(seen, 2)
+ mytester:asserteq(s.modules[1].bias:size(1), 20)
+ mytester:asserteq(s2.modules[1].bias:size(1), 20)
+end
+
mytester:add(nntest)
if not nn then
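
For context, a minimal sketch of the sharing semantics that nntest.Typecast pins down, using only what the test itself exercises (the optional tensorCache argument to Module:type() and nn.utils.recursiveType); the module sizes and variable names are illustrative:

require 'nn'

-- Two networks whose first-layer biases share one storage.
local net1 = nn.Sequential():add(nn.Linear(15, 10))
local net2 = nn.Sequential():add(nn.Linear(15, 10))
net2.modules[1].bias:set(net1.modules[1].bias)

-- Casting each net on its own converts the shared storage twice,
-- leaving two independent float copies:
--   net1:float(); net2:float()

-- Threading one tensorCache through both casts maps each source
-- tensor/storage to a single converted tensor, so sharing survives:
local tensorCache = {}
net1:type('torch.FloatTensor', tensorCache)
net2:type('torch.FloatTensor', tensorCache)

net1.modules[1].bias:fill(3)
assert(net2.modules[1].bias[1] == 3) -- still shared after the cast

-- Equivalently, cast a whole table of modules with one internal cache:
--   nn.utils.recursiveType({net1, net2}, 'torch.FloatTensor')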
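Similarly, a small sketch of the traversal nntest.Module_apply relies on: apply(callback) visits the module itself and then recurses into self.modules, which is why the test's callback reaches the Linear nested inside s2:

require 'nn'

local net = nn.Sequential()
net:add(nn.Linear(10, 10))
net:add(nn.Sequential():add(nn.Tanh()))

-- Collect the type of every module the traversal touches.
local visited = {}
net:apply(function(module)
  table.insert(visited, torch.type(module))
end)
-- visited: nn.Sequential, nn.Linear, nn.Sequential, nn.Tanh
print(table.concat(visited, ', '))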