| author | John Bauer <horatio@gmail.com> | 2022-04-09 09:52:07 +0300 |
|---|---|---|
| committer | John Bauer <horatio@gmail.com> | 2022-11-03 20:00:19 +0300 |
| commit | e35f0b853043b1b3514c131aa2b6af58ae1e327f | |
| tree | f0c509b1fb86f28731f7a0df5e6c04be7ed387d7 | |
| parent | cfbf9d0616b4032541f191c8bee90f0456774eee | |
Add a flag to control how many tags to use when labeling shift transitions (branch con_shift_tags2)
Keeps existing models viable by filling in a default common_tags when loading models saved before this field existed
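In outline, the change counts the preterminal (POS) tags across the training trees, keeps the most frequent ones as labels for Shift transitions, and lets every rarer tag fall back to an unlabeled Shift. Below is a minimal sketch of just the frequency cut, assuming plain lists of tags in place of Stanza's Tree objects; the name `most_common_tags` is hypothetical, an illustration rather than the `get_common_tags` helper the commit adds:

```python
from collections import Counter

def most_common_tags(tagged_sentences, num_tags=5):
    """Illustrative stand-in for the commit's Tree.get_common_tags:
    return the num_tags most frequent tags, sorted alphabetically.

    tagged_sentences: an iterable of tag lists, standing in for the
    preterminal labels the parser would read off each tree.
    """
    if num_tags == 0:
        return set()
    counts = Counter()
    for tags in tagged_sentences:
        counts.update(tags)
    if num_tags < 0:
        # mirrors the -1 setting of the new --num_tag_shifts flag: keep all tags
        return sorted(counts)
    return sorted(tag for tag, _ in counts.most_common(num_tags))

print(most_common_tags([["DT", "NN", "VBZ"], ["DT", "NN", "."]], num_tags=2))
# ['DT', 'NN']
```

Sorting the surviving tags, as the real helper also does, presumably keeps the label set deterministic across runs, which matters when a saved model is reloaded.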
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | stanza/models/constituency/lstm_model.py | 4 |
| -rw-r--r-- | stanza/models/constituency/parse_tree.py | 17 |
| -rw-r--r-- | stanza/models/constituency/trainer.py | 24 |
| -rw-r--r-- | stanza/models/constituency_parser.py | 2 |
| -rw-r--r-- | stanza/tests/constituency/test_parse_tree.py | 12 |

5 files changed, 50 insertions, 9 deletions
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 264f7314..e894fed5 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -181,7 +181,7 @@ class ConstituencyComposition(Enum):
     UNTIED_MAX = 8
 
 class LSTMModel(BaseModel, nn.Module):
-    def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, transitions, constituents, tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):
+    def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, transitions, constituents, tags, common_tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):
         """
         pretrain: a Pretrain object
         transitions: a list of all possible transitions which will be
@@ -287,6 +287,7 @@ class LSTMModel(BaseModel, nn.Module):
 
         self.rare_words = set(rare_words)
+        self.common_tags = set(common_tags)
 
         self.tags = sorted(list(tags))
         if self.tag_embedding_dim > 0:
             self.tag_map = { t: i+2 for i, t in enumerate(self.tags) }
@@ -1060,6 +1061,7 @@ class LSTMModel(BaseModel, nn.Module):
             'transitions': self.transitions,
             'constituents': self.constituents,
             'tags': self.tags,
+            'common_tags': self.common_tags,
             'words': self.delta_words,
             'rare_words': self.rare_words,
             'root_labels': self.root_labels,
diff --git a/stanza/models/constituency/parse_tree.py b/stanza/models/constituency/parse_tree.py
index 7db70caf..bca64f44 100644
--- a/stanza/models/constituency/parse_tree.py
+++ b/stanza/models/constituency/parse_tree.py
@@ -327,6 +327,23 @@ class Tree(StanzaObject):
         tree.visit_preorder(preterminal = lambda x: tags.add(x.label))
         return sorted(tags)
 
+    @staticmethod
+    def get_common_tags(trees, num_tags=5):
+        """
+        Walks over all of the trees and gets the most frequently occurring tags from the trees
+        """
+        if num_tags == 0:
+            return set()
+
+        if isinstance(trees, Tree):
+            trees = [trees]
+
+        tags = Counter()
+        for tree in trees:
+            tree.visit_preorder(preterminal = lambda x: tags.update([x.label]))
+        return sorted(x[0] for x in tags.most_common()[:num_tags])
+
     @staticmethod
     def get_unique_words(trees):
         """
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index dd6f3421..516b5843 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -110,6 +110,8 @@ class Trainer:
             saved_args.update(update_args)
 
         model_type = params['model_type']
+        # TODO: when all models have this baked in, no need to 'get' this parameter
+        common_tags = params.get('common_tags', set())
         if model_type == 'LSTM':
             pt = load_pretrain(saved_args.get('wordvec_pretrain_file', None), foundation_cache)
             bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)
@@ -124,6 +126,7 @@ class Trainer:
                               transitions=Trainer.fix_shift_transitions(params['transitions']),
                               constituents=params['constituents'],
                               tags=params['tags'],
+                              common_tags=common_tags,
                               words=params['words'],
                               rare_words=params['rare_words'],
                               root_labels=params['root_labels'],
@@ -411,20 +414,25 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache,
         if tag not in tags:
             logger.info("Found tag in dev set which does not exist in train set: %s Continuing...", tag)
 
+    num_common_tags = args['num_tag_shifts']
+    if num_common_tags < 0:
+        num_common_tags = len(tags)
+    common_tags = parse_tree.Tree.get_common_tags(train_trees, num_common_tags)
     unary_limit = max(max(t.count_unary_depth() for t in train_trees),
                       max(t.count_unary_depth() for t in dev_trees)) + 1
+
     if silver_trees:
         unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees))
-    train_sequences = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], tags)
+    train_sequences = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], common_tags)
     # the training transitions will all be labeled with the tags
     # currently we are just checking correctness
     # we add an unlabeled Shift so that the model can represent previously unseen tags
     # at train time we will redo some tags as <UNK> to train the unlabeled Shift
     # (this also will essentially be a form of dropout)
     train_transitions = transition_sequence.all_transitions(train_sequences + [[parse_transitions.Shift()]])
-    dev_sequences = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], tags)
+    dev_sequences = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], common_tags)
     dev_transitions = transition_sequence.all_transitions(dev_sequences)
-    silver_sequences = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], tags)
+    silver_sequences = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], common_tags)
     silver_transitions = transition_sequence.all_transitions(silver_sequences)
 
     logger.info("Total unique transitions in train set: %d", len(train_transitions))
@@ -496,7 +504,7 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache,
             # using the model's current values works for if the new
             # dataset is the same or smaller
             # TODO: handle a larger dataset as well
-            model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, trainer.model.transitions, trainer.model.constituents, trainer.model.tags, trainer.model.delta_words, trainer.model.rare_words, trainer.model.root_labels, trainer.model.constituent_opens, trainer.model.unary_limit(), args)
+            model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, trainer.model.transitions, trainer.model.constituents, trainer.model.tags, trainer.model.common_tags, trainer.model.delta_words, trainer.model.rare_words, trainer.model.root_labels, trainer.model.constituent_opens, trainer.model.unary_limit(), args)
             if args['cuda']:
                 model.cuda()
             model.copy_with_new_structure(trainer.model)
@@ -513,14 +521,14 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache,
             temp_args['pattn_num_layers'] = 0
             temp_args['lattn_d_proj'] = 0
 
-            temp_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, temp_args)
+            temp_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, common_tags, words, rare_words, root_labels, open_nodes, unary_limit, temp_args)
             if args['cuda']:
                 temp_model.cuda()
             temp_optim = build_optimizer(temp_args, temp_model, True)
             scheduler = build_scheduler(temp_args, temp_optim)
             trainer = Trainer(temp_model, temp_optim, scheduler)
         else:
-            model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, args)
+            model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, common_tags, words, rare_words, root_labels, open_nodes, unary_limit, args)
             if args['cuda']:
                 model.cuda()
@@ -763,7 +771,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file'])
             backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])
             bert_model, bert_tokenizer = foundation_cache.load_bert(args['bert_model'])
-            new_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, model.transitions, model.constituents, model.tags, model.delta_words, model.rare_words, model.root_labels, model.constituent_opens, model.unary_limit(), temp_args)
+            new_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, model.transitions, model.constituents, model.tags, model.common_tags, model.delta_words, model.rare_words, model.root_labels, model.constituent_opens, model.unary_limit(), temp_args)
             if args['cuda']:
                 new_model.cuda()
             new_model.copy_with_new_structure(model)
@@ -840,7 +848,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
     # the state is build as a bulk operation
     gold_trees = [x.tree.dropout_tags(args['tag_dropout']) for x in training_batch]
     preterminals = [list(x.yield_preterminals()) for x in gold_trees]
-    train_sequences = transition_sequence.build_treebank(gold_trees, args['transition_scheme'], model.tags)
+    train_sequences = transition_sequence.build_treebank(gold_trees, args['transition_scheme'], model.common_tags)
     initial_states = model.initial_state_from_preterminals(preterminals, gold_trees)
     current_batch = [state._replace(gold_sequence=sequence) for sequence, state in zip(train_sequences, initial_states)]
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index c88f8a01..291e8b20 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -371,6 +371,8 @@ def parse_args(args=None):
     # 0.4: 0.951272
     parser.add_argument('--tag_dropout', default=0.1, type=float,
                         help='Fraction of tags to replace with <UNK> at training time')
+    parser.add_argument('--num_tag_shifts', default=5, type=int,
+                        help='How many tags to use when labeling shifts. -1 means all of them')
 
     # combining dummy and open node embeddings might be a slight improvement
     # for example, after 550 iterations, one experiment had
diff --git a/stanza/tests/constituency/test_parse_tree.py b/stanza/tests/constituency/test_parse_tree.py
index 3f531219..26cd732f 100644
--- a/stanza/tests/constituency/test_parse_tree.py
+++ b/stanza/tests/constituency/test_parse_tree.py
@@ -92,6 +92,18 @@ def test_rare_words():
     expected = ['Who', 'in', 'sits']
     assert words == expected
 
+def test_common_tags():
+    """
+    Test getting the most common tags from the trees
+    """
+    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
+
+    trees = tree_reader.read_trees(text)
+
+    tags = Tree.get_common_tags(trees, 3)
+    expected = ['.', 'DT', 'NN']
+    assert tags == expected
+
 def test_common_words():
     """
     Test getting the unique words from a tree
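For reference, a quick usage sketch of the new helper, mirroring the added `test_common_tags` test (the imports assume the module layout shown in the diff):

```python
from stanza.models.constituency import tree_reader
from stanza.models.constituency.parse_tree import Tree

# two bracketed trees, as in the new test above
text = "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
trees = tree_reader.read_trees(text)

# '.', 'DT', and 'NN' each occur in both trees, so they are the three most
# frequent preterminal labels; the result comes back alphabetically sorted
print(Tree.get_common_tags(trees, 3))   # ['.', 'DT', 'NN']
```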