diff options
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 19 |
1 files changed, 19 insertions, 0 deletions
@@ -195,6 +195,25 @@ function nnx.getGradParameters(...) end function nnx.flattenParameters(parameters) + -- already flat ? + local flat = true + for k = 2,#parameters do + if parameters[k]:storage() ~= parameters[k-1]:storage() then + flat = false + break + end + end + if flat then + local nParameters = 0 + for k,param in ipairs(parameters) do + nParameters = nParameters + param:nElement() + end + flatParameters = torch.Tensor(parameters[1]:storage()) + if nParameters ~= flatParameters:nElement() then + error('<nnx.flattenParameters> weird parameters') + end + return flatParameters + end -- compute offsets of each parameter local offsets = {} local sizes = {} |