diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-24 18:29:38 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-24 18:29:38 +0400 |
commit | 52568ca8072eed52ef784262144802dfc62d296a (patch) | |
tree | 222af391bfacdec95a4b748e8bad2dfcde9d93ef | |
parent | 3675e32fded83807fa5e96604dbfab7d72c04d5b (diff) |
Added a mini-batch parameter to OnlineTrainer.
-rw-r--r-- | OnlineTrainer.lua | 30 |
1 files changed, 20 insertions, 10 deletions
diff --git a/OnlineTrainer.lua b/OnlineTrainer.lua index 8a06aa7..1c47d5e 100644 --- a/OnlineTrainer.lua +++ b/OnlineTrainer.lua @@ -20,6 +20,7 @@ function OnlineTrainer:__init(...) {arg='preprocessor', type='nn.Module', help='a preprocessor to prime the data before the module'}, {arg='optimizer', type='nn.Optimization', help='an optimization method'}, + {arg='batchSize', type='number', help='[mini] batch size', default=1}, {arg='maxEpoch', type='number', help='maximum number of epochs', default=50}, {arg='dispProgress', type='boolean', help='display a progress bar during training/testing', default=true}, {arg='save', type='string', help='path to save networks and log training'}, @@ -70,29 +71,38 @@ function OnlineTrainer:train(dataset) self.time = sys.clock() self.currentError = 0 - for t = 1,dataset:size() do + for t = 1,dataset:size(),self.batchSize do -- disp progress if self.dispProgress then xlua.progress(t, dataset:size()) end - -- load new sample - local sample = dataset[self.trainOffset + shuffledIndices[t]] - local input = sample[1] - local target = sample[2] - - -- optional preprocess (no learning is done for that guy) - if self.preprocessor then input = self.preprocessor:forward(input) end + -- create mini batch + local inputs = {} + local targets = {} + for i = t,math.min(t+self.batchSize-1,dataset:size()) do + -- load new sample + local sample = dataset[self.trainOffset + shuffledIndices[i]] + local input = sample[1] + local target = sample[2] + + -- optional preprocess (no learning is done for that guy) + if self.preprocessor then input = self.preprocessor:forward(input) end + + -- store input/target + table.insert(inputs, input) + table.insert(targets, target) + end -- optimize the model given current input/target set - local error = self.optimizer:forward({input}, {target}) + local error = self.optimizer:forward(inputs, targets) -- accumulate error self.currentError = self.currentError + error -- call user hook, if any if self.hookTrainSample then - self.hookTrainSample(self, sample) + self.hookTrainSample(self, {inputs[#inputs], targets[#targets]}) end end |