github.com/OpenNMT/OpenNMT-py.git
author     Vincent Nguyen <vince62s@yahoo.com>    2022-10-26 13:28:12 +0300
committer  GitHub <noreply@github.com>            2022-10-26 13:28:12 +0300
commit     2edb50d29b2ccd6927136e28e7aefd95b827b6b1 (patch)
tree       ca36078a1a7371107ca355c86822373afda19184
parent     a0761136982ac7c30eba58a31cec1327db9dcc37 (diff)
parent     6de7962aaed0cbe99235c60b61de937b79bb0416 (diff)
Merge pull request #2228 from vince62s/stats
add stats
-rw-r--r--  onmt/inputters/text_corpus.py    1
-rw-r--r--  onmt/trainer.py                  5
-rw-r--r--  onmt/utils/loss.py              15
-rw-r--r--  onmt/utils/report_manager.py     5
-rw-r--r--  onmt/utils/statistics.py        17
5 files changed, 33 insertions, 10 deletions
diff --git a/onmt/inputters/text_corpus.py b/onmt/inputters/text_corpus.py
index b5b7c021..7aaf77b2 100644
--- a/onmt/inputters/text_corpus.py
+++ b/onmt/inputters/text_corpus.py
@@ -219,7 +219,6 @@ def build_corpora_iters(corpora, transforms, corpora_info,
transforms[name] for name in transform_names if name in transforms
]
transform_pipe = TransformPipe.build_from(corpus_transform)
- logger.info(f"{c_id}'s transforms: {str(transform_pipe)}")
corpus_iter = ParallelCorpusIterator(
corpus, transform_pipe,
skip_empty_level=skip_empty_level, stride=stride, offset=offset)
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 26671819..d401d673 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -263,7 +263,9 @@ class Trainer(object):
valid_stats = self.validate(
valid_iter, moving_average=self.moving_average)
self._report_step(self.optim.learning_rate(),
- step, valid_stats=valid_stats)
+ step, train_stats=total_stats,
+ valid_stats=valid_stats)
+
# Run patience mechanism
if self.earlystopper is not None:
self.earlystopper(valid_stats, step)
@@ -374,6 +376,7 @@ class Trainer(object):
src_lengths = batch['srclen']
if src_lengths is not None:
report_stats.n_src_words += src_lengths.sum().item()
+ total_stats.n_src_words += src_lengths.sum().item()
tgt_outer = batch['tgt']
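
The trainer change does two things: at validation time it now passes the accumulated training statistics (total_stats) to _report_step alongside valid_stats, and it counts source tokens into the running total as well as into the per-report window. A minimal sketch of that two-accumulator pattern, with a hypothetical Stats class and plain lists standing in for OpenNMT-py's real classes and tensors:

class Stats:
    def __init__(self):
        self.n_src_words = 0

report_stats, total_stats = Stats(), Stats()
# 'batches' stands in for the training iterator; 'srclen' is a tensor
# in the real code, here a plain list so the sketch runs anywhere.
batches = [{'srclen': [7, 9, 5]}, {'srclen': [8, 8]}]
for batch in batches:
    n_src = sum(batch['srclen'])
    report_stats.n_src_words += n_src  # windowed: reset at each report
    total_stats.n_src_words += n_src   # cumulative: reaches _report_step

Without the second accumulation, total_stats would reach _report_step with n_src_words still at zero, and the per-batch source-token average added below would always print 0.
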
diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py
index 5c3e4367..7409bcff 100644
--- a/onmt/utils/loss.py
+++ b/onmt/utils/loss.py
@@ -159,7 +159,8 @@ class LossCompute(nn.Module):
target_data[correct_mask] += offset_align
# Compute sum of perplexities for stats
- stats = self._stats(loss.item(), scores_data, target_data)
+ stats = self._stats(len(batch['srclen']), loss.item(),
+ scores_data, target_data)
return loss, stats
@@ -231,7 +232,8 @@ class LossCompute(nn.Module):
align_head=align_head, ref_align=ref_align)
loss += align_loss
- stats = self._stats(loss.item(), scores, target)
+ stats = self._stats(len(batch['srclen']), loss.item(),
+ scores, target)
if self.lambda_coverage != 0.0:
coverage_loss = self._compute_coverage_loss(
@@ -247,7 +249,7 @@ class LossCompute(nn.Module):
return loss / float(normfactor), stats
- def _stats(self, loss, scores, target):
+ def _stats(self, bsz, loss, scores, target):
"""
Args:
loss (int): the loss computed by the loss criterion.
@@ -263,8 +265,11 @@ class LossCompute(nn.Module):
num_non_padding = non_padding.sum().item()
# in the case criterion reduction is None then we need
# to sum the loss of each sentence in the batch
- return onmt.utils.Statistics(loss,
- num_non_padding, num_correct)
+ return onmt.utils.Statistics(loss=loss,
+ n_batchs=1,
+ n_sents=bsz,
+ n_words=num_non_padding,
+ n_correct=num_correct)
class LabelSmoothingLoss(nn.Module):
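
To see what the new signature records: len(batch['srclen']) is the number of sentences in the batch, so every loss computation now contributes exactly one batch and bsz sentences to the totals. A toy stand-in (a namedtuple, not the real onmt.utils.Statistics) that mirrors the constructor call in the hunk above:

from collections import namedtuple

Statistics = namedtuple('Statistics',
                        'loss n_batchs n_sents n_words n_correct')

def make_stats(loss, bsz, num_non_padding, num_correct):
    # mirrors the new onmt.utils.Statistics(...) call above
    return Statistics(loss=loss, n_batchs=1, n_sents=bsz,
                      n_words=num_non_padding, n_correct=num_correct)

print(make_stats(12.5, bsz=32, num_non_padding=700, num_correct=430))
# Statistics(loss=12.5, n_batchs=1, n_sents=32, n_words=700, n_correct=430)

Setting n_batchs=1 per call is what makes per-batch averages meaningful once many of these objects are summed together in Statistics.update().
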
diff --git a/onmt/utils/report_manager.py b/onmt/utils/report_manager.py
index 6642ab0d..0a4bc43a 100644
--- a/onmt/utils/report_manager.py
+++ b/onmt/utils/report_manager.py
@@ -147,6 +147,11 @@ class ReportMgr(ReportMgrBase):
if train_stats is not None:
self.log('Train perplexity: %g' % train_stats.ppl())
self.log('Train accuracy: %g' % train_stats.accuracy())
+ self.log('Sentences processed: %g' % train_stats.n_sents)
+ self.log('Average bsz: %4.0f/%4.0f/%2.0f' %
+ (train_stats.n_src_words / train_stats.n_batchs,
+ train_stats.n_words / train_stats.n_batchs,
+ train_stats.n_sents / train_stats.n_batchs))
self.maybe_log_tensorboard(train_stats,
"train",
diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py
index e98e60c4..1b60de6f 100644
--- a/onmt/utils/statistics.py
+++ b/onmt/utils/statistics.py
@@ -16,8 +16,11 @@ class Statistics(object):
* elapsed time
"""
- def __init__(self, loss=0, n_words=0, n_correct=0, computed_metrics={}):
+ def __init__(self, loss=0, n_batchs=0, n_sents=0,
+ n_words=0, n_correct=0, computed_metrics={}):
self.loss = loss
+ self.n_batchs = n_batchs
+ self.n_sents = n_sents
self.n_words = n_words
self.n_correct = n_correct
self.n_src_words = 0
@@ -79,11 +82,14 @@ class Statistics(object):
"""
self.loss += stat.loss
+ self.n_batchs += stat.n_batchs
+ self.n_sents += stat.n_sents
self.n_words += stat.n_words
self.n_correct += stat.n_correct
self.computed_metrics = stat.computed_metrics
if update_n_src_words:
+ print("updating n_src_word")
self.n_src_words += stat.n_src_words
def accuracy(self):
@@ -115,13 +121,18 @@ class Statistics(object):
if num_steps > 0:
step_fmt = "%s/%5d" % (step_fmt, num_steps)
logger.info(
- ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
- "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec;")
+ ("Step %s; acc: %2.1f; ppl: %5.1f; xent: %2.1f; " +
+ "lr: %7.5f; sents: %7.0f; bsz: %4.0f/%4.0f/%2.0f; " +
+ "%3.0f/%3.0f tok/s; %6.0f sec;")
% (step_fmt,
self.accuracy(),
self.ppl(),
self.xent(),
learning_rate,
+ self.n_sents,
+ self.n_src_words / self.n_batchs,
+ self.n_words / self.n_batchs,
+ self.n_sents / self.n_batchs,
self.n_src_words / (t + 1e-5),
self.n_words / (t + 1e-5),
time.time() - start) +
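
Putting the statistics.py changes together: update() sums the new batch and sentence counters the same way it already sums words, so by the time output() runs, dividing by n_batchs yields per-batch averages over everything accumulated since the last report. A self-contained toy version of that aggregation (plain Python, hypothetical Stats class, not the real Statistics):

class Stats:
    def __init__(self, n_batchs=0, n_sents=0, n_words=0):
        self.n_batchs, self.n_sents, self.n_words = n_batchs, n_sents, n_words

    def update(self, other):
        # counts are summed, exactly like Statistics.update() above
        self.n_batchs += other.n_batchs
        self.n_sents += other.n_sents
        self.n_words += other.n_words

agg = Stats()
for s in (Stats(1, 32, 700), Stats(1, 28, 640)):  # two batches' stats
    agg.update(s)
print('sents: %d; tgt tok/batch: %.0f; sents/batch: %.0f' %
      (agg.n_sents, agg.n_words / agg.n_batchs, agg.n_sents / agg.n_batchs))
# -> sents: 60; tgt tok/batch: 670; sents/batch: 30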