github.com/torch/nn.git
author     Clement Farabet <clement.farabet@gmail.com>   2012-09-14 18:23:19 +0400
committer  Clement Farabet <clement.farabet@gmail.com>   2012-09-14 18:23:19 +0400
commit     2f1b5b074ef88c70c2ebb9b4aeddd5554c029452 (patch)
tree       99681176623b24325860bc29d5c250c1767c74e2 /hessian.lua
parent     66421de171a303366dd6e87bc9b6acfd99d03dcd (diff)
Fixed flatten function when hessian is activated.
Diffstat (limited to 'hessian.lua')
-rw-r--r--   hessian.lua   91
1 file changed, 34 insertions, 57 deletions
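
Note: the flatten() touched below is the local helper used by the getParameters override that nn.hessian.enable() installs, so the fix is exercised by any hessian-enabled call to getParameters. A minimal way to hit that path (a sketch only; the small mlp and variable names are illustrative, not from the patch):

require 'nn'

-- Enable the hessian machinery; this installs the getParameters override
-- patched in the diff below.
nn.hessian.enable()

-- Any module works here; this mlp is just an example.
local mlp = nn.Sequential()
mlp:add(nn.Linear(10, 10))
mlp:add(nn.Tanh())
mlp:add(nn.Linear(10, 2))

-- Runs the rewritten flatten() on the parameters and gradients; the
-- 'hessianParameters and hessianParameters[1]' guard added in the second hunk
-- is what lets this call succeed before any hessian buffers have been set up.
local flatParameters = mlp:getParameters()
print(flatParameters:nElement())   -- total number of parameter elements
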
diff --git a/hessian.lua b/hessian.lua
index 4ecc90a..422e1d9 100644
--- a/hessian.lua
+++ b/hessian.lua
@@ -265,70 +265,47 @@ function nn.hessian.enable()
       -- get parameters
       local parameters,gradParameters,hessianParameters = self:parameters()
+      local function storageInSet(set, storage) --this is waste of time (need correct hash)
+         for key, val in pairs(set) do
+            if key == storage then
+               return val
+            end
+         end
+      end
+
       -- this function flattens arbitrary lists of parameters,
       -- even complex shared ones
       local function flatten(parameters)
-         -- already flat ?
-         local flat = true
-         for k = 2,#parameters do
-            if parameters[k]:storage() ~= parameters[k-1]:storage() then
-               flat = false
-               break
+         local storages = {}
+         local nParameters = 0
+         for k = 1,#parameters do
+            if not storageInSet(storages, parameters[k]:storage()) then
+               storages[parameters[k]:storage()] = nParameters
+               nParameters = nParameters + parameters[k]:storage():size()
             end
          end
-         if flat then
-            local nParameters = 0
-            for k,param in ipairs(parameters) do
-               nParameters = nParameters + param:nElement()
-            end
-            local flatParameters = parameters[1].new(parameters[1]:storage())
-            if nParameters ~= flatParameters:nElement() then
-               error('flattenParameters(): weird parameters')
-            end
-            return flatParameters
+
+         local flatParameters = torch.Tensor(nParameters):fill(1)
+         local flatStorage = flatParameters:storage()
+
+         for k = 1,#parameters do
+            local storageOffset = storageInSet(storages, parameters[k]:storage())
+            parameters[k]:set(flatStorage,
+                              storageOffset + parameters[k]:storageOffset(),
+                              parameters[k]:size(),
+                              parameters[k]:stride())
+            parameters[k]:zero()
          end
-         -- compute offsets of each parameter
-         local offsets = {}
-         local sizes = {}
-         local strides = {}
-         local elements = {}
-         local storageOffsets = {}
-         local params = {}
-         local nParameters = 0
-         for k,param in ipairs(parameters) do
-            table.insert(offsets, nParameters+1)
-            table.insert(sizes, param:size())
-            table.insert(strides, param:stride())
-            table.insert(elements, param:nElement())
-            table.insert(storageOffsets, param:storageOffset())
-            local isView = false
-            for i = 1,k-1 do
-               if param:storage() == parameters[i]:storage() then
-                  offsets[k] = offsets[i]
-                  if storageOffsets[k] ~= storageOffsets[i] or elements[k] ~= elements[i] then
-                     error('flattenParameters(): cannot flatten shared weights with different structures')
-                  end
-                  isView = true
-                  break
-               end
-            end
-            if not isView then
-               nParameters = nParameters + param:nElement()
-            end
+         if (flatParameters:sum() ~= 0) then
+            print("<getParameters()> WARNING: found "
+                  .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
+                  .. flatParameters:sum() .. " storage elements that are unused, this "
+                  .. "might be an issue for your optimization procedure)")
          end
-         -- create flat vector
-         local flatParameters = parameters[1].new(nParameters)
-         local storage = flatParameters:storage()
-         -- reallocate all parameters in flat vector
-         for i = 1,#parameters do
-            local data = parameters[i]:clone()
-            parameters[i]:set(storage, offsets[i], elements[i]):resize(sizes[i],strides[i]):copy(data)
-            data = nil
-            collectgarbage()
+
+         for k, v in pairs(storages) do
+            flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
          end
-         -- cleanup
-         collectgarbage()
-         -- return flat param
          return flatParameters
       end
@@ -336,7 +313,7 @@ function nn.hessian.enable()
       local flatParameters = flatten(parameters)
       local flatGradParameters = flatten(gradParameters)
       local flatHessianParameters
-      if hessianParameters[1] then
+      if hessianParameters and hessianParameters[1] then
          flatHessianParameters = flatten(hessianParameters)
       end
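
For reference, here is the flattening strategy the new flatten() implements, restated as a standalone sketch (it assumes only the torch package; offsetOf, flattenSketch and the demo tensors are illustrative names, not part of the patch). Each distinct underlying storage gets an offset into one flat vector, found by linear search with == (as in the patch's storageInSet, since storage handles from separate :storage() calls may not hash to the same table key); every parameter tensor is then re-pointed into the flat storage at its offset, and the old storages' contents are copied in, so tensors that shared a storage before flattening still share afterwards. The patch additionally fills the flat vector with ones and zeroes each re-pointed view so it can warn about "holes" (storage elements no parameter maps to); that check is omitted here for brevity.

-- Standalone sketch of the storage-offset flattening used by the new flatten().
-- Assumes the 'torch' package; all names below are illustrative.
require 'torch'

-- Linear-search lookup, mirroring the patch's storageInSet: storages returned
-- by separate :storage() calls may be distinct wrappers, but == compares the
-- underlying data, so plain table indexing is not reliable here.
local function offsetOf(map, storage)
   for s, offset in pairs(map) do
      if s == storage then return offset end
   end
end

local function flattenSketch(params)
   -- assign each distinct storage an offset into the future flat vector
   local offsets, nTotal = {}, 0
   for _, p in ipairs(params) do
      if not offsetOf(offsets, p:storage()) then
         offsets[p:storage()] = nTotal
         nTotal = nTotal + p:storage():size()
      end
   end

   -- allocate the flat vector and re-point every parameter into it,
   -- keeping each tensor's own storageOffset, size and stride
   local flat = torch.Tensor(nTotal)
   local flatStorage = flat:storage()
   for _, p in ipairs(params) do
      local base = offsetOf(offsets, p:storage())
      p:set(flatStorage, base + p:storageOffset(), p:size(), p:stride())
   end

   -- copy the original storages' contents into the flat vector
   for s, base in pairs(offsets) do
      flat[{{base + 1, base + s:size()}}]:copy(torch.Tensor():set(s))
   end
   return flat
end

-- demo: two views sharing one storage, plus an independent tensor
local shared = torch.randn(6)
local a = shared:narrow(1, 1, 3)   -- first half of 'shared'
local b = shared:narrow(1, 4, 3)   -- second half, same storage
local c = torch.randn(2)
local flat = flattenSketch({a, b, c})
print(flat:nElement())             -- 8: the shared storage counted once, plus 2
a:fill(7)
print(flat[1])                     -- 7: 'a' now writes directly into 'flat'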