From b928d4e9867d16be1f6a94368656eb6fea9ca015 Mon Sep 17 00:00:00 2001
From: Clement Farabet
Date: Mon, 19 Sep 2011 22:33:41 -0400
Subject: Fixed (?) flattenParameters.

---
 init.lua | 30 +++++++++++++++++++++++++-----
 1 file changed, 25 insertions(+), 5 deletions(-)

diff --git a/init.lua b/init.lua
index 054b7ec..d34abba 100644
--- a/init.lua
+++ b/init.lua
@@ -190,14 +190,32 @@ end
 function nnx.flattenParameters(parameters)
    -- compute offsets of each parameter
    local offsets = {}
-   local dimensions = {}
+   local sizes = {}
+   local strides = {}
    local elements = {}
+   local storageOffsets = {}
+   local params = {}
    local nParameters = 0
-   for _,param in ipairs(parameters) do
+   for k,param in ipairs(parameters) do
       table.insert(offsets, nParameters+1)
-      table.insert(dimensions, param:size())
+      table.insert(sizes, param:size())
+      table.insert(strides, param:stride())
       table.insert(elements, param:nElement())
-      nParameters = nParameters + param:nElement()
+      table.insert(storageOffsets, param:storageOffset())
+      local isView = false
+      for i = 1,k-1 do
+         if param:storage() == parameters[i]:storage() then
+            offsets[k] = offsets[i]
+            if storageOffsets[k] ~= storageOffsets[i] or elements[k] ~= elements[i] then
+               error('cannot flatten shared weights with different structures')
+            end
+            isView = true
+            break
+         end
+      end
+      if not isView then
+         nParameters = nParameters + param:nElement()
+      end
    end
    -- create flat vector
    local flatParameters = torch.Tensor(nParameters)
@@ -205,7 +223,9 @@ function nnx.flattenParameters(parameters)
    -- reallocate all parameters in flat vector
    for i = 1,#parameters do
       local data = parameters[i]:clone()
-      parameters[i]:set(storage, offsets[i], elements[i]):resize(dimensions[i]):copy(data)
+      parameters[i]:set(storage, offsets[i], elements[i]):resize(sizes[i]):copy(data)
+      data = nil
+      collectgarbage()
    end
    -- cleanup
    collectgarbage()
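
Usage note (not part of the commit): the first hunk teaches nnx.flattenParameters to detect parameter tensors that already share a storage; a later occurrence reuses the earlier one's offset instead of reserving new space, and an error is raised when the shared views differ in storage offset or element count. Below is a minimal calling sketch, assuming a Torch7 setup with 'nn' and this nnx build installed; the names lin1, lin2 and params are illustrative and do not come from the patch.

require 'nn'
require 'nnx'

-- two layers whose weight/bias tensors we want to move into one storage
local lin1 = nn.Linear(10, 20)
local lin2 = nn.Linear(20, 2)

-- collect the parameter tensors in a plain Lua table
local params = { lin1.weight, lin1.bias, lin2.weight, lin2.bias }

-- relocate every tensor into one contiguous storage; afterwards each
-- entry of 'params' is a view into that shared storage, so a single
-- update through the flat storage reaches all parameters at once
nnx.flattenParameters(params)

print(lin1.weight:storage() == lin2.bias:storage())  -- expected: true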