Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'OnlineTrainer.lua')
-rw-r--r--OnlineTrainer.lua61
1 files changed, 19 insertions, 42 deletions
diff --git a/OnlineTrainer.lua b/OnlineTrainer.lua
index 2b7f2b5..dc6e860 100644
--- a/OnlineTrainer.lua
+++ b/OnlineTrainer.lua
@@ -16,19 +16,23 @@ function OnlineTrainer:__init(...)
.. '> ',
{arg='module', type='nn.Module', help='a module to train', req=true},
- {arg='criterion', type='nn.Criterion', help='a criterion to estimate the error'},
- {arg='preprocessor', type='nn.Module', help='a preprocessor to prime the data before the module'},
- {arg='optimizer', type='nn.Optimization', help='an optimization method'},
-
- {arg='batchSize', type='number', help='[mini] batch size', default=1},
- {arg='maxEpoch', type='number', help='maximum number of epochs', default=50},
- {arg='dispProgress', type='boolean', help='display a progress bar during training/testing', default=true},
- {arg='save', type='string', help='path to save networks and log training'},
- {arg='timestamp', type='boolean', help='if true, appends a timestamp to each network saved', default=false}
+ {arg='criterion', type='nn.Criterion',
+ help='a criterion to estimate the error'},
+ {arg='preprocessor', type='nn.Module',
+ help='a preprocessor to prime the data before the module'},
+ {arg='optimizer', type='nn.Optimization',
+ help='an optimization method'},
+ {arg='batchSize', type='number',
+ help='[mini] batch size', default=1},
+ {arg='maxEpoch', type='number',
+ help='maximum number of epochs', default=50},
+ {arg='dispProgress', type='boolean',
+ help='display a progress bar during training/testing', default=true},
+ {arg='save', type='string',
+ help='path to save networks and log training'},
+ {arg='timestamp', type='boolean',
+ help='if true, appends a timestamp to each network saved', default=false}
)
- -- private params
- self.trainOffset = 0
- self.testOffset = 0
end
function OnlineTrainer:log()
@@ -56,18 +60,9 @@ function OnlineTrainer:train(dataset)
local criterion = self.criterion
self.trainset = dataset
- local shuffledIndices = {}
- if not self.shuffleIndices then
- for t = 1,dataset:size() do
- shuffledIndices[t] = t
- end
- else
- shuffledIndices = lab.randperm(dataset:size())
- end
-
while true do
print('<trainer> on training set:')
- print("<trainer> online epoch # " .. self.epoch .. '[batchSize = ' .. self.batchSize .. ']')
+ print("<trainer> online epoch # " .. self.epoch .. ' [batchSize = ' .. self.batchSize .. ']')
self.time = sys.clock()
self.currentError = 0
@@ -82,7 +77,7 @@ function OnlineTrainer:train(dataset)
local targets = {}
for i = t,math.min(t+self.batchSize-1,dataset:size()) do
-- load new sample
- local sample = dataset[self.trainOffset + shuffledIndices[i]]
+ local sample = dataset[i]
local input = sample[1]
local target = sample[2]
@@ -121,10 +116,6 @@ function OnlineTrainer:train(dataset)
self.epoch = self.epoch + 1
- if dataset.infiniteSet then
- self.trainOffset = self.trainOffset + dataset:size()
- end
-
if self.maxEpoch > 0 and self.epoch > self.maxEpoch then
print("<trainer> you have reached the maximum number of epochs")
break
@@ -137,20 +128,10 @@ function OnlineTrainer:test(dataset)
print('<trainer> on testing Set:')
local module = self.module
- local shuffledIndices = {}
local criterion = self.criterion
self.currentError = 0
self.testset = dataset
- local shuffledIndices = {}
- if not self.shuffleIndices then
- for t = 1,dataset:size() do
- shuffledIndices[t] = t
- end
- else
- shuffledIndices = lab.randperm(dataset:size())
- end
-
self.time = sys.clock()
for t = 1,dataset:size() do
-- disp progress
@@ -159,7 +140,7 @@ function OnlineTrainer:test(dataset)
end
-- get new sample
- local sample = dataset[self.testOffset + shuffledIndices[t]]
+ local sample = dataset[t]
local input = sample[1]
local target = sample[2]
@@ -190,10 +171,6 @@ function OnlineTrainer:test(dataset)
self.hookTestEpoch(self)
end
- if dataset.infiniteSet then
- self.testOffset = self.testOffset + dataset:size()
- end
-
return self.currentError
end