Diffstat (limited to 'OnlineTrainer.lua')
-rw-r--r--   OnlineTrainer.lua   61
1 file changed, 19 insertions(+), 42 deletions(-)
diff --git a/OnlineTrainer.lua b/OnlineTrainer.lua
index 2b7f2b5..dc6e860 100644
--- a/OnlineTrainer.lua
+++ b/OnlineTrainer.lua
@@ -16,19 +16,23 @@ function OnlineTrainer:__init(...)
                      .. '> ',
       {arg='module', type='nn.Module', help='a module to train', req=true},
-      {arg='criterion', type='nn.Criterion', help='a criterion to estimate the error'},
-      {arg='preprocessor', type='nn.Module', help='a preprocessor to prime the data before the module'},
-      {arg='optimizer', type='nn.Optimization', help='an optimization method'},
-
-      {arg='batchSize', type='number', help='[mini] batch size', default=1},
-      {arg='maxEpoch', type='number', help='maximum number of epochs', default=50},
-      {arg='dispProgress', type='boolean', help='display a progress bar during training/testing', default=true},
-      {arg='save', type='string', help='path to save networks and log training'},
-      {arg='timestamp', type='boolean', help='if true, appends a timestamp to each network saved', default=false}
+      {arg='criterion', type='nn.Criterion',
+       help='a criterion to estimate the error'},
+      {arg='preprocessor', type='nn.Module',
+       help='a preprocessor to prime the data before the module'},
+      {arg='optimizer', type='nn.Optimization',
+       help='an optimization method'},
+      {arg='batchSize', type='number',
+       help='[mini] batch size', default=1},
+      {arg='maxEpoch', type='number',
+       help='maximum number of epochs', default=50},
+      {arg='dispProgress', type='boolean',
+       help='display a progress bar during training/testing', default=true},
+      {arg='save', type='string',
+       help='path to save networks and log training'},
+      {arg='timestamp', type='boolean',
+       help='if true, appends a timestamp to each network saved', default=false}
    )
-   -- private params
-   self.trainOffset = 0
-   self.testOffset = 0
 end
 
 function OnlineTrainer:log()
@@ -56,18 +60,9 @@ function OnlineTrainer:train(dataset)
    local criterion = self.criterion
    self.trainset = dataset
 
-   local shuffledIndices = {}
-   if not self.shuffleIndices then
-      for t = 1,dataset:size() do
-         shuffledIndices[t] = t
-      end
-   else
-      shuffledIndices = lab.randperm(dataset:size())
-   end
-
    while true do
       print('<trainer> on training set:')
-      print("<trainer> online epoch # " .. self.epoch .. '[batchSize = ' .. self.batchSize .. ']')
+      print("<trainer> online epoch # " .. self.epoch .. ' [batchSize = ' .. self.batchSize .. ']')
 
       self.time = sys.clock()
       self.currentError = 0
@@ -82,7 +77,7 @@ function OnlineTrainer:train(dataset)
          local targets = {}
         for i = t,math.min(t+self.batchSize-1,dataset:size()) do
            -- load new sample
-            local sample = dataset[self.trainOffset + shuffledIndices[i]]
+            local sample = dataset[i]
            local input = sample[1]
            local target = sample[2]
@@ -121,10 +116,6 @@ function OnlineTrainer:train(dataset)
      self.epoch = self.epoch + 1
 
-      if dataset.infiniteSet then
-         self.trainOffset = self.trainOffset + dataset:size()
-      end
-
      if self.maxEpoch > 0 and self.epoch > self.maxEpoch then
         print("<trainer> you have reached the maximum number of epochs")
         break
@@ -137,20 +128,10 @@ function OnlineTrainer:test(dataset)
    print('<trainer> on testing Set:')
 
    local module = self.module
-   local shuffledIndices = {}
    local criterion = self.criterion
 
    self.currentError = 0
    self.testset = dataset
 
-   local shuffledIndices = {}
-   if not self.shuffleIndices then
-      for t = 1,dataset:size() do
-         shuffledIndices[t] = t
-      end
-   else
-      shuffledIndices = lab.randperm(dataset:size())
-   end
-
    self.time = sys.clock()
    for t = 1,dataset:size() do
       -- disp progress
@@ -159,7 +140,7 @@ function OnlineTrainer:test(dataset)
      end
 
      -- get new sample
-      local sample = dataset[self.testOffset + shuffledIndices[t]]
+      local sample = dataset[t]
      local input = sample[1]
      local target = sample[2]
@@ -190,10 +171,6 @@ function OnlineTrainer:test(dataset)
         self.hookTestEpoch(self)
      end
 
-   if dataset.infiniteSet then
-      self.testOffset = self.testOffset + dataset:size()
-   end
-
    return self.currentError
 end
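Net effect of this patch: OnlineTrainer no longer shuffles indices or tracks trainOffset/testOffset for infinite sets; both train() and test() now walk the dataset in plain index order, from dataset[1] to dataset[dataset:size()]. A minimal sketch of how a caller could keep shuffled epochs after this change, assuming the torch5-era API visible in this file (lab.randperm, a dataset exposing :size() and dataset[i] -> {input, target}); the helper name shuffledView is hypothetical, not part of this repo:

-- Sketch only: after this patch the trainer visits samples in index order,
-- so shuffling (previously done via lab.randperm inside the trainer) must
-- now happen on the dataset side. 'dataset' and 'trainer' are assumed to
-- exist already; 'shuffledView' is a hypothetical wrapper.
local function shuffledView(dataset)
   local perm = lab.randperm(dataset:size())  -- random permutation of 1..N
   local view = {}
   function view:size() return dataset:size() end
   -- numeric lookups view[i] are remapped through the permutation
   setmetatable(view, {__index = function(_, i)
      if type(i) == 'number' then return dataset[perm[i]] end
   end})
   return view
end

trainer:train(shuffledView(dataset))   -- shuffled epoch
trainer:test(dataset)                  -- evaluation keeps natural order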