github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>   2022-09-08 01:22:51 +0300
committer  John Bauer <horatio@gmail.com>   2022-09-08 20:04:00 +0300
commit     7aee87a84e1574043c382b02f6f1f0f4a691e2ce (patch)
tree       dc655e88f1f9b5a19a5eb5a63c41268ace06a0fd
parent     1d33718ab625da4251f8b6863f9b74a13c5fa09a (diff)
Rearrange a bunch of functions from prepare_tokenizer_treebank to a common file
Move the read/write conllu functions to a common folder so they can be used elsewhere
Move the MWT_RE etc as well
Move prepare_treebank_labels to common (and rename it)
Move convert_conllu_to_txt as well
Refactor a tokenizer_conllu_name function
-rw-r--r--  stanza/tests/tokenization/test_tokenization_lst20.py    2
-rw-r--r--  stanza/tests/tokenization/test_tokenization_orchid.py   2
-rw-r--r--  stanza/utils/datasets/common.py                         82
-rw-r--r--  stanza/utils/datasets/corenlp_segmenter_dataset.py      7
-rwxr-xr-x  stanza/utils/datasets/prepare_tokenizer_treebank.py     89
5 files changed, 98 insertions, 84 deletions
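
For downstream code, the practical effect of this commit is that the shared conllu helpers are imported from stanza.utils.datasets.common rather than from prepare_tokenizer_treebank. A minimal sketch of an updated call site follows; the data directory and treebank short name are illustrative examples, not values taken from the commit.

    # Sketch only: "data/tokenize" and "en_ewt" are assumed example values.
    from stanza.utils.datasets.common import convert_conllu_to_txt

    # Expects <tokenizer_dir>/<short_name>.<shard>.gold.conllu for each shard and
    # writes <tokenizer_dir>/<short_name>.<shard>.txt next to it via conllu_to_text.pl.
    convert_conllu_to_txt("data/tokenize", "en_ewt", shards=("train", "dev", "test"))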
diff --git a/stanza/tests/tokenization/test_tokenization_lst20.py b/stanza/tests/tokenization/test_tokenization_lst20.py
index a0728123..67928b5b 100644
--- a/stanza/tests/tokenization/test_tokenization_lst20.py
+++ b/stanza/tests/tokenization/test_tokenization_lst20.py
@@ -6,7 +6,7 @@ import pytest
import stanza
from stanza.tests import *
-from stanza.utils.datasets.prepare_tokenizer_treebank import convert_conllu_to_txt
+from stanza.utils.datasets.common import convert_conllu_to_txt
from stanza.utils.datasets.tokenization.convert_th_lst20 import read_document
from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section
diff --git a/stanza/tests/tokenization/test_tokenization_orchid.py b/stanza/tests/tokenization/test_tokenization_orchid.py
index 8c0fb9f5..8a186e26 100644
--- a/stanza/tests/tokenization/test_tokenization_orchid.py
+++ b/stanza/tests/tokenization/test_tokenization_orchid.py
@@ -8,7 +8,7 @@ import xml.etree.ElementTree as ET
import stanza
from stanza.tests import *
-from stanza.utils.datasets.prepare_tokenizer_treebank import convert_conllu_to_txt
+from stanza.utils.datasets.common import convert_conllu_to_txt
from stanza.utils.datasets.tokenization.convert_th_orchid import parse_xml
from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section
diff --git a/stanza/utils/datasets/common.py b/stanza/utils/datasets/common.py
index efdbb8cf..871ebb80 100644
--- a/stanza/utils/datasets/common.py
+++ b/stanza/utils/datasets/common.py
@@ -3,13 +3,93 @@ import argparse
import glob
import logging
import os
+import re
+import subprocess
import sys
-import stanza.utils.default_paths as default_paths
from stanza.models.common.short_name_to_treebank import canonical_treebank_name
+import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data
+import stanza.utils.default_paths as default_paths
logger = logging.getLogger('stanza')
+# RE to see if the index of a conllu line represents an MWT
+MWT_RE = re.compile("^[0-9]+[-][0-9]+")
+
+# RE to see if the index of a conllu line represents an MWT or copy node
+MWT_OR_COPY_RE = re.compile("^[0-9]+[-.][0-9]+")
+
+# more restrictive than an actual int as we expect certain formats in the conllu files
+INT_RE = re.compile("^[0-9]+$")
+
+CONLLU_TO_TXT_PERL = os.path.join(os.path.split(__file__)[0], "conllu_to_text.pl")
+
+def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")):
+ """
+ Uses the udtools perl script to convert a conllu file to txt
+
+ TODO: switch to a python version to get rid of some perl dependence
+ """
+ for dataset in shards:
+ output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+ output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
+
+ if not os.path.exists(output_conllu):
+ # the perl script doesn't raise an error code for file not found!
+ raise FileNotFoundError("Cannot convert %s as the file cannot be found" % output_conllu)
+ # use an external script to produce the txt files
+ subprocess.check_output(f"perl {CONLLU_TO_TXT_PERL} {output_conllu} > {output_txt}", shell=True)
+
+def mwt_name(base_dir, short_name, dataset):
+ return os.path.join(base_dir, f"{short_name}-ud-{dataset}-mwt.json")
+
+def tokenizer_conllu_name(base_dir, short_name, dataset):
+ return os.path.join(base_dir, f"{short_name}.{dataset}.gold.conllu")
+
+def prepare_tokenizer_dataset_labels(input_txt, input_conllu, tokenizer_dir, short_name, dataset):
+ prepare_tokenizer_data.main([input_txt,
+ input_conllu,
+ "-o", f"{tokenizer_dir}/{short_name}-ud-{dataset}.toklabels",
+ "-m", mwt_name(tokenizer_dir, short_name, dataset)])
+
+def prepare_tokenizer_treebank_labels(tokenizer_dir, short_name):
+ """
+ Given the txt and gold.conllu files, prepare mwt and label files for train/dev/test
+ """
+ for dataset in ("train", "dev", "test"):
+ output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
+ output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+ try:
+ prepare_tokenizer_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except:
+ print("Failed to convert %s to %s" % (output_txt, output_conllu))
+ raise
+
+def read_sentences_from_conllu(filename):
+ sents = []
+ cache = []
+ with open(filename, encoding="utf-8") as infile:
+ for line in infile:
+ line = line.strip()
+ if len(line) == 0:
+ if len(cache) > 0:
+ sents.append(cache)
+ cache = []
+ continue
+ cache.append(line)
+ if len(cache) > 0:
+ sents.append(cache)
+ return sents
+
+def write_sentences_to_conllu(filename, sents):
+ with open(filename, 'w', encoding="utf-8") as outfile:
+ for lines in sents:
+ for line in lines:
+ print(line, file=outfile)
+ print("", file=outfile)
+
def find_treebank_dataset_file(treebank, udbase_dir, dataset, extension, fail=False):
"""
For a given treebank, dataset, extension, look for the exact filename to use.
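
The read/write helpers added to common.py above are symmetric: read_sentences_from_conllu splits a conllu file on blank lines into lists of raw lines, and write_sentences_to_conllu prints each sentence back followed by a blank separator. A small round-trip sketch under assumed filenames (none of the names below come from the commit):

    from stanza.utils.datasets.common import read_sentences_from_conllu, write_sentences_to_conllu

    # Each sentence is a list of raw conllu lines, comment lines included.
    sents = read_sentences_from_conllu("en_ewt.train.gold.conllu")       # assumed filename
    # Example manipulation: keep only sentences with at most 20 lines.
    kept = [sent for sent in sents if len(sent) <= 20]
    write_sentences_to_conllu("en_ewt.train.filtered.conllu", kept)      # assumed filename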
diff --git a/stanza/utils/datasets/corenlp_segmenter_dataset.py b/stanza/utils/datasets/corenlp_segmenter_dataset.py
index 9ebf7783..b2a275a3 100644
--- a/stanza/utils/datasets/corenlp_segmenter_dataset.py
+++ b/stanza/utils/datasets/corenlp_segmenter_dataset.py
@@ -12,6 +12,7 @@ import os
import sys
import tempfile
+import stanza.utils.datasets.common as common
import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
import stanza.utils.default_paths as default_paths
@@ -54,9 +55,9 @@ def process_treebank(treebank, paths, output_dir):
dev_file = f"{tokenizer_dir}/{short_name}.dev.gold.conllu"
test_file = f"{tokenizer_dir}/{short_name}.test.gold.conllu"
- train_set = prepare_tokenizer_treebank.read_sentences_from_conllu(train_file)
- dev_set = prepare_tokenizer_treebank.read_sentences_from_conllu(dev_file)
- test_set = prepare_tokenizer_treebank.read_sentences_from_conllu(test_file)
+ train_set = common.read_sentences_from_conllu(train_file)
+ dev_set = common.read_sentences_from_conllu(dev_file)
+ test_set = common.read_sentences_from_conllu(test_file)
train_out = os.path.join(output_dir, f"{short_name}.train.seg.txt")
test_out = os.path.join(output_dir, f"{short_name}.test.seg.txt")
diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py
index d03b81ac..ca10af3f 100755
--- a/stanza/utils/datasets/prepare_tokenizer_treebank.py
+++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py
@@ -26,14 +26,13 @@ import glob
import os
import random
import re
-import subprocess
import tempfile
from collections import Counter
from stanza.models.common.constant import treebank_to_short_name
import stanza.utils.datasets.common as common
-import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data
+from stanza.utils.datasets.common import read_sentences_from_conllu, write_sentences_to_conllu, INT_RE, MWT_RE, MWT_OR_COPY_RE
import stanza.utils.datasets.tokenization.convert_my_alt as convert_my_alt
import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp
import stanza.utils.datasets.tokenization.convert_th_best as convert_th_best
@@ -82,29 +81,6 @@ def copy_conllu_treebank(treebank, paths, dest_dir, postprocess=None, augment=Tr
postprocess(tokenizer_dir, "test.gold", dest_dir, "test.gold", short_name)
copy_conllu_file(dest_dir, "test.gold", dest_dir, "test.in", short_name)
-def read_sentences_from_conllu(filename):
- sents = []
- cache = []
- with open(filename, encoding="utf-8") as infile:
- for line in infile:
- line = line.strip()
- if len(line) == 0:
- if len(cache) > 0:
- sents.append(cache)
- cache = []
- continue
- cache.append(line)
- if len(cache) > 0:
- sents.append(cache)
- return sents
-
-def write_sentences_to_conllu(filename, sents):
- with open(filename, 'w', encoding="utf-8") as outfile:
- for lines in sents:
- for line in lines:
- print(line, file=outfile)
- print("", file=outfile)
-
def split_train_file(treebank, train_input_conllu, train_output_conllu, dev_output_conllu):
# set the seed for each data file so that the results are the same
# regardless of how many treebanks are processed at once
@@ -129,49 +105,6 @@ def split_train_file(treebank, train_input_conllu, train_output_conllu, dev_outp
return True
-def mwt_name(base_dir, short_name, dataset):
- return f"{base_dir}/{short_name}-ud-{dataset}-mwt.json"
-
-def prepare_dataset_labels(input_txt, input_conllu, tokenizer_dir, short_name, dataset):
- prepare_tokenizer_data.main([input_txt,
- input_conllu,
- "-o", f"{tokenizer_dir}/{short_name}-ud-{dataset}.toklabels",
- "-m", mwt_name(tokenizer_dir, short_name, dataset)])
-
-def prepare_treebank_labels(tokenizer_dir, short_name):
- for dataset in ("train", "dev", "test"):
- output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
- output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
- try:
- prepare_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset)
- except (KeyboardInterrupt, SystemExit):
- raise
- except:
- print("Failed to convert %s to %s" % (output_txt, output_conllu))
- raise
-
-CONLLU_TO_TXT_PERL = os.path.join(os.path.split(__file__)[0], "conllu_to_text.pl")
-
-def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")):
- for dataset in shards:
- output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
- output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
-
- if not os.path.exists(output_conllu):
- # the perl script doesn't raise an error code for file not found!
- raise FileNotFoundError("Cannot convert %s as the file cannot be found" % output_conllu)
- # use an external script to produce the txt files
- subprocess.check_output(f"perl {CONLLU_TO_TXT_PERL} {output_conllu} > {output_txt}", shell=True)
-
-
-# RE to see if the index of a conllu line represents an MWT
-MWT_RE = re.compile("^[0-9]+[-][0-9]+")
-
-# RE to see if the index of a conllu line represents an MWT or copy node
-MWT_OR_COPY_RE = re.compile("^[0-9]+[-.][0-9]+")
-
-# more restrictive than an actual int as we expect certain formats in the conllu files
-INT_RE = re.compile("^[0-9]+$")
def strip_mwt_from_sentences(sents):
"""
@@ -801,7 +734,7 @@ def build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset
def build_combined_korean(udbase_dir, tokenizer_dir, short_name):
for dataset in ("train", "dev", "test"):
- output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+ output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
build_combined_korean_dataset(udbase_dir, tokenizer_dir, short_name, dataset, output_conllu)
def build_combined_italian_dataset(paths, dataset):
@@ -1005,7 +938,7 @@ def build_combined_dataset(paths, short_name, augment):
build_fn = COMBINED_FNS[short_name]
extra_fn = COMBINED_EXTRA_FNS.get(short_name, None)
for dataset in ("train", "dev", "test"):
- output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+ output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
sents = build_fn(paths, dataset)
if dataset == 'train' and augment:
sents = augment_punct(sents)
@@ -1025,7 +958,7 @@ def build_bio_dataset(paths, udbase_dir, tokenizer_dir, handparsed_dir, short_na
name, bio_dataset = short_name.split("_")
assert name == 'en'
for dataset in ("train", "dev", "test"):
- output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+ output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
if dataset == 'train':
sents = build_combined_english_dataset(paths, dataset)
if dataset == 'train' and augment:
@@ -1045,7 +978,7 @@ def build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, da
check_gum_ready(udbase_dir)
random.seed(1234)
- output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+ output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
treebanks = ["UD_English-GUM", "UD_English-GUMReddit"]
sents = []
@@ -1066,7 +999,7 @@ def prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_la
if input_conllu is None:
input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
if output_conllu is None:
- output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
+ output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, dataset)
print("Reading from %s and writing to %s" % (input_conllu, output_conllu))
if short_name == "te_mtg" and dataset == 'train' and augment:
@@ -1116,9 +1049,9 @@ def process_partial_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name,
train_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "train", "conllu")
test_input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, "test", "conllu")
- train_output_conllu = f"{tokenizer_dir}/{short_name}.train.gold.conllu"
- dev_output_conllu = f"{tokenizer_dir}/{short_name}.dev.gold.conllu"
- test_output_conllu = f"{tokenizer_dir}/{short_name}.test.gold.conllu"
+ train_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "train")
+ dev_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "dev")
+ test_output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, "test")
if (common.num_words_in_file(train_input_conllu) <= 1000 and
common.num_words_in_file(test_input_conllu) > 5000):
@@ -1194,10 +1127,10 @@ def process_treebank(treebank, paths, args):
process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, args.augment)
if not short_name in ('th_orchid', 'th_lst20'):
- convert_conllu_to_txt(tokenizer_dir, short_name)
+ common.convert_conllu_to_txt(tokenizer_dir, short_name)
if args.prepare_labels:
- prepare_treebank_labels(tokenizer_dir, short_name)
+ common.prepare_tokenizer_treebank_labels(tokenizer_dir, short_name)
def main():