
github.com/clementfarabet/lua---nnx.git
author    Clement Farabet <clement.farabet@gmail.com>    2011-08-31 03:16:32 +0400
committer Clement Farabet <clement.farabet@gmail.com>    2011-08-31 03:16:32 +0400
commit    987894db868ed9b5ad0cd746a8c3569985acd71d (patch)
tree      d7cc16937910feaf80d8d25f0f4526d1a99887cd
parent    dab3bb7517155399bc6f9e377e9fc15c16063aa1 (diff)
Got rid of shuffle flags in Trainers.
-rw-r--r--  OnlineTrainer.lua  34
-rw-r--r--  Trainer.lua         7
2 files changed, 2 insertions, 39 deletions
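
Note: with the shuffle flags removed, OnlineTrainer:train() and OnlineTrainer:test() now walk the dataset in its stored order. A caller that still wants shuffled epochs can permute the dataset before handing it to the trainer. The snippet below is a minimal sketch of such a wrapper; shuffledView is a hypothetical helper, not part of nnx, and it assumes the usual torch/lab environment is already loaded, as in the trainers themselves.

-- Hypothetical helper: wraps a dataset so that numeric indexing goes
-- through a random permutation (lab.randperm is the same call the old
-- shuffle code used).
local function shuffledView(dataset)
   local perm = lab.randperm(dataset:size())
   local view = {}
   function view:size() return dataset:size() end
   setmetatable(view, {__index = function(_, i)
      if type(i) == 'number' then return dataset[perm[i]] end
      return dataset[i]   -- forward non-numeric keys unchanged
   end})
   return view
end

-- usage: trainer:train(shuffledView(trainset))
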
diff --git a/OnlineTrainer.lua b/OnlineTrainer.lua
index db662e5..dc6e860 100644
--- a/OnlineTrainer.lua
+++ b/OnlineTrainer.lua
@@ -33,9 +33,6 @@ function OnlineTrainer:__init(...)
{arg='timestamp', type='boolean',
help='if true, appends a timestamp to each network saved', default=false}
)
- -- private params
- self.trainOffset = 0
- self.testOffset = 0
end
function OnlineTrainer:log()
@@ -63,15 +60,6 @@ function OnlineTrainer:train(dataset)
local criterion = self.criterion
self.trainset = dataset
- local shuffledIndices = {}
- if not self.shuffleIndices then
- for t = 1,dataset:size() do
- shuffledIndices[t] = t
- end
- else
- shuffledIndices = lab.randperm(dataset:size())
- end
-
while true do
print('<trainer> on training set:')
print("<trainer> online epoch # " .. self.epoch .. ' [batchSize = ' .. self.batchSize .. ']')
@@ -89,7 +77,7 @@ function OnlineTrainer:train(dataset)
local targets = {}
for i = t,math.min(t+self.batchSize-1,dataset:size()) do
-- load new sample
- local sample = dataset[self.trainOffset + shuffledIndices[i]]
+ local sample = dataset[i]
local input = sample[1]
local target = sample[2]
@@ -128,10 +116,6 @@ function OnlineTrainer:train(dataset)
self.epoch = self.epoch + 1
- if dataset.infiniteSet then
- self.trainOffset = self.trainOffset + dataset:size()
- end
-
if self.maxEpoch > 0 and self.epoch > self.maxEpoch then
print("<trainer> you have reached the maximum number of epochs")
break
@@ -144,20 +128,10 @@ function OnlineTrainer:test(dataset)
print('<trainer> on testing Set:')
local module = self.module
- local shuffledIndices = {}
local criterion = self.criterion
self.currentError = 0
self.testset = dataset
- local shuffledIndices = {}
- if not self.shuffleIndices then
- for t = 1,dataset:size() do
- shuffledIndices[t] = t
- end
- else
- shuffledIndices = lab.randperm(dataset:size())
- end
-
self.time = sys.clock()
for t = 1,dataset:size() do
-- disp progress
@@ -166,7 +140,7 @@ function OnlineTrainer:test(dataset)
end
-- get new sample
- local sample = dataset[self.testOffset + shuffledIndices[t]]
+ local sample = dataset[t]
local input = sample[1]
local target = sample[2]
@@ -197,10 +171,6 @@ function OnlineTrainer:test(dataset)
self.hookTestEpoch(self)
end
- if dataset.infiniteSet then
- self.testOffset = self.testOffset + dataset:size()
- end
-
return self.currentError
end
diff --git a/Trainer.lua b/Trainer.lua
index 3388ef7..b7da770 100644
--- a/Trainer.lua
+++ b/Trainer.lua
@@ -4,7 +4,6 @@ function Trainer:__init()
self.learningRate = 0.01
self.learningRateDecay = 0
self.maxIteration = 25
- self.shuffleIndices = true
end
function Trainer:train(dataset)
@@ -14,14 +13,12 @@ function Trainer:write(file)
file:writeDouble(self.learningRate)
file:writeDouble(self.learningRateDecay)
file:writeInt(self.maxIteration)
- file:writeBool(self.shuffleIndices)
end
function Trainer:read(file)
self.learningRate = file:readDouble()
self.learningRateDecay = file:readDouble()
self.maxIteration = file:readInt()
- self.shuffleIndices = file:readBool()
end
function Trainer:share(mlp, ...)
@@ -30,10 +27,6 @@ function Trainer:share(mlp, ...)
end
end
-function Trainer:setShuffle(bool)
- self.shuffleIndices = bool
-end
-
function Trainer:clone(...)
local f = torch.MemoryFile("rw"):binary()
f:writeObject(self)