diff options
 StochasticTrainer.lua | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/StochasticTrainer.lua b/StochasticTrainer.lua
index 006de5c..65f9fe3 100644
--- a/StochasticTrainer.lua
+++ b/StochasticTrainer.lua
@@ -69,6 +69,7 @@ function StochasticTrainer:train(dataset)
    module:zeroGradParameters()

+   self.time = sys.clock()
    self.currentError = 0
    for t = 1,dataset:size() do
       -- disp progress
@@ -142,6 +143,10 @@ function StochasticTrainer:train(dataset)
    self.currentError = self.currentError / dataset:size()
    print("<trainer> current error = " .. self.currentError)

+   self.time = sys.clock() - self.time
+   self.time = self.time / dataset:size()
+   print("<trainer> time to learn 1 sample = " .. (self.time*1000) .. 'ms')
+
    if self.hookTrainEpoch then
       self.hookTrainEpoch(self)
    end
@@ -172,7 +177,6 @@ function StochasticTrainer:test(dataset)
    self.currentError = 0
    self.testset = dataset

-   local shuffledIndices = {}
    if not self.shuffleIndices then
       for t = 1,dataset:size() do
@@ -182,6 +186,7 @@ function StochasticTrainer:test(dataset)
       shuffledIndices = lab.randperm(dataset:size())
    end

+   self.time = sys.clock()
    for t = 1,dataset:size() do
       -- disp progress
@@ -219,6 +224,10 @@ function StochasticTrainer:test(dataset)
    self.currentError = self.currentError / dataset:size()
    print("<trainer> test current error = " .. self.currentError)

+   self.time = sys.clock() - self.time
+   self.time = self.time / dataset:size()
+   print("<trainer> time to test 1 sample = " .. (self.time*1000) .. 'ms')
+
    if self.hookTestEpoch then
       self.hookTestEpoch(self)
    end