Diffstat (limited to 'Module.lua')
-rw-r--r--  Module.lua  29
1 file changed, 22 insertions, 7 deletions
@@ -170,17 +170,32 @@ function Module:getParameters()
                            parameters[k]:stride())
          parameters[k]:zero()
       end
-      if (flatParameters:sum() ~= 0) then
-         print("<getParameters()> WARNING: found "
-               .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
-               .. flatParameters:sum() .. " storage elements that are unused, this "
-               .. "might be an issue for your optimization procedure)")
+
+      local maskParameters = flatParameters:clone()
+      local cumSumOfHoles = flatParameters:cumsum(1)
+      local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
+      local flatUsedParameters = torch.Tensor(nUsedParameters)
+      local flatUsedStorage = flatUsedParameters:storage()
+
+      for k = 1,#parameters do
+         local offset = cumSumOfHoles[parameters[k]:storageOffset()]
+         parameters[k]:set(flatUsedStorage,
+                           parameters[k]:storageOffset() - offset,
+                           parameters[k]:size(),
+                           parameters[k]:stride())
       end
-      for k, v in pairs(storages) do
+      for k, v in pairs(storages) do -- we could remove this loop
          flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
       end
 
-      return flatParameters
+      local counter = 0
+      for k = 1,flatParameters:nElement() do
+         if maskParameters[k] == 0 then
+            counter = counter + 1
+            flatUsedParameters[counter] = flatParameters[k]
+         end
+      end
+      return flatUsedParameters
    end
 
    -- flatten parameters and gradients
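The change replaces the old warning about "holes" with an actual compaction step: flatParameters starts filled with ones, every storage slot covered by a parameter view is zeroed, and a running sum of the remaining ones tells each used slot how far left it must shift. Below is a minimal standalone sketch of that compaction, assuming Torch7's 'torch' package; the toy mask, the values vector, and names like 'mask' and 'compacted' are illustrative and not part of the commit.

-- Minimal sketch of the hole-compaction idea, under the assumptions above.
require 'torch'

-- 1 marks an unused storage slot (a "hole"), 0 a slot covered by some
-- parameter view; getParameters() leaves flatParameters in this state
-- after zeroing the rebased views.
local mask   = torch.Tensor{0, 1, 1, 0, 0, 1, 0}
local values = torch.Tensor{10, -1, -1, 20, 30, -1, 40} -- -1 = hole garbage

-- cumSum[i] counts the holes at or before slot i, so used slot i lands at
-- position i - cumSum[i] in the compacted vector (the same arithmetic that
-- rebases each view's storageOffset in the patch).
local cumSum = mask:cumsum(1)
local nUsed  = mask:nElement() - cumSum[mask:nElement()]
local compacted = torch.Tensor(nUsed)

-- Walk the original layout and keep only the used slots.
local counter = 0
for i = 1, mask:nElement() do
   if mask[i] == 0 then
      counter = counter + 1
      compacted[counter] = values[i]
   end
end

print(compacted) -- 10, 20, 30, 40

Walking the original layout with a counter, rather than indexing the source at k + cumSum[k] directly, matters whenever two or more holes are adjacent: the one-step index can then still land inside a run of holes, while position k + cumSum[k] only equals the true source slot when k is already a used position.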