diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-29 22:19:40 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-29 22:19:40 +0400 |
commit | c6cb639ce76899720b82612eafb214b610a8b9a2 (patch) | |
tree | 9c00c5b7a0bb3e06510233054a433a81c59e0244 /init.lua | |
parent | ce7043e49884a2aeb2bc7489aa6e50d21682aba5 (diff) |
Added a convenient method to re-alloc all params of a module in a flat vector.
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 26 |
1 files changed, 26 insertions, 0 deletions
@@ -185,3 +185,29 @@ function nnx.getGradParameters(...) -- return all parameters found return holder end + +function nnx.flattenParameters(parameters) + -- compute offsets of each parameter + local offsets = {} + local dimensions = {} + local elements = {} + local nParameters = 0 + for _,param in ipairs(parameters) do + table.insert(offsets, nParameters+1) + table.insert(dimensions, param:size()) + table.insert(elements, param:nElement()) + nParameters = nParameters + param:nElement() + end + -- create flat vector + local flatParameters = torch.Tensor(nParameters) + local storage = flatParameters:storage() + -- reallocate all parameters in flat vector + for i = 1,#parameters do + local data = parameters[i]:clone() + parameters[i]:set(storage, offsets[i], elements[i]):resize(dimensions[i]):copy(data) + end + -- cleanup + collectgarbage() + -- return new flat vector that contains all discrete parameters + return flatParameters +end |