Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvo Danihelka <ivo@danihelka.net>2012-11-29 14:55:23 +0400
committerIvo Danihelka <ivo@danihelka.net>2012-11-29 14:55:23 +0400
commit34da23919a23b02a1cbddbc62f40ae502a0be946 (patch)
tree6dbfffaab69f22666c3aebaf65062e4ae7d45177 /Module.lua
parent6160010cbc434e02ab3b8f2c414afd96a5fb2c30 (diff)
Improved Module:getParameters() speed when using many storages.
Diffstat (limited to 'Module.lua')
-rw-r--r--Module.lua21
1 files changed, 12 insertions, 9 deletions
diff --git a/Module.lua b/Module.lua
index 4423aea..e9659e4 100644
--- a/Module.lua
+++ b/Module.lua
@@ -139,12 +139,13 @@ function Module:getParameters()
-- get parameters
local parameters,gradParameters = self:parameters()
- local function storageInSet(set, storage) --this is waste of time (need correct hash)
- for key, val in pairs(set) do
- if key == storage then
- return val
- end
+ local function storageInSet(set, storage)
+ local storageAndOffset = set[torch.pointer(storage)]
+ if storageAndOffset == nil then
+ return nil
end
+ local storage, offset = unpack(storageAndOffset)
+ return offset
end
-- this function flattens arbitrary lists of parameters,
@@ -155,9 +156,10 @@ function Module:getParameters()
local storages = {}
local nParameters = 0
for k = 1,#parameters do
- if not storageInSet(storages, parameters[k]:storage()) then
- storages[parameters[k]:storage()] = nParameters
- nParameters = nParameters + parameters[k]:storage():size()
+ local storage = parameters[k]:storage()
+ if not storageInSet(storages, storage) then
+ storages[torch.pointer(storage)] = {storage, nParameters}
+ nParameters = nParameters + storage:size()
end
end
@@ -186,7 +188,8 @@ function Module:getParameters()
parameters[k]:stride())
end
- for k, v in pairs(storages) do
+ for _, storageAndOffset in pairs(storages) do
+ local k, v = unpack(storageAndOffset)
flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
end
for k = 1,flatUsedParameters:nElement() do