diff options
author | Vincent Nguyen <vince62s@yahoo.com> | 2022-10-26 13:28:12 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-26 13:28:12 +0300 |
commit | 2edb50d29b2ccd6927136e28e7aefd95b827b6b1 (patch) | |
tree | ca36078a1a7371107ca355c86822373afda19184 | |
parent | a0761136982ac7c30eba58a31cec1327db9dcc37 (diff) | |
parent | 6de7962aaed0cbe99235c60b61de937b79bb0416 (diff) |
Merge pull request #2228 from vince62s/stats
add stats
-rw-r--r-- | onmt/inputters/text_corpus.py | 1 | ||||
-rw-r--r-- | onmt/trainer.py | 5 | ||||
-rw-r--r-- | onmt/utils/loss.py | 15 | ||||
-rw-r--r-- | onmt/utils/report_manager.py | 5 | ||||
-rw-r--r-- | onmt/utils/statistics.py | 17 |
5 files changed, 33 insertions, 10 deletions
diff --git a/onmt/inputters/text_corpus.py b/onmt/inputters/text_corpus.py index b5b7c021..7aaf77b2 100644 --- a/onmt/inputters/text_corpus.py +++ b/onmt/inputters/text_corpus.py @@ -219,7 +219,6 @@ def build_corpora_iters(corpora, transforms, corpora_info, transforms[name] for name in transform_names if name in transforms ] transform_pipe = TransformPipe.build_from(corpus_transform) - logger.info(f"{c_id}'s transforms: {str(transform_pipe)}") corpus_iter = ParallelCorpusIterator( corpus, transform_pipe, skip_empty_level=skip_empty_level, stride=stride, offset=offset) diff --git a/onmt/trainer.py b/onmt/trainer.py index 26671819..d401d673 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -263,7 +263,9 @@ class Trainer(object): valid_stats = self.validate( valid_iter, moving_average=self.moving_average) self._report_step(self.optim.learning_rate(), - step, valid_stats=valid_stats) + step, train_stats=total_stats, + valid_stats=valid_stats) + # Run patience mechanism if self.earlystopper is not None: self.earlystopper(valid_stats, step) @@ -374,6 +376,7 @@ class Trainer(object): src_lengths = batch['srclen'] if src_lengths is not None: report_stats.n_src_words += src_lengths.sum().item() + total_stats.n_src_words += src_lengths.sum().item() tgt_outer = batch['tgt'] diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index 5c3e4367..7409bcff 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -159,7 +159,8 @@ class LossCompute(nn.Module): target_data[correct_mask] += offset_align # Compute sum of perplexities for stats - stats = self._stats(loss.item(), scores_data, target_data) + stats = self._stats(len(batch['srclen']), loss.item(), + scores_data, target_data) return loss, stats @@ -231,7 +232,8 @@ class LossCompute(nn.Module): align_head=align_head, ref_align=ref_align) loss += align_loss - stats = self._stats(loss.item(), scores, target) + stats = self._stats(len(batch['srclen']), loss.item(), + scores, target) if self.lambda_coverage != 0.0: coverage_loss = self._compute_coverage_loss( @@ -247,7 +249,7 @@ class LossCompute(nn.Module): return loss / float(normfactor), stats - def _stats(self, loss, scores, target): + def _stats(self, bsz, loss, scores, target): """ Args: loss (int): the loss computed by the loss criterion. @@ -263,8 +265,11 @@ class LossCompute(nn.Module): num_non_padding = non_padding.sum().item() # in the case criterion reduction is None then we need # to sum the loss of each sentence in the batch - return onmt.utils.Statistics(loss, - num_non_padding, num_correct) + return onmt.utils.Statistics(loss=loss, + n_batchs=1, + n_sents=bsz, + n_words=num_non_padding, + n_correct=num_correct) class LabelSmoothingLoss(nn.Module): diff --git a/onmt/utils/report_manager.py b/onmt/utils/report_manager.py index 6642ab0d..0a4bc43a 100644 --- a/onmt/utils/report_manager.py +++ b/onmt/utils/report_manager.py @@ -147,6 +147,11 @@ class ReportMgr(ReportMgrBase): if train_stats is not None: self.log('Train perplexity: %g' % train_stats.ppl()) self.log('Train accuracy: %g' % train_stats.accuracy()) + self.log('Sentences processed: %g' % train_stats.n_sents) + self.log('Average bsz: %4.0f/%4.0f/%2.0f' % + (train_stats.n_src_words / train_stats.n_batchs, + train_stats.n_words / train_stats.n_batchs, + train_stats.n_sents / train_stats.n_batchs)) self.maybe_log_tensorboard(train_stats, "train", diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index e98e60c4..1b60de6f 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -16,8 +16,11 @@ class Statistics(object): * elapsed time """ - def __init__(self, loss=0, n_words=0, n_correct=0, computed_metrics={}): + def __init__(self, loss=0, n_batchs=0, n_sents=0, + n_words=0, n_correct=0, computed_metrics={}): self.loss = loss + self.n_batchs = n_batchs + self.n_sents = n_sents self.n_words = n_words self.n_correct = n_correct self.n_src_words = 0 @@ -79,11 +82,14 @@ class Statistics(object): """ self.loss += stat.loss + self.n_batchs += stat.n_batchs + self.n_sents += stat.n_sents self.n_words += stat.n_words self.n_correct += stat.n_correct self.computed_metrics = stat.computed_metrics if update_n_src_words: + print("updating n_src_word") self.n_src_words += stat.n_src_words def accuracy(self): @@ -115,13 +121,18 @@ class Statistics(object): if num_steps > 0: step_fmt = "%s/%5d" % (step_fmt, num_steps) logger.info( - ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + - "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec;") + ("Step %s; acc: %2.1f; ppl: %5.1f; xent: %2.1f; " + + "lr: %7.5f; sents: %7.0f; bsz: %4.0f/%4.0f/%2.0f; " + + "%3.0f/%3.0f tok/s; %6.0f sec;") % (step_fmt, self.accuracy(), self.ppl(), self.xent(), learning_rate, + self.n_sents, + self.n_src_words / self.n_batchs, + self.n_words / self.n_batchs, + self.n_sents / self.n_batchs, self.n_src_words / (t + 1e-5), self.n_words / (t + 1e-5), time.time() - start) + |