github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-09-12 09:19:47 +0300
committer John Bauer <horatio@gmail.com>  2022-09-12 09:19:47 +0300
commit    c4d785729e42ac90f298e0ef4ab487d14fa35591 (patch)
tree      48c833aa1529957c7be3062fd32705b1fa3cbcd4
parent    824bb780bba86cc0610346f538cf9eecd065fe33 (diff)
Throw out batches whose loss has gone to NaN. Log the number of times it happens
-rw-r--r--  stanza/models/constituency/trainer.py | 28
1 file changed, 21 insertions(+), 7 deletions(-)
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 5e2eb8ff..9a6ea4c1 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -500,14 +500,15 @@ def train(args, model_load_file, model_save_each_file, retag_pipeline):
 
 TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])
 
-class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used'])):
+class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
     def __add__(self, other):
         transitions_correct = self.transitions_correct + other.transitions_correct
         transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect
         repairs_used = self.repairs_used + other.repairs_used
         fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used
         epoch_loss = self.epoch_loss + other.epoch_loss
-        return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used)
+        nans = self.nans + other.nans
+        return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
 
 
 def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_save_each_filename, evaluator):
@@ -585,6 +586,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_save_each_filename, evaluator):
             trainer.save(args['checkpoint_save_name'], save_optimizer=True)
         if model_save_each_filename:
             trainer.save(model_save_each_filename % trainer.epochs_trained, save_optimizer=True)
+        if epoch_stats.nans > 0:
+            logger.warning("Had to ignore %d batches with NaN", epoch_stats.nans)
         logger.info("Epoch %d finished\n Transitions correct: %s\n Transitions incorrect: %s\n Total loss for epoch: %.5f\n Dev score (%5d): %8f\n Best dev score (%5d): %8f", trainer.epochs_trained, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.epoch_loss, trainer.epochs_trained, f1, best_epoch, best_f1)
 
         if args['wandb']:
@@ -642,17 +645,23 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio
     optimizer = trainer.optimizer
     scheduler = trainer.scheduler
 
-    epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0)
+    epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0)
 
     for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
         batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
         batch_stats = train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, model_loss_function, args)
 
-        optimizer.step()
+        # Early in the training, some trees will be degenerate in a
+        # way that results in layers going up the tree amplifying the
+        # weights until they overflow.  Generally that problem
+        # resolves itself in a few iterations, so for now we just
+        # ignore those batches, but report how often it happens
+        if batch_stats.nans == 0:
+            optimizer.step()
         optimizer.zero_grad()
-
         epoch_stats = epoch_stats + batch_stats
+
         old_lr = scheduler.get_last_lr()[0]
         scheduler.step()
         new_lr = scheduler.get_last_lr()[0]
 
@@ -775,9 +784,14 @@ def train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, model_loss_function, args):
             logger.info("  %s norm: %f grad: %f", n, torch.linalg.norm(p), torch.linalg.norm(p.grad))
         if not matched:
             logger.info("  (none found!)")
 
-    batch_loss = tree_loss.item()
+    if torch.any(torch.isnan(tree_loss)):
+        batch_loss = 0.0
+        nans = 1
+    else:
+        batch_loss = tree_loss.item()
+        nans = 0
-    return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used)
+    return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
 
 def build_batch_from_trees(batch_size, data_iterator, model):
     """