github.com/clementfarabet/lua---nnx.git
author     Clement Farabet <clement.farabet@gmail.com>  2011-08-24 18:29:38 +0400
committer  Clement Farabet <clement.farabet@gmail.com>  2011-08-24 18:29:38 +0400
commit     52568ca8072eed52ef784262144802dfc62d296a
tree       222af391bfacdec95a4b748e8bad2dfcde9d93ef
parent     3675e32fded83807fa5e96604dbfab7d72c04d5b

Added a mini-batch parameter to OnlineTrainer.

 OnlineTrainer.lua | 30 ++++++++++++++++++++----------
 1 file changed, 20 insertions(+), 10 deletions(-)
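For context, here is how the new option would be used. Only optimizer, batchSize, maxEpoch, and dispProgress are confirmed by the diff below; the model, the module argument name, and the optimizer/trainset objects are illustrative assumptions:

require 'nnx'

-- illustrative model; any nn.Module should do
local mlp = nn.Sequential()
mlp:add(nn.Linear(100, 10))

-- 'optimizer' is assumed to be an nn.Optimization built elsewhere
local trainer = nn.OnlineTrainer{
   module = mlp,           -- assumed argument name (not visible in this diff)
   optimizer = optimizer,
   batchSize = 16,         -- the new [mini] batch size option, default 1
   maxEpoch = 50,
   dispProgress = true
}

-- 'trainset' is assumed indexable as trainset[i] and to provide size()
trainer:train(trainset)

With batchSize left at its default of 1, the trainer reduces to the previous per-sample behavior, since the optimizer already received its single sample wrapped in tables.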
diff --git a/OnlineTrainer.lua b/OnlineTrainer.lua
index 8a06aa7..1c47d5e 100644
--- a/OnlineTrainer.lua
+++ b/OnlineTrainer.lua
@@ -20,6 +20,7 @@ function OnlineTrainer:__init(...)
{arg='preprocessor', type='nn.Module', help='a preprocessor to prime the data before the module'},
{arg='optimizer', type='nn.Optimization', help='an optimization method'},
+ {arg='batchSize', type='number', help='[mini] batch size', default=1},
{arg='maxEpoch', type='number', help='maximum number of epochs', default=50},
{arg='dispProgress', type='boolean', help='display a progress bar during training/testing', default=true},
{arg='save', type='string', help='path to save networks and log training'},
@@ -70,29 +71,38 @@ function OnlineTrainer:train(dataset)
self.time = sys.clock()
self.currentError = 0
- for t = 1,dataset:size() do
+ for t = 1,dataset:size(),self.batchSize do
-- disp progress
if self.dispProgress then
xlua.progress(t, dataset:size())
end
- -- load new sample
- local sample = dataset[self.trainOffset + shuffledIndices[t]]
- local input = sample[1]
- local target = sample[2]
-
- -- optional preprocess (no learning is done for that guy)
- if self.preprocessor then input = self.preprocessor:forward(input) end
+ -- create mini batch
+ local inputs = {}
+ local targets = {}
+ for i = t,math.min(t+self.batchSize-1,dataset:size()) do
+ -- load new sample
+ local sample = dataset[self.trainOffset + shuffledIndices[i]]
+ local input = sample[1]
+ local target = sample[2]
+
+ -- optional preprocess (no learning is done for that guy)
+ if self.preprocessor then input = self.preprocessor:forward(input) end
+
+ -- store input/target
+ table.insert(inputs, input)
+ table.insert(targets, target)
+ end
-- optimize the model given current input/target set
- local error = self.optimizer:forward({input}, {target})
+ local error = self.optimizer:forward(inputs, targets)
-- accumulate error
self.currentError = self.currentError + error
-- call user hook, if any
if self.hookTrainSample then
- self.hookTrainSample(self, sample)
+ self.hookTrainSample(self, {inputs[#inputs], targets[#targets]})
end
end
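Two consequences of this hunk are worth noting. First, the last mini-batch may be shorter than batchSize when dataset:size() is not a multiple of it, thanks to the math.min bound. Second, hookTrainSample now fires once per mini-batch and receives only the last loaded {input, target} pair rather than every sample. Below is a sketch of the call contract an nn.Optimization now has to honor; the class name and body are hypothetical, and only the forward(inputs, targets) signature and the scalar error return come from the diff:

-- hypothetical stub illustrating the batched forward contract
local DummyOptimization = torch.class('nn.DummyOptimization', 'nn.Optimization')

function DummyOptimization:forward(inputs, targets)
   -- inputs and targets are parallel tables: inputs[i] pairs with targets[i]
   local batchError = 0
   for i = 1,#inputs do
      -- a real optimizer would evaluate the module/criterion on
      -- (inputs[i], targets[i]) here; a zero cost keeps the stub runnable
      batchError = batchError + 0
   end
   return batchError  -- accumulated into self.currentError by the trainer
end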