author | John Bauer <horatio@gmail.com> | 2022-09-08 01:22:51 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-09-08 20:04:00 +0300 |
commit | 7aee87a84e1574043c382b02f6f1f0f4a691e2ce (patch) | |
tree | dc655e88f1f9b5a19a5eb5a63c41268ace06a0fd | |
parent | 1d33718ab625da4251f8b6863f9b74a13c5fa09a (diff) | |
Rearrange a bunch of functions from prepare_tokenizer_treebank to a common file
Move the read/write conllu functions to a common file so they can be used elsewhere
Move the MWT_RE, MWT_OR_COPY_RE, and INT_RE regexes as well
Move prepare_treebank_labels to common (and rename it to prepare_tokenizer_treebank_labels)
Move convert_conllu_to_txt as well
Refactor the repeated gold conllu filenames into a tokenizer_conllu_name function
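For code that previously imported these helpers from prepare_tokenizer_treebank, the new import site is stanza.utils.datasets.common. A minimal usage sketch (the module path and function names are taken from the diff below; the tokenizer_dir and short_name values are hypothetical examples):

```python
# Post-refactor import locations introduced by this commit
from stanza.utils.datasets.common import (
    MWT_RE, MWT_OR_COPY_RE, INT_RE,      # conllu index regexes
    convert_conllu_to_txt,               # wraps conllu_to_text.pl
    prepare_tokenizer_treebank_labels,   # formerly prepare_treebank_labels
    read_sentences_from_conllu,
    write_sentences_to_conllu,
    tokenizer_conllu_name,
)

# Hypothetical values, purely for illustration
tokenizer_dir = "data/tokenize"
short_name = "en_ewt"

# Produce {short_name}.{dataset}.txt from the gold conllu files, then the
# -ud-{dataset}.toklabels and -mwt.json label files, as process_treebank does
convert_conllu_to_txt(tokenizer_dir, short_name)
prepare_tokenizer_treebank_labels(tokenizer_dir, short_name)
```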
-rw-r--r-- | stanza/tests/tokenization/test_tokenization_lst20.py | 2 |
-rw-r--r-- | stanza/tests/tokenization/test_tokenization_orchid.py | 2 |
-rw-r--r-- | stanza/utils/datasets/common.py | 82 |
-rw-r--r-- | stanza/utils/datasets/corenlp_segmenter_dataset.py | 7 |
-rwxr-xr-x | stanza/utils/datasets/prepare_tokenizer_treebank.py | 89 |
5 files changed, 98 insertions, 84 deletions
diff --git a/stanza/tests/tokenization/test_tokenization_lst20.py b/stanza/tests/tokenization/test_tokenization_lst20.py
index a0728123..67928b5b 100644
--- a/stanza/tests/tokenization/test_tokenization_lst20.py
+++ b/stanza/tests/tokenization/test_tokenization_lst20.py
@@ -6,7 +6,7 @@ import pytest
 
 import stanza
 from stanza.tests import *
-from stanza.utils.datasets.prepare_tokenizer_treebank import convert_conllu_to_txt
+from stanza.utils.datasets.common import convert_conllu_to_txt
 from stanza.utils.datasets.tokenization.convert_th_lst20 import read_document
 from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section
diff --git a/stanza/tests/tokenization/test_tokenization_orchid.py b/stanza/tests/tokenization/test_tokenization_orchid.py
index 8c0fb9f5..8a186e26 100644
--- a/stanza/tests/tokenization/test_tokenization_orchid.py
+++ b/stanza/tests/tokenization/test_tokenization_orchid.py
@@ -8,7 +8,7 @@ import xml.etree.ElementTree as ET
 
 import stanza
 from stanza.tests import *
-from stanza.utils.datasets.prepare_tokenizer_treebank import convert_conllu_to_txt
+from stanza.utils.datasets.common import convert_conllu_to_txt
 from stanza.utils.datasets.tokenization.convert_th_orchid import parse_xml
 from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section
diff --git a/stanza/utils/datasets/common.py b/stanza/utils/datasets/common.py
index efdbb8cf..871ebb80 100644
--- a/stanza/utils/datasets/common.py
+++ b/stanza/utils/datasets/common.py
@@ -3,13 +3,93 @@ import argparse
 import glob
 import logging
 import os
+import re
+import subprocess
 import sys
 
-import stanza.utils.default_paths as default_paths
 from stanza.models.common.short_name_to_treebank import canonical_treebank_name
+import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data
+import stanza.utils.default_paths as default_paths
 
 logger = logging.getLogger('stanza')
 
+# RE to see if the index of a conllu line represents an MWT
+MWT_RE = re.compile("^[0-9]+[-][0-9]+")
+
+# RE to see if the index of a conllu line represents an MWT or copy node
+MWT_OR_COPY_RE = re.compile("^[0-9]+[-.][0-9]+")
+
+# more restrictive than an actual int as we expect certain formats in the conllu files
+INT_RE = re.compile("^[0-9]+$")
+
+CONLLU_TO_TXT_PERL = os.path.join(os.path.split(__file__)[0], "conllu_to_text.pl")
+
+def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")):
+    """
+    Uses the udtools perl script to convert a conllu file to txt
+
+    TODO: switch to a python version to get rid of some perl dependence
+    """
+    for dataset in shards:
+        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+        output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
+
+        if not os.path.exists(output_conllu):
+            # the perl script doesn't raise an error code for file not found!
+            raise FileNotFoundError("Cannot convert %s as the file cannot be found" % output_conllu)
+        # use an external script to produce the txt files
+        subprocess.check_output(f"perl {CONLLU_TO_TXT_PERL} {output_conllu} > {output_txt}", shell=True)
+
+def mwt_name(base_dir, short_name, dataset):
+    return os.path.join(base_dir, f"{short_name}-ud-{dataset}-mwt.json")
+
+def tokenizer_conllu_name(base_dir, short_name, dataset):
+    return os.path.join(base_dir, f"{short_name}.{dataset}.gold.conllu")
+
+def prepare_tokenizer_dataset_labels(input_txt, input_conllu, tokenizer_dir, short_name, dataset):
+    prepare_tokenizer_data.main([input_txt,
+                                 input_conllu,
+                                 "-o", f"{tokenizer_dir}/{short_name}-ud-{dataset}.toklabels",
+                                 "-m", mwt_name(tokenizer_dir, short_name, dataset)])
+
+def prepare_tokenizer_treebank_labels(tokenizer_dir, short_name):
+    """
+    Given the txt and gold.conllu files, prepare mwt and label files for train/dev/test
+    """
+    for dataset in ("train", "dev", "test"):
+        output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
+        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+        try:
+            prepare_tokenizer_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset)
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except:
+            print("Failed to convert %s to %s" % (output_txt, output_conllu))
+            raise
+
+def read_sentences_from_conllu(filename):
+    sents = []
+    cache = []
+    with open(filename, encoding="utf-8") as infile:
+        for line in infile:
+            line = line.strip()
+            if len(line) == 0:
+                if len(cache) > 0:
+                    sents.append(cache)
+                    cache = []
+                continue
+            cache.append(line)
+    if len(cache) > 0:
+        sents.append(cache)
+    return sents
+
+def write_sentences_to_conllu(filename, sents):
+    with open(filename, 'w', encoding="utf-8") as outfile:
+        for lines in sents:
+            for line in lines:
+                print(line, file=outfile)
+            print("", file=outfile)
+
 def find_treebank_dataset_file(treebank, udbase_dir, dataset, extension, fail=False):
     """
     For a given treebank, dataset, extension, look for the exact filename to use.
diff --git a/stanza/utils/datasets/corenlp_segmenter_dataset.py b/stanza/utils/datasets/corenlp_segmenter_dataset.py
index 9ebf7783..b2a275a3 100644
--- a/stanza/utils/datasets/corenlp_segmenter_dataset.py
+++ b/stanza/utils/datasets/corenlp_segmenter_dataset.py
@@ -12,6 +12,7 @@ import os
 import sys
 import tempfile
 
+import stanza.utils.datasets.common as common
 import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
 import stanza.utils.default_paths as default_paths
 
@@ -54,9 +55,9 @@ def process_treebank(treebank, paths, output_dir):
     dev_file = f"{tokenizer_dir}/{short_name}.dev.gold.conllu"
     test_file = f"{tokenizer_dir}/{short_name}.test.gold.conllu"
 
-    train_set = prepare_tokenizer_treebank.read_sentences_from_conllu(train_file)
-    dev_set = prepare_tokenizer_treebank.read_sentences_from_conllu(dev_file)
-    test_set = prepare_tokenizer_treebank.read_sentences_from_conllu(test_file)
+    train_set = common.read_sentences_from_conllu(train_file)
+    dev_set = common.read_sentences_from_conllu(dev_file)
+    test_set = common.read_sentences_from_conllu(test_file)
 
     train_out = os.path.join(output_dir, f"{short_name}.train.seg.txt")
     test_out = os.path.join(output_dir, f"{short_name}.test.seg.txt")
diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py
index d03b81ac..ca10af3f 100755
--- a/stanza/utils/datasets/prepare_tokenizer_treebank.py
+++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py
@@ -26,14 +26,13 @@ import glob
 import os
 import random
 import re
-import subprocess
 import tempfile
 
 from collections import Counter
 
 from stanza.models.common.constant import treebank_to_short_name
 import stanza.utils.datasets.common as common
-import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data
+from stanza.utils.datasets.common import read_sentences_from_conllu, write_sentences_to_conllu, INT_RE, MWT_RE, MWT_OR_COPY_RE
 import stanza.utils.datasets.tokenization.convert_my_alt as convert_my_alt
 import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp
 import stanza.utils.datasets.tokenization.convert_th_best as convert_th_best
@@ -82,29 +81,6 @@ def copy_conllu_treebank(treebank, paths, dest_dir, postprocess=None, augment=Tr
         postprocess(tokenizer_dir, "test.gold", dest_dir, "test.gold", short_name)
         copy_conllu_file(dest_dir, "test.gold", dest_dir, "test.in", short_name)
 
-def read_sentences_from_conllu(filename):
-    sents = []
-    cache = []
-    with open(filename, encoding="utf-8") as infile:
-        for line in infile:
-            line = line.strip()
-            if len(line) == 0:
-                if len(cache) > 0:
-                    sents.append(cache)
-                    cache = []
-                continue
-            cache.append(line)
-    if len(cache) > 0:
-        sents.append(cache)
-    return sents
-
-def write_sentences_to_conllu(filename, sents):
-    with open(filename, 'w', encoding="utf-8") as outfile:
-        for lines in sents:
-            for line in lines:
-                print(line, file=outfile)
-            print("", file=outfile)
-
 def split_train_file(treebank, train_input_conllu, train_output_conllu, dev_output_conllu):
     # set the seed for each data file so that the results are the same
     # regardless of how many treebanks are processed at once
@@ -129,49 +105,6 @@ def split_train_file(treebank, train_input_conllu, train_output_conllu, dev_outp
 
     return True
 
-def mwt_name(base_dir, short_name, dataset):
-    return f"{base_dir}/{short_name}-ud-{dataset}-mwt.json"
-
-def prepare_dataset_labels(input_txt, input_conllu, tokenizer_dir, short_name, dataset):
-    prepare_tokenizer_data.main([input_txt,
-                                 input_conllu,
-                                 "-o", f"{tokenizer_dir}/{short_name}-ud-{dataset}.toklabels",
-                                 "-m", mwt_name(tokenizer_dir, short_name, dataset)])
-
-def prepare_treebank_labels(tokenizer_dir, short_name):
-    for dataset in ("train", "dev", "test"):
-        output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
-        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
-        try:
-            prepare_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset)
-        except (KeyboardInterrupt, SystemExit):
-            raise
-        except:
-            print("Failed to convert %s to %s" % (output_txt, output_conllu))
-            raise
-
-CONLLU_TO_TXT_PERL = os.path.join(os.path.split(__file__)[0], "conllu_to_text.pl")
-
-def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")):
-    for dataset in shards:
-        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
-        output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
-
-        if not os.path.exists(output_conllu):
-            # the perl script doesn't raise an error code for file not found!
-            raise FileNotFoundError("Cannot convert %s as the file cannot be found" % output_conllu)
-        # use an external script to produce the txt files
-        subprocess.check_output(f"perl {CONLLU_TO_TXT_PERL} {output_conllu} > {output_txt}", shell=True)
-
-
-# RE to see if the index of a conllu line represents an MWT
-MWT_RE = re.compile("^[0-9]+[-][0-9]+")
-
-# RE to see if the index of a conllu line represents an MWT or copy node
-MWT_OR_COPY_RE = re.compile("^[0-9]+[-.][0-9]+")
-
-# more restrictive than an actual int as we expect certain formats in the conllu files
-INT_RE = re.compile("^[0-9]+$")
 
 def strip_mwt_from_sentences(sents):
     """
@@ -801,7 +734,7 @@ def build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset
 
 def build_combined_korean(udbase_dir, tokenizer_dir, short_name):
     for dataset in ("train", "dev", "test"):
-        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
         build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset, output_conllu)
 
 def build_combined_italian_dataset(paths, dataset):
@@ -1005,7 +938,7 @@ def build_combined_dataset(paths, short_name, augment):
     build_fn = COMBINED_FNS[short_name]
     extra_fn = COMBINED_EXTRA_FNS.get(short_name, None)
     for dataset in ("train", "dev", "test"):
-        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
         sents = build_fn(paths, dataset)
         if dataset == 'train' and augment:
             sents = augment_punct(sents)
@@ -1025,7 +958,7 @@ def build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_na
     name, bio_dataset = short_name.split("_")
     assert name == 'en'
     for dataset in ("train", "dev", "test"):
-        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
         if dataset == 'train':
             sents = build_combined_english_dataset(paths, dataset)
             if dataset == 'train' and augment:
@@ -1045,7 +978,7 @@ def build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, da
     check_gum_ready(udbase_dir)
     random.seed(1234)
 
-    output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+    output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
 
     treebanks = ["UD_English-GUM", "UD_English-GUMReddit"]
     sents = []
@@ -1066,7 +999,7 @@ def prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_la
     if input_conllu is None:
         input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
     if output_conllu is None:
-        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+        output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
 
     print("Reading from %s and writing to %s" % (input_conllu, output_conllu))
     if short_name == "te_mtg" and dataset == 'train' and augment:
@@ -1116,9 +1049,9 @@ def process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name,
     train_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu")
     test_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "test", "conllu")
 
-    train_output_conllu = f"{tokenizer_dir}/{short_name}.train.gold.conllu"
-    dev_output_conllu = f"{tokenizer_dir}/{short_name}.dev.gold.conllu"
-    test_output_conllu = f"{tokenizer_dir}/{short_name}.test.gold.conllu"
+    train_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "train")
+    dev_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "dev")
+    test_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "test")
 
     if (common.num_words_in_file(train_input_conllu) <= 1000 and
         common.num_words_in_file(test_input_conllu) > 5000):
@@ -1194,10 +1127,10 @@ def process_treebank(treebank, paths, args):
         process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, args.augment)
 
     if not short_name in ('th_orchid', 'th_lst20'):
-        convert_conllu_to_txt(tokenizer_dir, short_name)
+        common.convert_conllu_to_txt(tokenizer_dir, short_name)
 
     if args.prepare_labels:
-        prepare_treebank_labels(tokenizer_dir, short_name)
+        common.prepare_tokenizer_treebank_labels(tokenizer_dir, short_name)
 
 
 def main():
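The new convert_conllu_to_txt still shells out to conllu_to_text.pl, and its docstring keeps a TODO about dropping the perl dependence. Purely as a hedged sketch of where a pure-Python replacement might start (the function name is hypothetical, it only recovers the `# text =` comments via the read_sentences_from_conllu helper above, and it does not reproduce everything the perl script handles):

```python
import os

from stanza.utils.datasets.common import read_sentences_from_conllu

def naive_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")):
    """Rough illustration only: rebuild raw text from '# text =' comments.

    conllu_to_text.pl handles more cases (paragraph boundaries, rebuilding
    text from tokens and SpaceAfter=No when the comment is missing, etc.).
    """
    for dataset in shards:
        input_conllu = os.path.join(tokenizer_dir, f"{short_name}.{dataset}.gold.conllu")
        output_txt = os.path.join(tokenizer_dir, f"{short_name}.{dataset}.txt")
        with open(output_txt, "w", encoding="utf-8") as fout:
            for sentence in read_sentences_from_conllu(input_conllu):
                for line in sentence:
                    if line.startswith("# text ="):
                        fout.write(line[len("# text ="):].strip() + "\n")
                        break
```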