diff options
author | Ronan Collobert <ronan@collobert.com> | 2012-01-25 17:55:20 +0400 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2012-01-25 17:55:20 +0400 |
commit | 4df3893abd1b9f840f1d9a8c1859799ccbf941de (patch) | |
tree | e8a1e1cc1b6ea6e47855347b157eaf419fdb357b /StochasticGradient.lua |
initial revamp of torch7 tree
Diffstat (limited to 'StochasticGradient.lua')
-rw-r--r-- | StochasticGradient.lua | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/StochasticGradient.lua b/StochasticGradient.lua new file mode 100644 index 0000000..2d5e810 --- /dev/null +++ b/StochasticGradient.lua @@ -0,0 +1,57 @@ +local StochasticGradient = torch.class('nn.StochasticGradient') + +function StochasticGradient:__init(module, criterion) + self.learningRate = 0.01 + self.learningRateDecay = 0 + self.maxIteration = 25 + self.shuffleIndices = true + self.module = module + self.criterion = criterion +end + +function StochasticGradient:train(dataset) + local iteration = 1 + local currentLearningRate = self.learningRate + local module = self.module + local criterion = self.criterion + + local shuffledIndices = torch.randperm(dataset:size(), 'torch.LongTensor') + if not self.shuffleIndices then + for t = 1,dataset:size() do + shuffledIndices[t] = t + end + end + + print("# StochasticGradient: training") + + while true do + local currentError = 0 + for t = 1,dataset:size() do + local example = dataset[shuffledIndices[t]] + local input = example[1] + local target = example[2] + + currentError = currentError + criterion:forward(module:forward(input), target) + + module:updateGradInput(input, criterion:updateGradInput(module.output, target)) + module:accUpdateGradParameters(input, criterion.gradInput, currentLearningRate) + + if self.hookExample then + self.hookExample(self, example) + end + end + + if self.hookIteration then + self.hookIteration(self, iteration) + end + + currentError = currentError / dataset:size() + print("# current error = " .. currentError) + iteration = iteration + 1 + currentLearningRate = self.learningRate/(1+iteration*self.learningRateDecay) + if self.maxIteration > 0 and iteration > self.maxIteration then + print("# StochasticGradient: you have reached the maximum number of iterations") + break + end + end +end |