diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-09-14 18:23:19 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-09-14 18:23:19 +0400 |
commit | 2f1b5b074ef88c70c2ebb9b4aeddd5554c029452 (patch) | |
tree | 99681176623b24325860bc29d5c250c1767c74e2 /hessian.lua | |
parent | 66421de171a303366dd6e87bc9b6acfd99d03dcd (diff) |
Fixed flatten function when hessian is activated.
Diffstat (limited to 'hessian.lua')
-rw-r--r-- | hessian.lua | 91 |
1 files changed, 34 insertions, 57 deletions
```diff
diff --git a/hessian.lua b/hessian.lua
index 4ecc90a..422e1d9 100644
--- a/hessian.lua
+++ b/hessian.lua
@@ -265,70 +265,47 @@ function nn.hessian.enable()
       -- get parameters
       local parameters,gradParameters,hessianParameters = self:parameters()
 
+      local function storageInSet(set, storage) --this is waste of time (need correct hash)
+         for key, val in pairs(set) do
+            if key == storage then
+               return val
+            end
+         end
+      end
+
       -- this function flattens arbitrary lists of parameters,
       -- even complex shared ones
       local function flatten(parameters)
-         -- already flat ?
-         local flat = true
-         for k = 2,#parameters do
-            if parameters[k]:storage() ~= parameters[k-1]:storage() then
-               flat = false
-               break
+         local storages = {}
+         local nParameters = 0
+         for k = 1,#parameters do
+            if not storageInSet(storages, parameters[k]:storage()) then
+               storages[parameters[k]:storage()] = nParameters
+               nParameters = nParameters + parameters[k]:storage():size()
             end
          end
-         if flat then
-            local nParameters = 0
-            for k,param in ipairs(parameters) do
-               nParameters = nParameters + param:nElement()
-            end
-            local flatParameters = parameters[1].new(parameters[1]:storage())
-            if nParameters ~= flatParameters:nElement() then
-               error('flattenParameters(): weird parameters')
-            end
-            return flatParameters
+
+         local flatParameters = torch.Tensor(nParameters):fill(1)
+         local flatStorage = flatParameters:storage()
+
+         for k = 1,#parameters do
+            local storageOffset = storageInSet(storages, parameters[k]:storage())
+            parameters[k]:set(flatStorage,
+                              storageOffset + parameters[k]:storageOffset(),
+                              parameters[k]:size(),
+                              parameters[k]:stride())
+            parameters[k]:zero()
          end
-         -- compute offsets of each parameter
-         local offsets = {}
-         local sizes = {}
-         local strides = {}
-         local elements = {}
-         local storageOffsets = {}
-         local params = {}
-         local nParameters = 0
-         for k,param in ipairs(parameters) do
-            table.insert(offsets, nParameters+1)
-            table.insert(sizes, param:size())
-            table.insert(strides, param:stride())
-            table.insert(elements, param:nElement())
-            table.insert(storageOffsets, param:storageOffset())
-            local isView = false
-            for i = 1,k-1 do
-               if param:storage() == parameters[i]:storage() then
-                  offsets[k] = offsets[i]
-                  if storageOffsets[k] ~= storageOffsets[i] or elements[k] ~= elements[i] then
-                     error('flattenParameters(): cannot flatten shared weights with different structures')
-                  end
-                  isView = true
-                  break
-               end
-            end
-            if not isView then
-               nParameters = nParameters + param:nElement()
-            end
+         if (flatParameters:sum() ~= 0) then
+            print("<getParameters()> WARNING: found "
+                  .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
+                  .. flatParameters:sum() .. " storage elements that are unused, this "
+                  .. "might be an issue for your optimization procedure)")
          end
-         -- create flat vector
-         local flatParameters = parameters[1].new(nParameters)
-         local storage = flatParameters:storage()
-         -- reallocate all parameters in flat vector
-         for i = 1,#parameters do
-            local data = parameters[i]:clone()
-            parameters[i]:set(storage, offsets[i], elements[i]):resize(sizes[i],strides[i]):copy(data)
-            data = nil
-            collectgarbage()
+
+         for k, v in pairs(storages) do
+            flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
          end
-         -- cleanup
-         collectgarbage()
-         -- return flat param
          return flatParameters
       end
 
@@ -336,7 +313,7 @@ function nn.hessian.enable()
       local flatParameters = flatten(parameters)
       local flatGradParameters = flatten(gradParameters)
       local flatHessianParameters
-      if hessianParameters[1] then
+      if hessianParameters and hessianParameters[1] then
          flatHessianParameters = flatten(hessianParameters)
       end
 
```