
github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>  2022-04-09 09:52:07 +0300
committer  John Bauer <horatio@gmail.com>  2022-11-03 20:00:19 +0300
commit     e35f0b853043b1b3514c131aa2b6af58ae1e327f (patch)
tree       f0c509b1fb86f28731f7a0df5e6c04be7ed387d7
parent     cfbf9d0616b4032541f191c8bee90f0456774eee (diff)

Add a flag to control how many tags to use when labeling shift transitions (con_shift_tags2)
Keeps existing models viable by fixing common_tags when loading existing models
-rw-r--r--  stanza/models/constituency/lstm_model.py      |  4
-rw-r--r--  stanza/models/constituency/parse_tree.py      | 17
-rw-r--r--  stanza/models/constituency/trainer.py         | 24
-rw-r--r--  stanza/models/constituency_parser.py          |  2
-rw-r--r--  stanza/tests/constituency/test_parse_tree.py  | 12
5 files changed, 50 insertions, 9 deletions
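
Before the patch itself, a minimal standalone sketch of the idea behind the change (toy code, not the Stanza transition classes): a Shift transition can carry the tag of the word being shifted, but only the most frequent tags get their own labeled Shift, and rarer tags fall back to a single unlabeled Shift so the parser can still handle tags it has never seen.

from collections import Counter

def pick_common_tags(tagged_sents, num_tags=5):
    """Toy stand-in for Tree.get_common_tags: keep the num_tags most frequent tags."""
    counts = Counter(tag for sent in tagged_sents for tag, _word in sent)
    return {tag for tag, _ in counts.most_common(num_tags)}

def shift_label(tag, common_tags):
    """Label a Shift with its tag only when the tag is common enough."""
    return tag if tag in common_tags else None   # None -> plain, unlabeled Shift

sents = [[("DT", "the"), ("NN", "seat"), (".", "?")],
         [("DT", "this"), ("NN", "chair"), ("VBZ", "is"), ("JJ", "free"), (".", ".")]]
common = pick_common_tags(sents, num_tags=3)                  # {'DT', 'NN', '.'}
labels = [shift_label(tag, common) for tag, _word in sents[1]]
# ['DT', 'NN', None, None, '.'] -- VBZ and JJ fall back to the unlabeled Shift
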
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 264f7314..e894fed5 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -181,7 +181,7 @@ class ConstituencyComposition(Enum):
UNTIED_MAX = 8
class LSTMModel(BaseModel, nn.Module):
- def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, transitions, constituents, tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):
+ def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, transitions, constituents, tags, common_tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):
"""
pretrain: a Pretrain object
transitions: a list of all possible transitions which will be
@@ -287,6 +287,7 @@ class LSTMModel(BaseModel, nn.Module):
self.rare_words = set(rare_words)
+ self.common_tags = set(common_tags)
self.tags = sorted(list(tags))
if self.tag_embedding_dim > 0:
self.tag_map = { t: i+2 for i, t in enumerate(self.tags) }
@@ -1060,6 +1061,7 @@ class LSTMModel(BaseModel, nn.Module):
'transitions': self.transitions,
'constituents': self.constituents,
'tags': self.tags,
+ 'common_tags': self.common_tags,
'words': self.delta_words,
'rare_words': self.rare_words,
'root_labels': self.root_labels,
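
The two hunks above thread the new common_tags argument through the constructor and into the dict the model saves. A toy sketch of that flow (a hypothetical ToyModel, not the real LSTMModel; only the two tag fields are shown):

class ToyModel:
    def __init__(self, tags, common_tags):
        self.tags = sorted(list(tags))
        # stored the same way as in the hunk above: a set, kept separate from self.tags
        self.common_tags = set(common_tags)

    def saved_params(self):
        # mirrors the saved dict above, abridged to the tag-related entries
        return {
            'tags': self.tags,
            'common_tags': self.common_tags,
        }

model = ToyModel(tags=['DT', 'NN', '.', 'VBZ'], common_tags=['DT', 'NN', '.'])
assert model.saved_params()['common_tags'] == {'DT', 'NN', '.'}
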
diff --git a/stanza/models/constituency/parse_tree.py b/stanza/models/constituency/parse_tree.py
index 7db70caf..bca64f44 100644
--- a/stanza/models/constituency/parse_tree.py
+++ b/stanza/models/constituency/parse_tree.py
@@ -327,6 +327,23 @@ class Tree(StanzaObject):
tree.visit_preorder(preterminal = lambda x: tags.add(x.label))
return sorted(tags)
+
+    @staticmethod
+    def get_common_tags(trees, num_tags=5):
+        """
+        Walks over all of the trees and gets the most frequently occurring tags from the trees
+        """
+        if num_tags == 0:
+            return set()
+
+        if isinstance(trees, Tree):
+            trees = [trees]
+
+        tags = Counter()
+        for tree in trees:
+            tree.visit_preorder(preterminal = lambda x: tags.update([x.label]))
+        return sorted(x[0] for x in tags.most_common()[:num_tags])
+
@staticmethod
def get_unique_words(trees):
"""
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index dd6f3421..516b5843 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -110,6 +110,8 @@ class Trainer:
saved_args.update(update_args)
model_type = params['model_type']
+ # TODO: when all models have this baked in, no need to 'get' this parameter
+ common_tags = params.get('common_tags', set())
if model_type == 'LSTM':
pt = load_pretrain(saved_args.get('wordvec_pretrain_file', None), foundation_cache)
bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)
@@ -124,6 +126,7 @@ class Trainer:
transitions=Trainer.fix_shift_transitions(params['transitions']),
constituents=params['constituents'],
tags=params['tags'],
+ common_tags=common_tags,
words=params['words'],
rare_words=params['rare_words'],
root_labels=params['root_labels'],
@@ -411,20 +414,25 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache,
if tag not in tags:
logger.info("Found tag in dev set which does not exist in train set: %s Continuing...", tag)
+ num_common_tags = args['num_tag_shifts']
+ if num_common_tags < 0:
+ num_common_tags = len(tags)
+ common_tags = parse_tree.Tree.get_common_tags(train_trees, num_common_tags)
unary_limit = max(max(t.count_unary_depth() for t in train_trees),
max(t.count_unary_depth() for t in dev_trees)) + 1
+
if silver_trees:
unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees))
- train_sequences = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], tags)
+ train_sequences = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], common_tags)
# the training transitions will all be labeled with the tags
# currently we are just checking correctness
# we add an unlabeled Shift so that the model can represent previously unseen tags
# at train time we will redo some tags as <UNK> to train the unlabeled Shift
# (this also will essentially be a form of dropout)
train_transitions = transition_sequence.all_transitions(train_sequences + [[parse_transitions.Shift()]])
- dev_sequences = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], tags)
+ dev_sequences = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], common_tags)
dev_transitions = transition_sequence.all_transitions(dev_sequences)
- silver_sequences = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], tags)
+ silver_sequences = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], common_tags)
silver_transitions = transition_sequence.all_transitions(silver_sequences)
logger.info("Total unique transitions in train set: %d", len(train_transitions))
@@ -496,7 +504,7 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache,
# using the model's current values works for if the new
# dataset is the same or smaller
# TODO: handle a larger dataset as well
- model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, trainer.model.transitions, trainer.model.constituents, trainer.model.tags, trainer.model.delta_words, trainer.model.rare_words, trainer.model.root_labels, trainer.model.constituent_opens, trainer.model.unary_limit(), args)
+ model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, trainer.model.transitions, trainer.model.constituents, trainer.model.tags, trainer.model.common_tags, trainer.model.delta_words, trainer.model.rare_words, trainer.model.root_labels, trainer.model.constituent_opens, trainer.model.unary_limit(), args)
if args['cuda']:
model.cuda()
model.copy_with_new_structure(trainer.model)
@@ -513,14 +521,14 @@ def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache,
temp_args['pattn_num_layers'] = 0
temp_args['lattn_d_proj'] = 0
- temp_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, temp_args)
+ temp_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, common_tags, words, rare_words, root_labels, open_nodes, unary_limit, temp_args)
if args['cuda']:
temp_model.cuda()
temp_optim = build_optimizer(temp_args, temp_model, True)
scheduler = build_scheduler(temp_args, temp_optim)
trainer = Trainer(temp_model, temp_optim, scheduler)
else:
- model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, args)
+ model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, common_tags, words, rare_words, root_labels, open_nodes, unary_limit, args)
if args['cuda']:
model.cuda()
@@ -763,7 +771,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file'])
backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])
bert_model, bert_tokenizer = foundation_cache.load_bert(args['bert_model'])
- new_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, model.transitions, model.constituents, model.tags, model.delta_words, model.rare_words, model.root_labels, model.constituent_opens, model.unary_limit(), temp_args)
+ new_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, model.transitions, model.constituents, model.tags, model.common_tags, model.delta_words, model.rare_words, model.root_labels, model.constituent_opens, model.unary_limit(), temp_args)
if args['cuda']:
new_model.cuda()
new_model.copy_with_new_structure(model)
@@ -840,7 +848,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
# the state is built as a bulk operation
gold_trees = [x.tree.dropout_tags(args['tag_dropout']) for x in training_batch]
preterminals = [list(x.yield_preterminals()) for x in gold_trees]
- train_sequences = transition_sequence.build_treebank(gold_trees, args['transition_scheme'], model.tags)
+ train_sequences = transition_sequence.build_treebank(gold_trees, args['transition_scheme'], model.common_tags)
initial_states = model.initial_state_from_preterminals(preterminals, gold_trees)
current_batch = [state._replace(gold_sequence=sequence)
for sequence, state in zip(train_sequences, initial_states)]
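
A minimal sketch of the backward-compatibility guard in the Trainer.load hunk above (a plain dict stands in for a loaded checkpoint; this is not the actual load code): models saved before this change have no 'common_tags' entry, so the load path defaults it to an empty set.

# checkpoint saved by an older model: no 'common_tags' key
old_params = {
    "model_type": "LSTM",
    "tags": [".", "DT", "NN", "VBZ"],
}

# the same .get() guard as in Trainer.load above:
# a missing key becomes an empty set, so shifts from old models carry no tag labels
common_tags = old_params.get("common_tags", set())
assert common_tags == set()
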
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index c88f8a01..291e8b20 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -371,6 +371,8 @@ def parse_args(args=None):
# 0.4: 0.951272
parser.add_argument('--tag_dropout', default=0.1, type=float,
help='Fraction of tags to replace with <UNK> at training time')
+ parser.add_argument('--num_tag_shifts', default=5, type=int,
+ help='How many tags to use when labeling shifts. -1 means all of them')
# combining dummy and open node embeddings might be a slight improvement
# for example, after 550 iterations, one experiment had
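
A small standalone sketch of how the new flag's value is interpreted, mirroring the build_trainer hunk earlier in this commit (toy argparse setup and tag inventory, not the full parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num_tag_shifts', default=5, type=int,
                    help='How many tags to use when labeling shifts. -1 means all of them')
args = vars(parser.parse_args([]))                 # defaults: {'num_tag_shifts': 5}

tags = ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP']        # toy tag inventory from a train set

# same interpretation as in build_trainer: negative means use every tag,
# 0 disables labeled shifts, anything else keeps the N most frequent tags
num_common_tags = args['num_tag_shifts']
if num_common_tags < 0:
    num_common_tags = len(tags)
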
diff --git a/stanza/tests/constituency/test_parse_tree.py b/stanza/tests/constituency/test_parse_tree.py
index 3f531219..26cd732f 100644
--- a/stanza/tests/constituency/test_parse_tree.py
+++ b/stanza/tests/constituency/test_parse_tree.py
@@ -92,6 +92,18 @@ def test_rare_words():
expected = ['Who', 'in', 'sits']
assert words == expected
+def test_common_tags():
+    """
+    Test getting the most common tags from the trees
+    """
+    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
+
+    trees = tree_reader.read_trees(text)
+
+    tags = Tree.get_common_tags(trees, 3)
+    expected = ['.', 'DT', 'NN']
+    assert tags == expected
+
def test_common_words():
"""
Test getting the unique words from a tree