| author | John Bauer <horatio@gmail.com> | 2022-10-29 11:09:40 +0300 |
|---|---|---|
| committer | John Bauer <horatio@gmail.com> | 2022-10-30 09:20:37 +0300 |
| commit | d881203e04be81c97b96c376d6f350c4e5181886 (patch) | |
| tree | d7d4e3685e4265e02a6d2c6076f42c2520255cac | |
| parent | 0a527352cd0d61d6385ed54e6d454c14b4593e5b (diff) | |
Rough draft of using silver trees.
Mostly untested. Includes an unfinished test of the silver data.
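For context on the mechanics: the diff below builds each training epoch from `epoch_size` gold trees plus `epoch_size` silver trees, carrying shuffled leftovers over to the next epoch so neither dataset gets truncated. Here is a minimal standalone sketch of that scheme; `next_epoch_data` mirrors the helper added to `trainer.py`, while the demo data and sizes are invented:

```python
import random

def next_epoch_data(leftover_data, train_data, epoch_size):
    """Return (leftovers for the next epoch, items for this epoch), reshuffling as needed."""
    if not train_data:
        return [], []
    epoch_data = leftover_data
    while len(epoch_data) < epoch_size:
        random.shuffle(train_data)
        epoch_data.extend(train_data)
    return epoch_data[epoch_size:], epoch_data[:epoch_size]

# Invented demo: 8 gold trees, 100 silver trees, epoch size 8.
gold = ["gold_%d" % i for i in range(8)]
silver = ["silver_%d" % i for i in range(100)]
leftover_gold, leftover_silver = [], []
for epoch in range(3):
    leftover_gold, epoch_gold = next_epoch_data(leftover_gold, gold, 8)
    leftover_silver, epoch_silver = next_epoch_data(leftover_silver, silver, 8)
    epoch_data = epoch_gold + epoch_silver   # gold and silver are simply concatenated
    print(epoch, len(epoch_data))            # each epoch sees 8 gold + 8 silver items
```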
| -rw-r--r-- | stanza/models/constituency/trainer.py | 102 |
| -rw-r--r-- | stanza/models/constituency/utils.py | 3 |
| -rw-r--r-- | stanza/models/constituency_parser.py | 1 |
| -rw-r--r-- | stanza/tests/constituency/test_trainer.py | 17 |
4 files changed, 94 insertions, 29 deletions
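One note before the diff: the validation being factored out below enforces that every constituent label, transition, and root state seen in the dev or silver trees already occurs in the gold training data, since the model can only predict what it was trained on. A small illustrative version of that check, using plain label lists instead of the `parse_tree.Tree` helpers the real code calls:

```python
def check_labels(train_labels, other_labels, treebank_name):
    """Raise if the other dataset uses a label the train set never produces."""
    known = set(train_labels)
    for label in other_labels:
        if label not in known:
            raise RuntimeError("Found label {} in the {} set which does not exist in the train set".format(label, treebank_name))

check_labels(["S", "NP", "VP"], ["NP", "VP"], "silver")  # passes silently
# check_labels(["S", "NP"], ["ADJP"], "silver")          # would raise RuntimeError
```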
```diff
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 000af875..6f1e4818 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -349,16 +349,39 @@ def add_grad_clipping(trainer, grad_clipping):
         if p.requires_grad:
             p.register_hook(lambda grad: torch.clamp(grad, -grad_clipping, grad_clipping))
 
-def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file):
+def check_constituents(train_constituents, trees, treebank_name):
+    """
+    Check that all the constituents in the other dataset are known in the train set
+    """
+    constituents = parse_tree.Tree.get_unique_constituent_labels(trees)
+    for con in constituents:
+        if con not in train_constituents:
+            raise RuntimeError("Found label {} in the {} set which don't exist in the train set".format(con, treebank_name))
+
+def check_transitions(train_transitions, other_transitions, treebank_name):
+    """
+    Check that all the transitions in the other dataset are known in the train set
+    """
+    for trans in other_transitions:
+        if trans not in train_transitions:
+            raise RuntimeError("Found transition {} in the {} set which don't exist in the train set".format(trans, treebank_name))
+
+def check_root_labels(root_labels, other_trees, treebank_name):
+    """
+    Check that all the root states in the other dataset are known in the train set
+    """
+    for root_state in parse_tree.Tree.get_root_labels(other_trees):
+        if root_state not in root_labels:
+            raise RuntimeError("Found root state {} in the {} set which is not a ROOT state in the train set".format(root_state, treebank_name))
+
+def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file):
     """
     Builds a Trainer (with model) and the train_sequences and transitions for the given trees.
     """
     train_constituents = parse_tree.Tree.get_unique_constituent_labels(train_trees)
-    dev_constituents = parse_tree.Tree.get_unique_constituent_labels(dev_trees)
     logger.info("Unique constituents in training set: %s", train_constituents)
-    for con in dev_constituents:
-        if con not in train_constituents:
-            raise RuntimeError("Found label {} in the dev set which don't exist in the train set".format(con))
+    check_constituents(train_constituents, dev_trees, "dev")
+    check_constituents(train_constituents, silver_trees, "silver")
 
     constituent_counts = parse_tree.Tree.get_constituent_counts(train_trees)
     logger.info("Constituent node counts: %s", constituent_counts)
@@ -374,22 +397,24 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
     unary_limit = max(max(t.count_unary_depth() for t in train_trees),
                       max(t.count_unary_depth() for t in dev_trees)) + 1
+    if silver_trees:
+        unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees))
     train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'])
     dev_sequences, dev_transitions = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'])
+    silver_sequences, silver_transitions = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'])
 
     logger.info("Total unique transitions in train set: %d", len(train_transitions))
     logger.info("Unique transitions in training set: %s", train_transitions)
-    for trans in dev_transitions:
-        if trans not in train_transitions:
-            raise RuntimeError("Found transition {} in the dev set which don't exist in the train set".format(trans))
+    check_transitions(train_transitions, dev_transitions, "dev")
+    # theoretically could just train based on the items in the silver dataset
+    check_transitions(train_transitions, silver_transitions, "silver")
 
     verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit)
     verify_transitions(dev_trees, dev_sequences, args['transition_scheme'], unary_limit)
 
     root_labels = parse_tree.Tree.get_root_labels(train_trees)
-    for root_state in parse_tree.Tree.get_root_labels(dev_trees):
-        if root_state not in root_labels:
-            raise RuntimeError("Found root state {} in the dev set which is not a ROOT state in the train set".format(root_state))
+    check_root_labels(root_labels, dev_trees, "dev")
+    check_root_labels(root_labels, silver_trees, "silver")
 
     # we don't check against the words in the dev set as it is
     # expected there will be some UNK words
@@ -418,12 +443,12 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
         # grad clipping is not saved with the rest of the model
         add_grad_clipping(trainer, args['grad_clipping'])
 
-        # TODO: turn finetune, relearn_structure, multistage into an enum
+        # TODO: turn finetune, relearn_structure, multistage into an enum?
         # finetune just means continue learning, so checkpoint is sufficient
         # relearn_structure is essentially a one stage multistage
        # multistage with a checkpoint will have the proper optimizer for that epoch
         # and no special learning mode means we are training a new model and should continue
-        return trainer, train_sequences, train_transitions
+        return trainer, train_sequences, silver_sequences, train_transitions
 
     if args['finetune']:
         logger.info("Loading model to finetune: %s", model_load_file)
@@ -477,7 +502,7 @@ def build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_fil
 
     add_grad_clipping(trainer, args['grad_clipping'])
 
-    return trainer, train_sequences, train_transitions
+    return trainer, train_sequences, silver_sequences, train_transitions
 
 def remove_duplicates(trees, dataset):
     """
@@ -542,16 +567,23 @@ def train(args, model_load_file, model_save_each_file, retag_pipeline):
     logger.info("Read %d trees for the dev set", len(dev_trees))
     dev_trees = remove_duplicates(dev_trees, "dev")
 
+    silver_trees = []
+    if args['silver_file']:
+        silver_trees = tree_reader.read_treebank(args['silver_file'])
+        logger.info("Read %d trees for the silver training set", len(silver_trees))
+        silver_trees = remove_duplicates(silver_trees, "silver")
+
     if retag_pipeline is not None:
         logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
         train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos'])
         dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos'])
+        silver_trees = retag_trees(silver_trees, retag_pipeline, args['retag_xpos'])
         logger.info("Retagging finished")
 
     foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
-    trainer, train_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file)
+    trainer, train_sequences, silver_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file)
 
-    trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_each_file, evaluator)
+    trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, model_save_each_file, evaluator)
 
     if args['wandb']:
         wandb.finish()
@@ -571,7 +603,27 @@ class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct',
         return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
 
-def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, foundation_cache, model_save_each_filename, evaluator):
+def compose_train_data(trees, sequences):
+    preterminal_lists = [[Tree(label=preterminal.label, children=Tree(label=preterminal.children[0].label))
+                          for preterminal in tree.yield_preterminals()]
+                         for tree in trees]
+    data = [TrainItem(*x) for x in zip(trees, sequences, preterminal_lists)]
+    return data
+
+def next_epoch_data(leftover_training_data, train_data, epoch_size):
+    if not train_data:
+        return [], []
+
+    epoch_data = leftover_training_data
+    while len(epoch_data) < epoch_size:
+        random.shuffle(train_data)
+        epoch_data.extend(train_data)
+    leftover_training_data = epoch_data[epoch_size:]
+    epoch_data = epoch_data[:epoch_size]
+
+    return leftover_training_data, epoch_data
+
+def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, model_save_each_filename, evaluator):
     """
     Given an initialized model, a processed dataset, and a secondary dev dataset, train the model
@@ -598,10 +650,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
                           for (y, x) in enumerate(model.transitions)}
 
     model.train()
 
-    preterminal_lists = [[Tree(label=preterminal.label, children=Tree(label=preterminal.children[0].label))
-                          for preterminal in tree.yield_preterminals()]
-                         for tree in train_trees]
-    train_data = [TrainItem(*x) for x in zip(train_trees, train_sequences, preterminal_lists)]
+    train_data = compose_train_data(train_trees, train_sequences)
+    silver_data = compose_train_data(silver_trees, silver_sequences)
 
     if not args['epoch_size']:
         args['epoch_size'] = len(train_data)
@@ -614,6 +664,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
         multistage_splits[args['epochs'] * 3 // 4] = (args['pattn_num_layers'], True)
 
     leftover_training_data = []
+    leftover_silver_data = []
     if trainer.best_epoch > 0:
         logger.info("Restarting trainer with a model trained for %d epochs. Best epoch %d, f1 %f", trainer.epochs_trained, trainer.best_epoch, trainer.best_f1)
     # trainer.epochs_trained+1 so that if the trainer gets saved after 1 epoch, the epochs_trained is 1
@@ -622,12 +673,9 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
         logger.info("Starting epoch %d", trainer.epochs_trained)
         if args['log_norms']:
             model.log_norms()
-        epoch_data = leftover_training_data
-        while len(epoch_data) < args['epoch_size']:
-            random.shuffle(train_data)
-            epoch_data.extend(train_data)
-        leftover_training_data = epoch_data[args['epoch_size']:]
-        epoch_data = epoch_data[:args['epoch_size']]
+        leftover_training_data, epoch_data = next_epoch_data(leftover_training_data, train_data, args['epoch_size'])
+        leftover_silver_data, epoch_silver_data = next_epoch_data(leftover_silver_data, silver_data, args['epoch_size'])
+        epoch_data = epoch_data + epoch_silver_data
         epoch_data.sort(key=lambda x: len(x[1]))
 
         epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, model_loss_function, epoch_data, args)
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
index 01dc3522..e3bfa525 100644
--- a/stanza/models/constituency/utils.py
+++ b/stanza/models/constituency/utils.py
@@ -78,6 +78,9 @@ def retag_trees(trees, pipeline, xpos=True):
     Returns a list of new trees
     """
+    if len(trees) == 0:
+        return trees
+
     sentences = []
     try:
         for idx, tree in enumerate(trees):
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 33e3cd05..7a1349e3 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -177,6 +177,7 @@ def parse_args(args=None):
     parser.add_argument('--delta_embedding_dim', type=int, default=100, help="Embedding size for a delta embedding")
 
     parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--silver_file', type=str, default=None, help='Secondary training file.')
     parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
     parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
     parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer'])
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index 71b537e1..a9a37f4f 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -57,6 +57,7 @@ def build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK):
     # TODO: build a fake embedding some other way?
     train_trees = tree_reader.read_trees(treebank)
     dev_trees = train_trees[-1:]
+    silver_trees = []
 
     args = ['--wordvec_pretrain_file', wordvec_pretrain_file] + list(args)
     args = constituency_parser.parse_args(args)
@@ -65,7 +66,7 @@ def build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK):
     # might be None, unless we're testing loading an existing model
     model_load_name = args['load_name']
 
-    model, _, _ = trainer.build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_name)
+    model, _, _, _ = trainer.build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_name)
     assert isinstance(model.model, lstm_model.LSTMModel)
     return model
 
@@ -148,7 +149,7 @@ class TestTrainer:
         args['wandb'] = None
         return args
 
-    def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None):
+    def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False):
         """
         Runs a test of the trainer for a few iterations.
@@ -160,6 +161,8 @@ class TestTrainer:
         extra_args += ['--epochs', '%d' % num_epochs]
 
         train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
+        if use_silver:
+            extra_args += ['--silver_file', str(eval_treebank_file)]
         args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)
 
         each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
@@ -204,6 +207,16 @@ class TestTrainer:
         with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
             self.run_train_test(wordvec_pretrain_file, tmpdirname)
 
+    def test_train_silver(self, wordvec_pretrain_file):
+        """
+        Test the whole thing for a few iterations on the fake data
+
+        This tests that it works if you give it a silver file, but
+        doesn't actually test that the silver file is being used
+        """
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True)
+
     def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
         train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
         args = ['--multistage', '--pattn_num_layers', '1']
```
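For anyone trying the new flag: assuming the arguments shown in the diff, a silver-augmented training run could be configured the same way the test does, by handing `parse_args` a flag list. The file paths below are placeholders, not real data:

```python
from stanza.models import constituency_parser

# Placeholder paths; parse_args is the same entry point the tests call.
args = constituency_parser.parse_args([
    '--mode', 'train',
    '--train_file', 'gold_train.mrg',      # gold training treebank
    '--eval_file', 'dev.mrg',              # dev treebank
    '--silver_file', 'silver_train.mrg',   # secondary (silver) training file
    '--wordvec_pretrain_file', 'pretrain.pt',
])
print(args['silver_file'])                 # 'silver_train.mrg'
```

As the new test's docstring notes, `test_train_silver` only checks that a silver file is accepted end to end, not that the silver trees actually improve the model.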