Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author John Bauer <horatio@gmail.com> 2022-11-13 22:42:22 +0300
committer John Bauer <horatio@gmail.com> 2022-11-13 22:42:22 +0300
commit 62c42a342ec9813bbbd8825816582ecc2b8bf16b (patch)
tree f9593894b6647271971aef98ccf5857d23d18cab
parent 766341942962e5a5a0aa0cda3dd170ac098ac6f9 (diff)
Add the ability to ensemble the retag models in --score_dev or ensemble modes
-rw-r--r-- stanza/models/constituency/ensemble.py 2
-rw-r--r-- stanza/models/constituency/retagging.py 37
-rw-r--r-- stanza/models/constituency/trainer.py 6
-rw-r--r-- stanza/models/constituency/utils.py 20
4 files changed, 44 insertions, 21 deletions
diff --git a/stanza/models/constituency/ensemble.py b/stanza/models/constituency/ensemble.py
index f05216dc..1a354c6f 100644
--- a/stanza/models/constituency/ensemble.py
+++ b/stanza/models/constituency/ensemble.py
@@ -269,7 +269,7 @@ def parse_args(args=None):
def main():
args = parse_args()
retag_pipeline = retagging.build_retag_pipeline(args)
- foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
+ foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
ensemble = Ensemble(args['models'], args, foundation_cache)
ensemble.eval()
diff --git a/stanza/models/constituency/retagging.py b/stanza/models/constituency/retagging.py
index 4bf9e69e..5fa535b2 100644
--- a/stanza/models/constituency/retagging.py
+++ b/stanza/models/constituency/retagging.py
@@ -8,10 +8,12 @@ so as to avoid unnecessary circular imports
(eg, Pipeline imports constituency/trainer which imports this which imports Pipeline)
"""
+import copy
import logging
from stanza import Pipeline
+from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.common.vocab import VOCAB_PREFIX
logger = logging.getLogger('stanza')
@@ -22,7 +24,7 @@ def add_retag_args(parser):
"""
parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time')
parser.add_argument('--retag_method', default='xpos', choices=['xpos', 'upos'], help='Which tags to use when retagging')
- parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default')
+ parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default. Can specify multiple taggers with ;')
parser.add_argument('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees")
def postprocess_args(args):
@@ -38,10 +40,14 @@ def postprocess_args(args):
def build_retag_pipeline(args):
"""
- Build a retag pipeline based on the arguments
+ Builds retag pipelines based on the arguments
May alter the arguments if the pipeline is incompatible, such as
taggers with no xpos
+
+ Will return a list of one or more retag pipelines.
+ Multiple tagger models can be specified by having them
+ semi-colon separated in retag_model_path.
"""
# some argument sets might not use 'mode'
if args['retag_package'] is not None and args.get('mode', None) != 'remove_optimizer':
@@ -52,17 +58,28 @@ def build_retag_pipeline(args):
raise ValueError("Retag package %s does not specify the language, and it is not clear from the arguments" % args['retag_package'])
lang = args.get('lang', None)
package = args['retag_package']
+ foundation_cache = FoundationCache()
retag_args = {"lang": lang,
"processors": "tokenize, pos",
"tokenize_pretokenized": True,
"package": {"pos": package}}
- if args['retag_model_path'] is not None:
- retag_args['pos_model_path'] = args['retag_model_path']
- retag_pipeline = Pipeline(**retag_args)
- if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX):
- logger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package)
- args['retag_xpos'] = False
- args['retag_method'] = 'upos'
- return retag_pipeline
+
+ def build(retag_args, path):
+ retag_args = copy.deepcopy(retag_args)
+ if path is not None:
+ retag_args['pos_model_path'] = path
+
+ retag_pipeline = Pipeline(foundation_cache=foundation_cache, **retag_args)
+ if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX):
+ logger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package)
+ args['retag_xpos'] = False
+ args['retag_method'] = 'upos'
+ return retag_pipeline
+
+ if args['retag_model_path'] is None:
+ return [build(retag_args, None)]
+ paths = args['retag_model_path'].split(";")
+ # can be length 1 if only one tagger to work with
+ return [build(retag_args, path) for path in paths]
return None
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 2ee3ae2f..97a6af32 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -215,7 +215,7 @@ def load_model_parse_text(args, model_file, retag_pipeline):
"""
Load a model, then parse text and write it to stdout or args['predict_file']
"""
- foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
+ foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
load_args = {
'wordvec_pretrain_file': args['wordvec_pretrain_file'],
'charlm_forward_file': args['charlm_forward_file'],
@@ -288,7 +288,7 @@ def evaluate(args, model_file, retag_pipeline):
kbest = None
with EvaluateParser(kbest=kbest) as evaluator:
- foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
+ foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
load_args = {
'wordvec_pretrain_file': args['wordvec_pretrain_file'],
'charlm_forward_file': args['charlm_forward_file'],
@@ -579,7 +579,7 @@ def train(args, model_load_file, model_save_each_file, retag_pipeline):
silver_trees = retag_trees(silver_trees, retag_pipeline, args['retag_xpos'])
logger.info("Retagging finished")
- foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
+ foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
trainer, train_sequences, silver_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file)
trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, model_save_each_file, evaluator)
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
index be89ab7d..e207bc43 100644
--- a/stanza/models/constituency/utils.py
+++ b/stanza/models/constituency/utils.py
@@ -2,7 +2,7 @@
Collects a few of the conparser utility methods which don't belong elsewhere
"""
-from collections import deque
+from collections import Counter, deque
import copy
import logging
@@ -75,7 +75,7 @@ def replace_tags(tree, tags):
return new_tree
-def retag_trees(trees, pipeline, xpos=True):
+def retag_trees(trees, pipelines, xpos=True):
"""
Retag all of the trees using the given processor
@@ -99,11 +99,17 @@ def retag_trees(trees, pipeline, xpos=True):
raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e
doc = Document(sentences)
- doc = pipeline(doc)
- if xpos:
- tag_lists = [[x.xpos for x in sentence.words] for sentence in doc.sentences]
- else:
- tag_lists = [[x.upos for x in sentence.words] for sentence in doc.sentences]
+ tag_lists = []
+ for pipeline in pipelines:
+ doc = pipeline(doc)
+ tag_lists.append([[x.xpos if xpos else x.upos for x in sentence.words] for sentence in doc.sentences])
+ # tag_lists: for N pipeline, S sentences
+ # we now have N lists of S sentences each
+ # for sentence in zip(*tag_lists): N lists of |s| tags for this given sentence s
+ # for tag in zip(*sentence): N predicted tags.
+ # most common one in the Counter will be chosen
+ tag_lists = [[Counter(tag).most_common(1)[0][0] for tag in zip(*sentence)]
+ for sentence in zip(*tag_lists)]
for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)):
try: