diff options
author | John Bauer <horatio@gmail.com> | 2022-11-13 22:42:22 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-13 22:42:22 +0300 |
commit | 62c42a342ec9813bbbd8825816582ecc2b8bf16b (patch) | |
tree | f9593894b6647271971aef98ccf5857d23d18cab | |
parent | 766341942962e5a5a0aa0cda3dd170ac098ac6f9 (diff) |
Add the ability to ensemble the retag models in --score_dev or ensemble modes
-rw-r--r-- | stanza/models/constituency/ensemble.py | 2 | ||||
-rw-r--r-- | stanza/models/constituency/retagging.py | 37 | ||||
-rw-r--r-- | stanza/models/constituency/trainer.py | 6 | ||||
-rw-r--r-- | stanza/models/constituency/utils.py | 20 |
4 files changed, 44 insertions, 21 deletions
diff --git a/stanza/models/constituency/ensemble.py b/stanza/models/constituency/ensemble.py index f05216dc..1a354c6f 100644 --- a/stanza/models/constituency/ensemble.py +++ b/stanza/models/constituency/ensemble.py @@ -269,7 +269,7 @@ def parse_args(args=None): def main(): args = parse_args() retag_pipeline = retagging.build_retag_pipeline(args) - foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache() + foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() ensemble = Ensemble(args['models'], args, foundation_cache) ensemble.eval() diff --git a/stanza/models/constituency/retagging.py b/stanza/models/constituency/retagging.py index 4bf9e69e..5fa535b2 100644 --- a/stanza/models/constituency/retagging.py +++ b/stanza/models/constituency/retagging.py @@ -8,10 +8,12 @@ so as to avoid unnecessary circular imports (eg, Pipeline imports constituency/trainer which imports this which imports Pipeline) """ +import copy import logging from stanza import Pipeline +from stanza.models.common.foundation_cache import FoundationCache from stanza.models.common.vocab import VOCAB_PREFIX logger = logging.getLogger('stanza') @@ -22,7 +24,7 @@ def add_retag_args(parser): """ parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time') parser.add_argument('--retag_method', default='xpos', choices=['xpos', 'upos'], help='Which tags to use when retagging') - parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default') + parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default. Can specify multiple taggers with ;') parser.add_argument('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees") def postprocess_args(args): @@ -38,10 +40,14 @@ def postprocess_args(args): def build_retag_pipeline(args): """ - Build a retag pipeline based on the arguments + Builds retag pipelines based on the arguments May alter the arguments if the pipeline is incompatible, such as taggers with no xpos + + Will return a list of one or more retag pipelines. + Multiple tagger models can be specified by having them + semi-colon separated in retag_model_path. """ # some argument sets might not use 'mode' if args['retag_package'] is not None and args.get('mode', None) != 'remove_optimizer': @@ -52,17 +58,28 @@ def build_retag_pipeline(args): raise ValueError("Retag package %s does not specify the language, and it is not clear from the arguments" % args['retag_package']) lang = args.get('lang', None) package = args['retag_package'] + foundation_cache = FoundationCache() retag_args = {"lang": lang, "processors": "tokenize, pos", "tokenize_pretokenized": True, "package": {"pos": package}} - if args['retag_model_path'] is not None: - retag_args['pos_model_path'] = args['retag_model_path'] - retag_pipeline = Pipeline(**retag_args) - if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX): - logger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package) - args['retag_xpos'] = False - args['retag_method'] = 'upos' - return retag_pipeline + + def build(retag_args, path): + retag_args = copy.deepcopy(retag_args) + if path is not None: + retag_args['pos_model_path'] = path + + retag_pipeline = Pipeline(foundation_cache=foundation_cache, **retag_args) + if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX): + logger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package) + args['retag_xpos'] = False + args['retag_method'] = 'upos' + return retag_pipeline + + if args['retag_model_path'] is None: + return [build(retag_args, None)] + paths = args['retag_model_path'].split(";") + # can be length 1 if only one tagger to work with + return [build(retag_args, path) for path in paths] return None diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index 2ee3ae2f..97a6af32 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -215,7 +215,7 @@ def load_model_parse_text(args, model_file, retag_pipeline): """ Load a model, then parse text and write it to stdout or args['predict_file'] """ - foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache() + foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() load_args = { 'wordvec_pretrain_file': args['wordvec_pretrain_file'], 'charlm_forward_file': args['charlm_forward_file'], @@ -288,7 +288,7 @@ def evaluate(args, model_file, retag_pipeline): kbest = None with EvaluateParser(kbest=kbest) as evaluator: - foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache() + foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() load_args = { 'wordvec_pretrain_file': args['wordvec_pretrain_file'], 'charlm_forward_file': args['charlm_forward_file'], @@ -579,7 +579,7 @@ def train(args, model_load_file, model_save_each_file, retag_pipeline): silver_trees = retag_trees(silver_trees, retag_pipeline, args['retag_xpos']) logger.info("Retagging finished") - foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache() + foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() trainer, train_sequences, silver_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file) trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, model_save_each_file, evaluator) diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py index be89ab7d..e207bc43 100644 --- a/stanza/models/constituency/utils.py +++ b/stanza/models/constituency/utils.py @@ -2,7 +2,7 @@ Collects a few of the conparser utility methods which don't belong elsewhere """ -from collections import deque +from collections import Counter, deque import copy import logging @@ -75,7 +75,7 @@ def replace_tags(tree, tags): return new_tree -def retag_trees(trees, pipeline, xpos=True): +def retag_trees(trees, pipelines, xpos=True): """ Retag all of the trees using the given processor @@ -99,11 +99,17 @@ def retag_trees(trees, pipeline, xpos=True): raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e doc = Document(sentences) - doc = pipeline(doc) - if xpos: - tag_lists = [[x.xpos for x in sentence.words] for sentence in doc.sentences] - else: - tag_lists = [[x.upos for x in sentence.words] for sentence in doc.sentences] + tag_lists = [] + for pipeline in pipelines: + doc = pipeline(doc) + tag_lists.append([[x.xpos if xpos else x.upos for x in sentence.words] for sentence in doc.sentences]) + # tag_lists: for N pipeline, S sentences + # we now have N lists of S sentences each + # for sentence in zip(*tag_lists): N lists of |s| tags for this given sentence s + # for tag in zip(*sentence): N predicted tags. + # most common one in the Counter will be chosen + tag_lists = [[Counter(tag).most_common(1)[0][0] for tag in zip(*sentence)] + for sentence in zip(*tag_lists)] for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)): try: |