github.com/stanfordnlp/stanza.git

author    John Bauer <horatio@gmail.com>  2022-10-23 06:30:00 +0300
committer John Bauer <horatio@gmail.com>  2022-11-11 09:41:22 +0300
commit    786449a49c8f5e40ea004f24ddcc7126cf482948
tree      7bb1f2fed6e6fdcf8d598aa9249820d9376b265d
parent    080b714426d8be12a0bcf39d5e6020be23dfb33d
Add a margin_loss term (margin_penalty)
-rw-r--r--  stanza/models/constituency/base_model.py  10
-rw-r--r--  stanza/models/constituency/trainer.py      48
-rw-r--r--  stanza/models/constituency_parser.py        2
3 files changed, 46 insertions, 14 deletions
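
This commit adds a margin ranking term on top of the parser's existing cross-entropy transition loss: once training reaches --margin_loss_initial_epoch, each batch's saved initial states are reparsed with the gold transitions and with the model's own predictions, and the gold parse's score is pushed above the predicted parse's score. As background, here is a minimal, self-contained sketch of how torch.nn.MarginRankingLoss with reduction='sum' and a target of ones behaves; the numbers are invented and are not parser output.

    import torch
    from torch import nn

    # target = 1 means "the first input should score higher than the second"; with the
    # default margin of 0.0 the per-pair penalty is max(0, second_score - first_score).
    margin_loss_function = nn.MarginRankingLoss(reduction='sum')

    gold_scores  = torch.tensor([2.0, 0.5, 1.0])   # scores of gold transition sequences
    model_scores = torch.tensor([1.5, 1.0, 1.0])   # scores of the model's own parses
    loss = margin_loss_function(gold_scores, model_scores, torch.ones_like(gold_scores))
    print(loss)   # tensor(0.5000): only the second pair, where the model outscored the gold parse, contributes
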
diff --git a/stanza/models/constituency/base_model.py b/stanza/models/constituency/base_model.py
index 23781bce..de73554c 100644
--- a/stanza/models/constituency/base_model.py
+++ b/stanza/models/constituency/base_model.py
@@ -204,6 +204,16 @@ class BaseModel(ABC):
for tree in trees]
return self.initial_state_from_preterminals(preterminal_lists, gold_trees=trees)
+ def build_batch_from_states(self, batch_size, data_iterator):
+ state_batch = []
+ for _ in range(batch_size):
+ state = next(data_iterator, None)
+ if state is None:
+ break
+ state_batch.append(state)
+
+ return state_batch
+
def build_batch_from_trees(self, batch_size, data_iterator):
"""
Read from the data_iterator batch_size trees and turn them into new parsing states
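
The new build_batch_from_states mirrors build_batch_from_trees just below it, but drains an iterator of already-constructed parser states instead of building fresh states from trees. Later in this commit it is handed to parse_sentences so a batch's untouched initial states can be reparsed for the margin term. A hedged usage sketch, where model is a BaseModel subclass and initial_states is a list of parser states as in the trainer hunks further down:

    state_iterator = iter(initial_states)
    # Pull up to 32 states; a short final batch is returned as-is, and an empty
    # list signals that the iterator is exhausted.
    state_batch = model.build_batch_from_states(32, state_iterator)
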
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 09ef534c..792c1af9 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -591,15 +591,16 @@ def train(args, model_load_file, model_save_each_file, retag_pipeline):
TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])
-class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
+class EpochStats(namedtuple("EpochStats", ['model_loss', 'margin_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
def __add__(self, other):
transitions_correct = self.transitions_correct + other.transitions_correct
transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect
repairs_used = self.repairs_used + other.repairs_used
fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used
- epoch_loss = self.epoch_loss + other.epoch_loss
+ model_loss = self.model_loss + other.model_loss
+ margin_loss = self.margin_loss + other.margin_loss
nans = self.nans + other.nans
- return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
+ return EpochStats(model_loss, margin_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
def compose_train_data(trees, sequences):
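
EpochStats previously carried a single epoch_loss; it now tracks model_loss and margin_loss separately so the two terms can be logged and reported to wandb independently, and __add__ keeps summing the tuple component-wise. A hypothetical illustration of the roll-up, with invented values and Counter imported from collections as in the trainer:

    from collections import Counter

    a = EpochStats(1.00, 0.20, Counter(), Counter(), Counter(), 0, 0)
    b = EpochStats(0.50, 0.05, Counter(), Counter(), Counter(), 2, 1)
    total = a + b   # model_loss=1.5, margin_loss=0.25, fake_transitions_used=2, nans=1
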
@@ -641,8 +642,10 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
# datasets when using 'mean' instead of 'sum' for reduction
# (Remember to adjust the weight decay when rerunning that experiment)
model_loss_function = nn.CrossEntropyLoss(reduction='sum')
+ margin_loss_function = nn.MarginRankingLoss(reduction='sum')
if args['cuda']:
model_loss_function.cuda()
+ margin_loss_function.cuda()
device = next(model.parameters()).device
transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0)
@@ -679,7 +682,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
epoch_data = epoch_data + epoch_silver_data
epoch_data.sort(key=lambda x: len(x[1]))
- epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, model_loss_function, epoch_data, args)
+ epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, model_loss_function, margin_loss_function, epoch_data, args)
# print statistics
f1, _ = run_dev_set(model, dev_trees, args, evaluator)
@@ -695,10 +698,10 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
trainer.save(model_save_each_filename % trainer.epochs_trained, save_optimizer=True)
if epoch_stats.nans > 0:
logger.warning("Had to ignore %d batches with NaN", epoch_stats.nans)
- logger.info("Epoch %d finished\n Transitions correct: %s\n Transitions incorrect: %s\n Total loss for epoch: %.5f\n Dev score (%5d): %8f\n Best dev score (%5d): %8f", trainer.epochs_trained, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.epoch_loss, trainer.epochs_trained, f1, trainer.best_epoch, trainer.best_f1)
+ logger.info("Epoch %d finished\n Transitions correct: %s\n Transitions incorrect: %s\n Total model loss for epoch: %.5f\n Total margin loss for epoch: %.5f\n Dev score (%5d): %8f\n Best dev score (%5d): %8f", trainer.epochs_trained, epoch_stats.transitions_correct, epoch_stats.transitions_incorrect, epoch_stats.model_loss, epoch_stats.margin_loss, trainer.epochs_trained, f1, trainer.best_epoch, trainer.best_f1)
if args['wandb']:
- wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained)
+ wandb.log({'model_loss': epoch_stats.model_loss, 'margin_loss': epoch_stats.margin_loss, 'dev_score': f1}, step=trainer.epochs_trained)
if args['wandb_norm_regex']:
watch_regex = re.compile(args['wandb_norm_regex'])
for n, p in model.named_parameters():
@@ -751,7 +754,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
return trainer
-def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_function, epoch_data, args):
+def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_function, margin_loss_function, epoch_data, args):
interval_starts = list(range(0, len(epoch_data), args['train_batch_size']))
random.shuffle(interval_starts)
@@ -759,11 +762,11 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio
optimizer = trainer.optimizer
scheduler = trainer.scheduler
- epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0)
+ epoch_stats = EpochStats(0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)
for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
- batch_stats = train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, model_loss_function, args)
+ batch_stats = train_model_one_batch(epoch, batch_idx, model, batch, transition_tensors, model_loss_function, margin_loss_function, args)
trainer.batches_trained += 1
# Early in the training, some trees will be degenerate in a
@@ -795,7 +798,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, model_loss_functio
return epoch_stats
-def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, model_loss_function, args):
+def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, model_loss_function, margin_loss_function, args):
"""
Train the model for one batch
@@ -810,6 +813,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
initial_states = model.initial_state_from_preterminals([x.preterminals for x in training_batch], [x.tree for x in training_batch])
initial_states = [state._replace(gold_sequence=sequence)
for (tree, sequence, _), state in zip(training_batch, initial_states)]
+ # save the untouched initial_states for later so we can reparse the trees
current_batch = initial_states
transitions_correct = Counter()
@@ -888,8 +892,21 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
errors = torch.cat(all_errors)
answers = torch.cat(all_answers)
- tree_loss = model_loss_function(errors, answers)
+ model_loss = model_loss_function(errors, answers)
+ tree_loss = model_loss
+
+ if epoch >= args['margin_loss_initial_epoch']:
+ gold_parsed = model.parse_sentences(iter(initial_states), model.build_batch_from_states, len(initial_states), model.predict_gold, keep_scores=True)
+ gold_scores = torch.stack([x.predictions[0].score for x in gold_parsed])
+ model_parsed = model.parse_sentences(iter(initial_states), model.build_batch_from_states, len(initial_states), model.predict, keep_scores=True)
+ model_scores = torch.stack([x.predictions[0].score for x in model_parsed])
+ margin_loss = margin_loss_function(gold_scores, model_scores, torch.ones_like(gold_scores))
+ tree_loss += margin_loss
+ else:
+ margin_loss = 0.0
+
tree_loss.backward()
+
if args['watch_regex']:
matched = False
logger.info("Watching %s ... epoch %d batch %d", args['watch_regex'], epoch, batch_idx)
@@ -906,13 +923,16 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
if not matched:
logger.info(" (none found!)")
if torch.any(torch.isnan(tree_loss)):
- batch_loss = 0.0
+ model_loss = 0.0
+ margin_loss = 0.0
nans = 1
else:
- batch_loss = tree_loss.item()
+ model_loss = model_loss.item()
+ if margin_loss != 0.0:
+ margin_loss = margin_loss.item()
nans = 0
- return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
+ return EpochStats(model_loss, margin_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
def run_dev_set(model, dev_trees, args, evaluator=None):
"""
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 98a166ad..a1296ee6 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -230,6 +230,8 @@ def parse_args(args=None):
parser.add_argument('--oracle_frequency', type=float, default=0.8, help="How often to use the oracle vs how often to force the correct transition")
parser.add_argument('--oracle_forced_errors', type=float, default=0.001, help="Occasionally have the model randomly walk through the state space to try to learn how to recover")
+ parser.add_argument('--margin_loss_initial_epoch', type=int, default=10, help="Initial epochs focus on learning the model itself")
+
# 30 is slightly slower than 50, for example, but seems to train a bit better on WSJ
# earlier version of the model (less accurate overall) had the following results with adadelta:
# 30: 0.9085
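
constituency_parser.py gains the matching command line option. --margin_loss_initial_epoch defaults to 10 and, per the help text, keeps the initial epochs focused on learning the model itself before the margin term is added. A hedged sketch of overriding it; this assumes parse_args returns the usual dict of options, as the args['...'] lookups in trainer.py suggest:

    from stanza.models.constituency_parser import parse_args

    # Start adding the margin term at epoch 20 instead of the default 10.
    args = parse_args(['--margin_loss_initial_epoch', '20'])
    print(args['margin_loss_initial_epoch'])   # 20
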