diff options
author | James Kirkpatrick <kirkpatrick@google.com> | 2014-10-01 12:16:27 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-10-27 17:05:16 +0300 |
commit | 8581f84f45c2baef72dd32f6336410480b617ff7 (patch) | |
tree | e590661f77f4780dee406b44cc6e11e3fdce09cf /test | |
parent | 8d7d03ebe72ae507ae716f292778123dc34e04b1 (diff) |
Corrected getParamaters for partial views
Module:getParameters was incorrectly overwriting parameters that
were partial views on larger storages.
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index b035928..aa1d062 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1745,6 +1745,37 @@ function nntest.Module_getParameters_7() mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector') end +function nntest.Module_getParameters_8() + local function makeMLP(nin, ns) + local net = nn.Sequential() + + for k,v in ipairs(ns) do + net:add(nn.Linear(nin, v)) + nin = v + end + _,_ = net:getParameters() + return net + end + + local mlp1 = makeMLP(10, {10,10}) + local mlp2 = makeMLP(10, {10,10}) + + local net = nn.Sequential():add(mlp1:get(1)) + :add(mlp2:get(1)) + + -- clone the second MLP to ensure that the weights before calling getParameters are preserved + mlp2 = mlp2:clone() + + local p, gp = net:getParameters() + + mytester:asserteq((p[{ {1,100} }] - net.modules[1].weight):norm(), 0, 'error when using partial realloc') + mytester:asserteq((p[{ {111,210} }] - net.modules[2].weight):norm(), 0, 'error when using partial realloc') + -- check that the weights have the same values as before get Parameters was called + mytester:asserteq((net.modules[1].weight - mlp1.modules[1].weight):norm(), 0, ' error when using partial realloc') + mytester:asserteq((net.modules[2].weight - mlp2.modules[1].weight):norm(), 0, ' error when using partial realloc') + +end + function nntest.PairwiseDistance() -- Note: testJacobian doesn't support table inputs, and rather than re-write -- it so that it does, I'll just use a split table module on the input. |