diff options
Diffstat (limited to 'Module.lua')
-rw-r--r-- | Module.lua | 11 |
1 files changed, 9 insertions, 2 deletions
@@ -205,6 +205,7 @@ function Module:getParameters() parameters[k]:zero() end + local maskParameters= flatParameters:float():clone() local cumSumOfHoles = flatParameters:float():cumsum(1) local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] local flatUsedParameters = Tensor(nUsedParameters) @@ -222,12 +223,18 @@ function Module:getParameters() local k, v = unpack(storageAndOffset) flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) end + if cumSumOfHoles:sum() == 0 then flatUsedParameters:copy(flatParameters) else - for k = 1,flatUsedParameters:nElement() do - flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k]] + local counter = 0 + for k = 1,flatParameters:nElement() do + if maskParameters[k] == 0 then + counter = counter + 1 + flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] + end end + assert (counter == nUsedParameters) end return flatUsedParameters end |