author    | Clement Farabet <clement.farabet@gmail.com> | 2012-07-14 10:36:04 +0400
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-07-14 10:36:04 +0400
commit    | 0511421f2df3d2a015986de292b4e39af65ea298 (patch)
tree      | 124a6356d9ca14fb949f34c65f735e64cd206a39
parent    | f61a5ccdcd0f8b8f87df63f3d18a445380c18b22 (diff)
Attempt to fix getParameters.
-rw-r--r-- | Module.lua | 89
1 file changed, 33 insertions, 56 deletions
@@ -139,70 +139,47 @@ function Module:getParameters()
    -- get parameters
    local parameters,gradParameters = self:parameters()
 
+   local function storageInSet(set, storage) --this is waste of time (need correct hash)
+      for key, val in pairs(set) do
+         if key == storage then
+            return val
+         end
+      end
+   end
+
    -- this function flattens arbitrary lists of parameters,
    -- even complex shared ones
    local function flatten(parameters)
-      -- already flat ?
-      local flat = true
-      for k = 2,#parameters do
-         if parameters[k]:storage() ~= parameters[k-1]:storage() then
-            flat = false
-            break
+      local storages = {}
+      local nParameters = 0
+      for k = 1,#parameters do
+         if not storageInSet(storages, parameters[k]:storage()) then
+            storages[parameters[k]:storage()] = nParameters
+            nParameters = nParameters + parameters[k]:storage():size()
          end
       end
-      if flat then
-         local nParameters = 0
-         for k,param in ipairs(parameters) do
-            nParameters = nParameters + param:nElement()
-         end
-         local flatParameters = parameters[1].new(parameters[1]:storage())
-         if nParameters ~= flatParameters:nElement() then
-            error('flattenParameters(): weird parameters')
-         end
-         return flatParameters
+
+      local flatParameters = torch.Tensor(nParameters):fill(1)
+      local flatStorage = flatParameters:storage()
+
+      for k = 1,#parameters do
+         local storageOffset = storageInSet(storages, parameters[k]:storage())
+         parameters[k]:set(flatStorage,
+                           storageOffset + parameters[k]:storageOffset(),
+                           parameters[k]:size(),
+                           parameters[k]:stride())
+         parameters[k]:zero()
       end
-      -- compute offsets of each parameter
-      local offsets = {}
-      local sizes = {}
-      local strides = {}
-      local elements = {}
-      local storageOffsets = {}
-      local params = {}
-      local nParameters = 0
-      for k,param in ipairs(parameters) do
-         table.insert(offsets, nParameters+1)
-         table.insert(sizes, param:size())
-         table.insert(strides, param:stride())
-         table.insert(elements, param:nElement())
-         table.insert(storageOffsets, param:storageOffset())
-         local isView = false
-         for i = 1,k-1 do
-            if param:storage() == parameters[i]:storage() then
-               offsets[k] = offsets[i]
-               if storageOffsets[k] ~= storageOffsets[i] or elements[k] ~= elements[i] then
-                  error('flattenParameters(): cannot flatten shared weights with different structures')
-               end
-               isView = true
-               break
-            end
-         end
-         if not isView then
-            nParameters = nParameters + param:nElement()
-         end
+      if (flatParameters:sum() ~= 0) then
+         print("<getParameters()> WARNING: found "
+               .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
+               .. flatParameters:sum() .. " storage elements that are unused, this "
+               .. "might be an issue for your optimization procedure)")
       end
-      -- create flat vector
-      local flatParameters = parameters[1].new(nParameters)
-      local storage = flatParameters:storage()
-      -- reallocate all parameters in flat vector
-      for i = 1,#parameters do
-         local data = parameters[i]:clone()
-         parameters[i]:set(storage, offsets[i], elements[i]):resize(sizes[i],strides[i]):copy(data)
-         data = nil
-         collectgarbage()
+
+      for k, v in pairs(storages) do
+         flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
       end
-      -- cleanup
-      collectgarbage()
-      -- return flat param
       return flatParameters
    end
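
Below is a minimal usage sketch (not part of this commit) of the flattening behaviour touched here, assuming the stock torch7/nn API of this period; the network layout and sizes are illustrative only.

-- Minimal sketch: flatten a small MLP's parameters into one contiguous vector.
require 'nn'

local mlp = nn.Sequential()
mlp:add(nn.Linear(10, 20))
mlp:add(nn.Tanh())
mlp:add(nn.Linear(20, 1))

-- getParameters() flattens all weights/biases (and their gradients)
-- into two contiguous vectors via the flatten() helper above.
local p, gp = mlp:getParameters()
print(p:size(1))  -- total parameter count: 20*10 + 20 + 1*20 + 1 = 241

-- After flattening, each module tensor is a view into p's storage,
-- so writes through the flat vector are visible in the modules.
p:fill(0.5)
print(mlp:get(1).weight[1][1])  -- prints 0.5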