diff options
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 30 |
1 files changed, 29 insertions, 1 deletions
@@ -93,19 +93,21 @@ torch.include('nnx', 'SpatialColorTransform.lua') -- criterions: torch.include('nnx', 'SuperCriterion.lua') torch.include('nnx', 'SparseCriterion.lua') +torch.include('nnx', 'DistNLLCriterion.lua') torch.include('nnx', 'SpatialMSECriterion.lua') torch.include('nnx', 'SpatialClassNLLCriterion.lua') torch.include('nnx', 'SpatialSparseCriterion.lua') -- optimizations: torch.include('nnx', 'Optimization.lua') +torch.include('nnx', 'BatchOptimization.lua') torch.include('nnx', 'SGDOptimization.lua') torch.include('nnx', 'LBFGSOptimization.lua') -- trainers: torch.include('nnx', 'Trainer.lua') torch.include('nnx', 'OnlineTrainer.lua') -torch.include('nnx', 'StochasticTrainer.lua') +torch.include('nnx', 'BatchTrainer.lua') -- datasets: torch.include('nnx', 'DataSet.lua') @@ -185,3 +187,29 @@ function nnx.getGradParameters(...) -- return all parameters found return holder end + +function nnx.flattenParameters(parameters) + -- compute offsets of each parameter + local offsets = {} + local dimensions = {} + local elements = {} + local nParameters = 0 + for _,param in ipairs(parameters) do + table.insert(offsets, nParameters+1) + table.insert(dimensions, param:size()) + table.insert(elements, param:nElement()) + nParameters = nParameters + param:nElement() + end + -- create flat vector + local flatParameters = torch.Tensor(nParameters) + local storage = flatParameters:storage() + -- reallocate all parameters in flat vector + for i = 1,#parameters do + local data = parameters[i]:clone() + parameters[i]:set(storage, offsets[i], elements[i]):resize(dimensions[i]):copy(data) + end + -- cleanup + collectgarbage() + -- return new flat vector that contains all discrete parameters + return flatParameters +end |