github.com/clementfarabet/lua---nnx.git
author     Clement Farabet <clement.farabet@gmail.com>   2011-08-30 00:32:06 +0400
committer  Clement Farabet <clement.farabet@gmail.com>   2011-08-30 00:32:06 +0400
commit     68efd2b303b962f97e6cd3fb4e864ae192bac4ad (patch)
tree       ad22aebcec6a6f41bb3093ca2d8cc1ef0da21850
parent     80edcf411716bcfb0ac1cdcdaa967922d1bf0fea (diff)
Got rid of StochasticTrainer.
-rw-r--r--  OnlineTrainer.lua        2
-rw-r--r--  StochasticTrainer.lua  265
-rw-r--r--  init.lua                 1
-rw-r--r--  nnx-1.0-1.rockspec       1
4 files changed, 1 insertion(+), 268 deletions(-)
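
Note for anyone still using the removed class: a minimal migration sketch follows. It is illustrative only; the OnlineTrainer constructor options (module, criterion, optimizer, batchSize, maxEpoch) and the hook names are assumptions drawn from the removed StochasticTrainer code and the OnlineTrainer hunk below, so check OnlineTrainer.lua for the authoritative interface.

-- Illustrative migration sketch (not part of this commit).
require 'nnx'

-- a toy model and criterion
local module = nn.Sequential()
module:add(nn.Linear(10, 2))
module:add(nn.LogSoftMax())
local criterion = nn.ClassNLLCriterion()

-- the same SGD settings the removed class built internally:
-- (learningRate, weightDecay, momentum)
local optimizer = nn.SGDOptimization(1e-2, 0, 0)

-- assumption: OnlineTrainer accepts named options like the removed class did
local trainer = nn.OnlineTrainer{
   module    = module,
   criterion = criterion,
   optimizer = optimizer,
   batchSize = 1,     -- batchSize = 1 approximates the old per-sample updates
   maxEpoch  = 50
}

-- the per-sample/per-epoch hooks documented in the removed class attach the
-- same way, assuming OnlineTrainer still honours them
trainer.hookTrainSample = function(trainer, sample) --[[ e.g. custom logging ]] end

-- 'dataset' is a placeholder for any object implementing size() and [index] access
trainer:train(dataset)
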
diff --git a/OnlineTrainer.lua b/OnlineTrainer.lua
index bb120e8..db662e5 100644
--- a/OnlineTrainer.lua
+++ b/OnlineTrainer.lua
@@ -74,7 +74,7 @@ function OnlineTrainer:train(dataset)
while true do
print('<trainer> on training set:')
- print("<trainer> online epoch # " .. self.epoch .. '[batchSize = ' .. self.batchSize .. ']')
+ print("<trainer> online epoch # " .. self.epoch .. ' [batchSize = ' .. self.batchSize .. ']')
self.time = sys.clock()
self.currentError = 0
diff --git a/StochasticTrainer.lua b/StochasticTrainer.lua
deleted file mode 100644
index 62fb670..0000000
--- a/StochasticTrainer.lua
+++ /dev/null
@@ -1,265 +0,0 @@
-local StochasticTrainer, parent = torch.class('nn.StochasticTrainer','nn.Trainer')
-
-function StochasticTrainer:__init(...)
- parent.__init(self)
- -- unpack args
- xlua.unpack_class(self, {...},
- 'StochasticTrainer',
-
- 'A general-purpose stochastic trainer class.\n'
- .. 'Provides 4 user hooks to perform extra work after each sample, or each epoch:\n'
- .. '> trainer = nn.StochasticTrainer(...) \n'
- .. '> trainer.hookTrainSample = function(trainer, sample) ... end \n'
- .. '> trainer.hookTrainEpoch = function(trainer) ... end \n'
- .. '> trainer.hookTestSample = function(trainer, sample) ... end \n'
- .. '> trainer.hookTestEpoch = function(trainer) ... end \n'
- .. '> ',
-
- {arg='module', type='nn.Module', help='a module to train', req=true},
- {arg='criterion', type='nn.Module', help='a criterion to estimate the error'},
- {arg='preprocessor', type='nn.Module', help='a preprocessor to prime the data before the module'},
-
- {arg='learningRate', type='number', help='learning rate (W = W - rate*dE/dW)', default=1e-2},
- {arg='learningRateDecay', type='number', help='learning rate decay (rate = rate * (1-decay), at each epoch)', default=0},
- {arg='weightDecay', type='number', help='amount of weight decay (W = W - decay*W)', default=0},
- {arg='momentum', type='number', help='amount of momentum on weights (dE/W = dE/dW + momentum*prev(dE/dW))', default=0},
- {arg='maxEpoch', type='number', help='maximum number of epochs', default=50},
-
- {arg='maxTarget', type='boolean', help='replaces an CxHxW target map by a HxN target of max values (for NLL criterions)', default=false},
- {arg='dispProgress', type='boolean', help='display a progress bar during training/testing', default=true},
- {arg='skipUniformTargets', type='boolean', help='skip uniform (flat) targets during training', default=false},
-
- {arg='save', type='string', help='path to save networks and log training'},
- {arg='timestamp', type='boolean', help='if true, appends a timestamp to each network saved', default=false}
- )
- -- instantiate SGD optimization module
- self.optimizer = nn.SGDOptimization(self.learningRate, self.weightDecay, self.momentum)
- -- private params
- self.errorArray = self.skipUniformTargets
- self.trainOffset = 0
- self.testOffset = 0
-end
-
-function StochasticTrainer:log()
- -- save network
- local filename = self.save
- os.execute('mkdir -p ' .. sys.dirname(filename))
- if self.timestamp then
- -- use a timestamp to store all networks uniquely
- filename = filename .. '-' .. os.date("%Y_%m_%d_%X")
- else
- -- if no timestamp, just store the previous one
- if sys.filep(filename) then
- os.execute('mv ' .. filename .. ' ' .. filename .. '.old')
- end
- end
- print('<trainer> saving network to '..filename)
- local file = torch.DiskFile(filename,'w')
- self.module:write(file)
- file:close()
-end
-
-function StochasticTrainer:train(dataset)
- self.epoch = self.epoch or 1
- local currentLearningRate = self.learningRate
- local module = self.module
- local criterion = self.criterion
- self.trainset = dataset
-
- local shuffledIndices = {}
- if not self.shuffleIndices then
- for t = 1,dataset:size() do
- shuffledIndices[t] = t
- end
- else
- shuffledIndices = lab.randperm(dataset:size())
- end
-
- local parameters = nnx.getParameters(module)
- local gradParameters = nnx.getGradParameters(module)
-
- while true do
- print('<trainer> on training set:')
- print("<trainer> stochastic gradient descent epoch # " .. self.epoch)
-
- module:zeroGradParameters()
-
- self.time = sys.clock()
- self.currentError = 0
- for t = 1,dataset:size() do
- -- disp progress
- if self.dispProgress then
- xlua.progress(t, dataset:size())
- end
-
- -- load new sample
- local sample = dataset[self.trainOffset + shuffledIndices[t]]
- local input = sample[1]
- local target = sample[2]
- local sample_x = sample.x
- local sample_y = sample.y
-
- -- get max of target ?
- if self.maxTarget then
- target = torch.Tensor(target:nElement()):copy(target)
- _,target = lab.max(target)
- target = target[1]
- end
-
- -- is target uniform ?
- local isUniform = false
- if self.errorArray and target:min() == target:max() then
- isUniform = true
- end
-
- -- perform SGD step
- if not (self.skipUniformTargets and isUniform) then
- -- optional preprocess
- if self.preprocessor then input = self.preprocessor:forward(input) end
-
- -- forward through model and criterion
- -- (if no criterion, it is assumed to be contained in the model)
- local modelOut, error
- if criterion then
- modelOut = module:forward(input)
- error = criterion:forward(modelOut, target)
- else
- modelOut, error = module:forward(input, target, sample_x, sample_y)
- end
-
- -- accumulate error
- self.currentError = self.currentError + error
-
- -- reset gradients
- module:zeroGradParameters()
-
- -- backward through model
- -- (if no criterion, it is assumed that derror is internally generated)
- if criterion then
- local derror = criterion:backward(module.output, target)
- module:backward(input, derror)
- else
- module:backward(input)
- end
-
- -- update parameters in the model
- self.optimizer:forward(parameters, gradParameters)
- end
-
- -- call user hook, if any
- if self.hookTrainSample then
- self.hookTrainSample(self, sample)
- end
- end
-
- self.currentError = self.currentError / dataset:size()
- print("<trainer> current error = " .. self.currentError)
-
- self.time = sys.clock() - self.time
- self.time = self.time / dataset:size()
- print("<trainer> time to learn 1 sample = " .. (self.time*1000) .. 'ms')
-
- if self.hookTrainEpoch then
- self.hookTrainEpoch(self)
- end
-
- if self.save then self:log() end
-
- self.epoch = self.epoch + 1
- currentLearningRate = self.learningRate/(1+self.epoch*self.learningRateDecay)
- self.optimizer.learningRate = currentLearningRate
-
- if dataset.infiniteSet then
- self.trainOffset = self.trainOffset + dataset:size()
- end
-
- if self.maxEpoch > 0 and self.epoch > self.maxEpoch then
- print("<trainer> you have reached the maximum number of epochs")
- break
- end
- end
-end
-
-
-function StochasticTrainer:test(dataset)
- print('<trainer> on testing Set:')
-
- local module = self.module
- local shuffledIndices = {}
- local criterion = self.criterion
- self.currentError = 0
- self.testset = dataset
-
- local shuffledIndices = {}
- if not self.shuffleIndices then
- for t = 1,dataset:size() do
- shuffledIndices[t] = t
- end
- else
- shuffledIndices = lab.randperm(dataset:size())
- end
-
- self.time = sys.clock()
- for t = 1,dataset:size() do
- -- disp progress
- if self.dispProgress then
- xlua.progress(t, dataset:size())
- end
-
- -- get new sample
- local sample = dataset[self.testOffset + shuffledIndices[t]]
- local input = sample[1]
- local target = sample[2]
-
- -- max target ?
- if self.maxTarget then
- target = torch.Tensor(target:nElement()):copy(target)
- _,target = lab.max(target)
- target = target[1]
- end
-
- -- test sample through current model
- if self.preprocessor then input = self.preprocessor:forward(input) end
- if criterion then
- self.currentError = self.currentError +
- criterion:forward(module:forward(input), target)
- else
- local _,error = module:forward(input, target)
- self.currentError = self.currentError + error
- end
-
- -- user hook
- if self.hookTestSample then
- self.hookTestSample(self, sample)
- end
- end
-
- self.currentError = self.currentError / dataset:size()
- print("<trainer> test current error = " .. self.currentError)
-
- self.time = sys.clock() - self.time
- self.time = self.time / dataset:size()
- print("<trainer> time to test 1 sample = " .. (self.time*1000) .. 'ms')
-
- if self.hookTestEpoch then
- self.hookTestEpoch(self)
- end
-
- if dataset.infiniteSet then
- self.testOffset = self.testOffset + dataset:size()
- end
-
- return self.currentError
-end
-
-function StochasticTrainer:write(file)
- parent.write(self,file)
- file:writeObject(self.module)
- file:writeObject(self.criterion)
-end
-
-function StochasticTrainer:read(file)
- parent.read(self,file)
- self.module = file:readObject()
- self.criterion = file:readObject()
-end
diff --git a/init.lua b/init.lua
index 04877d3..20246bc 100644
--- a/init.lua
+++ b/init.lua
@@ -105,7 +105,6 @@ torch.include('nnx', 'LBFGSOptimization.lua')
-- trainers:
torch.include('nnx', 'Trainer.lua')
torch.include('nnx', 'OnlineTrainer.lua')
-torch.include('nnx', 'StochasticTrainer.lua')
-- datasets:
torch.include('nnx', 'DataSet.lua')
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index 4529d24..cfbc571 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -83,7 +83,6 @@ build = {
install_files(/lua/nnx SpatialCriterion.lua)
install_files(/lua/nnx Trainer.lua)
install_files(/lua/nnx OnlineTrainer.lua)
- install_files(/lua/nnx StochasticTrainer.lua)
install_files(/lua/nnx DataSet.lua)
install_files(/lua/nnx DataList.lua)
install_files(/lua/nnx DataSetLabelMe.lua)