diff options
author | John Bauer <horatio@gmail.com> | 2022-10-29 03:33:40 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-10-29 03:33:40 +0300 |
commit | ffca4f4054550a977b6b929c13154c1cf652a39e (patch) | |
tree | ba6a52a3dc29bf8771e917f16c37f6ba9de8e4bb | |
parent | 29fb29f7425ea22025fb4986c33b3f2c9fa1f155 (diff) |
Add functionality to turn a tokenized text file into a file of parse trees
-rw-r--r-- | stanza/models/constituency/trainer.py | 53 | ||||
-rw-r--r-- | stanza/models/constituency_parser.py | 5 | ||||
-rw-r--r-- | stanza/utils/training/run_constituency.py | 9 |
3 files changed, 65 insertions, 2 deletions
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index b3d5cc5e..6f9bd79b 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -211,6 +211,59 @@ def verify_transitions(trees, sequences, transition_scheme, unary_limit): if tree != result: raise RuntimeError("Transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree, sequence, result)) +def load_model_parse_text(args, model_file, retag_pipeline): + """ + Load a model, then parse text and write it to stdout or args['predict_file'] + """ + foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache() + load_args = { + 'wordvec_pretrain_file': args['wordvec_pretrain_file'], + 'charlm_forward_file': args['charlm_forward_file'], + 'charlm_backward_file': args['charlm_backward_file'], + 'cuda': args['cuda'], + } + trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache) + model = trainer.model + logger.info("Loaded model from %s", model_file) + + parse_text(args, model, retag_pipeline) + +def parse_text(args, model, retag_pipeline): + """ + Use the given model to parse text and write it + + refactored so it can be used elsewhere, such as Ensemble + """ + if args['tokenized_file']: + with open(args['tokenized_file'], encoding='utf-8') as fin: + lines = fin.readlines() + lines = [x.strip() for x in lines] + # a large chunk of VI wiki data was pretokenized with sentences too long + # or with the em-dash, so let's filter those for now + # TODO: remove later + lines = [x for x in lines if x and len(x) <= 100 and len(x) >= 10 and '—' not in x] + docs = [[word.replace("_", " ") for word in sentence.split()] for sentence in lines] + logger.info("Processing %d lines", len(docs)) + doc = retag_pipeline(docs) + if args['retag_method'] == 'xpos': + words = [[(w.text, w.xpos) for w in s.words] for s in doc.sentences] + else: + words = [[(w.text, w.upos) for w in s.words] for s in doc.sentences] + assert len(words) == len(docs) + treebank = model.parse_sentences_no_grad(iter(tqdm(words)), model.build_batch_from_tagged_words, args['eval_batch_size'], model.predict, keep_scores=False) + if args['predict_file']: + predict_file = args['predict_file'] + if args['predict_dir']: + predict_file = os.path.join(args['predict_dir'], predict_file) + with open(predict_file, "w", encoding="utf-8") as fout: + for result in treebank: + fout.write(args['predict_format'].format(result.predictions[0].tree)) + fout.write("\n") + else: + for result in treebank: + print(args['predict_format'].format(result.predictions[0].tree)) + + def evaluate(args, model_file, retag_pipeline): """ Loads the given model file and tests the eval_file treebank. diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py index 6b991977..bc4a8331 100644 --- a/stanza/models/constituency_parser.py +++ b/stanza/models/constituency_parser.py @@ -178,7 +178,8 @@ def parse_args(args=None): parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.') parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.') - parser.add_argument('--mode', default='train', choices=['train', 'predict', 'remove_optimizer']) + parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.') + parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer']) parser.add_argument('--num_generate', type=int, default=0, help='When running a dev set, how many sentences to generate beyond the greedy one') parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging') parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions') @@ -519,6 +520,8 @@ def main(args=None): trainer.train(args, model_load_file, model_save_each_file, retag_pipeline) elif args['mode'] == 'predict': trainer.evaluate(args, model_load_file, retag_pipeline) + elif args['mode'] == 'parse_text': + trainer.load_model_parse_text(args, model_load_file, retag_pipeline) elif args['mode'] == 'remove_optimizer': trainer.remove_optimizer(args, args['save_name'], model_load_file) diff --git a/stanza/utils/training/run_constituency.py b/stanza/utils/training/run_constituency.py index 313b150e..97ff1559 100644 --- a/stanza/utils/training/run_constituency.py +++ b/stanza/utils/training/run_constituency.py @@ -30,6 +30,8 @@ def add_constituency_args(parser): parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm') parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package") + parser.add_argument('--parse_text', dest='mode', action='store_const', const="parse_text", help='Parse a text file') + def run_treebank(mode, paths, treebank, short_name, temp_output_file, command_args, extra_args): constituency_dir = paths["CONSTITUENCY_DATA_DIR"] @@ -93,7 +95,12 @@ def run_treebank(mode, paths, treebank, short_name, logger.info("Running test step with args: {}".format(test_args)) constituency_parser.main(test_args) - + if mode == "parse_text": + text_args = ['--shorthand', short_name, + '--mode', 'parse_text'] + text_args = text_args + default_args + extra_args + logger.info("Processing text with args: {}".format(text_args)) + constituency_parser.main(text_args) def main(): common.main(run_treebank, "constituency", "constituency", add_constituency_args) |