diff options
author | John Bauer <horatio@gmail.com> | 2022-10-29 05:21:42 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-10-29 05:21:42 +0300 |
commit | b595e437999761e17f1eade0e565a11558d51207 (patch) | |
tree | 648d4bfac4559dbb13fc67f3af762725e95ec39a | |
parent | ad48b49dfcde0d72c9499e522e1b7bdc7523c91e (diff) |
Connect model ensembles to the predict_text functionality
-rw-r--r-- | stanza/models/constituency/ensemble.py | 63 |
1 files changed, 44 insertions, 19 deletions
diff --git a/stanza/models/constituency/ensemble.py b/stanza/models/constituency/ensemble.py
index fcf9b9b1..0b8cedc5 100644
--- a/stanza/models/constituency/ensemble.py
+++ b/stanza/models/constituency/ensemble.py
@@ -21,8 +21,8 @@ from stanza.models.common.foundation_cache import FoundationCache
 from stanza.models.constituency import parse_transitions
 from stanza.models.constituency import retagging
 from stanza.models.constituency import tree_reader
-from stanza.models.constituency.trainer import Trainer, run_dev_set
-from stanza.models.constituency.utils import retag_trees
+from stanza.models.constituency.trainer import Trainer, run_dev_set, parse_text
+from stanza.models.constituency.utils import add_predict_output_args, retag_trees
 from stanza.resources.common import DEFAULT_MODEL_DIR
 from stanza.server.parser_eval import EvaluateParser, ParseResult, ScoredTree
 from stanza.utils.default_paths import get_default_paths
@@ -59,6 +59,24 @@ class Ensemble:
         for model in self.models:
             model.eval()
 
+    def build_batch_from_tagged_words(self, batch_size, data_iterator):
+        """
+        Read from the data_iterator batch_size tagged sentences and turn them into new parsing states
+
+        Expects a list of list of (word, tag)
+        """
+        state_batch = []
+        for _ in range(batch_size):
+            sentence = next(data_iterator, None)
+            if sentence is None:
+                break
+            state_batch.append(sentence)
+
+        if len(state_batch) > 0:
+            state_batch = [model.initial_state_from_words(state_batch) for model in self.models]
+            state_batch = list(zip(*state_batch))
+        return state_batch
+
     def build_batch_from_trees(self, batch_size, data_iterator):
         """
         Read from the data_iterator batch_size trees and turn them into N lists of parsing states
@@ -203,6 +221,8 @@ def parse_args(args=None):
     parser = argparse.ArgumentParser()
 
     parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
+
     parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
     parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
     parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
@@ -215,7 +235,8 @@ def parse_args(args=None):
     parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
     parser.add_argument('models', type=str, nargs='+', default=None, help="Which model(s) to load")
 
-    parser.add_argument('--mode', default='predict', choices=['predict'])
+    parser.add_argument('--mode', default='predict', choices=['parse_text', 'predict'])
+    add_predict_output_args(parser)
 
     retagging.add_retag_args(parser)
     # TODO: get default method & package from run_constituency.py
@@ -225,8 +246,6 @@ def parse_args(args=None):
 
     retagging.postprocess_args(args)
     args['num_generate'] = 0
-    args['predict_file'] = None
-    args['predict_dir'] = None
 
     if not args['eval_file'] and args['lang'] in DEFAULT_EVAL:
         paths = get_default_paths()
@@ -247,20 +266,26 @@ def main():
     foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
 
     ensemble = Ensemble(args['models'], args, foundation_cache)
-
-    with EvaluateParser() as evaluator:
-        treebank = tree_reader.read_treebank(args['eval_file'])
-        logger.info("Read %d trees for evaluation", len(treebank))
-
-        if retag_pipeline is not None:
-            logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
-            treebank = retag_trees(treebank, retag_pipeline, args['retag_xpos'])
-            logger.info("Retagging finished")
-
-        f1, kbestF1 = run_dev_set(ensemble, treebank, args, evaluator)
-        logger.info("F1 score on %s: %f", args['eval_file'], f1)
-        if kbestF1 is not None:
-            logger.info("KBest F1 score on %s: %f", args['eval_file'], kbestF1)
+    ensemble.eval()
+
+    if args['mode'] == 'predict':
+        with EvaluateParser() as evaluator:
+            treebank = tree_reader.read_treebank(args['eval_file'])
+            logger.info("Read %d trees for evaluation", len(treebank))
+
+            if retag_pipeline is not None:
+                logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
+                treebank = retag_trees(treebank, retag_pipeline, args['retag_xpos'])
+                logger.info("Retagging finished")
+
+            f1, kbestF1 = run_dev_set(ensemble, treebank, args, evaluator)
+            logger.info("F1 score on %s: %f", args['eval_file'], f1)
+            if kbestF1 is not None:
+                logger.info("KBest F1 score on %s: %f", args['eval_file'], kbestF1)
+    elif args['mode'] == 'parse_text':
+        parse_text(args, ensemble, retag_pipeline)
+    else:
+        raise ValueError("Unhandled mode %s" % args['mode'])
 
 
 if __name__ == "__main__":