Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2012-09-07 20:44:49 +0400
committerClement Farabet <clement.farabet@gmail.com>2012-09-07 20:44:49 +0400
commitcde19db1b991380bec5700f11ff837677dcbf444 (patch)
tree86050f677af1b13a52c5b63a5458bd0d7768e36a /test
parentbf724fdabdca054b2cfde2fee1908c3b216e65a6 (diff)
New tester.
Diffstat (limited to 'test')
-rw-r--r--test/test.lua99
1 files changed, 97 insertions, 2 deletions
diff --git a/test/test.lua b/test/test.lua
index 4d4383d..ff8b0d9 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1078,10 +1078,105 @@ function nntest.VolumetricConvolution()
mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
end
function nntest.Module_getParameters_1()
   -- Flatten a single Linear(10,10): the test expects the first 100 entries
   -- of the flat vector to be the weights and the next 10 to be the bias.
   local net = nn.Sequential()
   net:add(nn.Linear(10, 10))
   local flat = net:getParameters()

   local lin = net.modules[1]
   local wdiff = (flat[{ {1,100} }] - lin.weight):norm()
   local bdiff = (flat[{ {101,110} }] - lin.bias):norm()
   mytester:asserteq(wdiff, 0, 'getParameters(): weights wrong')
   mytester:asserteq(bdiff, 0, 'getParameters(): bias wrong')
end
+
function nntest.Module_getParameters_2()
   -- Flatten once, append a second Linear, and flatten again: the test
   -- expects the new module's weights/bias to occupy the slots right
   -- after the first module's 110 parameters.
   local net = nn.Sequential()
   net:add(nn.Linear(10, 10))
   local flat = net:getParameters()

   net:add(nn.Linear(10, 10))
   flat = net:getParameters()

   local second = net.modules[2]
   mytester:asserteq((flat[{ {111,210} }] - second.weight):norm(), 0, 'error when appending new module')
   mytester:asserteq((flat[{ {211,220} }] - second.bias):norm(), 0, 'error when appending new module')
end
+
function nntest.Module_getParameters_3()
   -- A plain clone() copies parameter values, so right after flattening
   -- both segments of the flat vector match both modules (the clone's
   -- values equal the original's). After reset() the first module's
   -- parameters must diverge from the clone's segment, since the two
   -- modules do not share storage.
   local net = nn.Sequential()
   net:add(nn.Linear(10, 10))
   net:add(net.modules[1]:clone())
   local flat = net:getParameters()

   local first, second = net.modules[1], net.modules[2]

   mytester:asserteq((flat[{ {1,100} }] - first.weight):norm(), 0, 'error when using cloning')
   mytester:asserteq((flat[{ {101,110} }] - first.bias):norm(), 0, 'error when using cloning')

   mytester:asserteq((flat[{ {111,210} }] - second.weight):norm(), 0, 'error when using cloning')
   mytester:asserteq((flat[{ {211,220} }] - second.bias):norm(), 0, 'error when using cloning')

   -- The clone started as an exact value copy, so module 1 also matches
   -- the clone's segment at this point.
   mytester:asserteq((flat[{ {111,210} }] - first.weight):norm(), 0, 'error when using cloning')
   mytester:asserteq((flat[{ {211,220} }] - first.bias):norm(), 0, 'error when using cloning')

   net:reset()

   mytester:assertgt((flat[{ {111,210} }] - first.weight):norm(), 0, 'error when using cloning')
   mytester:assertgt((flat[{ {211,220} }] - first.bias):norm(), 0, 'error when using cloning')
end
+
function nntest.Module_getParameters_4()
   -- Build with a deep clone, flatten, then append a third Linear and
   -- flatten again: the test expects the three modules to occupy
   -- consecutive, non-overlapping segments of 110 entries each.
   local net = nn.Sequential()
   net:add(nn.Linear(10, 10))
   net:add(net.modules[1]:clone())
   local flat = net:getParameters()

   net:add(nn.Linear(10, 10))
   flat = net:getParameters()

   local mods = net.modules

   mytester:asserteq((flat[{ {1,100} }] - mods[1].weight):norm(), 0, 'error when using cloning')
   mytester:asserteq((flat[{ {101,110} }] - mods[1].bias):norm(), 0, 'error when using cloning')

   mytester:asserteq((flat[{ {111,210} }] - mods[2].weight):norm(), 0, 'error when using cloning')
   mytester:asserteq((flat[{ {211,220} }] - mods[2].bias):norm(), 0, 'error when using cloning')

   mytester:asserteq((flat[{ {221,320} }] - mods[3].weight):norm(), 0, 'error when using cloning')
   mytester:asserteq((flat[{ {321,330} }] - mods[3].bias):norm(), 0, 'error when using cloning')
end
+
function nntest.Module_getParameters_5()
   -- clone('weight','bias') shares parameter storage with the original,
   -- so both modules must map onto the SAME 110-entry segment of the
   -- flat vector — and keep doing so even after reset().
   local net = nn.Sequential()
   net:add(nn.Linear(10, 10))
   net:add(net.modules[1]:clone('weight','bias'))
   local flat = net:getParameters()

   local first, shared = net.modules[1], net.modules[2]

   mytester:asserteq((flat[{ {1,100} }] - first.weight):norm(), 0, 'error when using cloning+sharing')
   mytester:asserteq((flat[{ {101,110} }] - first.bias):norm(), 0, 'error when using cloning+sharing')

   mytester:asserteq((flat[{ {1,100} }] - shared.weight):norm(), 0, 'error when using cloning+sharing')
   mytester:asserteq((flat[{ {101,110} }] - shared.bias):norm(), 0, 'error when using cloning+sharing')

   net:reset()

   -- Sharing must survive reset(): the flat vector still views the
   -- shared module's (freshly re-initialized) parameters.
   mytester:asserteq((flat[{ {1,100} }] - shared.weight):norm(), 0, 'error when using cloning+sharing')
   mytester:asserteq((flat[{ {101,110} }] - shared.bias):norm(), 0, 'error when using cloning+sharing')
end
+
function nntest.Module_getParameters_6()
   -- Shared clone plus a later append: after re-flattening, modules 1
   -- and 2 must still alias the same first 110-entry segment, while the
   -- newly appended Linear gets its own segment directly after it.
   local net = nn.Sequential()
   net:add(nn.Linear(10, 10))
   net:add(net.modules[1]:clone('weight','bias'))
   local flat = net:getParameters()

   net:add(nn.Linear(10, 10))
   flat = net:getParameters()

   local mods = net.modules

   mytester:asserteq((flat[{ {1,100} }] - mods[1].weight):norm(), 0, 'error when using cloning+sharing')
   mytester:asserteq((flat[{ {101,110} }] - mods[1].bias):norm(), 0, 'error when using cloning+sharing')

   mytester:asserteq((flat[{ {1,100} }] - mods[2].weight):norm(), 0, 'error when using cloning+sharing')
   mytester:asserteq((flat[{ {101,110} }] - mods[2].bias):norm(), 0, 'error when using cloning+sharing')

   mytester:asserteq((flat[{ {111,210} }] - mods[3].weight):norm(), 0, 'error when using cloning+sharing')
   mytester:asserteq((flat[{ {211,220} }] - mods[3].bias):norm(), 0, 'error when using cloning+sharing')
end
mytester:add(nntest)
---mytester:add(test_SpatialConvolution)
---mytester:add(test_AbsCriterion)
if not nn then
require 'nn'