Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2012-01-25 17:55:20 +0400
committerRonan Collobert <ronan@collobert.com>2012-01-25 17:55:20 +0400
commit4df3893abd1b9f840f1d9a8c1859799ccbf941de (patch)
treee8a1e1cc1b6ea6e47855347b157eaf419fdb357b /StochasticGradient.lua
initial revamp of torch7 tree
Diffstat (limited to 'StochasticGradient.lua')
-rw-r--r--StochasticGradient.lua57
1 file changed, 57 insertions, 0 deletions
diff --git a/StochasticGradient.lua b/StochasticGradient.lua
new file mode 100644
index 0000000..2d5e810
--- /dev/null
+++ b/StochasticGradient.lua
@@ -0,0 +1,57 @@
+local StochasticGradient = torch.class('nn.StochasticGradient')
+
--- Create a trainer for the given model and loss.
-- @param module    the model to train (used via :forward and gradient calls)
-- @param criterion the loss to minimize
function StochasticGradient:__init(module, criterion)
   -- what we train
   self.module = module
   self.criterion = criterion
   -- default hyper-parameters; callers may overwrite these fields directly
   self.learningRate = 0.01     -- step size used by accUpdateGradParameters
   self.learningRateDecay = 0   -- 0 disables decay
   self.maxIteration = 25       -- epochs; a value <= 0 means train forever
   self.shuffleIndices = true   -- visit examples in random order
end
+
--- Train the module on a dataset with stochastic gradient descent.
-- The dataset must implement dataset:size() and integer indexing
-- dataset[i], where each example is a pair {input, target}.
-- Optional hooks: self.hookExample(self, example) runs after each example;
-- self.hookIteration(self, iteration) runs after each epoch.
-- Prints the mean loss per epoch; returns nothing.
function StochasticGradient:train(dataset)
   local iteration = 1
   local currentLearningRate = self.learningRate
   local module = self.module
   local criterion = self.criterion

   -- Visit order for the epoch: a random permutation, replaced by the
   -- identity permutation when shuffling is disabled.
   local shuffledIndices = torch.randperm(dataset:size(), 'torch.LongTensor')
   if not self.shuffleIndices then
      for t = 1,dataset:size() do
         shuffledIndices[t] = t
      end
   end

   print("# StochasticGradient: training")

   while true do
      local currentError = 0
      for t = 1,dataset:size() do
         local example = dataset[shuffledIndices[t]]
         local input = example[1]
         local target = example[2]

         -- Forward pass; accumulate the loss over the epoch.
         currentError = currentError + criterion:forward(module:forward(input), target)

         -- Backward pass: criterion gradient feeds the module's gradInput,
         -- then parameters are updated immediately (per-example SGD),
         -- scaled by the current learning rate. Relies on module.output /
         -- criterion.gradInput having been set by the calls just above.
         module:updateGradInput(input, criterion:updateGradInput(module.output, target))
         module:accUpdateGradParameters(input, criterion.gradInput, currentLearningRate)

         if self.hookExample then
            self.hookExample(self, example)
         end
      end

      if self.hookIteration then
         self.hookIteration(self, iteration)
      end

      -- Report the mean loss for this epoch.
      currentError = currentError / dataset:size()
      print("# current error = " .. currentError)
      iteration = iteration + 1
      -- Decay the learning rate for the next epoch (no-op when decay is 0).
      currentLearningRate = self.learningRate/(1+iteration*self.learningRateDecay)
      -- maxIteration <= 0 disables the stopping check (train forever).
      if self.maxIteration > 0 and iteration > self.maxIteration then
         print("# StochasticGradient: you have reached the maximum number of iterations")
         break
      end
   end
end