diff options
Diffstat (limited to 'Module.lua')
-rw-r--r-- | Module.lua | 129 |
1 files changed, 64 insertions, 65 deletions
@@ -137,92 +137,91 @@ end function Module:reset() end -function Module:getParameters() - -- get parameters - local parameters,gradParameters = self:parameters() - +-- this function flattens arbitrary lists of parameters, +-- even complex shared ones +function Module.flatten(parameters) local function storageInSet(set, storage) local storageAndOffset = set[torch.pointer(storage)] if storageAndOffset == nil then - return nil + return nil end local _, offset = table.unpack(storageAndOffset) return offset end - -- this function flattens arbitrary lists of parameters, - -- even complex shared ones - local function flatten(parameters) - if not parameters or #parameters == 0 then - return torch.Tensor() + if not parameters or #parameters == 0 then + return torch.Tensor() + end + local Tensor = parameters[1].new + local dtype = parameters[1]:type() + + local storages = {} + local nParameters = 0 + for k = 1,#parameters do + if parameters[k]:type() ~= dtype then + error("Inconsistent parameter types. " .. parameters[k]:type() .. + " ~= " .. dtype) end - local Tensor = parameters[1].new - local dtype = parameters[1]:type() - - local storages = {} - local nParameters = 0 - for k = 1,#parameters do - if parameters[k]:type() ~= dtype then - error("Inconsistent parameter types. " .. parameters[k]:type() .. - " ~= " .. dtype) - end - local storage = parameters[k]:storage() - if not storageInSet(storages, storage) then - storages[torch.pointer(storage)] = {storage, nParameters} - nParameters = nParameters + storage:size() - end + local storage = parameters[k]:storage() + if not storageInSet(storages, storage) then + storages[torch.pointer(storage)] = {storage, nParameters} + nParameters = nParameters + storage:size() end + end - local flatParameters = Tensor(nParameters):fill(1) - local flatStorage = flatParameters:storage() + local flatParameters = Tensor(nParameters):fill(1) + local flatStorage = flatParameters:storage() - for k = 1,#parameters do - local storageOffset = storageInSet(storages, parameters[k]:storage()) - parameters[k]:set(flatStorage, - storageOffset + parameters[k]:storageOffset(), - parameters[k]:size(), - parameters[k]:stride()) - parameters[k]:zero() - end + for k = 1,#parameters do + local storageOffset = storageInSet(storages, parameters[k]:storage()) + parameters[k]:set(flatStorage, + storageOffset + parameters[k]:storageOffset(), + parameters[k]:size(), + parameters[k]:stride()) + parameters[k]:zero() + end - local maskParameters = flatParameters:float():clone() - local cumSumOfHoles = flatParameters:float():cumsum(1) - local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] - local flatUsedParameters = Tensor(nUsedParameters) - local flatUsedStorage = flatUsedParameters:storage() - - for k = 1,#parameters do - local offset = cumSumOfHoles[parameters[k]:storageOffset()] - parameters[k]:set(flatUsedStorage, - parameters[k]:storageOffset() - offset, - parameters[k]:size(), - parameters[k]:stride()) - end + local maskParameters = flatParameters:float():clone() + local cumSumOfHoles = flatParameters:float():cumsum(1) + local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] + local flatUsedParameters = Tensor(nUsedParameters) + local flatUsedStorage = flatUsedParameters:storage() + + for k = 1,#parameters do + local offset = cumSumOfHoles[parameters[k]:storageOffset()] + parameters[k]:set(flatUsedStorage, + parameters[k]:storageOffset() - offset, + parameters[k]:size(), + parameters[k]:stride()) + end - for _, storageAndOffset in pairs(storages) do - local k, v = table.unpack(storageAndOffset) - flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) - end + for _, storageAndOffset in pairs(storages) do + local k, v = table.unpack(storageAndOffset) + flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) + end - if cumSumOfHoles:sum() == 0 then - flatUsedParameters:copy(flatParameters) - else - local counter = 0 - for k = 1,flatParameters:nElement() do - if maskParameters[k] == 0 then - counter = counter + 1 - flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] - end + if cumSumOfHoles:sum() == 0 then + flatUsedParameters:copy(flatParameters) + else + local counter = 0 + for k = 1,flatParameters:nElement() do + if maskParameters[k] == 0 then + counter = counter + 1 + flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] end - assert (counter == nUsedParameters) end - return flatUsedParameters + assert (counter == nUsedParameters) end + return flatUsedParameters +end +function Module:getParameters() + -- get parameters + local parameters,gradParameters = self:parameters() -- flatten parameters and gradients - local flatParameters = flatten(parameters) + local flatParameters = Module.flatten(parameters) collectgarbage() - local flatGradParameters = flatten(gradParameters) + local flatGradParameters = Module.flatten(gradParameters) collectgarbage() -- return new flat vector that contains all discrete parameters |