author     Clement Farabet <clement.farabet@gmail.com>  2011-08-24 01:33:57 +0400
committer  Clement Farabet <clement.farabet@gmail.com>  2011-08-24 01:33:57 +0400
commit     27353f00f1806c0657f9194c06f296c4038e8545
tree       b0d7dc2cf74bc1adca28773dda9c8a4fd01ad872
parent     cd5f492473c47835b8ca2c8e9686c15479354faa
Added OnlineTrainer class.
-rw-r--r--   OnlineTrainer.lua    229
-rw-r--r--   init.lua               1
-rw-r--r--   nnx-1.0-1.rockspec     1
3 files changed, 231 insertions, 0 deletions
diff --git a/OnlineTrainer.lua b/OnlineTrainer.lua
new file mode 100644
index 0000000..9dd81f2
--- /dev/null
+++ b/OnlineTrainer.lua
@@ -0,0 +1,229 @@
+local OnlineTrainer, parent = torch.class('nn.OnlineTrainer','nn.Trainer')
+
+function OnlineTrainer:__init(...)
+   parent.__init(self)
+   -- unpack args
+   xlua.unpack_class(self, {...},
+      'OnlineTrainer',
+
+      'A general-purpose online trainer class.\n'
+      .. 'Provides 4 user hooks to perform extra work after each sample, or each epoch:\n'
+      .. '> trainer = nn.OnlineTrainer(...) \n'
+      .. '> trainer.hookTrainSample = function(trainer, sample) ... end \n'
+      .. '> trainer.hookTrainEpoch = function(trainer) ... end \n'
+      .. '> trainer.hookTestSample = function(trainer, sample) ... end \n'
+      .. '> trainer.hookTestEpoch = function(trainer) ... end \n'
+      .. '> ',
+
+      {arg='module', type='nn.Module', help='a module to train', req=true},
+      {arg='criterion', type='nn.Criterion', help='a criterion to estimate the error'},
+      {arg='preprocessor', type='nn.Module', help='a preprocessor to prime the data before the module'},
+      {arg='optimizer', type='nn.Optimization', help='an optimization method'},
+
+      {arg='maxEpoch', type='number', help='maximum number of epochs', default=50},
+      {arg='dispProgress', type='boolean', help='display a progress bar during training/testing', default=true},
+      {arg='save', type='string', help='path to save networks and log training'},
+      {arg='timestamp', type='boolean', help='if true, appends a timestamp to each network saved', default=false}
+   )
+   -- private params
+   self.trainOffset = 0
+   self.testOffset = 0
+end
+
+function OnlineTrainer:log()
+   -- save network
+   local filename = self.save
+   os.execute('mkdir -p ' .. sys.dirname(filename))
+   if self.timestamp then
+      -- use a timestamp to store all networks uniquely
+      filename = filename .. '-' .. os.date("%Y_%m_%d_%X")
+   else
+      -- if no timestamp, just store the previous one
+      if sys.filep(filename) then
+         os.execute('mv ' .. filename .. ' ' .. filename .. '.old')
+      end
+   end
+   print('<trainer> saving network to '..filename)
+   local file = torch.DiskFile(filename,'w')
+   self.module:write(file)
+   file:close()
+end
+
+function OnlineTrainer:train(dataset)
+   self.epoch = self.epoch or 1
+   local module = self.module
+   local criterion = self.criterion
+   self.trainset = dataset
+
+   local shuffledIndices = {}
+   if not self.shuffleIndices then
+      for t = 1,dataset:size() do
+         shuffledIndices[t] = t
+      end
+   else
+      shuffledIndices = lab.randperm(dataset:size())
+   end
+
+   local parameters = nnx.getParameters(module)
+   local gradParameters = nnx.getGradParameters(module)
+
+   while true do
+      print('<trainer> on training set:')
+      print("<trainer> stochastic gradient descent epoch # " .. self.epoch)
+
+      module:zeroGradParameters()
+
+      self.time = sys.clock()
+      self.currentError = 0
+      for t = 1,dataset:size() do
+         -- disp progress
+         if self.dispProgress then
+            xlua.progress(t, dataset:size())
+         end
+
+         -- load new sample
+         local sample = dataset[self.trainOffset + shuffledIndices[t]]
+         local input = sample[1]
+         local target = sample[2]
+         local sample_x = sample.x
+         local sample_y = sample.y
+
+         -- optional preprocess (no learning is done for that guy)
+         if self.preprocessor then input = self.preprocessor:forward(input) end
+
+         -- forward through model and criterion
+         -- (if no criterion, it is assumed to be contained in the model)
+         local modelOut, error
+         if criterion then
+            modelOut = module:forward(input)
+            error = criterion:forward(modelOut, target)
+         else
+            modelOut, error = module:forward(input, target, sample_x, sample_y)
+         end
+
+         -- accumulate error
+         self.currentError = self.currentError + error
+
+         -- reset gradients
+         module:zeroGradParameters()
+
+         -- backward through model
+         -- (if no criterion, it is assumed that derror is internally generated)
+         if criterion then
+            local derror = criterion:backward(module.output, target)
+            module:backward(input, derror)
+         else
+            module:backward(input)
+         end
+
+         -- update parameters in the model
+         self.optimizer:forward(parameters, gradParameters)
+
+         -- call user hook, if any
+         if self.hookTrainSample then
+            self.hookTrainSample(self, sample)
+         end
+      end
+
+      self.currentError = self.currentError / dataset:size()
+      print("<trainer> current error = " .. self.currentError)
+
+      self.time = sys.clock() - self.time
+      self.time = self.time / dataset:size()
+      print("<trainer> time to learn 1 sample = " .. (self.time*1000) .. 'ms')
+
+      if self.hookTrainEpoch then
+         self.hookTrainEpoch(self)
+      end
+
+      if self.save then self:log() end
+
+      self.epoch = self.epoch + 1
+
+      if dataset.infiniteSet then
+         self.trainOffset = self.trainOffset + dataset:size()
+      end
+
+      if self.maxEpoch > 0 and self.epoch > self.maxEpoch then
+         print("<trainer> you have reached the maximum number of epochs")
+         break
+      end
+   end
+end
+
+
+function OnlineTrainer:test(dataset)
+   print('<trainer> on testing Set:')
+
+   local module = self.module
+   local shuffledIndices = {}
+   local criterion = self.criterion
+   self.currentError = 0
+   self.testset = dataset
+
+   local shuffledIndices = {}
+   if not self.shuffleIndices then
+      for t = 1,dataset:size() do
+         shuffledIndices[t] = t
+      end
+   else
+      shuffledIndices = lab.randperm(dataset:size())
+   end
+
+   self.time = sys.clock()
+   for t = 1,dataset:size() do
+      -- disp progress
+      if self.dispProgress then
+         xlua.progress(t, dataset:size())
+      end
+
+      -- get new sample
+      local sample = dataset[self.testOffset + shuffledIndices[t]]
+      local input = sample[1]
+      local target = sample[2]
+
+      -- test sample through current model
+      if self.preprocessor then input = self.preprocessor:forward(input) end
+      if criterion then
+         self.currentError = self.currentError
+            + criterion:forward(module:forward(input), target)
+      else
+         local _,error = module:forward(input, target)
+         self.currentError = self.currentError + error
+      end
+
+      -- user hook
+      if self.hookTestSample then
+         self.hookTestSample(self, sample)
+      end
+   end
+
+   self.currentError = self.currentError / dataset:size()
+   print("<trainer> test current error = " .. self.currentError)
+
+   self.time = sys.clock() - self.time
+   self.time = self.time / dataset:size()
+   print("<trainer> time to test 1 sample = " .. (self.time*1000) .. 'ms')
+
+   if self.hookTestEpoch then
+      self.hookTestEpoch(self)
+   end
+
+   if dataset.infiniteSet then
+      self.testOffset = self.testOffset + dataset:size()
+   end
+
+   return self.currentError
+end
+
+function OnlineTrainer:write(file)
+   parent.write(self,file)
+   file:writeObject(self.module)
+   file:writeObject(self.criterion)
+end
+
+function OnlineTrainer:read(file)
+   parent.read(self,file)
+   self.module = file:readObject()
+   self.criterion = file:readObject()
+end
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -103,6 +103,7 @@ torch.include('nnx', 'SGDOptimization.lua')
 
 -- trainers:
 torch.include('nnx', 'Trainer.lua')
+torch.include('nnx', 'OnlineTrainer.lua')
 torch.include('nnx', 'StochasticTrainer.lua')
 
 -- datasets:
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index ba2120c..abcb5da 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -77,6 +77,7 @@ build = {
          install_files(/lua/nnx SuperCriterion.lua)
          install_files(/lua/nnx SpatialCriterion.lua)
          install_files(/lua/nnx Trainer.lua)
+         install_files(/lua/nnx OnlineTrainer.lua)
          install_files(/lua/nnx StochasticTrainer.lua)
          install_files(/lua/nnx DataSet.lua)
          install_files(/lua/nnx DataList.lua)
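
For reference, the hooks documented in the class help string can be wired up as below. This is a minimal usage sketch, not part of the commit: `mlp`, `trainSet`, and `testSet` are hypothetical placeholders, and the `learningRate` option passed to nn.SGDOptimization is an assumption about that class's constructor, which this diff does not show.

-- Minimal usage sketch (not part of this commit). Assumes a network `mlp`,
-- datasets `trainSet`/`testSet` exposing :size() and [index] access, and
-- that nn.SGDOptimization accepts a `learningRate` option (assumption).
require 'nnx'

local trainer = nn.OnlineTrainer{
   module    = mlp,                                   -- nn.Module to train
   criterion = nn.MSECriterion(),                     -- per-sample loss
   optimizer = nn.SGDOptimization{learningRate=1e-3}  -- any nn.Optimization
}

-- optional hooks, called after each sample / each epoch:
trainer.hookTrainSample = function(trainer, sample)
   -- e.g. inspect trainer.currentError or the current sample here
end
trainer.hookTrainEpoch = function(trainer)
   trainer:test(testSet)   -- evaluate on the test set after every epoch
end

trainer:train(trainSet)    -- runs until maxEpoch (default 50) is reached

Note that when no criterion is given, the trainer expects the module itself to return both an output and an error from forward(input, target), and to generate its own gradients in backward(input).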