github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-10-29 11:09:40 +0300
committer John Bauer <horatio@gmail.com>  2022-10-30 09:20:37 +0300
commit    d881203e04be81c97b96c376d6f350c4e5181886
tree      d7d4e3685e4265e02a6d2c6076f42c2520255cac
parent    0a527352cd0d61d6385ed54e6d454c14b4593e5b
Rough draft of using silver trees.

Mostly untested. Includes an unfinished test of the silver data.
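
As a usage sketch, not part of the patch itself: the silver path is switched on with the new --silver_file option defined in constituency_parser.py below. The treebank paths here are hypothetical placeholders; parse_args and its dict-style result both appear in this diff.

    # Hypothetical invocation; the .pt/.mrg paths are placeholders, not
    # files shipped with stanza.  --silver_file is the option this patch adds.
    from stanza.models import constituency_parser

    args = constituency_parser.parse_args([
        '--wordvec_pretrain_file', 'pretrain.pt',  # placeholder
        '--train_file', 'gold_train.mrg',          # gold treebank (placeholder)
        '--silver_file', 'silver_train.mrg',       # machine-labeled treebank (placeholder)
        '--eval_file', 'dev.mrg',                  # dev treebank (placeholder)
    ])
    assert args['silver_file'] == 'silver_train.mrg'

In train mode the silver trees are then read, validated against the train set, and mixed into each epoch alongside the gold trees, as the trainer.py changes below show.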
Diffstat:
-rw-r--r--  stanza/models/constituency/trainer.py      | 102
-rw-r--r--  stanza/models/constituency/utils.py        |   3
-rw-r--r--  stanza/models/constituency_parser.py       |   1
-rw-r--r--  stanza/tests/constituency/test_trainer.py  |  17
4 files changed, 94 insertions(+), 29 deletions(-)
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 000af875..6f1e4818 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -349,16 +349,39 @@ def add_grad_clipping(trainer, grad_clipping):
         if p.requires_grad:
             p.register_hook(lambda grad: torch.clamp(grad, -grad_clipping, grad_clipping))
 
-def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file):
+def check_constituents(train_constituents, trees, treebank_name):
+    """
+    Check that all the constituents in the other dataset are known in the train set
+    """
+    constituents = parse_tree.Tree.get_unique_constituent_labels(trees)
+    for con in constituents:
+        if con not in train_constituents:
+            raise RuntimeError("Found label {} in the {} set which does not exist in the train set".format(con, treebank_name))
+
+def check_transitions(train_transitions, other_transitions, treebank_name):
+    """
+    Check that all the transitions in the other dataset are known in the train set
+    """
+    for trans in other_transitions:
+        if trans not in train_transitions:
+            raise RuntimeError("Found transition {} in the {} set which does not exist in the train set".format(trans, treebank_name))
+
+def check_root_labels(root_labels, other_trees, treebank_name):
+    """
+    Check that all the root states in the other dataset are known in the train set
+    """
+    for root_state in parse_tree.Tree.get_root_labels(other_trees):
+        if root_state not in root_labels:
+            raise RuntimeError("Found root state {} in the {} set which is not a ROOT state in the train set".format(root_state, treebank_name))
+
+def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file):
     """
     Builds a Trainer (with model) and the train_sequences and transitions for the given trees.
     """
     train_constituents = parse_tree.Tree.get_unique_constituent_labels(train_trees)
-    dev_constituents = parse_tree.Tree.get_unique_constituent_labels(dev_trees)
     logger.info("Unique constituents in training set: %s", train_constituents)
-    for con in dev_constituents:
-        if con not in train_constituents:
-            raise RuntimeError("Found label {} in the dev set which don't exist in the train set".format(con))
+    check_constituents(train_constituents, dev_trees, "dev")
+    check_constituents(train_constituents, silver_trees, "silver")
 
     constituent_counts = parse_tree.Tree.get_constituent_counts(train_trees)
     logger.info("Constituent node counts: %s", constituent_counts)
@@ -374,22 +397,24 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file):
     unary_limit = max(max(t.count_unary_depth() for t in train_trees),
                       max(t.count_unary_depth() for t in dev_trees)) + 1
+    if silver_trees:
+        unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees))
     train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'])
     dev_sequences, dev_transitions = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'])
+    silver_sequences, silver_transitions = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'])
 
     logger.info("Total unique transitions in train set: %d", len(train_transitions))
     logger.info("Unique transitions in training set: %s", train_transitions)
-    for trans in dev_transitions:
-        if trans not in train_transitions:
-            raise RuntimeError("Found transition {} in the dev set which don't exist in the train set".format(trans))
+    check_transitions(train_transitions, dev_transitions, "dev")
+    # theoretically could just train based on the items in the silver dataset
+    check_transitions(train_transitions, silver_transitions, "silver")
 
     verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit)
     verify_transitions(dev_trees, dev_sequences, args['transition_scheme'], unary_limit)
 
     root_labels = parse_tree.Tree.get_root_labels(train_trees)
-    for root_state in parse_tree.Tree.get_root_labels(dev_trees):
-        if root_state not in root_labels:
-            raise RuntimeError("Found root state {} in the dev set which is not a ROOT state in the train set".format(root_state))
+    check_root_labels(root_labels, dev_trees, "dev")
+    check_root_labels(root_labels, silver_trees, "silver")
 
     # we don't check against the words in the dev set as it is
     # expected there will be some UNK words
@@ -418,12 +443,12 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file):
         # grad clipping is not saved with the rest of the model
         add_grad_clipping(trainer, args['grad_clipping'])
 
-        # TODO: turn finetune, relearn_structure, multistage into an enum
+        # TODO: turn finetune, relearn_structure, multistage into an enum?
         # finetune just means continue learning, so checkpoint is sufficient
         # relearn_structure is essentially a one stage multistage
         # multistage with a checkpoint will have the proper optimizer for that epoch
         # and no special learning mode means we are training a new model and should continue
-        return trainer, train_sequences, train_transitions
+        return trainer, train_sequences, silver_sequences, train_transitions
 
     if args['finetune']:
         logger.info("Loading model to finetune: %s", model_load_file)
@@ -477,7 +502,7 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file):
 
     add_grad_clipping(trainer, args['grad_clipping'])
 
-    return trainer, train_sequences, train_transitions
+    return trainer, train_sequences, silver_sequences, train_transitions
 
 def remove_duplicates(trees, dataset):
     """
@@ -542,16 +567,23 @@ def train(args, model_load_file, model_save_each_file, retag_pipeline):
     logger.info("Read %d trees for the dev set", len(dev_trees))
     dev_trees = remove_duplicates(dev_trees, "dev")
 
+    silver_trees = []
+    if args['silver_file']:
+        silver_trees = tree_reader.read_treebank(args['silver_file'])
+        logger.info("Read %d trees for the silver training set", len(silver_trees))
+        silver_trees = remove_duplicates(silver_trees, "silver")
+
     if retag_pipeline is not None:
         logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
         train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos'])
         dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos'])
+        silver_trees = retag_trees(silver_trees, retag_pipeline, args['retag_xpos'])
         logger.info("Retagging finished")
 
     foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
-    trainer, train_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file)
+    trainer, train_sequences, silver_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file)
 
-    trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_each_file, evaluator)
+    trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, model_save_each_file, evaluator)
 
     if args['wandb']:
         wandb.finish()
@@ -571,7 +603,27 @@ class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct',
         return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
 
-def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_save_each_filename, evaluator):
+def compose_train_data(trees, sequences):
+    preterminal_lists = [[Tree(label=preterminal.label, children=Tree(label=preterminal.children[0].label))
+                          for preterminal in tree.yield_preterminals()]
+                         for tree in trees]
+    data = [TrainItem(*x) for x in zip(trees, sequences, preterminal_lists)]
+    return data
+
+def next_epoch_data(leftover_training_data, train_data, epoch_size):
+    if not train_data:
+        return [], []
+
+    epoch_data = leftover_training_data
+    while len(epoch_data) < epoch_size:
+        random.shuffle(train_data)
+        epoch_data.extend(train_data)
+    leftover_training_data = epoch_data[epoch_size:]
+    epoch_data = epoch_data[:epoch_size]
+
+    return leftover_training_data, epoch_data
+
+def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, model_save_each_filename, evaluator):
     """
     Given an initialized model, a processed dataset, and a secondary dev dataset, train the model
@@ -598,10 +650,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_save_each_filename, evaluator):
                           for (y, x) in enumerate(model.transitions)}
     model.train()
 
-    preterminal_lists = [[Tree(label=preterminal.label, children=Tree(label=preterminal.children[0].label))
-                          for preterminal in tree.yield_preterminals()]
-                         for tree in train_trees]
-    train_data = [TrainItem(*x) for x in zip(train_trees, train_sequences, preterminal_lists)]
+    train_data = compose_train_data(train_trees, train_sequences)
+    silver_data = compose_train_data(silver_trees, silver_sequences)
 
     if not args['epoch_size']:
         args['epoch_size'] = len(train_data)
@@ -614,6 +664,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, model_save_each_filename, evaluator):
         multistage_splits[args['epochs'] * 3 // 4] = (args['pattn_num_layers'], True)
 
     leftover_training_data = []
+    leftover_silver_data = []
     if trainer.best_epoch > 0:
         logger.info("Restarting trainer with a model trained for %d epochs. Best epoch %d, f1 %f", trainer.epochs_trained, trainer.best_epoch, trainer.best_f1)
         # trainer.epochs_trained+1 so that if the trainer gets saved after 1 epoch, the epochs_trained is 1
@@ -622,12 +673,9 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, model_save_each_filename, evaluator):
         logger.info("Starting epoch %d", trainer.epochs_trained)
         if args['log_norms']:
             model.log_norms()
-        epoch_data = leftover_training_data
-        while len(epoch_data) < args['epoch_size']:
-            random.shuffle(train_data)
-            epoch_data.extend(train_data)
-        leftover_training_data = epoch_data[args['epoch_size']:]
-        epoch_data = epoch_data[:args['epoch_size']]
+        leftover_training_data, epoch_data = next_epoch_data(leftover_training_data, train_data, args['epoch_size'])
+        leftover_silver_data, epoch_silver_data = next_epoch_data(leftover_silver_data, silver_data, args['epoch_size'])
+        epoch_data = epoch_data + epoch_silver_data
         epoch_data.sort(key=lambda x: len(x[1]))
 
         epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, model_loss_function, epoch_data, args)
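
To make the new epoch composition concrete: next_epoch_data hands each epoch exactly epoch_size items, reshuffling and recycling its pool as needed, and any surplus carries over to the next epoch. With a silver file the same is done for silver_data, so an epoch trains on up to 2 * epoch_size trees. A standalone toy run, with the function copied from the patch above:

    import random

    def next_epoch_data(leftover_training_data, train_data, epoch_size):
        # as in the patch: returns (carry-over, this epoch's items)
        if not train_data:
            return [], []
        epoch_data = leftover_training_data
        while len(epoch_data) < epoch_size:
            random.shuffle(train_data)
            epoch_data.extend(train_data)
        leftover_training_data = epoch_data[epoch_size:]
        epoch_data = epoch_data[:epoch_size]
        return leftover_training_data, epoch_data

    random.seed(0)
    pool = list(range(5))   # stand-in for five TrainItem entries
    leftover = []
    for epoch in range(3):
        leftover, batch = next_epoch_data(leftover, pool, epoch_size=3)
        print("epoch", epoch, "batch", batch, "carry-over", leftover)
    # Each epoch sees exactly 3 items; the 5-item pool is reshuffled and
    # recycled, so no tree is starved between epochs.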
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
index 01dc3522..e3bfa525 100644
--- a/stanza/models/constituency/utils.py
+++ b/stanza/models/constituency/utils.py
@@ -78,6 +78,9 @@ def retag_trees(trees, pipeline, xpos=True):
     Returns a list of new trees
     """
+    if len(trees) == 0:
+        return trees
+
     sentences = []
     try:
         for idx, tree in enumerate(trees):
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 33e3cd05..7a1349e3 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -177,6 +177,7 @@ def parse_args(args=None):
     parser.add_argument('--delta_embedding_dim', type=int, default=100, help="Embedding size for a delta embedding")
 
     parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--silver_file', type=str, default=None, help='Secondary training file.')
     parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
     parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
     parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer'])
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 71b537e1..a9a37f4f 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -57,6 +57,7 @@ def build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK):
     # TODO: build a fake embedding some other way?
     train_trees = tree_reader.read_trees(treebank)
     dev_trees = train_trees[-1:]
+    silver_trees = []
 
     args = ['--wordvec_pretrain_file', wordvec_pretrain_file] + list(args)
     args = constituency_parser.parse_args(args)
@@ -65,7 +66,7 @@ def build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK):
     # might be None, unless we're testing loading an existing model
     model_load_name = args['load_name']
 
-    model, _, _ = trainer.build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_name)
+    model, _, _, _ = trainer.build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_name)
     assert isinstance(model.model, lstm_model.LSTMModel)
     return model
@@ -148,7 +149,7 @@ class TestTrainer:
         args['wandb'] = None
         return args
 
-    def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None):
+    def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False):
         """
         Runs a test of the trainer for a few iterations.
@@ -160,6 +161,8 @@ class TestTrainer:
         extra_args += ['--epochs', '%d' % num_epochs]
 
         train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
+        if use_silver:
+            extra_args += ['--silver_file', str(eval_treebank_file)]
         args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)
 
         each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
@@ -204,6 +207,16 @@ class TestTrainer:
         with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
             self.run_train_test(wordvec_pretrain_file, tmpdirname)
 
+    def test_train_silver(self, wordvec_pretrain_file):
+        """
+        Test the whole thing for a few iterations on the fake data
+
+        This tests that it works if you give it a silver file, but
+        doesn't actually test that the silver file is being used
+        """
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True)
+
     def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
         train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
         args = ['--multistage', '--pattn_num_layers', '1']