diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-09-07 20:44:49 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-09-07 20:44:49 +0400 |
commit | cde19db1b991380bec5700f11ff837677dcbf444 (patch) | |
tree | 86050f677af1b13a52c5b63a5458bd0d7768e36a /test | |
parent | bf724fdabdca054b2cfde2fee1908c3b216e65a6 (diff) |
New tester.
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 99 |
1 files changed, 97 insertions, 2 deletions
diff --git a/test/test.lua b/test/test.lua index 4d4383d..ff8b0d9 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1078,10 +1078,105 @@ function nntest.VolumetricConvolution() mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') end +function nntest.Module_getParameters_1() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + local p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'getParameters(): weights wrong') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'getParameters(): bias wrong') +end + +function nntest.Module_getParameters_2() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + local p = n:getParameters() + + n:add( nn.Linear(10,10) ) + p = n:getParameters() + + mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when appending new module') + mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when appending new module') +end + +function nntest.Module_getParameters_3() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + n:add( n.modules[1]:clone() ) + local p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning') + + mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when using cloning') + + mytester:asserteq((p[{ {111,210} }] - n.modules[1].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {211,220} }] - n.modules[1].bias):norm(), 0, 'error when using cloning') + + n:reset() + + mytester:assertgt((p[{ {111,210} }] - n.modules[1].weight):norm(), 0, 'error when using cloning') + mytester:assertgt((p[{ {211,220} }] - n.modules[1].bias):norm(), 0, 'error when using cloning') +end + +function nntest.Module_getParameters_4() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + n:add( n.modules[1]:clone() ) + local p = n:getParameters() + + n:add(nn.Linear(10,10)) + p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning') + + mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when using cloning') + + mytester:asserteq((p[{ {221,320} }] - n.modules[3].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {321,330} }] - n.modules[3].bias):norm(), 0, 'error when using cloning') +end + +function nntest.Module_getParameters_5() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + n:add( n.modules[1]:clone('weight','bias') ) + local p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning+sharing') + + mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing') + + n:reset() + + mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing') +end + +function nntest.Module_getParameters_6() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + n:add( n.modules[1]:clone('weight','bias') ) + local p = n:getParameters() + + n:add(nn.Linear(10,10)) + p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning+sharing') + + mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing') + + mytester:asserteq((p[{ {111,210} }] - n.modules[3].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {211,220} }] - n.modules[3].bias):norm(), 0, 'error when using cloning+sharing') +end mytester:add(nntest) ---mytester:add(test_SpatialConvolution) ---mytester:add(test_AbsCriterion) if not nn then require 'nn' |