diff options
author | James Kirkpatrick <kirkpatrick@google.com> | 2014-10-01 12:16:27 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-10-27 17:05:16 +0300 |
commit | 8581f84f45c2baef72dd32f6336410480b617ff7 (patch) | |
tree | e590661f77f4780dee406b44cc6e11e3fdce09cf /Module.lua | |
parent | 8d7d03ebe72ae507ae716f292778123dc34e04b1 (diff) |
Corrected getParamaters for partial views
Module:getParameters was incorrectly overwriting parameters that
were partial views on larger storages.
Diffstat (limited to 'Module.lua')
-rw-r--r-- | Module.lua | 11 |
1 files changed, 9 insertions, 2 deletions
@@ -205,6 +205,7 @@ function Module:getParameters() parameters[k]:zero() end + local maskParameters= flatParameters:float():clone() local cumSumOfHoles = flatParameters:float():cumsum(1) local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] local flatUsedParameters = Tensor(nUsedParameters) @@ -222,12 +223,18 @@ function Module:getParameters() local k, v = unpack(storageAndOffset) flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) end + if cumSumOfHoles:sum() == 0 then flatUsedParameters:copy(flatParameters) else - for k = 1,flatUsedParameters:nElement() do - flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k]] + local counter = 0 + for k = 1,flatParameters:nElement() do + if maskParameters[k] == 0 then + counter = counter + 1 + flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] + end end + assert (counter == nUsedParameters) end return flatUsedParameters end |