diff options
Diffstat (limited to 'test/test.lua')
-rw-r--r-- | test/test.lua | 99 |
1 files changed, 97 insertions, 2 deletions
diff --git a/test/test.lua b/test/test.lua index 4d4383d..ff8b0d9 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1078,10 +1078,105 @@ function nntest.VolumetricConvolution() mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') end +function nntest.Module_getParameters_1() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + local p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'getParameters(): weights wrong') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'getParameters(): bias wrong') +end + +function nntest.Module_getParameters_2() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + local p = n:getParameters() + + n:add( nn.Linear(10,10) ) + p = n:getParameters() + + mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when appending new module') + mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when appending new module') +end + +function nntest.Module_getParameters_3() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + n:add( n.modules[1]:clone() ) + local p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning') + + mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when using cloning') + + mytester:asserteq((p[{ {111,210} }] - n.modules[1].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {211,220} }] - n.modules[1].bias):norm(), 0, 'error when using cloning') + + n:reset() + + mytester:assertgt((p[{ {111,210} }] - n.modules[1].weight):norm(), 0, 'error when using cloning') + mytester:assertgt((p[{ {211,220} }] - n.modules[1].bias):norm(), 0, 'error when using cloning') +end + +function nntest.Module_getParameters_4() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + n:add( n.modules[1]:clone() ) + local p = n:getParameters() + + n:add(nn.Linear(10,10)) + p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning') + + mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when using cloning') + + mytester:asserteq((p[{ {221,320} }] - n.modules[3].weight):norm(), 0, 'error when using cloning') + mytester:asserteq((p[{ {321,330} }] - n.modules[3].bias):norm(), 0, 'error when using cloning') +end + +function nntest.Module_getParameters_5() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + n:add( n.modules[1]:clone('weight','bias') ) + local p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning+sharing') + + mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing') + + n:reset() + + mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing') +end + +function nntest.Module_getParameters_6() + local n = nn.Sequential() + n:add( nn.Linear(10,10) ) + n:add( n.modules[1]:clone('weight','bias') ) + local p = n:getParameters() + + n:add(nn.Linear(10,10)) + p = n:getParameters() + + mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning+sharing') + + mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing') + + mytester:asserteq((p[{ {111,210} }] - n.modules[3].weight):norm(), 0, 'error when using cloning+sharing') + mytester:asserteq((p[{ {211,220} }] - n.modules[3].bias):norm(), 0, 'error when using cloning+sharing') +end mytester:add(nntest) ---mytester:add(test_SpatialConvolution) ---mytester:add(test_AbsCriterion) if not nn then require 'nn' |