
github.com/clementfarabet/lua---nnx.git
Diffstat (limited to 'BatchTrainer.lua')
-rw-r--r--   BatchTrainer.lua   170
1 files changed, 170 insertions, 0 deletions
diff --git a/BatchTrainer.lua b/BatchTrainer.lua
new file mode 100644
index 0000000..a5b135d
--- /dev/null
+++ b/BatchTrainer.lua
@@ -0,0 +1,170 @@
+local BatchTrainer, parent = torch.class('nn.BatchTrainer', 'nn.OnlineTrainer')
+
+-- Essentially similar to the OnlineTrainer, but it only uses the parts
+-- of the code which prepare the data and the tester. train() has been
+-- replaced by nextBatch(), which moves the trainer one batch further
+-- through the data. When the first epoch is finished, the batches are
+-- reused. Each call to optimizer:forward() in nextBatch() creates a
+-- closure with the current batch as input.
+
+function BatchTrainer:__init(...)
+ local args = {...}
+ parent.__init(self, args)
+ -- unpack args
+ xlua.unpack_class(
+ self, args,
+ 'BatchTrainer',
+ 'A modified version of the general-purpose online trainer class\n'
+ .. ' which only preps the input batch and calls the optimizer to\n'
+ .. ' create a closure.\n',
+ {arg='trainset', type='nn.DataList',
+ help='dataset from which to draw batches', req=true},
+ {arg='module', type='nn.Module', help='a module to train', req=true},
+ {arg='criterion', type='nn.Criterion',
+ help='a criterion to estimate the error'},
+ {arg='preprocessor', type='nn.Module',
+ help='a preprocessor to prime the data before the module'},
+ {arg='optimizer', type='nn.Optimization',
+ help='an optimization method'},
+ {arg='batchSize', type='number',
+ help='[mini] batch size', default=1},
+ {arg='maxEpoch', type='number',
+ help='maximum number of epochs', default=50},
+ {arg='dispProgress', type='boolean',
+ help='display a progress bar during training/testing', default=true},
+ {arg='save', type='string',
+ help='path to save networks and log training'},
+ {arg='timestamp', type='boolean',
+ help='if true, appends a timestamp to each network saved', default=false}
+ )
+ self.epoch = 1
+ self.batch = nil
+ self.trainOffset = nil
+end
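The constructor arguments above suggest the following construction sketch. It is not part of this commit: the trainset (an nn.DataList) and the optimizer (an nn.Optimization) are assumed to have been built elsewhere, and the call relies on the usual xlua named-argument style.

require 'nnx'

-- a small classifier module and criterion (illustrative choices only)
local module = nn.Sequential()
module:add(nn.Linear(100, 10))
module:add(nn.LogSoftMax())
local criterion = nn.ClassNLLCriterion()

-- trainset and optimizer are assumed to exist already (see above)
local trainer = nn.BatchTrainer{trainset  = trainset,
                                module    = module,
                                criterion = criterion,
                                optimizer = optimizer,
                                batchSize = 16,
                                maxEpoch  = 50}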
+
+-- update the counters
+function BatchTrainer:next()
+ if not self.batch or not self.trainOffset then
+ -- initialize
+ self.batch = 1
+ self.trainOffset = 1
+ else
+ -- hook to run something on the current batch
+ -- (e.g. if you want to run a test on this batch before
+ -- switching to the next one)
+ if self.hookTrainBatch then
+ self.hookTrainBatch(self)
+ end
+
+ -- simple batch increment
+ self.batch = self.batch + 1
+ self.trainOffset = self.trainOffset + self.batchSize
+
+ -- test for new epoch
+ if self.trainOffset > self.trainset:size() then
+
+ -- hook to run on current epoch before switching to next
+ if self.hookTrainEpoch then
+ self.hookTrainEpoch(self)
+ end
+
+ if self.save then self:log() end
+
+ self.trainOffset = 1
+ self.epoch = self.epoch + 1
+ self.batch = 1
+ end
+
+ -- on all but the first batch we need to reset the children
+ if self.optimizer.parallelize > 1 then
+ parallel.children:send('break')
+ end
+
+ end
+ -- disp progress
+ if self.dispProgress then
+ xlua.progress(self.trainOffset, self.trainset:size())
+ end
+
+end
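Since next() simply calls self.hookTrainEpoch(self) and self.hookTrainBatch(self) when they are set, a user hook is just a function field assigned on the trainer. A minimal sketch, reusing the trainer constructed above:

-- called by next() once per completed epoch, with the trainer as sole argument
trainer.hookTrainEpoch = function(trainer)
   print('<hook> finished epoch # ' .. trainer.epoch)
end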
+
+-- this function is called train() in the online trainer. It seems to
+-- make more sense to call it nextBatch() here, as the training is
+-- done outside of this code.
+
+function BatchTrainer:nextBatch()
+ self:next()
+ local module = self.module
+ local criterion = self.criterion
+ local t = self.trainOffset
+ local ds = self.trainset:size()
+ local bs = self.batchSize
+
+ print('<trainer> on training set:')
+ print("<trainer> online epoch # " .. self.epoch
+ .. ' batch # '..self.batch
+ .. ' [batchSize = ' .. self.batchSize .. ']')
+
+ -- create mini batch
+ self.inputs = self.inputs or {}
+ self.targets = self.targets or {}
+ local inputs = {}
+ local targets = {}
+ if not self.inputs[self.batch] then
+
+ self.inputs[self.batch] = {}
+ inputs = self.inputs[self.batch]
+ self.targets[self.batch] = {}
+ targets = self.targets[self.batch]
+
+ for i = t,math.min(t+bs-1,ds) do
+ -- load new sample
+ local sample = self.trainset[i]
+ local input = sample[1]
+ local target = sample[2]
+
+ -- optional preprocessing (no learning is done for the preprocessor)
+ if self.preprocessor then input = self.preprocessor:forward(input) end
+
+ -- store input/target
+ table.insert(inputs, input)
+ table.insert(targets, target)
+ end
+ else
+ -- get batch from cache
+ inputs = self.inputs[self.batch]
+ targets = self.targets[self.batch]
+ end
+
+ -- set up the closure (batch evaluation) for the optimizer
+ local err = self.optimizer:forward(inputs, targets)
+
+end
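Because the actual training is driven from outside, a caller typically just alternates calls to nextBatch() with whatever work the optimizer schedules. A minimal driving loop, sketched under the assumption that reaching maxEpoch is the caller's stopping condition:

-- each call prepares (or re-uses) one batch and passes it to the optimizer
while trainer.epoch <= trainer.maxEpoch do
   trainer:nextBatch()
end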
+
+-- special test to just get results of current batch
+function BatchTrainer:testBatch()
+ local criterion = self.criterion
+ local module = self.module
+
+ local inputs = self.inputs[self.batch]
+ local targets = self.targets[self.batch]
+
+ self.currentError = 0
+
+ for i = 1,#inputs do
+ local input = inputs[i]
+ local target = targets[i]
+ if criterion then
+ self.currentError = self.currentError +
+ criterion:forward(module:forward(input), target)
+ else
+ local _,err = module:forward(input, target)
+ self.currentError = self.currentError + err
+ end
+ -- user hook
+ if self.hookTestSample then
+ self.hookTestSample(self, {input, target})
+ end
+ end
+end
+
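testBatch() only reads self.inputs[self.batch] and self.targets[self.batch], so it fits naturally inside the hookTrainBatch hook mentioned in next(), which fires just before the trainer moves on to the following batch. A sketch, using the trainer and hook names from the code above:

-- evaluate the batch that was just prepared before switching to the next one
trainer.hookTrainBatch = function(trainer)
   trainer:testBatch()
   print('<hook> batch # ' .. trainer.batch .. ' error = ' .. trainer.currentError)
end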