Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-09-20 06:33:41 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-09-20 06:33:41 +0400
commitb928d4e9867d16be1f6a94368656eb6fea9ca015 (patch)
treef96decfbb013c619f69b22f8b14b9b80a963fe7a /init.lua
parent123a517433f1a4774dfc696f73d7a2f3a17186b6 (diff)
Fixed (?) flattenParameters.
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua30
1 files changed, 25 insertions, 5 deletions
diff --git a/init.lua b/init.lua
index 054b7ec..d34abba 100644
--- a/init.lua
+++ b/init.lua
@@ -190,14 +190,32 @@ end
function nnx.flattenParameters(parameters)
-- compute offsets of each parameter
local offsets = {}
- local dimensions = {}
+ local sizes = {}
+ local strides = {}
local elements = {}
+ local storageOffsets = {}
+ local params = {}
local nParameters = 0
- for _,param in ipairs(parameters) do
+ for k,param in ipairs(parameters) do
table.insert(offsets, nParameters+1)
- table.insert(dimensions, param:size())
+ table.insert(sizes, param:size())
+ table.insert(strides, param:stride())
table.insert(elements, param:nElement())
- nParameters = nParameters + param:nElement()
+ table.insert(storageOffsets, param:storageOffset())
+ local isView = false
+ for i = 1,k-1 do
+ if param:storage() == parameters[i]:storage() then
+ offsets[k] = offsets[i]
+ if storageOffsets[k] ~= storageOffsets[i] or elements[k] ~= elements[i] then
+ error('<nnx.flattenParameters> canot flatten shared weights with different structures')
+ end
+ isView = true
+ break
+ end
+ end
+ if not isView then
+ nParameters = nParameters + param:nElement()
+ end
end
-- create flat vector
local flatParameters = torch.Tensor(nParameters)
@@ -205,7 +223,9 @@ function nnx.flattenParameters(parameters)
-- reallocate all parameters in flat vector
for i = 1,#parameters do
local data = parameters[i]:clone()
- parameters[i]:set(storage, offsets[i], elements[i]):resize(dimensions[i]):copy(data)
+ parameters[i]:set(storage, offsets[i], elements[i]):resize(sizes[i]):copy(data)
+ data = nil
+ collectgarbage()
end
-- cleanup
collectgarbage()