Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua30
1 files changed, 29 insertions, 1 deletions
diff --git a/init.lua b/init.lua
index 6b8b7e5..3519294 100644
--- a/init.lua
+++ b/init.lua
@@ -93,19 +93,21 @@ torch.include('nnx', 'SpatialColorTransform.lua')
-- criterions:
torch.include('nnx', 'SuperCriterion.lua')
torch.include('nnx', 'SparseCriterion.lua')
+torch.include('nnx', 'DistNLLCriterion.lua')
torch.include('nnx', 'SpatialMSECriterion.lua')
torch.include('nnx', 'SpatialClassNLLCriterion.lua')
torch.include('nnx', 'SpatialSparseCriterion.lua')
-- optimizations:
torch.include('nnx', 'Optimization.lua')
+torch.include('nnx', 'BatchOptimization.lua')
torch.include('nnx', 'SGDOptimization.lua')
torch.include('nnx', 'LBFGSOptimization.lua')
-- trainers:
torch.include('nnx', 'Trainer.lua')
torch.include('nnx', 'OnlineTrainer.lua')
-torch.include('nnx', 'StochasticTrainer.lua')
+torch.include('nnx', 'BatchTrainer.lua')
-- datasets:
torch.include('nnx', 'DataSet.lua')
@@ -185,3 +187,29 @@ function nnx.getGradParameters(...)
-- return all parameters found
return holder
end
+
+function nnx.flattenParameters(parameters)
+ -- compute offsets of each parameter
+ local offsets = {}
+ local dimensions = {}
+ local elements = {}
+ local nParameters = 0
+ for _,param in ipairs(parameters) do
+ table.insert(offsets, nParameters+1)
+ table.insert(dimensions, param:size())
+ table.insert(elements, param:nElement())
+ nParameters = nParameters + param:nElement()
+ end
+ -- create flat vector
+ local flatParameters = torch.Tensor(nParameters)
+ local storage = flatParameters:storage()
+ -- reallocate all parameters in flat vector
+ for i = 1,#parameters do
+ local data = parameters[i]:clone()
+ parameters[i]:set(storage, offsets[i], elements[i]):resize(dimensions[i]):copy(data)
+ end
+ -- cleanup
+ collectgarbage()
+ -- return new flat vector that contains all discrete parameters
+ return flatParameters
+end