Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua26
1 files changed, 26 insertions, 0 deletions
diff --git a/init.lua b/init.lua
index 6b8b7e5..04877d3 100644
--- a/init.lua
+++ b/init.lua
@@ -185,3 +185,29 @@ function nnx.getGradParameters(...)
-- return all parameters found
return holder
end
+
+function nnx.flattenParameters(parameters)
+ -- compute offsets of each parameter
+ local offsets = {}
+ local dimensions = {}
+ local elements = {}
+ local nParameters = 0
+ for _,param in ipairs(parameters) do
+ table.insert(offsets, nParameters+1)
+ table.insert(dimensions, param:size())
+ table.insert(elements, param:nElement())
+ nParameters = nParameters + param:nElement()
+ end
+ -- create flat vector
+ local flatParameters = torch.Tensor(nParameters)
+ local storage = flatParameters:storage()
+ -- reallocate all parameters in flat vector
+ for i = 1,#parameters do
+ local data = parameters[i]:clone()
+ parameters[i]:set(storage, offsets[i], elements[i]):resize(dimensions[i]):copy(data)
+ end
+ -- cleanup
+ collectgarbage()
+ -- return new flat vector that contains all discrete parameters
+ return flatParameters
+end