diff options
author | Ivo Danihelka <ivo@danihelka.net> | 2012-11-29 14:55:23 +0400 |
---|---|---|
committer | Ivo Danihelka <ivo@danihelka.net> | 2012-11-29 14:55:23 +0400 |
commit | 34da23919a23b02a1cbddbc62f40ae502a0be946 (patch) | |
tree | 6dbfffaab69f22666c3aebaf65062e4ae7d45177 /Module.lua | |
parent | 6160010cbc434e02ab3b8f2c414afd96a5fb2c30 (diff) |
Improved Module:getParameters() speed when using many storages.
Diffstat (limited to 'Module.lua')
-rw-r--r-- | Module.lua | 21 |
1 files changed, 12 insertions, 9 deletions
@@ -139,12 +139,13 @@ function Module:getParameters() -- get parameters local parameters,gradParameters = self:parameters() - local function storageInSet(set, storage) --this is waste of time (need correct hash) - for key, val in pairs(set) do - if key == storage then - return val - end + local function storageInSet(set, storage) + local storageAndOffset = set[torch.pointer(storage)] + if storageAndOffset == nil then + return nil end + local storage, offset = unpack(storageAndOffset) + return offset end -- this function flattens arbitrary lists of parameters, @@ -155,9 +156,10 @@ function Module:getParameters() local storages = {} local nParameters = 0 for k = 1,#parameters do - if not storageInSet(storages, parameters[k]:storage()) then - storages[parameters[k]:storage()] = nParameters - nParameters = nParameters + parameters[k]:storage():size() + local storage = parameters[k]:storage() + if not storageInSet(storages, storage) then + storages[torch.pointer(storage)] = {storage, nParameters} + nParameters = nParameters + storage:size() end end @@ -186,7 +188,8 @@ function Module:getParameters() parameters[k]:stride()) end - for k, v in pairs(storages) do + for _, storageAndOffset in pairs(storages) do + local k, v = unpack(storageAndOffset) flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) end for k = 1,flatUsedParameters:nElement() do |