github.com/stanfordnlp/stanza.git

author     John Bauer <horatio@gmail.com>    2022-10-29 03:33:40 +0300
committer  John Bauer <horatio@gmail.com>    2022-10-29 03:33:40 +0300
commit     ffca4f4054550a977b6b929c13154c1cf652a39e (patch)
tree       ba6a52a3dc29bf8771e917f16c37f6ba9de8e4bb
parent     29fb29f7425ea22025fb4986c33b3f2c9fa1f155 (diff)
Add functionality to turn a tokenized text file into a file of parse trees
-rw-r--r--  stanza/models/constituency/trainer.py      | 53
-rw-r--r--  stanza/models/constituency_parser.py       |  5
-rw-r--r--  stanza/utils/training/run_constituency.py  |  9
3 files changed, 65 insertions, 2 deletions
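In short, the commit adds a parse_text mode to the constituency parser: it reads a pretokenized text file (one sentence per line) and writes one bracketed parse tree per sentence, either to stdout or to --predict_file. A minimal invocation sketch, assuming an already-trained model; the file names are placeholders, and the flags that locate the model and its embeddings are omitted for brevity:

    # Hypothetical driver; paths are placeholders, model-loading flags elided.
    from stanza.models import constituency_parser

    constituency_parser.main([
        '--mode', 'parse_text',
        '--tokenized_file', 'sentences_tokenized.txt',  # one sentence per line
        '--predict_file', 'trees.txt',                  # omit to print to stdout
    ])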
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index b3d5cc5e..6f9bd79b 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -211,6 +211,59 @@ def verify_transitions(trees, sequences, transition_scheme, unary_limit):
         if tree != result:
             raise RuntimeError("Transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree, sequence, result))
 
+def load_model_parse_text(args, model_file, retag_pipeline):
+    """
+    Load a model, then parse text and write it to stdout or args['predict_file']
+    """
+    foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
+    load_args = {
+        'wordvec_pretrain_file': args['wordvec_pretrain_file'],
+        'charlm_forward_file': args['charlm_forward_file'],
+        'charlm_backward_file': args['charlm_backward_file'],
+        'cuda': args['cuda'],
+    }
+    trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)
+    model = trainer.model
+    logger.info("Loaded model from %s", model_file)
+
+    parse_text(args, model, retag_pipeline)
+
+def parse_text(args, model, retag_pipeline):
+    """
+    Use the given model to parse text and write it
+
+    refactored so it can be used elsewhere, such as Ensemble
+    """
+    if args['tokenized_file']:
+        with open(args['tokenized_file'], encoding='utf-8') as fin:
+            lines = fin.readlines()
+        lines = [x.strip() for x in lines]
+        # a large chunk of VI wiki data was pretokenized with sentences too long
+        # or with the em-dash, so let's filter those for now
+        # TODO: remove later
+        lines = [x for x in lines if x and len(x) <= 100 and len(x) >= 10 and '—' not in x]
+        docs = [[word.replace("_", " ") for word in sentence.split()] for sentence in lines]
+        logger.info("Processing %d lines", len(docs))
+        doc = retag_pipeline(docs)
+        if args['retag_method'] == 'xpos':
+            words = [[(w.text, w.xpos) for w in s.words] for s in doc.sentences]
+        else:
+            words = [[(w.text, w.upos) for w in s.words] for s in doc.sentences]
+        assert len(words) == len(docs)
+        treebank = model.parse_sentences_no_grad(iter(tqdm(words)), model.build_batch_from_tagged_words, args['eval_batch_size'], model.predict, keep_scores=False)
+        if args['predict_file']:
+            predict_file = args['predict_file']
+            if args['predict_dir']:
+                predict_file = os.path.join(args['predict_dir'], predict_file)
+            with open(predict_file, "w", encoding="utf-8") as fout:
+                for result in treebank:
+                    fout.write(args['predict_format'].format(result.predictions[0].tree))
+                    fout.write("\n")
+        else:
+            for result in treebank:
+                print(args['predict_format'].format(result.predictions[0].tree))
+
+
 def evaluate(args, model_file, retag_pipeline):
     """
     Loads the given model file and tests the eval_file treebank.
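Two details of the input handling above are worth noting: tokens are split on whitespace and an underscore inside a token is mapped back to a space, so multi-word tokens survive pretokenization; and the temporary filter drops lines under 10 or over 100 characters, or containing an em-dash, as a workaround for one batch of Vietnamese wiki data. A small sketch of the token convention (the sentence is invented):

    # How parse_text turns one pretokenized line into a token list.
    line = "She moved to New_York last year"
    tokens = [word.replace("_", " ") for word in line.split()]
    # tokens == ['She', 'moved', 'to', 'New York', 'last', 'year']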
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 6b991977..bc4a8331 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -178,7 +178,8 @@ def parse_args(args=None):
     parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
     parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
-    parser.add_argument('--mode', default='train', choices=['train', 'predict', 'remove_optimizer'])
+    parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
+    parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer'])
     parser.add_argument('--num_generate', type=int, default=0, help='When running a dev set, how many sentences to generate beyond the greedy one')
     parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging')
     parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions')
@@ -519,6 +520,8 @@ def main(args=None):
         trainer.train(args, model_load_file, model_save_each_file, retag_pipeline)
     elif args['mode'] == 'predict':
         trainer.evaluate(args, model_load_file, retag_pipeline)
+    elif args['mode'] == 'parse_text':
+        trainer.load_model_parse_text(args, model_load_file, retag_pipeline)
     elif args['mode'] == 'remove_optimizer':
         trainer.remove_optimizer(args, args['save_name'], model_load_file)
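The new mode slots into the existing dispatch: --tokenized_file supplies the input and --mode parse_text routes to trainer.load_model_parse_text. A quick sketch of the argument handling, assuming parse_args returns the dict that main() indexes (argv values here are made up):

    # parse_args now accepts the new mode; everything else is unchanged.
    args = parse_args(['--mode', 'parse_text', '--tokenized_file', 'input.txt'])
    assert args['mode'] == 'parse_text'
    assert args['tokenized_file'] == 'input.txt'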
diff --git a/stanza/utils/training/run_constituency.py b/stanza/utils/training/run_constituency.py
index 313b150e..97ff1559 100644
--- a/stanza/utils/training/run_constituency.py
+++ b/stanza/utils/training/run_constituency.py
@@ -30,6 +30,8 @@ def add_constituency_args(parser):
     parser.add_argument('--charlm', default="default", type=str, help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')
     parser.add_argument('--no_charlm', dest='charlm', action="store_const", const=None, help="Don't use a charlm, even if one is used by default for this package")
 
+    parser.add_argument('--parse_text', dest='mode', action='store_const', const="parse_text", help='Parse a text file')
+
 def run_treebank(mode, paths, treebank, short_name,
                  temp_output_file, command_args, extra_args):
     constituency_dir = paths["CONSTITUENCY_DATA_DIR"]
@@ -93,7 +95,12 @@ def run_treebank(mode, paths, treebank, short_name,
logger.info("Running test step with args: {}".format(test_args))
constituency_parser.main(test_args)
-
+ if mode == "parse_text":
+ text_args = ['--shorthand', short_name,
+ '--mode', 'parse_text']
+ text_args = text_args + default_args + extra_args
+ logger.info("Processing text with args: {}".format(text_args))
+ constituency_parser.main(text_args)
def main():
common.main(run_treebank, "constituency", "constituency", add_constituency_args)
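Finally, a sketch of reaching the same mode through the treebank runner; the treebank shorthand and file names are hypothetical, and the extra flags are assumed to ride along via the extra_args that run_treebank appends to text_args:

    # Hypothetical runner invocation, expressed through sys.argv since
    # run_constituency.main() parses its own arguments.
    import sys
    from stanza.utils.training import run_constituency

    sys.argv = ['run_constituency', 'vi_vlsp22', '--parse_text',
                '--tokenized_file', 'vi_wiki_tokenized.txt',
                '--predict_file', 'vi_wiki_trees.txt']
    run_constituency.main()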