Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-10-28 00:51:50 +0300
committerJohn Bauer <horatio@gmail.com>2022-10-28 04:49:18 +0300
commitb8f3f380a50e1d178283a8567a42378ac2523036 (patch)
treee8c495a2fedba11cfc37696a628ce0bf5112c809
parent94fd9c8379873834f3f26109d3d24e3256e9bf87 (diff)
Refactor the retagging args & pipeline creation into a separate modeule
-rw-r--r--stanza/models/constituency/retagging.py65
-rw-r--r--stanza/models/constituency_parser.py36
2 files changed, 70 insertions, 31 deletions
diff --git a/stanza/models/constituency/retagging.py b/stanza/models/constituency/retagging.py
new file mode 100644
index 00000000..e74a47fc
--- /dev/null
+++ b/stanza/models/constituency/retagging.py
@@ -0,0 +1,65 @@
+"""
+Refactor a few functions specifically for retagging trees
+
+Retagging is important because the gold tags will not be available at runtime
+
+Note that the method which does the actual retagging is in utils.py
+so as to avoid unnecessary circular imports
+(eg, Pipeline imports constituency/trainer which imports this which imports Pipeline)
+"""
+
+from stanza import Pipeline
+
+from stanza.models.common.vocab import VOCAB_PREFIX
+
+def add_retag_args(parser):
+ """
+ Arguments specifically for retagging treebanks
+ """
+ parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time')
+ parser.add_argument('--retag_method', default='xpos', choices=['xpos', 'upos'], help='Which tags to use when retagging')
+ parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default')
+ parser.add_argument('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees")
+
+def postprocess_args(args):
+ """
+ After parsing args, unify some settings
+ """
+ if args['retag_method'] == 'xpos':
+ args['retag_xpos'] = True
+ elif args['retag_method'] == 'upos':
+ args['retag_xpos'] = False
+ else:
+ raise ValueError("Unknown retag method {}".format(xpos))
+
+def build_retag_pipeline(args):
+ """
+ Build a retag pipeline based on the arguments
+
+ May alter the arguments if the pipeline is incompatible, such as
+ taggers with no xpos
+ """
+ # some argument sets might not use 'mode'
+ if args['retag_package'] is not None and args.get('mode', None) != 'remove_optimizer':
+ if '_' in args['retag_package']:
+ lang, package = args['retag_package'].split('_', 1)
+ else:
+ if 'lang' not in args:
+ raise ValueError("Retag package %s does not specify the language, and it is not clear from the arguments" % args['retag_package'])
+ lang = args.get('lang', None)
+ package = args['retag_package']
+ retag_args = {"lang": lang,
+ "processors": "tokenize, pos",
+ "tokenize_pretokenized": True,
+ "package": {"pos": package},
+ "pos_tqdm": True}
+ if args['retag_model_path'] is not None:
+ retag_args['pos_model_path'] = args['retag_model_path']
+ retag_pipeline = Pipeline(**retag_args)
+ if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX):
+ logger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package)
+ args['retag_xpos'] = False
+ args['retag_method'] = 'upos'
+ return retag_pipeline
+
+ return None
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 3a139170..6b991977 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -104,6 +104,7 @@ The code breakdown is as follows:
constituency/lstm_model.py: adds LSTM features to the constituents to predict what the
correct transition to make is, allowing for predictions on previously unseen text
+ constituency/retagging.py: a couple utility methods specifically for retagging
constituency/utils.py: a couple utility methods
constituency/dyanmic_oracle.py: a dynamic oracle which currently
@@ -132,7 +133,7 @@ import torch
from stanza import Pipeline
from stanza.models.common import utils
-from stanza.models.common.vocab import VOCAB_PREFIX
+from stanza.models.constituency import retagging
from stanza.models.constituency import trainer
from stanza.models.constituency.lstm_model import ConstituencyComposition, SentenceBoundary, StackHistory
from stanza.models.constituency.parse_transitions import TransitionScheme
@@ -393,10 +394,7 @@ def parse_args(args=None):
parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file')
- parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time')
- parser.add_argument('--retag_method', default='xpos', choices=['xpos', 'upos'], help='Which tags to use when retagging')
- parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default')
- parser.add_argument('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees")
+ retagging.add_retag_args(parser)
# Partitioned Attention
parser.add_argument('--pattn_d_model', default=1024, type=int, help='Partitioned attention model dimensionality')
@@ -471,12 +469,7 @@ def parse_args(args=None):
args = vars(args)
- if args['retag_method'] == 'xpos':
- args['retag_xpos'] = True
- elif args['retag_method'] == 'upos':
- args['retag_xpos'] = False
- else:
- raise ValueError("Unknown retag method {}".format(xpos))
+ retagging.postprocess_args(args)
model_save_file = args['save_name'] if args['save_name'] else '{}_constituency.pt'.format(args['shorthand'])
@@ -520,26 +513,7 @@ def main(args=None):
else:
model_load_file = os.path.join(args['save_dir'], args['load_name'])
- if args['retag_package'] is not None and args['mode'] != 'remove_optimizer':
- if '_' in args['retag_package']:
- lang, package = args['retag_package'].split('_', 1)
- else:
- lang = args['lang']
- package = args['retag_package']
- retag_args = {"lang": lang,
- "processors": "tokenize, pos",
- "tokenize_pretokenized": True,
- "package": {"pos": package},
- "pos_tqdm": True}
- if args['retag_model_path'] is not None:
- retag_args['pos_model_path'] = args['retag_model_path']
- retag_pipeline = Pipeline(**retag_args)
- if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX):
- logger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package)
- args['retag_xpos'] = False
- args['retag_method'] = 'upos'
- else:
- retag_pipeline = None
+ retag_pipeline = retagging.build_retag_pipeline(args)
if args['mode'] == 'train':
trainer.train(args, model_load_file, model_save_each_file, retag_pipeline)