| author | John Bauer <horatio@gmail.com> | 2022-11-04 22:18:54 +0300 |
|---|---|---|
| committer | John Bauer <horatio@gmail.com> | 2022-11-04 22:18:54 +0300 |
| commit | 4dbe964a545d785d8151e6e82500b1ce1da1d1a7 (patch) | |
| tree | 50d6e913d00a61a33d136bedea158d26fb22def1 | |
| parent | 0efa21be676593e6b96893a5dbe60b2994fe69c9 (diff) | |
Refactor the tokenization method from tokenize_wiki.py. Reuse it to add an option to selftrain_vi_quad which only tokenizes the data without parsing it.
| -rw-r--r-- | stanza/utils/datasets/constituency/selftrain.py | 59 |
| -rw-r--r-- | stanza/utils/datasets/constituency/selftrain_vi_quad.py | 44 |
| -rw-r--r-- | stanza/utils/datasets/constituency/tokenize_wiki.py | 56 |

3 files changed, 95 insertions, 64 deletions
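The shared length flags added below follow a standard argparse pattern: `--no_min_len` and `--no_max_len` use `store_const` to write `None` into the same destination as the corresponding integer flag, and the filters treat `None` as "unlimited". A minimal standalone sketch of that pattern (not part of the commit itself):

```python
import argparse

parser = argparse.ArgumentParser()
# --max_len sets an integer bound; --no_max_len writes None into the
# same dest, which the downstream filters interpret as "unlimited"
parser.add_argument('--max_len', default=100, type=int,
                    help='Maximum length sentence to keep. None = unlimited')
parser.add_argument('--no_max_len', dest='max_len', action='store_const',
                    const=None, help='No maximum length')

print(parser.parse_args([]).max_len)                    # 100
print(parser.parse_args(['--max_len', '80']).max_len)   # 80
print(parser.parse_args(['--no_max_len']).max_len)      # None
```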
```diff
diff --git a/stanza/utils/datasets/constituency/selftrain.py b/stanza/utils/datasets/constituency/selftrain.py
index 07d2f20d..7334d70e 100644
--- a/stanza/utils/datasets/constituency/selftrain.py
+++ b/stanza/utils/datasets/constituency/selftrain.py
@@ -47,6 +47,33 @@ def common_args(parser):
         help='Output trees in PTB brackets (default is a bracket language format)'
     )
 
+def add_length_args(parser):
+    parser.add_argument(
+        '--min_len',
+        default=5,
+        type=int,
+        help='Minimum length sentence to keep. None = unlimited'
+    )
+    parser.add_argument(
+        '--no_min_len',
+        dest='min_len',
+        action='store_const',
+        const=None,
+        help='No minimum length'
+    )
+    parser.add_argument(
+        '--max_len',
+        default=100,
+        type=int,
+        help='Maximum length sentence to keep. None = unlimited'
+    )
+    parser.add_argument(
+        '--no_max_len',
+        dest='max_len',
+        action='store_const',
+        const=None,
+        help='No maximum length'
+    )
+
 def build_ssplit_pipe(ssplit, lang):
     if ssplit:
@@ -107,6 +134,38 @@ def split_docs(docs, ssplit_pipe, max_len=140, max_word_len=50, chunk_size=2000)
     logger.info("Sentences filtered for length: %d", filtered_sentences)
     return new_docs
 
+def tokenize_docs(docs, pipe, min_len, max_len):
+    """
+    Turn the text in docs into a list of whitespace separated sentences
+
+    docs: a list of strings
+    pipe: a Stanza pipeline for tokenizing
+    min_len, max_len: can be None to not filter by this attribute
+    """
+    results = []
+    docs = [stanza.Document([], text=t) for t in docs]
+    pipe(docs)
+    for doc in docs:
+        for sentence in doc.sentences:
+            if min_len and len(sentence.words) < min_len:
+                continue
+            if max_len and len(sentence.words) > max_len:
+                continue
+            text = sentence.text
+            if (text.find("|") >= 0 or text.find("_") >= 0 or
+                text.find("<") >= 0 or text.find(">") >= 0 or
+                text.find("[") >= 0 or text.find("]") >= 0 or
+                text.find('—') >= 0):  # an em dash, seems to be part of lists
+                continue
+            # the VI tokenizer in particular doesn't split these well
+            if any(any(w.text.find(c) >= 0 and len(w.text) > 1 for w in sentence.words)
+                   for c in '"()'):
+                continue
+            text = [w.text.replace(" ", "_") for w in sentence.words]
+            text = " ".join(text)
+            results.append(text)
+    return results
+
 def find_matching_trees(docs, num_sentences, accepted_trees, tag_pipe, parser_pipes,
                         shuffle=True, chunk_size=10, max_len=140, min_len=10, output_ptb=False):
     """
     Find trees where all the parsers in parser_pipes agree
diff --git a/stanza/utils/datasets/constituency/selftrain_vi_quad.py b/stanza/utils/datasets/constituency/selftrain_vi_quad.py
index ac5b287a..3426dcb4 100644
--- a/stanza/utils/datasets/constituency/selftrain_vi_quad.py
+++ b/stanza/utils/datasets/constituency/selftrain_vi_quad.py
@@ -6,6 +6,7 @@
 import argparse
 import json
 import logging
+import stanza
 from stanza.utils.datasets.constituency import selftrain
 
 logger = logging.getLogger('stanza')
@@ -15,11 +16,18 @@ def parse_args():
         description="Script that converts vi quad to silver standard trees"
     )
     selftrain.common_args(parser)
+    selftrain.add_length_args(parser)
     parser.add_argument(
         '--input_file',
         default="extern_data/vietnamese/ViQuAD/train_ViQuAD.json",
         help='Path to the ViQuAD train file'
     )
+    parser.add_argument(
+        '--tokenize_only',
+        default=False,
+        action='store_true',
+        help='Tokenize instead of writing trees'
+    )
     args = parser.parse_args()
     return args
@@ -61,22 +69,30 @@ def main():
     """
     args = parse_args()
 
-    tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang)
-    parser_pipes = selftrain.build_parser_pipes(args.lang, args.models)
-
-    # create a blank file. we will append to this file so that partial results can be used
-    with open(args.output_file, "w") as fout:
-        pass
-
-    accepted_trees = set()
     docs = read_quad(args.input_file)
     logger.info("Read %d lines from %s", len(docs), args.input_file)
 
-    new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100)
-    new_trees = [tree for tree in new_trees if tree.find("(_SQ") >= 0]
-    with open(args.output_file, "a") as fout:
-        for tree in sorted(new_trees):
-            fout.write(tree)
-            fout.write("\n")
+    if args.tokenize_only:
+        pipe = stanza.Pipeline(args.lang, processors="tokenize")
+        text = selftrain.tokenize_docs(docs, pipe, args.min_len, args.max_len)
+        with open(args.output_file, "w", encoding="utf-8") as fout:
+            for line in text:
+                fout.write(line)
+                fout.write("\n")
+    else:
+        tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang)
+        parser_pipes = selftrain.build_parser_pipes(args.lang, args.models)
+
+        # create a blank file. we will append to this file so that partial results can be used
+        with open(args.output_file, "w") as fout:
+            pass
+
+        accepted_trees = set()
+        new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100)
+        new_trees = [tree for tree in new_trees if tree.find("(_SQ") >= 0]
+        with open(args.output_file, "a") as fout:
+            for tree in sorted(new_trees):
+                fout.write(tree)
+                fout.write("\n")
 
 if __name__ == '__main__':
     main()
diff --git a/stanza/utils/datasets/constituency/tokenize_wiki.py b/stanza/utils/datasets/constituency/tokenize_wiki.py
index feeda4b2..6c6f80ac 100644
--- a/stanza/utils/datasets/constituency/tokenize_wiki.py
+++ b/stanza/utils/datasets/constituency/tokenize_wiki.py
@@ -17,6 +17,7 @@
 import argparse
 import stanza
 from stanza.models.common.utils import get_tqdm
 from stanza.utils.datasets.constituency import selftrain_wiki
+from stanza.utils.datasets.constituency.selftrain import add_length_args, tokenize_docs
 
 tqdm = get_tqdm()
@@ -39,36 +40,10 @@ def parse_args():
         default='extern_data/vietnamese/wikipedia/text/AA',
         help='Path to the wikipedia dump after processing by wikiextractor'
     )
-    parser.add_argument(
-        '--min_len',
-        default=5,
-        type=int,
-        help='Minimum length sentence to keep. None = unlimited'
-    )
-    parser.add_argument(
-        '--no_min_len',
-        dest='min_len',
-        action='store_const',
-        const=None,
-        help='No minimum length'
-    )
-    parser.add_argument(
-        '--max_len',
-        default=100,
-        type=int,
-        help='Maximum length sentence to keep. None = unlimited'
-    )
-    parser.add_argument(
-        '--no_max_len',
-        dest='max_len',
-        action='store_const',
-        const=None,
-        help='No maximum length'
-    )
+    add_length_args(parser)
     args = parser.parse_args()
     return args
 
-
 def main():
@@ -78,29 +53,10 @@ def main():
     args = parse_args()
     files = selftrain_wiki.list_wikipedia_files(args.input_dir)
 
     with open(args.output_file, "w", encoding="utf-8") as fout:
         for filename in tqdm(files):
             docs = selftrain_wiki.read_wiki_file(filename)
-            docs = [stanza.Document([], text=t) for t in docs]
-            pipe(docs)
-
-            for doc in docs:
-                for sentence in doc.sentences:
-                    if args.min_len and len(sentence.words) < args.min_len:
-                        continue
-                    if args.max_len and len(sentence.words) > args.max_len:
-                        continue
-                    text = sentence.text
-                    if (text.find("|") >= 0 or text.find("_") >= 0 or
-                        text.find("<") >= 0 or text.find(">") >= 0 or
-                        text.find("[") >= 0 or text.find("]") >= 0 or
-                        text.find('—') >= 0):  # an em dash, seems to be part of lists
-                        continue
-                    # the VI tokenizer in particular doesn't split these well
-                    if any(any(w.text.find(c) >= 0 and len(w.text) > 1 for w in sentence.words)
-                           for c in '"()'):
-                        continue
-                    text = [w.text.replace(" ", "_") for w in sentence.words]
-                    text = " ".join(text)
-                    fout.write(text)
-                    fout.write("\n")
+            text = tokenize_docs(docs, pipe, args.min_len, args.max_len)
+            for line in text:
+                fout.write(line)
+                fout.write("\n")
 
 if __name__ == '__main__':
     main()
```
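For reference, a minimal sketch of driving the extracted helper directly, assuming the Vietnamese tokenizer models have already been downloaded; the sample text and filter settings are illustrative only:

```python
import stanza
from stanza.utils.datasets.constituency.selftrain import tokenize_docs

# A tokenize-only pipeline, as built by the new --tokenize_only branch
# (requires stanza.download("vi") to have been run beforehand)
pipe = stanza.Pipeline("vi", processors="tokenize")

docs = ["Đây là một câu ví dụ. Đây là câu thứ hai."]  # illustrative input
for line in tokenize_docs(docs, pipe, min_len=None, max_len=100):
    # each result is one whitespace-separated sentence; spaces inside
    # multi-syllable Vietnamese words are replaced with "_"
    print(line)
```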