diff options
author | John Bauer <horatio@gmail.com> | 2022-10-28 00:51:50 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-10-28 04:49:18 +0300 |
commit | b8f3f380a50e1d178283a8567a42378ac2523036 (patch) | |
tree | e8c495a2fedba11cfc37696a628ce0bf5112c809 | |
parent | 94fd9c8379873834f3f26109d3d24e3256e9bf87 (diff) |
Refactor the retagging args & pipeline creation into a separate modeule
-rw-r--r-- | stanza/models/constituency/retagging.py | 65 | ||||
-rw-r--r-- | stanza/models/constituency_parser.py | 36 |
2 files changed, 70 insertions, 31 deletions
diff --git a/stanza/models/constituency/retagging.py b/stanza/models/constituency/retagging.py new file mode 100644 index 00000000..e74a47fc --- /dev/null +++ b/stanza/models/constituency/retagging.py @@ -0,0 +1,65 @@ +""" +Refactor a few functions specifically for retagging trees + +Retagging is important because the gold tags will not be available at runtime + +Note that the method which does the actual retagging is in utils.py +so as to avoid unnecessary circular imports +(eg, Pipeline imports constituency/trainer which imports this which imports Pipeline) +""" + +from stanza import Pipeline + +from stanza.models.common.vocab import VOCAB_PREFIX + +def add_retag_args(parser): + """ + Arguments specifically for retagging treebanks + """ + parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time') + parser.add_argument('--retag_method', default='xpos', choices=['xpos', 'upos'], help='Which tags to use when retagging') + parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default') + parser.add_argument('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees") + +def postprocess_args(args): + """ + After parsing args, unify some settings + """ + if args['retag_method'] == 'xpos': + args['retag_xpos'] = True + elif args['retag_method'] == 'upos': + args['retag_xpos'] = False + else: + raise ValueError("Unknown retag method {}".format(xpos)) + +def build_retag_pipeline(args): + """ + Build a retag pipeline based on the arguments + + May alter the arguments if the pipeline is incompatible, such as + taggers with no xpos + """ + # some argument sets might not use 'mode' + if args['retag_package'] is not None and args.get('mode', None) != 'remove_optimizer': + if '_' in args['retag_package']: + lang, package = args['retag_package'].split('_', 1) + else: + if 'lang' not in args: + raise ValueError("Retag package %s does not specify the language, and it is not clear from the arguments" % args['retag_package']) + lang = args.get('lang', None) + package = args['retag_package'] + retag_args = {"lang": lang, + "processors": "tokenize, pos", + "tokenize_pretokenized": True, + "package": {"pos": package}, + "pos_tqdm": True} + if args['retag_model_path'] is not None: + retag_args['pos_model_path'] = args['retag_model_path'] + retag_pipeline = Pipeline(**retag_args) + if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX): + logger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package) + args['retag_xpos'] = False + args['retag_method'] = 'upos' + return retag_pipeline + + return None diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py index 3a139170..6b991977 100644 --- a/stanza/models/constituency_parser.py +++ b/stanza/models/constituency_parser.py @@ -104,6 +104,7 @@ The code breakdown is as follows: constituency/lstm_model.py: adds LSTM features to the constituents to predict what the correct transition to make is, allowing for predictions on previously unseen text + constituency/retagging.py: a couple utility methods specifically for retagging constituency/utils.py: a couple utility methods constituency/dyanmic_oracle.py: a dynamic oracle which currently @@ -132,7 +133,7 @@ import torch from stanza import Pipeline from stanza.models.common import utils -from stanza.models.common.vocab import VOCAB_PREFIX +from stanza.models.constituency import retagging from stanza.models.constituency import trainer from stanza.models.constituency.lstm_model import ConstituencyComposition, SentenceBoundary, StackHistory from stanza.models.constituency.parse_transitions import TransitionScheme @@ -393,10 +394,7 @@ def parse_args(args=None): parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints") parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file') - parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time') - parser.add_argument('--retag_method', default='xpos', choices=['xpos', 'upos'], help='Which tags to use when retagging') - parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default') - parser.add_argument('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees") + retagging.add_retag_args(parser) # Partitioned Attention parser.add_argument('--pattn_d_model', default=1024, type=int, help='Partitioned attention model dimensionality') @@ -471,12 +469,7 @@ def parse_args(args=None): args = vars(args) - if args['retag_method'] == 'xpos': - args['retag_xpos'] = True - elif args['retag_method'] == 'upos': - args['retag_xpos'] = False - else: - raise ValueError("Unknown retag method {}".format(xpos)) + retagging.postprocess_args(args) model_save_file = args['save_name'] if args['save_name'] else '{}_constituency.pt'.format(args['shorthand']) @@ -520,26 +513,7 @@ def main(args=None): else: model_load_file = os.path.join(args['save_dir'], args['load_name']) - if args['retag_package'] is not None and args['mode'] != 'remove_optimizer': - if '_' in args['retag_package']: - lang, package = args['retag_package'].split('_', 1) - else: - lang = args['lang'] - package = args['retag_package'] - retag_args = {"lang": lang, - "processors": "tokenize, pos", - "tokenize_pretokenized": True, - "package": {"pos": package}, - "pos_tqdm": True} - if args['retag_model_path'] is not None: - retag_args['pos_model_path'] = args['retag_model_path'] - retag_pipeline = Pipeline(**retag_args) - if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX): - logger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package) - args['retag_xpos'] = False - args['retag_method'] = 'upos' - else: - retag_pipeline = None + retag_pipeline = retagging.build_retag_pipeline(args) if args['mode'] == 'train': trainer.train(args, model_load_file, model_save_each_file, retag_pipeline) |