Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--StochasticTrainer.lua11
1 files changed, 10 insertions, 1 deletions
diff --git a/StochasticTrainer.lua b/StochasticTrainer.lua
index 006de5c..65f9fe3 100644
--- a/StochasticTrainer.lua
+++ b/StochasticTrainer.lua
@@ -69,6 +69,7 @@ function StochasticTrainer:train(dataset)
module:zeroGradParameters()
+ self.time = sys.clock()
self.currentError = 0
for t = 1,dataset:size() do
-- disp progress
@@ -142,6 +143,10 @@ function StochasticTrainer:train(dataset)
self.currentError = self.currentError / dataset:size()
print("<trainer> current error = " .. self.currentError)
+ self.time = sys.clock() - self.time
+ self.time = self.time / dataset:size()
+ print("<trainer> time to learn 1 sample = " .. (self.time*1000) .. 'ms')
+
if self.hookTrainEpoch then
self.hookTrainEpoch(self)
end
@@ -172,7 +177,6 @@ function StochasticTrainer:test(dataset)
self.currentError = 0
self.testset = dataset
-
local shuffledIndices = {}
if not self.shuffleIndices then
for t = 1,dataset:size() do
@@ -182,6 +186,7 @@ function StochasticTrainer:test(dataset)
shuffledIndices = lab.randperm(dataset:size())
end
+ self.time = sys.clock()
for t = 1,dataset:size() do
-- disp progress
if self.dispProgress then
@@ -219,6 +224,10 @@ function StochasticTrainer:test(dataset)
self.currentError = self.currentError / dataset:size()
print("<trainer> test current error = " .. self.currentError)
+ self.time = sys.clock() - self.time
+ self.time = self.time / dataset:size()
+ print("<trainer> time to test 1 sample = " .. (self.time*1000) .. 'ms')
+
if self.hookTestEpoch then
self.hookTestEpoch(self)
end