diff options
author | John Bauer <horatio@gmail.com> | 2022-09-12 09:19:47 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-09-12 09:19:47 +0300 |
commit | c4d785729e42ac90f298e0ef4ab487d14fa35591 (patch) | |
tree | 48c833aa1529957c7be3062fd32705b1fa3cbcd4 | |
parent | 824bb780bba86cc0610346f538cf9eecd065fe33 (diff) |
Throw out batches which had gone to NaN. Log the number of times it happens
-rw-r--r-- | stanza/models/constituency/trainer.py | 28 |
1 files changed, 21 insertions, 7 deletions
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index 5e2eb8ff..9a6ea4c1 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -500,14 +500,15 @@ def train(args, model_load_file, model_save_each_file, retag_pipeline): TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals']) -class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used'])): +class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])): def __add__(self, other): transitions_correct = self.transitions_correct + other.transitions_correct transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect repairs_used = self.repairs_used + other.repairs_used fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used epoch_loss = self.epoch_loss + other.epoch_loss - return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used) + nans = self.nans + other.nans + return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_save_each_filename, evaluator): @@ -585,6 +586,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d trainer.save(args['checkpoint_save_name'], save_optimizer=True) if model_save_each_filename: trainer.save(model_save_each_filename % trainer.epochs_trained, save_optimizer=True) + if epoch_stats.nans > 0: + logger.warning("Had to ignore %d batches with NaN", epoch_stats.nans) logger.info("Epoch %d finished\n Transitions correct: %s\n Transitions incorrect: %s\n Total loss for epoch: %.5f\n Dev score (%5d): %8f\n Best dev score (%5d): %8f", trainer.epochs_trained, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.epoch_loss, trainer.epochs_trained, f1, best_epoch, best_f1) if args['wandb']: @@ -642,17 +645,23 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio optimizer = trainer.optimizer scheduler = trainer.scheduler - epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0) + epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0) for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)): batch = epoch_data[interval_start:interval_start+args['train_batch_size']] batch_stats = train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, model_loss_function, args) - optimizer.step() + # Early in the training, some trees will be degenerate in a + # way that results in layers going up the tree amplifying the + # weights until they overflow. Generally that problem + # resolves itself in a few iterations, so for now we just + # ignore those batches, but report how often it happens + if batch_stats.nans == 0: + optimizer.step() optimizer.zero_grad() - epoch_stats = epoch_stats + batch_stats + old_lr = scheduler.get_last_lr()[0] scheduler.step() new_lr = scheduler.get_last_lr()[0] @@ -775,9 +784,14 @@ def train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, mo logger.info(" %s norm: %f grad: %f", n, torch.linalg.norm(p), torch.linalg.norm(p.grad)) if not matched: logger.info(" (none found!)") - batch_loss = tree_loss.item() + if torch.any(torch.isnan(tree_loss)): + batch_loss = 0.0 + nans = 1 + else: + batch_loss = tree_loss.item() + nans = 0 - return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used) + return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) def build_batch_from_trees(batch_size, data_iterator, model): """ |