Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT-py.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'onmt/utils/statistics.py')
-rw-r--r--onmt/utils/statistics.py17
1 files changed, 14 insertions, 3 deletions
diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py
index e98e60c4..1b60de6f 100644
--- a/onmt/utils/statistics.py
+++ b/onmt/utils/statistics.py
@@ -16,8 +16,11 @@ class Statistics(object):
* elapsed time
"""
- def __init__(self, loss=0, n_words=0, n_correct=0, computed_metrics={}):
+ def __init__(self, loss=0, n_batchs=0, n_sents=0,
+ n_words=0, n_correct=0, computed_metrics={}):
self.loss = loss
+ self.n_batchs = n_batchs
+ self.n_sents = n_sents
self.n_words = n_words
self.n_correct = n_correct
self.n_src_words = 0
@@ -79,11 +82,14 @@ class Statistics(object):
"""
self.loss += stat.loss
+ self.n_batchs += stat.n_batchs
+ self.n_sents += stat.n_sents
self.n_words += stat.n_words
self.n_correct += stat.n_correct
self.computed_metrics = stat.computed_metrics
if update_n_src_words:
+ print("updating n_src_word")
self.n_src_words += stat.n_src_words
def accuracy(self):
@@ -115,13 +121,18 @@ class Statistics(object):
if num_steps > 0:
step_fmt = "%s/%5d" % (step_fmt, num_steps)
logger.info(
- ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
- "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec;")
+ ("Step %s; acc: %2.1f; ppl: %5.1f; xent: %2.1f; " +
+ "lr: %7.5f; sents: %7.0f; bsz: %4.0f/%4.0f/%2.0f; " +
+ "%3.0f/%3.0f tok/s; %6.0f sec;")
% (step_fmt,
self.accuracy(),
self.ppl(),
self.xent(),
learning_rate,
+ self.n_sents,
+ self.n_src_words / self.n_batchs,
+ self.n_words / self.n_batchs,
+ self.n_sents / self.n_batchs,
self.n_src_words / (t + 1e-5),
self.n_words / (t + 1e-5),
time.time() - start) +