github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>    2022-10-29 05:21:42 +0300
committer  John Bauer <horatio@gmail.com>    2022-10-29 05:21:42 +0300
commit     b595e437999761e17f1eade0e565a11558d51207 (patch)
tree       648d4bfac4559dbb13fc67f3af762725e95ec39a
parent     ad48b49dfcde0d72c9499e522e1b7bdc7523c91e (diff)
Connect model ensembles to the predict_text functionality
-rw-r--r--  stanza/models/constituency/ensemble.py  |  63
1 file changed, 44 insertions(+), 19 deletions(-)
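
This commit routes the constituency ensemble through the same text-parsing path (parse_text) already used by the single-model code, selected via the new --mode parse_text and --tokenized_file options. A minimal usage sketch, assuming stanza is installed and the module is run as a script; the model paths, the input file, and the --predict_file flag (presumably supplied by add_predict_output_args) are placeholders, not taken from this commit:

# Hypothetical invocation of the new parse_text mode; file names and the
# --predict_file flag are assumptions for illustration only.
import subprocess

subprocess.run([
    "python", "-m", "stanza.models.constituency.ensemble",
    "en_model_1.pt", "en_model_2.pt",          # hypothetical ensemble members
    "--mode", "parse_text",
    "--tokenized_file", "sentences.txt",       # pre-tokenized text to parse
    "--predict_file", "ensemble_output.mrg",   # assumed output flag from add_predict_output_args
], check=True)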
diff --git a/stanza/models/constituency/ensemble.py b/stanza/models/constituency/ensemble.py
index fcf9b9b1..0b8cedc5 100644
--- a/stanza/models/constituency/ensemble.py
+++ b/stanza/models/constituency/ensemble.py
@@ -21,8 +21,8 @@ from stanza.models.common.foundation_cache import FoundationCache
 from stanza.models.constituency import parse_transitions
 from stanza.models.constituency import retagging
 from stanza.models.constituency import tree_reader
-from stanza.models.constituency.trainer import Trainer, run_dev_set
-from stanza.models.constituency.utils import retag_trees
+from stanza.models.constituency.trainer import Trainer, run_dev_set, parse_text
+from stanza.models.constituency.utils import add_predict_output_args, retag_trees
 from stanza.resources.common import DEFAULT_MODEL_DIR
 from stanza.server.parser_eval import EvaluateParser, ParseResult, ScoredTree
 from stanza.utils.default_paths import get_default_paths
@@ -59,6 +59,24 @@ class Ensemble:
         for model in self.models:
             model.eval()
 
+    def build_batch_from_tagged_words(self, batch_size, data_iterator):
+        """
+        Read from the data_iterator batch_size tagged sentences and turn them into new parsing states
+
+        Expects a list of list of (word, tag)
+        """
+        state_batch = []
+        for _ in range(batch_size):
+            sentence = next(data_iterator, None)
+            if sentence is None:
+                break
+            state_batch.append(sentence)
+
+        if len(state_batch) > 0:
+            state_batch = [model.initial_state_from_words(state_batch) for model in self.models]
+            state_batch = list(zip(*state_batch))
+        return state_batch
+
     def build_batch_from_trees(self, batch_size, data_iterator):
         """
         Read from the data_iterator batch_size trees and turn them into N lists of parsing states
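
The new build_batch_from_tagged_words mirrors the existing build_batch_from_trees, but starts from pre-tagged words rather than gold trees: every model builds its own initial state for each sentence, and zip(*...) transposes the [models x sentences] lists into one tuple of per-model states per sentence. A standalone sketch of that shape handling (the lambdas below are stand-ins for model.initial_state_from_words, not the real parser API):

# Illustration of the batching and transposition only; not the real models.
tagged_sentences = iter([
    [("This", "DT"), ("works", "VBZ")],
    [("Parse", "VB"), ("me", "PRP")],
])

batch_size = 50
state_batch = []
for _ in range(batch_size):
    sentence = next(tagged_sentences, None)
    if sentence is None:
        break
    state_batch.append(sentence)

# two stand-in "models", each turning the batch of sentences into a list of states
models = [lambda batch, name=name: [(name, sent) for sent in batch]
          for name in ("model_a", "model_b")]

if len(state_batch) > 0:
    per_model = [build(state_batch) for build in models]  # [num_models][num_sentences]
    state_batch = list(zip(*per_model))                    # [num_sentences][num_models]

# state_batch[0] now holds one state from model_a and one from model_b for the first sentence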
@@ -203,6 +221,8 @@ def parse_args(args=None):
     parser = argparse.ArgumentParser()
 
     parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
+
     parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
     parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
     parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
@@ -215,7 +235,8 @@ def parse_args(args=None):
     parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
 
     parser.add_argument('models', type=str, nargs='+', default=None, help="Which model(s) to load")
-    parser.add_argument('--mode', default='predict', choices=['predict'])
+    parser.add_argument('--mode', default='predict', choices=['parse_text', 'predict'])
+    add_predict_output_args(parser)
 
     retagging.add_retag_args(parser)
     # TODO: get default method & package from run_constituency.py
@@ -225,8 +246,6 @@ def parse_args(args=None):
 
     retagging.postprocess_args(args)
     args['num_generate'] = 0
-    args['predict_file'] = None
-    args['predict_dir'] = None
 
     if not args['eval_file'] and args['lang'] in DEFAULT_EVAL:
         paths = get_default_paths()
@@ -247,20 +266,26 @@ def main():
 
     foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
     ensemble = Ensemble(args['models'], args, foundation_cache)
-
-    with EvaluateParser() as evaluator:
-        treebank = tree_reader.read_treebank(args['eval_file'])
-        logger.info("Read %d trees for evaluation", len(treebank))
-
-        if retag_pipeline is not None:
-            logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
-            treebank = retag_trees(treebank, retag_pipeline, args['retag_xpos'])
-            logger.info("Retagging finished")
-
-        f1, kbestF1 = run_dev_set(ensemble, treebank, args, evaluator)
-        logger.info("F1 score on %s: %f", args['eval_file'], f1)
-        if kbestF1 is not None:
-            logger.info("KBest F1 score on %s: %f", args['eval_file'], kbestF1)
+    ensemble.eval()
+
+    if args['mode'] == 'predict':
+        with EvaluateParser() as evaluator:
+            treebank = tree_reader.read_treebank(args['eval_file'])
+            logger.info("Read %d trees for evaluation", len(treebank))
+
+            if retag_pipeline is not None:
+                logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
+                treebank = retag_trees(treebank, retag_pipeline, args['retag_xpos'])
+                logger.info("Retagging finished")
+
+            f1, kbestF1 = run_dev_set(ensemble, treebank, args, evaluator)
+            logger.info("F1 score on %s: %f", args['eval_file'], f1)
+            if kbestF1 is not None:
+                logger.info("KBest F1 score on %s: %f", args['eval_file'], kbestF1)
+    elif args['mode'] == 'parse_text':
+        parse_text(args, ensemble, retag_pipeline)
+    else:
+        raise ValueError("Unhandled mode %s" % args['mode'])
 
 if __name__ == "__main__":