diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-31 03:16:32 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-31 03:16:32 +0400 |
commit | 987894db868ed9b5ad0cd746a8c3569985acd71d (patch) | |
tree | d7cc16937910feaf80d8d25f0f4526d1a99887cd | |
parent | dab3bb7517155399bc6f9e377e9fc15c16063aa1 (diff) |
God rid of shuffle flags in Trainers.
-rw-r--r-- | OnlineTrainer.lua | 34 | ||||
-rw-r--r-- | Trainer.lua | 7 |
2 files changed, 2 insertions, 39 deletions
diff --git a/OnlineTrainer.lua b/OnlineTrainer.lua index db662e5..dc6e860 100644 --- a/OnlineTrainer.lua +++ b/OnlineTrainer.lua @@ -33,9 +33,6 @@ function OnlineTrainer:__init(...) {arg='timestamp', type='boolean', help='if true, appends a timestamp to each network saved', default=false} ) - -- private params - self.trainOffset = 0 - self.testOffset = 0 end function OnlineTrainer:log() @@ -63,15 +60,6 @@ function OnlineTrainer:train(dataset) local criterion = self.criterion self.trainset = dataset - local shuffledIndices = {} - if not self.shuffleIndices then - for t = 1,dataset:size() do - shuffledIndices[t] = t - end - else - shuffledIndices = lab.randperm(dataset:size()) - end - while true do print('<trainer> on training set:') print("<trainer> online epoch # " .. self.epoch .. ' [batchSize = ' .. self.batchSize .. ']') @@ -89,7 +77,7 @@ function OnlineTrainer:train(dataset) local targets = {} for i = t,math.min(t+self.batchSize-1,dataset:size()) do -- load new sample - local sample = dataset[self.trainOffset + shuffledIndices[i]] + local sample = dataset[i] local input = sample[1] local target = sample[2] @@ -128,10 +116,6 @@ function OnlineTrainer:train(dataset) self.epoch = self.epoch + 1 - if dataset.infiniteSet then - self.trainOffset = self.trainOffset + dataset:size() - end - if self.maxEpoch > 0 and self.epoch > self.maxEpoch then print("<trainer> you have reached the maximum number of epochs") break @@ -144,20 +128,10 @@ function OnlineTrainer:test(dataset) print('<trainer> on testing Set:') local module = self.module - local shuffledIndices = {} local criterion = self.criterion self.currentError = 0 self.testset = dataset - local shuffledIndices = {} - if not self.shuffleIndices then - for t = 1,dataset:size() do - shuffledIndices[t] = t - end - else - shuffledIndices = lab.randperm(dataset:size()) - end - self.time = sys.clock() for t = 1,dataset:size() do -- disp progress @@ -166,7 +140,7 @@ function OnlineTrainer:test(dataset) end -- get new sample - local sample = dataset[self.testOffset + shuffledIndices[t]] + local sample = dataset[t] local input = sample[1] local target = sample[2] @@ -197,10 +171,6 @@ function OnlineTrainer:test(dataset) self.hookTestEpoch(self) end - if dataset.infiniteSet then - self.testOffset = self.testOffset + dataset:size() - end - return self.currentError end diff --git a/Trainer.lua b/Trainer.lua index 3388ef7..b7da770 100644 --- a/Trainer.lua +++ b/Trainer.lua @@ -4,7 +4,6 @@ function Trainer:__init() self.learningRate = 0.01 self.learningRateDecay = 0 self.maxIteration = 25 - self.shuffleIndices = true end function Trainer:train(dataset) @@ -14,14 +13,12 @@ function Trainer:write(file) file:writeDouble(self.learningRate) file:writeDouble(self.learningRateDecay) file:writeInt(self.maxIteration) - file:writeBool(self.shuffleIndices) end function Trainer:read(file) self.learningRate = file:readDouble() self.learningRateDecay = file:readDouble() self.maxIteration = file:readInt() - self.shuffleIndices = file:readBool() end function Trainer:share(mlp, ...) @@ -30,10 +27,6 @@ function Trainer:share(mlp, ...) end end -function Trainer:setShuffle(bool) - self.shuffleIndices = bool -end - function Trainer:clone(...) local f = torch.MemoryFile("rw"):binary() f:writeObject(self) |