author     John Bauer <horatio@gmail.com>    2022-10-23 06:30:00 +0300
committer  John Bauer <horatio@gmail.com>    2022-11-11 09:41:22 +0300
commit     786449a49c8f5e40ea004f24ddcc7126cf482948 (patch)
tree       7bb1f2fed6e6fdcf8d598aa9249820d9376b265d
parent     080b714426d8be12a0bcf39d5e6020be23dfb33d (diff)
Add a margin_loss term (margin_penalty)
-rw-r--r--  stanza/models/constituency/base_model.py  10
-rw-r--r--  stanza/models/constituency/trainer.py      48
-rw-r--r--  stanza/models/constituency_parser.py        2
3 files changed, 46 insertions, 14 deletions
diff --git a/stanza/models/constituency/base_model.py b/stanza/models/constituency/base_model.py
index 23781bce..de73554c 100644
--- a/stanza/models/constituency/base_model.py
+++ b/stanza/models/constituency/base_model.py
@@ -204,6 +204,16 @@ class BaseModel(ABC):
                               for tree in trees]
         return self.initial_state_from_preterminals(preterminal_lists, gold_trees=trees)
 
+    def build_batch_from_states(self, batch_size, data_iterator):
+        state_batch = []
+        for _ in range(batch_size):
+            state = next(data_iterator, None)
+            if state is None:
+                break
+            state_batch.append(state)
+
+        return state_batch
+
     def build_batch_from_trees(self, batch_size, data_iterator):
         """
         Read from the data_iterator batch_size trees and turn them into new parsing states
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 09ef534c..792c1af9 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -591,15 +591,16 @@ def train(args, model_load_file, model_save_each_file, retag_pipeline):
 
 TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])
 
-class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
+class EpochStats(namedtuple("EpochStats", ['model_loss', 'margin_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
     def __add__(self, other):
         transitions_correct = self.transitions_correct + other.transitions_correct
         transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect
         repairs_used = self.repairs_used + other.repairs_used
         fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used
-        epoch_loss = self.epoch_loss + other.epoch_loss
+        model_loss = self.model_loss + other.model_loss
+        margin_loss = self.margin_loss + other.margin_loss
         nans = self.nans + other.nans
-        return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
+        return EpochStats(model_loss, margin_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
 
 
 def compose_train_data(trees, sequences):
@@ -641,8 +642,10 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
     # datasets when using 'mean' instead of 'sum' for reduction
     # (Remember to adjust the weight decay when rerunning that experiment)
     model_loss_function = nn.CrossEntropyLoss(reduction='sum')
+    margin_loss_function = nn.MarginRankingLoss(reduction='sum')
     if args['cuda']:
         model_loss_function.cuda()
+        margin_loss_function.cuda()
 
     device = next(model.parameters()).device
     transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0)
@@ -679,7 +682,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             epoch_data = epoch_data + epoch_silver_data
         epoch_data.sort(key=lambda x: len(x[1]))
 
-        epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, model_loss_function, epoch_data, args)
+        epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, model_loss_function, margin_loss_function, epoch_data, args)
 
         # print statistics
         f1, _ = run_dev_set(model, dev_trees, args, evaluator)
@@ -695,10 +698,10 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             trainer.save(model_save_each_filename % trainer.epochs_trained, save_optimizer=True)
         if epoch_stats.nans > 0:
             logger.warning("Had to ignore %d batches with NaN", epoch_stats.nans)
-        logger.info("Epoch %d finished\n  Transitions correct: %s\n  Transitions incorrect: %s\n  Total loss for epoch: %.5f\n  Dev score (%5d): %8f\n  Best dev score (%5d): %8f", trainer.epochs_trained, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.epoch_loss, trainer.epochs_trained, f1, trainer.best_epoch, trainer.best_f1)
+        logger.info("Epoch %d finished\n  Transitions correct: %s\n  Transitions incorrect: %s\n  Total model loss for epoch: %.5f\n  Total margin loss for epoch: %.5f\n  Dev score (%5d): %8f\n  Best dev score (%5d): %8f", trainer.epochs_trained, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.model_loss, epoch_stats.margin_loss, trainer.epochs_trained, f1, trainer.best_epoch, trainer.best_f1)
 
         if args['wandb']:
-            wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained)
+            wandb.log({'model_loss': epoch_stats.model_loss, 'margin_loss': epoch_stats.margin_loss, 'dev_score': f1}, step=trainer.epochs_trained)
 
         if args['wandb_norm_regex']:
             watch_regex = re.compile(args['wandb_norm_regex'])
             for n, p in model.named_parameters():
@@ -751,7 +754,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
 
     return trainer
 
-def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_function, epoch_data, args):
+def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_function, margin_loss_function, epoch_data, args):
     interval_starts = list(range(0, len(epoch_data), args['train_batch_size']))
     random.shuffle(interval_starts)
 
@@ -759,11 +762,11 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio
     optimizer = trainer.optimizer
     scheduler = trainer.scheduler
 
-    epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0)
+    epoch_stats = EpochStats(0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)
 
     for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
         batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
-        batch_stats = train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, model_loss_function, args)
+        batch_stats = train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, model_loss_function, margin_loss_function, args)
         trainer.batches_trained += 1
 
         # Early in the training, some trees will be degenerate in a
@@ -795,7 +798,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio
 
     return epoch_stats
 
-def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, model_loss_function, args):
+def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, model_loss_function, margin_loss_function, args):
     """
     Train the model for one batch
 
@@ -810,6 +813,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
     initial_states = model.initial_state_from_preterminals([x.preterminals for x in training_batch], [x.tree for x in training_batch])
     initial_states = [state._replace(gold_sequence=sequence) for (tree, sequence, _), state in zip(training_batch, initial_states)]
+    # save the untouched initial_states for later so we can reparse the trees
     current_batch = initial_states
 
     transitions_correct = Counter()
@@ -888,8 +892,21 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
     errors = torch.cat(all_errors)
     answers = torch.cat(all_answers)
 
-    tree_loss = model_loss_function(errors, answers)
+    model_loss = model_loss_function(errors, answers)
+    tree_loss = model_loss
+
+    if epoch >= args['margin_loss_initial_epoch']:
+        gold_parsed = model.parse_sentences(iter(initial_states), model.build_batch_from_states, len(initial_states), model.predict_gold, keep_scores=True)
+        gold_scores = torch.stack([x.predictions[0].score for x in gold_parsed])
+        model_parsed = model.parse_sentences(iter(initial_states), model.build_batch_from_states, len(initial_states), model.predict, keep_scores=True)
+        model_scores = torch.stack([x.predictions[0].score for x in model_parsed])
+        margin_loss = margin_loss_function(gold_scores, model_scores, torch.ones_like(gold_scores))
+        tree_loss += margin_loss
+    else:
+        margin_loss = 0.0
+
     tree_loss.backward()
+
     if args['watch_regex']:
         matched = False
         logger.info("Watching %s ... epoch %d batch %d", args['watch_regex'], epoch, batch_idx)
@@ -906,13 +923,16 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
             if not matched:
                 logger.info("  (none found!)")
 
     if torch.any(torch.isnan(tree_loss)):
-        batch_loss = 0.0
+        model_loss = 0.0
+        margin_loss = 0.0
        nans = 1
     else:
-        batch_loss = tree_loss.item()
+        model_loss = model_loss.item()
+        if margin_loss != 0.0:
+            margin_loss = margin_loss.item()
        nans = 0
 
-    return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
+    return EpochStats(model_loss, margin_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
 
 def run_dev_set(model, dev_trees, args, evaluator=None):
     """
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 98a166ad..a1296ee6 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -230,6 +230,8 @@ def parse_args(args=None):
     parser.add_argument('--oracle_frequency', type=float, default=0.8, help="How often to use the oracle vs how often to force the correct transition")
     parser.add_argument('--oracle_forced_errors', type=float, default=0.001, help="Occasionally have the model randomly walk through the state space to try to learn how to recover")
 
+    parser.add_argument('--margin_loss_initial_epoch', type=int, default=10, help="Initial epochs focus on learning the model itself")
+
     # 30 is slightly slower than 50, for example, but seems to train a bit better on WSJ
     # earlier version of the model (less accurate overall) had the following results with adadelta:
     #   30: 0.9085
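
For readers skimming the trainer.py hunks: the new term ranks the score the model assigns to the gold parse of each tree against the score of its own best parse, and penalizes trees where the model's parse scores higher. Below is a minimal, self-contained sketch of that loss arithmetic only. The tensors transition_logits, gold_transitions, gold_scores, and model_scores are made-up stand-ins for the parser's real outputs (in the patch the scores come from model.parse_sentences with model.predict_gold and model.predict); only the use of nn.CrossEntropyLoss and nn.MarginRankingLoss mirrors the commit.

import torch
import torch.nn as nn

# Stand-in transition classifier outputs: 5 decisions over 4 transition types.
transition_logits = torch.randn(5, 4, requires_grad=True)
gold_transitions = torch.tensor([0, 2, 1, 3, 0])

# Stand-in per-tree parse scores; in the patch these come from
# model.parse_sentences(..., model.predict_gold, keep_scores=True) and
# model.parse_sentences(..., model.predict, keep_scores=True).
gold_scores = torch.tensor([3.2, 1.5, 0.7], requires_grad=True)
model_scores = torch.tensor([3.9, 1.1, 0.9], requires_grad=True)

# Cross-entropy over the predicted transitions, as with model_loss_function.
model_loss_function = nn.CrossEntropyLoss(reduction='sum')
model_loss = model_loss_function(transition_logits, gold_transitions)

# MarginRankingLoss with target +1 asks gold_scores to rank above model_scores:
#   loss_i = max(0, -(gold_i - model_i) + margin), with the default margin of 0,
# so only trees whose model-predicted parse outscores the gold parse contribute.
margin_loss_function = nn.MarginRankingLoss(reduction='sum')
margin_loss = margin_loss_function(gold_scores, model_scores, torch.ones_like(gold_scores))

tree_loss = model_loss + margin_loss
tree_loss.backward()
print("model loss %.4f  margin loss %.4f" % (model_loss.item(), margin_loss.item()))

The --margin_loss_initial_epoch flag added in constituency_parser.py (default 10) means the margin term is only switched on once the transition classifier has had several epochs of plain cross-entropy training.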