diff options
author | James Kirkpatrick <kirkpatrick@google.com> | 2014-10-01 12:16:27 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-10-27 17:05:16 +0300 |
commit | 8581f84f45c2baef72dd32f6336410480b617ff7 (patch) | |
tree | e590661f77f4780dee406b44cc6e11e3fdce09cf | |
parent | 8d7d03ebe72ae507ae716f292778123dc34e04b1 (diff) |
Corrected getParamaters for partial views
Module:getParameters was incorrectly overwriting parameters that
were partial views on larger storages.
-rw-r--r-- | Module.lua | 11 | ||||
-rw-r--r-- | test/test.lua | 31 |
2 files changed, 40 insertions, 2 deletions
@@ -205,6 +205,7 @@ function Module:getParameters() parameters[k]:zero() end + local maskParameters= flatParameters:float():clone() local cumSumOfHoles = flatParameters:float():cumsum(1) local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] local flatUsedParameters = Tensor(nUsedParameters) @@ -222,12 +223,18 @@ function Module:getParameters() local k, v = unpack(storageAndOffset) flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) end + if cumSumOfHoles:sum() == 0 then flatUsedParameters:copy(flatParameters) else - for k = 1,flatUsedParameters:nElement() do - flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k]] + local counter = 0 + for k = 1,flatParameters:nElement() do + if maskParameters[k] == 0 then + counter = counter + 1 + flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] + end end + assert (counter == nUsedParameters) end return flatUsedParameters end diff --git a/test/test.lua b/test/test.lua index b035928..aa1d062 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1745,6 +1745,37 @@ function nntest.Module_getParameters_7() mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector') end +function nntest.Module_getParameters_8() + local function makeMLP(nin, ns) + local net = nn.Sequential() + + for k,v in ipairs(ns) do + net:add(nn.Linear(nin, v)) + nin = v + end + _,_ = net:getParameters() + return net + end + + local mlp1 = makeMLP(10, {10,10}) + local mlp2 = makeMLP(10, {10,10}) + + local net = nn.Sequential():add(mlp1:get(1)) + :add(mlp2:get(1)) + + -- clone the second MLP to ensure that the weights before calling getParameters are preserved + mlp2 = mlp2:clone() + + local p, gp = net:getParameters() + + mytester:asserteq((p[{ {1,100} }] - net.modules[1].weight):norm(), 0, 'error when using partial realloc') + mytester:asserteq((p[{ {111,210} }] - net.modules[2].weight):norm(), 0, 'error when using partial realloc') + -- check that the weights have the same values as before get Parameters was called + mytester:asserteq((net.modules[1].weight - mlp1.modules[1].weight):norm(), 0, ' error when using partial realloc') + mytester:asserteq((net.modules[2].weight - mlp2.modules[1].weight):norm(), 0, ' error when using partial realloc') + +end + function nntest.PairwiseDistance() -- Note: testJacobian doesn't support table inputs, and rather than re-write -- it so that it does, I'll just use a split table module on the input. |