Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'Module.lua')
-rw-r--r--Module.lua21
1 files changed, 12 insertions, 9 deletions
diff --git a/Module.lua b/Module.lua
index 4423aea..e9659e4 100644
--- a/Module.lua
+++ b/Module.lua
@@ -139,12 +139,13 @@ function Module:getParameters()
-- get parameters
local parameters,gradParameters = self:parameters()
- local function storageInSet(set, storage) --this is waste of time (need correct hash)
- for key, val in pairs(set) do
- if key == storage then
- return val
- end
+ local function storageInSet(set, storage)
+ local storageAndOffset = set[torch.pointer(storage)]
+ if storageAndOffset == nil then
+ return nil
end
+ local storage, offset = unpack(storageAndOffset)
+ return offset
end
-- this function flattens arbitrary lists of parameters,
@@ -155,9 +156,10 @@ function Module:getParameters()
local storages = {}
local nParameters = 0
for k = 1,#parameters do
- if not storageInSet(storages, parameters[k]:storage()) then
- storages[parameters[k]:storage()] = nParameters
- nParameters = nParameters + parameters[k]:storage():size()
+ local storage = parameters[k]:storage()
+ if not storageInSet(storages, storage) then
+ storages[torch.pointer(storage)] = {storage, nParameters}
+ nParameters = nParameters + storage:size()
end
end
@@ -186,7 +188,8 @@ function Module:getParameters()
parameters[k]:stride())
end
- for k, v in pairs(storages) do
+ for _, storageAndOffset in pairs(storages) do
+ local k, v = unpack(storageAndOffset)
flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
end
for k = 1,flatUsedParameters:nElement() do