diff options
Diffstat (limited to 'onmt/utils/statistics.py')
-rw-r--r-- | onmt/utils/statistics.py | 17 |
1 files changed, 14 insertions, 3 deletions
diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index e98e60c4..1b60de6f 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -16,8 +16,11 @@ class Statistics(object): * elapsed time """ - def __init__(self, loss=0, n_words=0, n_correct=0, computed_metrics={}): + def __init__(self, loss=0, n_batchs=0, n_sents=0, + n_words=0, n_correct=0, computed_metrics={}): self.loss = loss + self.n_batchs = n_batchs + self.n_sents = n_sents self.n_words = n_words self.n_correct = n_correct self.n_src_words = 0 @@ -79,11 +82,14 @@ class Statistics(object): """ self.loss += stat.loss + self.n_batchs += stat.n_batchs + self.n_sents += stat.n_sents self.n_words += stat.n_words self.n_correct += stat.n_correct self.computed_metrics = stat.computed_metrics if update_n_src_words: + print("updating n_src_word") self.n_src_words += stat.n_src_words def accuracy(self): @@ -115,13 +121,18 @@ class Statistics(object): if num_steps > 0: step_fmt = "%s/%5d" % (step_fmt, num_steps) logger.info( - ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + - "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec;") + ("Step %s; acc: %2.1f; ppl: %5.1f; xent: %2.1f; " + + "lr: %7.5f; sents: %7.0f; bsz: %4.0f/%4.0f/%2.0f; " + + "%3.0f/%3.0f tok/s; %6.0f sec;") % (step_fmt, self.accuracy(), self.ppl(), self.xent(), learning_rate, + self.n_sents, + self.n_src_words / self.n_batchs, + self.n_words / self.n_batchs, + self.n_sents / self.n_batchs, self.n_src_words / (t + 1e-5), self.n_words / (t + 1e-5), time.time() - start) + |