From 154b0e8e59d3276744ae0c8ea56dc226f777fba8 Mon Sep 17 00:00:00 2001 From: J38 Date: Wed, 21 Apr 2021 07:55:39 -0700 Subject: processor for identifying languages of text class for multilingual pipeline model code for langid system add langid to constants add lang to Document integrate langid and multilingual into stanza adds script for creating UD train data add eval and model training from checkpoint allow new labels in eval set, fine-grained eval option handles fine-grained Norwegian and Chinese in UD data stores tag index with data, add all labels in eval set option for saving with label (e.g. epoch1) adds tweet cleaning adds lang subset option adds benchmark test adds emoji library dependency to stanza --- scripts/treebank_to_shorthand.sh | 4 +- setup.py | 2 +- stanza/__init__.py | 1 + stanza/models/common/constant.py | 13 +- stanza/models/common/doc.py | 11 + stanza/models/lang_identifier.py | 226 ++++++++++++ stanza/models/langid/__init__.py | 0 stanza/models/langid/create_ud_data.py | 205 +++++++++++ stanza/models/langid/data.py | 136 ++++++++ stanza/models/langid/model.py | 120 +++++++ stanza/models/langid/trainer.py | 53 +++ stanza/pipeline/_constants.py | 1 + stanza/pipeline/core.py | 1 + stanza/pipeline/langid_processor.py | 126 +++++++ stanza/pipeline/multilingual.py | 106 ++++++ stanza/resources/prepare_resources.py | 32 +- stanza/tests/test_langid.py | 608 +++++++++++++++++++++++++++++++++ 17 files changed, 1628 insertions(+), 17 deletions(-) create mode 100644 stanza/models/lang_identifier.py create mode 100644 stanza/models/langid/__init__.py create mode 100644 stanza/models/langid/create_ud_data.py create mode 100644 stanza/models/langid/data.py create mode 100644 stanza/models/langid/model.py create mode 100644 stanza/models/langid/trainer.py create mode 100644 stanza/pipeline/langid_processor.py create mode 100644 stanza/pipeline/multilingual.py create mode 100644 stanza/tests/test_langid.py diff --git a/scripts/treebank_to_shorthand.sh b/scripts/treebank_to_shorthand.sh index bb6f1793..f5395ec6 100755 --- a/scripts/treebank_to_shorthand.sh +++ b/scripts/treebank_to_shorthand.sh @@ -19,10 +19,10 @@ lang=`echo $treebank | sed -e 's#-.*$##g' -e 's#^[^_]*_##g'` lcode=${lang2lcode[$lang]} if [ -z "$lcode" ]; then if [ $lang == "Chinese" ]; then - if [ $tbname == "gsdsimp" ]; then + if [ $tbname == "gsdsimp" -o $tbname == "cfl" ]; then # TODO why not zh-hans? lcode=zh - elif [ $tbname == "gsd" -o $tbname == "hk" -o $tbname == "cfl" -o $tbname == "pud" ]; then + elif [ $tbname == "gsd" -o $tbname == "hk" -o $tbname == "pud" ]; then lcode=zh-hant fi elif [ $lang == "Norwegian" ]; then diff --git a/setup.py b/setup.py index 70e20fe3..890088da 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ setup( # your project is installed. 
For an analysis of "install_requires" vs pip's # requirements files see: # https://packaging.python.org/en/latest/requirements.html - install_requires=['numpy', 'protobuf', 'requests', 'torch>=1.3.0', 'tqdm'], + install_requires=['emoji', 'numpy', 'protobuf', 'requests', 'torch>=1.3.0', 'tqdm'], # List required Python versions python_requires='>=3.6', diff --git a/stanza/__init__.py b/stanza/__init__.py index 76f04fd9..25c6fd13 100644 --- a/stanza/__init__.py +++ b/stanza/__init__.py @@ -1,4 +1,5 @@ from stanza.pipeline.core import Pipeline +from stanza.pipeline.multilingual import MultilingualPipeline from stanza.models.common.doc import Document from stanza.resources.common import download from stanza.resources.installation import install_corenlp, download_corenlp_models diff --git a/stanza/models/common/constant.py b/stanza/models/common/constant.py index 3ba570ab..8fdabb4b 100644 --- a/stanza/models/common/constant.py +++ b/stanza/models/common/constant.py @@ -134,6 +134,7 @@ langlower2lcode = {lcode2lang[k].lower(): k.lower() for k in lcode2lang} # additional useful code to language mapping # added after dict invert to avoid conflict lcode2lang['nb'] = 'Norwegian' # Norwegian Bokmall mapped to default norwegian +lcode2lang['no'] = 'Norwegian' lcode2lang['zh'] = 'Simplified_Chinese' lang2lcode['Chinese'] = 'zh' @@ -142,12 +143,12 @@ lang2lcode['Chinese'] = 'zh' lang2lcode['Old_Russian'] = 'orv' treebank_special_cases = { - "UD_Chinese-GSDSimp": "zh_gsdsimp", + "UD_Chinese-GSDSimp": "zh-hans_gsdsimp", "UD_Chinese-GSD": "zh-hant_gsd", "UD_Chinese-HK": "zh-hant_hk", - "UD_Chinese-CFL": "zh-hant_cfl", + "UD_Chinese-CFL": "zh-hans_cfl", "UD_Chinese-PUD": "zh-hant_pud", - "UD_Norwegian-Bokmaal": "nb_bokmaal", + "UD_Norwegian-Bokmaal": "no_bokmaal", "UD_Norwegian-Nynorsk": "nn_nynorsk", "UD_Norwegian-NynorskLIA": "nn_nynorsklia", } @@ -174,3 +175,9 @@ def treebank_to_short_name(treebank): short = "{}_{}".format(lcode, corpus.lower()) return short + +def treebank_to_langid(treebank): + """ Convert treebank name to langid """ + short_name = treebank_to_short_name(treebank) + return short_name.split("_")[0] + diff --git a/stanza/models/common/doc.py b/stanza/models/common/doc.py index d24ea966..eb1d9f2f 100644 --- a/stanza/models/common/doc.py +++ b/stanza/models/common/doc.py @@ -72,6 +72,7 @@ class Document(StanzaObject): comments: A list of list of strings to use as comments on the sentences, either None or the same length as sentences """ self._sentences = [] + self._lang = None self._text = None self._num_tokens = 0 self._num_words = 0 @@ -80,6 +81,16 @@ class Document(StanzaObject): self._process_sentences(sentences, comments) self._ents = [] + @property + def lang(self): + """ Access the language of this document """ + return self._lang + + @lang.setter + def lang(self, value): + """ Set the language of this document """ + self._lang = value + @property def text(self): """ Access the raw text for this document. 
""" diff --git a/stanza/models/lang_identifier.py b/stanza/models/lang_identifier.py new file mode 100644 index 00000000..ca7aa8e2 --- /dev/null +++ b/stanza/models/lang_identifier.py @@ -0,0 +1,226 @@ +""" +Entry point for training and evaluating a Bi-LSTM language identifier +""" + +import argparse +import json +import logging +import os +import random +import torch + +from datetime import datetime +from stanza.models.langid.data import DataLoader +from stanza.models.langid.trainer import Trainer +from tqdm import tqdm + +logger = logging.getLogger('stanza') + +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument("--batch-mode", help="custom settings when running in batch mode", action="store_true") + parser.add_argument("--batch-size", help="batch size for training", type=int, default=64) + parser.add_argument("--eval-length", help="length of strings to eval on", type=int, default=None) + parser.add_argument("--eval-set", help="eval on dev or test", default="test") + parser.add_argument("--data-dir", help="directory with train/dev/test data", default=None) + parser.add_argument("--load-model", help="path to load model from", default=None) + parser.add_argument("--mode", help="train or eval", default="train") + parser.add_argument("--num-epochs", help="number of epochs for training", type=int, default=50) + parser.add_argument("--randomize", help="take random substrings of samples", action="store_true") + parser.add_argument("--randomize-lengths-range", help="range of lengths to use when random sampling text", + type=randomize_lengths_range, default="5,20") + parser.add_argument("--merge-labels-for-eval", + help="merge some language labels for eval (e.g. \"zh-hans\" and \"zh-hant\" to \"zh\")", + action="store_true") + parser.add_argument("--save-best-epochs", help="save model for every epoch with new best score", action="store_true") + parser.add_argument("--save-name", help="where to save model", default=None) + parser.add_argument("--use-cpu", help="use cpu", action="store_true") + args = parser.parse_args(args=args) + args.use_gpu = True if torch.cuda.is_available() and not args.use_cpu else False + return args + + +def randomize_lengths_range(range_list): + """ + Range of lengths for random samples + """ + range_boundaries = [int(x) for x in range_list.split(",")] + assert range_boundaries[0] < range_boundaries[1], f"Invalid range: ({range_boundaries[0]}, {range_boundaries[1]})" + return range_boundaries + + +def main(args=None): + args = parse_args(args=args) + torch.manual_seed(0) + if args.mode == "train": + train_model(args) + else: + eval_model(args) + + +def build_indexes(args): + tag_to_idx = {} + char_to_idx = {} + train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x] + for train_file in train_files: + with open(train_file) as curr_file: + lines = curr_file.read().strip().split("\n") + examples = [json.loads(line) for line in lines if line.strip()] + for example in examples: + label = example["label"] + if label not in tag_to_idx: + tag_to_idx[label] = len(tag_to_idx) + sequence = example["text"] + for char in list(sequence): + if char not in char_to_idx: + char_to_idx[char] = len(char_to_idx) + char_to_idx["UNK"] = len(char_to_idx) + char_to_idx[""] = len(char_to_idx) + + return tag_to_idx, char_to_idx + + +def train_model(args): + # set up indexes + tag_to_idx, char_to_idx = build_indexes(args) + # load training data + train_data = DataLoader() + train_files = [f"{args.data_dir}/{x}" for x in 
os.listdir(args.data_dir) if "train" in x] + train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize) + # load dev data + dev_data = DataLoader() + dev_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "dev" in x] + dev_data.load_data(args.batch_size, dev_files, char_to_idx, tag_to_idx, randomize=False, + max_length=args.eval_length) + # set up trainer + trainer_config = { + "model_path": args.save_name, + "char_to_idx": char_to_idx, + "tag_to_idx": tag_to_idx, + "batch_size": args.batch_size, + "lang_weights": train_data.lang_weights + } + if args.load_model: + trainer_config["load_model"] = args.load_model + logger.info(f"{datetime.now()}\tLoading model from: {args.load_model}") + trainer = Trainer(trainer_config, load_model=args.load_model, use_gpu=args.use_gpu) + # run training + best_accuracy = 0.0 + for epoch in range(1, args.num_epochs+1): + logger.info(f"{datetime.now()}\tEpoch {epoch}") + logger.info(f"{datetime.now()}\tNum training batches: {len(train_data.batches)}") + for train_batch in tqdm(train_data.batches, disable=args.batch_mode): + inputs = (train_batch["sentences"], train_batch["targets"]) + trainer.update(inputs) + logger.info(f"{datetime.now()}\tEpoch complete. Evaluating on dev data.") + curr_dev_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \ + eval_trainer(trainer, dev_data, batch_mode=args.batch_mode) + logger.info(f"{datetime.now()}\tCurrent dev accuracy: {curr_dev_accuracy}") + if curr_dev_accuracy > best_accuracy: + logger.info(f"{datetime.now()}\tNew best score. Saving model.") + model_label = f"epoch{epoch}" if args.save_best_epochs else None + trainer.save(label=model_label) + with open(score_log_path(args.save_name), "w") as score_log_file: + for score_log in [{"dev_accuracy": curr_dev_accuracy}, curr_confusion_matrix, curr_precisions, + curr_recalls, curr_f1s]: + score_log_file.write(json.dumps(score_log) + "\n") + best_accuracy = curr_dev_accuracy + + # reload training data + logger.info(f"{datetime.now()}\tResampling training data.") + train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize) + + +def score_log_path(file_path): + """ + Helper that will determine corresponding log file (e.g. 
/path/to/demo.pt to /path/to/demo.json)
+    """
+    model_suffix = os.path.splitext(file_path)[1]
+    if model_suffix:
+        score_log_path = f"{file_path[:-len(model_suffix)]}.json"
+    else:
+        score_log_path = f"{file_path}.json"
+    return score_log_path
+
+
+def eval_model(args):
+    # set up trainer
+    trainer_config = {
+        "model_path": None,
+        "load_model": args.load_model,
+        "batch_size": args.batch_size
+    }
+    trainer = Trainer(trainer_config, load_model=True, use_gpu=args.use_gpu)
+    # load test data
+    test_data = DataLoader()
+    test_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if args.eval_set in x]
+    test_data.load_data(args.batch_size, test_files, trainer.model.char_to_idx, trainer.model.tag_to_idx,
+                        randomize=False, max_length=args.eval_length)
+    curr_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \
+        eval_trainer(trainer, test_data, batch_mode=args.batch_mode, fine_grained=not args.merge_labels_for_eval)
+    logger.info(f"{datetime.now()}\t{args.eval_set} accuracy: {curr_accuracy}")
+    eval_save_path = args.save_name if args.save_name else score_log_path(args.load_model)
+    if not os.path.exists(eval_save_path) or args.save_name:
+        with open(eval_save_path, "w") as score_log_file:
+            for score_log in [{"dev_accuracy": curr_accuracy}, curr_confusion_matrix, curr_precisions,
+                              curr_recalls, curr_f1s]:
+                score_log_file.write(json.dumps(score_log) + "\n")
+
+
+def eval_trainer(trainer, dev_data, batch_mode=False, fine_grained=True):
+    """
+    Produce dev accuracy and confusion matrix for a trainer
+    """
+
+    # set up confusion matrix
+    tag_to_idx = dev_data.tag_to_idx
+    idx_to_tag = dev_data.idx_to_tag
+    confusion_matrix = {}
+    for row_label in tag_to_idx:
+        confusion_matrix[row_label] = {}
+        for col_label in tag_to_idx:
+            confusion_matrix[row_label][col_label] = 0
+
+    # process dev batches
+    for dev_batch in tqdm(dev_data.batches, disable=batch_mode):
+        inputs = (dev_batch["sentences"], dev_batch["targets"])
+        predictions = trainer.predict(inputs)
+        for target_idx, prediction in zip(dev_batch["targets"], predictions):
+            prediction_label = idx_to_tag[prediction] if fine_grained else idx_to_tag[prediction].split("-")[0]
+            confusion_matrix[idx_to_tag[target_idx]][prediction_label] += 1
+
+    # calculate dev accuracy
+    total_examples = sum([sum([confusion_matrix[i][j] for j in confusion_matrix[i]]) for i in confusion_matrix])
+    total_correct = sum([confusion_matrix[i][i] for i in confusion_matrix])
+    dev_accuracy = float(total_correct) / float(total_examples)
+
+    # calculate precision, recall, F1
+    precision_scores = {"type": "precision"}
+    recall_scores = {"type": "recall"}
+    f1_scores = {"type": "f1"}
+    for prediction_label in tag_to_idx:
+        total = sum([confusion_matrix[k][prediction_label] for k in tag_to_idx])
+        if total != 0.0:
+            precision_scores[prediction_label] = float(confusion_matrix[prediction_label][prediction_label])/float(total)
+        else:
+            precision_scores[prediction_label] = 0.0
+    for target_label in tag_to_idx:
+        total = sum([confusion_matrix[target_label][k] for k in tag_to_idx])
+        if total != 0:
+            recall_scores[target_label] = float(confusion_matrix[target_label][target_label])/float(total)
+        else:
+            recall_scores[target_label] = 0.0
+    for label in tag_to_idx:
+        if precision_scores[label] == 0.0 and recall_scores[label] == 0.0:
+            f1_scores[label] = 0.0
+        else:
+            f1_scores[label] = \
+                2.0 * (precision_scores[label] * recall_scores[label]) / (precision_scores[label] + recall_scores[label])
+
+    return dev_accuracy, confusion_matrix, 
precision_scores, recall_scores, f1_scores + + +if __name__ == "__main__": + main() + diff --git a/stanza/models/langid/__init__.py b/stanza/models/langid/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/stanza/models/langid/create_ud_data.py b/stanza/models/langid/create_ud_data.py new file mode 100644 index 00000000..7e7d5cc5 --- /dev/null +++ b/stanza/models/langid/create_ud_data.py @@ -0,0 +1,205 @@ +""" +Script for producing training/dev/test data from UD data or sentences + +Example output data format (one example per line): + +{"text": "Hello world.", "label": "en"} + +""" + +import argparse +import json +import logging +import os +import re +import sys + +from pathlib import Path +from random import randint, random, shuffle +from string import digits +from tqdm import tqdm + +from stanza.models.common.constant import treebank_to_langid + +logger = logging.getLogger('stanza') + +DEFAULT_LANGUAGES = "af,ar,be,bg,bxr,ca,cop,cs,cu,da,de,el,en,es,et,eu,fa,fi,fr,fro,ga,gd,gl,got,grc,he,hi,hr,hsb,hu,hy,id,it,ja,kk,kmr,ko,la,lt,lv,lzh,mr,mt,nl,nn,no,olo,orv,pl,pt,ro,ru,sk,sl,sme,sr,sv,swl,ta,te,tr,ug,uk,ur,vi,wo,zh-hans,zh-hant".split(",") + +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument("--data-format", help="input data format", choices=["ud", "one-per-line"], default="ud") + parser.add_argument("--eval-length", help="length of eval strings", type=int, default=10) + parser.add_argument("--languages", help="list of languages to use, or \"all\"", default=DEFAULT_LANGUAGES) + parser.add_argument("--min-window", help="minimal training example length", type=int, default=10) + parser.add_argument("--max-window", help="maximum training example length", type=int, default=50) + parser.add_argument("--ud-path", help="path to ud data") + parser.add_argument("--save-path", help="path to save data", default=".") + parser.add_argument("--splits", help="size of train/dev/test splits in percentages", type=splits_from_list, + default="0.8,0.1,0.1") + args = parser.parse_args(args=args) + return args + + +def splits_from_list(value_list): + return [float(x) for x in value_list.split(",")] + + +def main(args=None): + args = parse_args(args=args) + if isinstance(args.languages, str): + args.languages = args.languages.split(",") + data_paths = [f"{args.save_path}/{data_split}.jsonl" for data_split in ["train", "dev", "test"]] + lang_to_files = collect_files(args.ud_path, args.languages, data_format=args.data_format) + logger.info(f"Building UD data for languages: {','.join(args.languages)}") + for lang_id in tqdm(lang_to_files): + lang_examples = generate_examples(lang_id, lang_to_files[lang_id], splits=args.splits, + min_window=args.min_window, max_window=args.max_window, + eval_length=args.eval_length, data_format=args.data_format) + for (data_set, save_path) in zip(lang_examples, data_paths): + with open(save_path, "a") as json_file: + for json_entry in data_set: + json.dump(json_entry, json_file, ensure_ascii=False) + json_file.write("\n") + + +def collect_files(ud_path, languages, data_format="ud"): + """ + Given path to UD, collect files + If data_format = "ud", expects files to be of form *.conllu + If data_format = "one-per-line", expects files to be of form "*.sentences.txt" + In all cases, the UD path should be a directory with subdirectories for each language + """ + data_format_to_search_path = {"ud": "*/*.conllu", "one-per-line": "*/*sentences.txt"} + ud_files = Path(ud_path).glob(data_format_to_search_path[data_format]) + 
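+    # For orientation, the "ud" pattern above is meant to match paths like
+    #   <ud_path>/UD_French-GSD/fr_gsd-ud-train.conllu
+    # while "one-per-line" is meant to match paths like
+    #   <ud_path>/some_corpus/fr_web.sentences.txt
+    # (both paths are invented examples for illustration only)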
lang_to_files = {} + for ud_file in ud_files: + if data_format == "ud": + lang_id = treebank_to_langid(ud_file.parent.name) + else: + lang_id = ud_file.name.split("_")[0] + if lang_id not in languages and "all" not in languages: + continue + if not lang_id in lang_to_files: + lang_to_files[lang_id] = [] + lang_to_files[lang_id].append(ud_file) + return lang_to_files + + +def generate_examples(lang_id, list_of_files, splits=(0.8,0.1,0.1), min_window=10, max_window=50, + eval_length=10, data_format="ud"): + """ + Generate train/dev/test examples for a given language + """ + examples = [] + for ud_file in list_of_files: + sentences = sentences_from_file(ud_file, data_format=data_format) + for sentence in sentences: + sentence = clean_sentence(sentence) + if validate_sentence(sentence, min_window): + examples += sentence_to_windows(sentence, min_window=min_window, max_window=max_window) + shuffle(examples) + train_idx = int(splits[0] * len(examples)) + train_set = [example_json(lang_id, example) for example in examples[:train_idx]] + dev_idx = int(splits[1] * len(examples)) + train_idx + dev_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[train_idx:dev_idx]] + test_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[dev_idx:]] + return train_set, dev_set, test_set + + +def sentences_from_file(ud_file_path, data_format="ud"): + """ + Retrieve all sentences from a UD file + """ + if data_format == "ud": + with open(ud_file_path) as ud_file: + ud_file_contents = ud_file.read().strip() + assert "# text = " in ud_file_contents, \ + f"{ud_file_path} does not have expected format, \"# text =\" does not appear" + sentences = [x[9:] for x in ud_file_contents.split("\n") if x.startswith("# text = ")] + elif data_format == "one-per-line": + with open(ud_file_path) as ud_file: + sentences = [x for x in ud_file.read().strip().split("\n") if x] + return sentences + + +def sentence_to_windows(sentence, min_window, max_window): + """ + Create window size chunks from a sentence, always starting with a word + """ + windows = [] + words = sentence.split(" ") + curr_window = "" + for idx, word in enumerate(words): + curr_window += (" " + word) + curr_window = curr_window.lstrip() + next_word_len = len(words[idx+1]) + 1 if idx+1 < len(words) else 0 + if len(curr_window) + next_word_len > max_window: + curr_window = clean_sentence(curr_window) + if validate_sentence(curr_window, min_window): + windows.append(curr_window.strip()) + curr_window = "" + if len(curr_window) >= min_window: + windows.append(curr_window) + return windows + + +def validate_sentence(current_window, min_window): + """ + Sentence validation from: LSTM-LID + GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py + """ + if len(current_window) < min_window: + return False + return True + +def find(s, ch): + """ + Helper for clean_sentence from LSTM-LID + GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py + """ + return [i for i, ltr in enumerate(s) if ltr == ch] + + +def clean_sentence(line): + """ + Sentence cleaning from LSTM-LID + GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py + """ + # We remove some special characters and fix small errors in the data, to improve the quality of the data + line = line.replace("\n", '') #{"text": "- Mor.\n", "label": "da"} + line = line.replace("- ", '') #{"text": "- Mor.", "label": "da"} + line = line.replace("_", '') #{"text": "- Mor.", "label": "da"} + 
line = line.replace("\\", '') + line = line.replace("\"", '') + line = line.replace(" ", " ") + remove_digits = str.maketrans('', '', digits) + line = line.translate(remove_digits) + words = line.split() + new_words = [] + # Below fixes large I instead of l. Does not catch everything, but should also not really make any mistakes either + for word in words: + clean_word = word + s = clean_word + if clean_word[1:].__contains__("I"): + indices = find(clean_word, "I") + for indx in indices: + if clean_word[indx-1].islower(): + if len(clean_word) > indx + 1: + if clean_word[indx+1].islower(): + s = s[:indx] + "l" + s[indx + 1:] + else: + s = s[:indx] + "l" + s[indx + 1:] + new_words.append(s) + new_line = " ".join(new_words) + return new_line + + +def example_json(lang_id, text, eval_length=None): + if eval_length is not None: + text = text[:eval_length] + return {"text": text.strip(), "label": lang_id} + + +if __name__ == "__main__": + main() + diff --git a/stanza/models/langid/data.py b/stanza/models/langid/data.py new file mode 100644 index 00000000..b5e328cc --- /dev/null +++ b/stanza/models/langid/data.py @@ -0,0 +1,136 @@ +import json +import random +import torch + + +class DataLoader: + """ + Class for loading language id data and providing batches + """ + + def __init__(self, use_gpu=None): + self.batches = None + self.batches_iter = None + self.tag_to_idx = None + self.idx_to_tag = None + self.lang_weights = None + # set self.use_gpu and self.device + if use_gpu is None: + self.use_gpu = torch.cuda.is_available() + else: + self.use_gpu = use_gpu + if self.use_gpu: + self.device = torch.device("cuda") + else: + self.device = None + + def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5,20), + max_length=None): + """ + Load sequence data and labels, calculate weights for weighted cross entropy loss. 
+ Data is stored in a file, 1 example per line + Example: {"text": "Hello world.", "label": "en"} + """ + + # set up examples from data files + examples = [] + for data_file in data_files: + examples += [x for x in open(data_file).read().split("\n") if x.strip()] + random.shuffle(examples) + examples = [json.loads(x) for x in examples] + + # add additional labels in this data set to tag index + tag_index = dict(tag_index) + new_labels = set([x["label"] for x in examples]) - set(tag_index.keys()) + for new_label in new_labels: + tag_index[new_label] = len(tag_index) + self.tag_to_idx = tag_index + self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])] + + # set up lang counts used for weights for cross entropy loss + lang_counts = [0 for _ in tag_index] + + # optionally limit text to max length + if max_length is not None: + examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples] + + # randomize data + if randomize: + split_examples = [] + for example in examples: + sequence = example["text"] + label = example["label"] + sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1], + lower_lim=randomize_range[0]) + split_examples += [{"text": seq, "label": label} for seq in sequences] + examples = split_examples + random.shuffle(examples) + + # break into equal length batches + batch_lengths = {} + for example in examples: + sequence = example["text"] + label = example["label"] + if len(sequence) not in batch_lengths: + batch_lengths[len(sequence)] = [] + sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in list(sequence)] + batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label])) + lang_counts[tag_index[label]] += 1 + for length in batch_lengths: + random.shuffle(batch_lengths[length]) + + # create final set of batches + batches = [] + for length in batch_lengths: + for sublist in [batch_lengths[length][i:i + batch_size] for i in + range(0, len(batch_lengths[length]), batch_size)]: + batches.append(sublist) + + self.batches = [self.build_batch_tensors(batch) for batch in batches] + + # set up lang weights + most_frequent = max(lang_counts) + # set to 0.0 if lang_count is 0 or most_frequent/lang_count otherwise + lang_counts = [(most_frequent * x)/(max(1, x) ** 2) for x in lang_counts] + self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float) + + # shuffle batches to mix up lengths + random.shuffle(self.batches) + self.batches_iter = iter(self.batches) + + @staticmethod + def randomize_data(sentences, upper_lim=20, lower_lim=5): + """ + Takes the original data and creates random length examples with length between upper limit and lower limit + From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py + """ + + new_data = [] + for sentence in sentences: + remaining = sentence + while lower_lim < len(remaining): + lim = random.randint(lower_lim, upper_lim) + m = min(len(remaining), lim) + new_sentence = remaining[:m] + new_data.append(new_sentence) + split = remaining[m:].split(" ", 1) + if len(split) <= 1: + break + remaining = split[1] + random.shuffle(new_data) + return new_data + + def build_batch_tensors(self, batch): + """ + Helper to turn batches into tensors + """ + + batch_tensors = dict() + batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long) + batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long) + + return batch_tensors 
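+
+    # Sketch of the intended calling pattern (the file name and index dicts are
+    # placeholders): build char/tag indexes from the training files, call
+    # load_data(), then iterate over the equal-length batches of LongTensors, e.g.
+    #   loader = DataLoader(use_gpu=False)
+    #   loader.load_data(64, ["train.jsonl"], char_to_idx, tag_to_idx, randomize=True)
+    #   for batch in loader.batches:
+    #       sentences, targets = batch["sentences"], batch["targets"]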
+ + def next(self): + return next(self.batches_iter) + diff --git a/stanza/models/langid/model.py b/stanza/models/langid/model.py new file mode 100644 index 00000000..799030e3 --- /dev/null +++ b/stanza/models/langid/model.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn + + +class LangIDBiLSTM(nn.Module): + """ + Multi-layer BiLSTM model for language detecting. A recreation of "A reproduction of Apple's bi-directional LSTM models + for language identification in short strings." (Toftrup et al 2021) + + Arxiv: https://arxiv.org/abs/2102.06282 + GitHub: https://github.com/AU-DIS/LSTM_langid + """ + + def __init__(self, char_to_idx, tag_to_idx, num_layers, embedding_dim, hidden_dim, batch_size=64, weights=None, + dropout=0.0, lang_subset=None): + super(LangIDBiLSTM, self).__init__() + self.num_layers = num_layers + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.char_to_idx = char_to_idx + self.vocab_size = len(char_to_idx) + self.tag_to_idx = tag_to_idx + self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])] + self.lang_subset = lang_subset + self.padding_idx = char_to_idx[""] + self.tagset_size = len(tag_to_idx) + self.batch_size = batch_size + self.loss_train = nn.CrossEntropyLoss(weight=weights) + self.dropout_prob = dropout + + # embeddings for chars + self.char_embeds = nn.Embedding( + num_embeddings=self.vocab_size, + embedding_dim=self.embedding_dim, + padding_idx=self.padding_idx + ) + + # the bidirectional LSTM + self.lstm = nn.LSTM( + self.embedding_dim, + self.hidden_dim, + num_layers=self.num_layers, + bidirectional=True, + batch_first=True + ) + + # convert output to tag space + self.hidden_to_tag = nn.Linear( + self.hidden_dim * 2, + self.tagset_size + ) + + # dropout layer + self.dropout = nn.Dropout(p=self.dropout_prob) + + def build_lang_mask(self, use_gpu=None): + """ + Build language mask if a lang subset is specified (e.g. 
["en", "fr"]) + """ + device = torch.device("cuda") if use_gpu else None + lang_mask_list = [int(lang in self.lang_subset) for lang in self.idx_to_tag] if self.lang_subset else \ + [1 for lang in self.idx_to_tag] + self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float) + + def loss(self, Y_hat, Y): + return self.loss_train(Y_hat, Y) + + def forward(self, x): + # embed input + x = self.char_embeds(x) + + # run through LSTM + x, _ = self.lstm(x) + + # run through linear layer + x = self.hidden_to_tag(x) + + # sum character outputs for each sequence + x = torch.sum(x, dim=1) + + return x + + def prediction_scores(self, x): + prediction_probs = self(x) + if self.lang_subset: + prediction_batch_size = prediction_probs.size()[0] + batch_mask = torch.stack([self.lang_mask for _ in range(prediction_batch_size)]) + prediction_probs = prediction_probs * batch_mask + return torch.argmax(prediction_probs, dim=1) + + def save(self, path): + """ Save a model at path """ + checkpoint = { + "char_to_idx": self.char_to_idx, + "tag_to_idx": self.tag_to_idx, + "num_layers": self.num_layers, + "embedding_dim": self.embedding_dim, + "hidden_dim": self.hidden_dim, + "model_state_dict": self.state_dict() + } + torch.save(checkpoint, path) + + @classmethod + def load(cls, path, use_cuda=False, batch_size=64, lang_subset=None): + """ Load a serialized model located at path """ + if use_cuda: + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + else: + device = torch.device("cpu") + checkpoint = torch.load(path, map_location=torch.device("cpu")) + weights = checkpoint["model_state_dict"]["loss_train.weight"] + model = cls(checkpoint["char_to_idx"], checkpoint["tag_to_idx"], checkpoint["num_layers"], + checkpoint["embedding_dim"], checkpoint["hidden_dim"], batch_size=batch_size, weights=weights, + lang_subset=lang_subset) + model.load_state_dict(checkpoint["model_state_dict"]) + if use_cuda: + model.to(torch.device("cuda")) + model.build_lang_mask(use_gpu=use_cuda) + return model + diff --git a/stanza/models/langid/trainer.py b/stanza/models/langid/trainer.py new file mode 100644 index 00000000..6491508f --- /dev/null +++ b/stanza/models/langid/trainer.py @@ -0,0 +1,53 @@ +import torch +import torch.optim as optim + +from stanza.models.langid.model import LangIDBiLSTM + + +class Trainer: + + DEFAULT_BATCH_SIZE = 64 + DEFAULT_LAYERS = 2 + DEFAULT_EMBEDDING_DIM = 150 + DEFAULT_HIDDEN_DIM = 150 + + def __init__(self, config, load_model=False, use_gpu=None): + self.model_path = config["model_path"] + self.use_gpu = torch.cuda.is_available() if use_gpu is None else use_gpu + self.device = torch.device("cuda") if self.use_gpu else None + self.batch_size = config.get("batch_size", Trainer.DEFAULT_BATCH_SIZE) + if load_model: + self.load(config["load_model"]) + else: + self.model = LangIDBiLSTM(config["char_to_idx"], config["tag_to_idx"], Trainer.DEFAULT_LAYERS, + Trainer.DEFAULT_EMBEDDING_DIM, + Trainer.DEFAULT_HIDDEN_DIM, + batch_size=self.batch_size, + weights=config["lang_weights"]).to(self.device) + self.optimizer = optim.AdamW(self.model.parameters()) + + def update(self, inputs): + self.model.train() + sentences, targets = inputs + self.optimizer.zero_grad() + y_hat = self.model.forward(sentences) + loss = self.model.loss(y_hat, targets) + loss.backward() + self.optimizer.step() + + def predict(self, inputs): + self.model.eval() + sentences, targets = inputs + return torch.argmax(self.model(sentences), dim=1) + + def save(self, label=None): + # save a copy 
of model with label + if label: + self.model.save(f"{self.model_path[:-3]}-{label}.pt") + self.model.save(self.model_path) + + def load(self, model_path=None): + if not model_path: + model_path = self.model_path + self.model = LangIDBiLSTM.load(model_path, self.use_gpu, self.batch_size) + diff --git a/stanza/pipeline/_constants.py b/stanza/pipeline/_constants.py index 865a185e..db3e1cb5 100644 --- a/stanza/pipeline/_constants.py +++ b/stanza/pipeline/_constants.py @@ -1,6 +1,7 @@ """ Module defining constants """ # string constants for processor names +LANGID = 'langid' TOKENIZE = 'tokenize' MWT = 'mwt' POS = 'pos' diff --git a/stanza/pipeline/core.py b/stanza/pipeline/core.py index 40adef47..09299e05 100644 --- a/stanza/pipeline/core.py +++ b/stanza/pipeline/core.py @@ -15,6 +15,7 @@ from stanza.pipeline._constants import * from stanza.models.common.doc import Document from stanza.pipeline.processor import Processor, ProcessorRequirementsException from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES +from stanza.pipeline.langid_processor import LangIDProcessor from stanza.pipeline.tokenize_processor import TokenizeProcessor from stanza.pipeline.mwt_processor import MWTProcessor from stanza.pipeline.pos_processor import POSProcessor diff --git a/stanza/pipeline/langid_processor.py b/stanza/pipeline/langid_processor.py new file mode 100644 index 00000000..a512196e --- /dev/null +++ b/stanza/pipeline/langid_processor.py @@ -0,0 +1,126 @@ +""" +Processor for determining language of text. +""" + +import emoji +import re +import stanza +import torch + +from stanza.models.common.doc import Document +from stanza.models.langid.model import LangIDBiLSTM +from stanza.pipeline._constants import * +from stanza.pipeline.processor import UDProcessor, register_processor + + +@register_processor(name=LANGID) +class LangIDProcessor(UDProcessor): + """ + Class for detecting language of text. + """ + + # set of processor requirements this processor fulfills + PROVIDES_DEFAULT = set([LANGID]) + + # set of processor requirements for this processor + REQUIRES_DEFAULT = set([]) + + # default max sequence length + MAX_SEQ_LENGTH_DEFAULT = 1000 + + def _set_up_model(self, config, use_gpu): + batch_size = config.get("batch_size", 64) + self._model = LangIDBiLSTM.load(path=config["model_path"], use_cuda=use_gpu, + batch_size=batch_size, lang_subset=config.get("lang_subset")) + self._device = torch.device("cuda") if use_gpu else None + self._char_index = self._model.char_to_idx + self._clean_text = config.get("clean_text") + + def _text_to_tensor(self, docs): + """ + Map list of strings to batch tensor. Assumed all docs are same length. 
+ """ + + all_docs = [] + for doc in docs: + doc_chars = [self._char_index.get(c, self._char_index["UNK"]) for c in list(doc)] + all_docs.append(doc_chars) + return torch.tensor(all_docs, device=self._device, dtype=torch.long) + + def _id_langs(self, batch_tensor): + """ + Identify languages for each sequence in a batch tensor + """ + predictions = self._model.prediction_scores(batch_tensor) + prediction_labels = [self._model.idx_to_tag[prediction] for prediction in predictions] + + return prediction_labels + + # regexes for cleaning text + http_regex = re.compile("https?:\/\/t\.co/[a-zA-Z0-9]+") + handle_regex = re.compile("@[a-zA-Z0-9_]+") + hashtag_regex = re.compile("#[a-zA-Z]+") + punctuation_regex = re.compile("[!.]+") + all_regexes = [http_regex, handle_regex, hashtag_regex, punctuation_regex] + + @staticmethod + def clean_text(text): + """ + Process text to improve language id performance. Main emphasis is on tweets, this method removes shortened + urls, hashtags, handles, and punctuation and emoji. + """ + + for regex in LangIDProcessor.all_regexes: + text = regex.sub(" ", text) + + text = emoji.get_emoji_regexp().sub(" ", text) + + if text.strip(): + text = text.strip() + + return text + + def _process_list(self, docs): + """ + Identify language of list of strings or Documents + """ + + if len(docs) == 0: + # TO DO: what standard do we want for bad input, such as empty list? + # TO DO: more handling of bad input + return + + if isinstance(docs[0], str): + docs = [Document([], text) for text in docs] + + docs_by_length = {} + for doc in docs: + text = LangIDProcessor.clean_text(doc.text) if self._clean_text else doc.text + doc_length = len(text) + if doc_length not in docs_by_length: + docs_by_length[doc_length] = [] + docs_by_length[doc_length].append((doc, text)) + + for doc_length in docs_by_length: + inputs = [doc[1] for doc in docs_by_length[doc_length]] + predictions = self._id_langs(self._text_to_tensor(inputs)) + for doc, lang in zip(docs_by_length[doc_length], predictions): + doc[0].lang = lang + + return docs + + def process(self, doc): + """ + Handle single str or Document + """ + + wrapped_doc = [doc] + return self._process_list(wrapped_doc)[0] + + def bulk_process(self, docs): + """ + Handle list of strings or Documents + """ + + return self._process_list(docs) + diff --git a/stanza/pipeline/multilingual.py b/stanza/pipeline/multilingual.py new file mode 100644 index 00000000..a6c958ba --- /dev/null +++ b/stanza/pipeline/multilingual.py @@ -0,0 +1,106 @@ +""" +Class for running multilingual pipelines +""" + +import torch + +from stanza.models.common.doc import Document +from stanza.pipeline.core import Pipeline +from stanza.pipeline._constants import * + + +class MultilingualPipeline: + """ + Pipeline for handling multilingual data. Takes in text, detects language, and routes request to pipeline for that + language. 
+ """ + + def __init__( + self, + lang_id_config: dict = None, + lang_configs: dict = None, + ld_batch_size: int = 64, + max_cache_size: int = 10, + use_gpu: bool = None + ): + # set up configs and cache for various language pipelines + self.lang_id_config = {} if lang_id_config is None else lang_id_config + self.lang_configs = {} if lang_configs is None else lang_configs + self.max_cache_size = max_cache_size + self.pipeline_cache = {} + self.lang_request_history = [] + + # set use_gpu + if use_gpu is None: + self.use_gpu = torch.cuda.is_available() + else: + self.use_gpu = use_gpu + + # build language id pipeline + self.lang_id_pipeline = Pipeline(lang='multilingual', processors="langid", use_gpu=self.use_gpu, + **self.lang_id_config) + + def _update_pipeline_cache(self, lang): + """ + Do any necessary updates to the pipeline cache for this language. This includes building a new + pipeline for the lang, and possibly clearing out a language with the old last access date. + """ + + # update request history + if lang in self.lang_request_history: + self.lang_request_history.remove(lang) + self.lang_request_history.append(lang) + + # update language configs + if lang not in self.lang_configs: + self.lang_configs[lang] = {'lang': lang} + + # update pipeline cache + if lang not in self.pipeline_cache: + # clear least recently used lang from pipeline cache + if len(self.pipeline_cache) == self.max_cache_size: + lru_lang = self.lang_request_history[0] + self.pipeline_cache.remove(lru_lang) + self.lang_request_history.remove(lru_lang) + self.pipeline_cache[lang] = Pipeline(**self.lang_configs[lang]) + + def process(self, doc): + """ + Run language detection on a string, a Document, or a list of either, route to language specific pipeline + """ + + # only return a list if given a list + singleton_input = not isinstance(doc, list) + if singleton_input: + docs = [doc] + else: + docs = doc + + if docs and isinstance(docs[0], str): + docs = [Document([], text=text) for text in docs] + + # run language identification + docs_w_langid = self.lang_id_pipeline.process(docs) + + # create language specific batches, store global idx with each doc + lang_batches = {} + for doc in docs_w_langid: + if doc.lang not in lang_batches: + lang_batches[doc.lang] = [] + lang_batches[doc.lang].append(doc) + + # run through each language, submit a batch to the language specific pipeline + for lang in lang_batches.keys(): + self._update_pipeline_cache(lang) + self.pipeline_cache[lang](lang_batches[lang]) + + # only return a list if given a list + if singleton_input: + return docs_w_langid[0] + else: + return docs_w_langid + + def __call__(self, doc): + doc = self.process(doc) + return doc + diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py index 31177863..16f64a7f 100644 --- a/stanza/resources/prepare_resources.py +++ b/stanza/resources/prepare_resources.py @@ -83,10 +83,10 @@ default_treebanks = { "te": "mtg", "orv": "torot", "nn": "nynorsk", - "mr": "ufal" + "mr": "ufal", + "multilingual": "ud" } - # default ner for languages default_ners = { "ar": "aqmar", @@ -104,7 +104,6 @@ default_ners = { "zh-hans": "ontonotes", } - # default charlms for languages default_charlms = { "ar": "ccwiki", @@ -167,7 +166,8 @@ processor_to_ending = { "sentiment": "sentiment", "pretrain": "pretrain", "forward_charlm": "forward_charlm", - "backward_charlm": "backward_charlm" + "backward_charlm": "backward_charlm", + "langid": "langid" } ending_to_processor = {j: i for i, j in 
processor_to_ending.items()} @@ -269,9 +269,11 @@ def get_md5(path): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--input_dir', type=str, help='Input dir for various models.') - parser.add_argument('--output_dir', type=str, help='Output dir for various models.') + parser.add_argument('--input-dir', type=str, help='Input dir for various models.') + parser.add_argument('--output-dir', type=str, help='Output dir for various models.') args = parser.parse_args() + args.input_dir = os.path.abspath(args.input_dir) + args.output_dir = os.path.abspath(args.output_dir) return args @@ -309,15 +311,15 @@ def process_dirs(args): dirs = sorted(os.listdir(args.input_dir)) resources = {} - for dir in dirs: - print(f"Processing models in {dir}") - models = sorted(os.listdir(os.path.join(args.input_dir, dir))) + for model_dir in dirs: + print(f"Processing models in {model_dir}") + models = sorted(os.listdir(os.path.join(args.input_dir, model_dir))) for model in models: if not model.endswith('.pt'): continue # get processor lang, package, processor = split_model_name(model) # copy file - input_path = os.path.join(args.input_dir, dir, model) + input_path = os.path.join(args.input_dir, model_dir, model) output_path = os.path.join(args.output_dir, lang, processor, package + '.pt') ensure_dir(Path(output_path).parent) shutil.copy(input_path, output_path) @@ -381,17 +383,23 @@ def process_defaults(args): if lang in default_sentiment: processors.append('sentiment') + if lang == 'multilingual': + processors = ['langid'] + default_dependencies = {} + with zipfile.ZipFile('default.zip', 'w', zipfile.ZIP_DEFLATED) as zipf: for processor in processors: if processor == 'ner': package = ner_package elif processor in ['forward_charlm', 'backward_charlm']: package = charlm_package elif processor == 'sentiment': package = sentiment_package + elif processor == 'langid': package = 'ud' else: package = ud_package filename = os.path.join(args.output_dir, lang, processor, package + '.pt') + if os.path.exists(filename): print(" Model {} package {}: file {}".format(processor, package, filename)) - if processor in ['tokenize', 'mwt', 'lemma', 'pos', 'depparse', 'ner', 'sentiment']: + if processor in ['tokenize', 'mwt', 'lemma', 'pos', 'depparse', 'ner', 'sentiment', 'langid']: default_processors[processor] = package zipf.write(processor) zipf.write(os.path.join(processor, package + '.pt')) @@ -420,6 +428,7 @@ def process_defaults(args): def process_lcode(args): resources = json.load(open(os.path.join(args.output_dir, 'resources.json'))) resources_new = {} + resources_new["multilingual"] = resources["multilingual"] for lang in resources: if lang not in lcode2lang: print(lang + ' not found in lcode2lang!') @@ -449,3 +458,4 @@ def main(): if __name__ == '__main__': main() + diff --git a/stanza/tests/test_langid.py b/stanza/tests/test_langid.py new file mode 100644 index 00000000..9d42f9b3 --- /dev/null +++ b/stanza/tests/test_langid.py @@ -0,0 +1,608 @@ +""" +Basic tests of langid module +""" + +from stanza.models.common.doc import Document +from stanza.pipeline.core import Pipeline +from stanza.pipeline.multilingual import MultilingualPipeline + +def test_langid(): + """ + Basic test of language identification + """ + english_text = "This is an English sentence." + french_text = "C'est une phrase française." 
+ docs = [english_text, french_text] + + nlp = Pipeline(lang='multilingual', processors="langid") + docs = [Document([], text=text) for text in docs] + nlp(docs) + predictions = [doc.lang for doc in docs] + assert predictions == ["en", "fr"] + +def test_langid_benchmark(): + """ + Run lang id model on 500 examples, confirm reasonable accuracy. + """ + examples = [ + {"text": "contingentiam in naturalibus causis.", "label": "la"}, + {"text": "I jak opowiadał nieżyjący już pan Czesław", "label": "pl"}, + {"text": "Sonera gilt seit längerem als Übernahmekandidat", "label": "de"}, + {"text": "与银类似,汞也可以与空气中的硫化氢反应。", "label": "zh-hans"}, + {"text": "contradictionem implicat.", "label": "la"}, + {"text": "Bis zu Prozent gingen die Offerten etwa im", "label": "de"}, + {"text": "inneren Sicherheit vorgeschlagene Ausweitung der", "label": "de"}, + {"text": "Multimedia-PDA mit Mini-Tastatur", "label": "de"}, + {"text": "Ponášalo sa to na rovnicu o dvoch neznámych.", "label": "sk"}, + {"text": "이처럼 앞으로 심판의 그 날에 다시 올 메시아가 예수 그리스도이며 , 그는 모든 인류의", "label": "ko"}, + {"text": "Die Arbeitsgruppe bedauert , dass der weit über", "label": "de"}, + {"text": "И только раз довелось поговорить с ним не вполне", "label": "ru"}, + {"text": "de a-l lovi cu piciorul și conștiința că era", "label": "ro"}, + {"text": "relación coas pretensións do demandante e que, nos", "label": "gl"}, + {"text": "med petdeset in sedemdeset", "label": "sl"}, + {"text": "Catalunya; el Consell Comarcal del Vallès Oriental", "label": "ca"}, + {"text": "kunnen worden.", "label": "nl"}, + {"text": "Witkin je ve většině ohledů zcela jiný.", "label": "cs"}, + {"text": "lernen, so zu agieren, dass sie positive oder auch", "label": "de"}, + {"text": "olurmuş...", "label": "tr"}, + {"text": "sarcasmo de Altman, desde as «peruas» que discutem", "label": "pt"}, + {"text": "خلاف فوجداری مقدمہ درج کرے۔", "label": "ur"}, + {"text": "Norddal kommune :", "label": "no"}, + {"text": "dem Windows-.-Zeitalter , soll in diesem Jahr", "label": "de"}, + {"text": "przeklętych ucieleśniają mit poety-cygana,", "label": "pl"}, + {"text": "We do not believe the suspect has ties to this", "label": "en"}, + {"text": "groziņu pīšanu.", "label": "lv"}, + {"text": "Senior Vice-President David M. Thomas möchte", "label": "de"}, + {"text": "neomylně vybral nějakou knihu a začetl se.", "label": "cs"}, + {"text": "Statt dessen darf beispielsweise der Browser des", "label": "de"}, + {"text": "outubro, alcançando R $ bilhões em .", "label": "pt"}, + {"text": "(Porte, ), as it does other disciplines", "label": "en"}, + {"text": "uskupení se mylně domnívaly, že podporu", "label": "cs"}, + {"text": "Übernahme von Next Ende an dem System herum , das", "label": "de"}, + {"text": "No podemos decir a la Hacienda que los alemanes", "label": "es"}, + {"text": "и рѣста еи братья", "label": "orv"}, + {"text": "الذي اتخذ قرارا بتجميد اعلان الدولة الفلسطينية", "label": "ar"}, + {"text": "uurides Rootsi sõjaarhiivist toodud . sajandi", "label": "et"}, + {"text": "selskapets penger til å pusse opp sin enebolig på", "label": "no"}, + {"text": "средней полосе и севернее в Ярославской,", "label": "ru"}, + {"text": "il-massa żejda fil-ġemgħat u superġemgħat ta'", "label": "mt"}, + {"text": "The Global Beauties on internetilehekülg, mida", "label": "et"}, + {"text": "이스라엘 인들은 하나님이 그 큰 팔을 펴 이집트 인들을 치는 것을 보고 하나님을 두려워하며", "label": "ko"}, + {"text": "Snad ještě dodejme jeden ekonomický argument.", "label": "cs"}, + {"text": "Spalio d. 
vykusiame pirmajame rinkimų ture", "label": "lt"}, + {"text": "und schlechter Journalismus ein gutes Geschäft .", "label": "de"}, + {"text": "Du sodiečiai sėdi ant potvynio apsemtų namų stogo.", "label": "lt"}, + {"text": "цей є автентичним.", "label": "uk"}, + {"text": "Și îndegrabă fu cu îngerul mulțime de șireaguri", "label": "ro"}, + {"text": "sobra personal cualificado.", "label": "es"}, + {"text": "Tako se u Njemačkoj dvije trećine liječnika služe", "label": "hr"}, + {"text": "Dual-Athlon-Chipsatz noch in diesem Jahr", "label": "de"}, + {"text": "यहां तक कि चीन के चीफ ऑफ जनरल स्टाफ भी भारत का", "label": "hi"}, + {"text": "Li forestier du mont avale", "label": "fro"}, + {"text": "Netzwerken für Privatanwender zu bewundern .", "label": "de"}, + {"text": "만해는 승적을 가진 중이 결혼할 수 없다는 불교의 계율을 시대에 맞지 않는 것으로 보았다", "label": "ko"}, + {"text": "balance and weight distribution but not really for", "label": "en"}, + {"text": "og så e # tente vi opp den om morgonen å sfyrte", "label": "nn"}, + {"text": "변화는 의심의 여지가 없는 것이지만 반면에 진화는 논쟁의 씨앗이다 .", "label": "ko"}, + {"text": "puteare fac aceastea.", "label": "ro"}, + {"text": "Waitt seine Führungsmannschaft nicht dem", "label": "de"}, + {"text": "juhtimisega, tulid sealt.", "label": "et"}, + {"text": "Veränderungen .", "label": "de"}, + {"text": "banda en el Bayer Leverkusen de la Bundesliga de", "label": "es"}, + {"text": "В туже зиму посла всеволодъ сн҃а своѥго ст҃ослава", "label": "orv"}, + {"text": "пославъ приведе я мастеры ѿ грекъ", "label": "orv"}, + {"text": "En un nou escenari difícil d'imaginar fa poques", "label": "ca"}, + {"text": "καὶ γὰρ τινὲς αὐτοὺς εὐεργεσίαι εἶχον ἐκ Κροίσου", "label": "grc"}, + {"text": "직접적인 관련이 있다 .", "label": "ko"}, + {"text": "가까운 듯하면서도 멀다 .", "label": "ko"}, + {"text": "Er bietet ein ähnliches Leistungsniveau und", "label": "de"}, + {"text": "民都洛水牛是獨居的,並不會以群族聚居。", "label": "zh-hant"}, + {"text": "την τρομοκρατία.", "label": "el"}, + {"text": "hurbiltzen diren neurrian.", "label": "eu"}, + {"text": "Ah dimenticavo, ma tutta sta caciara per fare un", "label": "it"}, + {"text": "На первом этапе (-) прошла так называемая", "label": "ru"}, + {"text": "of games are on the market.", "label": "en"}, + {"text": "находится Мост дружбы, соединяющий узбекский и", "label": "ru"}, + {"text": "lessié je voldroie que li saint fussent aporté", "label": "fro"}, + {"text": "Дошла очередь и до Гималаев.", "label": "ru"}, + {"text": "vzácným suknem táhly pouští, si jednou chtěl do", "label": "cs"}, + {"text": "E no terceiro tipo sitúa a familias (%), nos que a", "label": "gl"}, + {"text": "وجابت دوريات امريكية وعراقية شوارع المدينة، فيما", "label": "ar"}, + {"text": "Jeg har bodd her i år .", "label": "no"}, + {"text": "Pohrozil, že odbory zostří postoj, pokud se", "label": "cs"}, + {"text": "tinham conseguido.", "label": "pt"}, + {"text": "Nicht-Erkrankten einen Anfangsverdacht für einen", "label": "de"}, + {"text": "permanece em aberto.", "label": "pt"}, + {"text": "questi possono promettere rendimenti fino a un", "label": "it"}, + {"text": "Tema juurutatud kahevedurisüsteemita oleksid", "label": "et"}, + {"text": "Поведение внешне простой игрушки оказалось", "label": "ru"}, + {"text": "Bundesländern war vom Börsenverein des Deutschen", "label": "de"}, + {"text": "acció, 'a mesura que avanci l'estiu, amb l'augment", "label": "ca"}, + {"text": "Dove trovare queste risorse? 
Jay Naidoo, ministro", "label": "it"}, + {"text": "essas gordurinhas.", "label": "pt"}, + {"text": "Im zweiten Schritt sollen im übernächsten Jahr", "label": "de"}, + {"text": "allveelaeva pole enam vaja, kuna külm sõda on läbi", "label": "et"}, + {"text": "उपद्रवी दुकानों को लूटने के साथ ही उनमें आग लगा", "label": "hi"}, + {"text": "@user nella sfortuna sei fortunata ..", "label": "it"}, + {"text": "математических школ в виде грозовых туч.", "label": "ru"}, + {"text": "No cambiaremos nunca nuestra forma de jugar por un", "label": "es"}, + {"text": "dla tej klasy ani wymogów minimalnych, z wyjątkiem", "label": "pl"}, + {"text": "en todo el mundo, mientras que en España consiguió", "label": "es"}, + {"text": "политики считать надежное обеспечение военной", "label": "ru"}, + {"text": "gogoratzen du, genio alemana delakoaren", "label": "eu"}, + {"text": "Бычий глаз.", "label": "ru"}, + {"text": "Opeření se v pravidelných obdobích obnovuje", "label": "cs"}, + {"text": "I no és només la seva, es tracta d'una resposta", "label": "ca"}, + {"text": "오경을 가르쳤다 .", "label": "ko"}, + {"text": "Nach der so genannten Start-up-Periode vergibt die", "label": "de"}, + {"text": "Saulista huomasi jo lapsena , että hänellä on", "label": "fi"}, + {"text": "Министерство культуры сочло нецелесообразным, и", "label": "ru"}, + {"text": "znepřátelené tábory v Tádžikistánu předseda", "label": "cs"}, + {"text": "καὶ ἦν ὁ λαὸς προσδοκῶν τὸν Ζαχαρίαν καὶ ἐθαύμαζον", "label": "grc"}, + {"text": "Вечером, в продукте, этот же человек говорил о", "label": "ru"}, + {"text": "lugar á formación de xuizos máis complexos.", "label": "gl"}, + {"text": "cheaper, in the end?", "label": "en"}, + {"text": "الوزارة في شأن صفقات بيع الشركات العامة التي تم", "label": "ar"}, + {"text": "tärkeintä elämässäni .", "label": "fi"}, + {"text": "Виконання Мінських угод було заблоковано Росією та", "label": "uk"}, + {"text": "Aby szybko rozpoznać żołnierzy desantu, należy", "label": "pl"}, + {"text": "Bankengeschäfte liegen vorn , sagte Strothmann .", "label": "de"}, + {"text": "продолжение работы.", "label": "ru"}, + {"text": "Metro AG plant Online-Offensive", "label": "de"}, + {"text": "nu vor veni, și să vor osîndi, aceia nu pot porni", "label": "ro"}, + {"text": "Ich denke , es geht in Wirklichkeit darum , NT bei", "label": "de"}, + {"text": "de turism care încasează contravaloarea", "label": "ro"}, + {"text": "Aurkaria itotzea da helburua, baloia lapurtu eta", "label": "eu"}, + {"text": "com a centre de formació en Tecnologies de la", "label": "ca"}, + {"text": "oportet igitur quod omne agens in agendo intendat", "label": "la"}, + {"text": "Jerzego Andrzejewskiego, oparty na chińskich", "label": "pl"}, + {"text": "sau một vài câu chuyện xã giao không dính dáng tới", "label": "vi"}, + {"text": "что экономическому прорыву жесткий авторитарный", "label": "ru"}, + {"text": "DRAM-Preisen scheinen DSPs ein", "label": "de"}, + {"text": "Jos dajan nubbái: Mana!", "label": "sme"}, + {"text": "toți carii ascultară de el să răsipiră.", "label": "ro"}, + {"text": "odpowiedzialności, które w systemie własności", "label": "pl"}, + {"text": "Dvomesečno potovanje do Mollenda v Peruju je", "label": "sl"}, + {"text": "d'entre les agències internacionals.", "label": "ca"}, + {"text": "Fahrzeugzugangssysteme gefertigt und an viele", "label": "de"}, + {"text": "in an answer to the sharers' petition in Cuthbert", "label": "en"}, + {"text": "Europa-Domain per Verordnung zu regeln .", "label": "de"}, + {"text": "#Balotelli. 
Su ebay prezzi stracciati per Silvio", "label": "it"}, + {"text": "Ne na košickém trávníku, ale už včera v letadle se", "label": "cs"}, + {"text": "zaměstnanosti a investičních strategií.", "label": "cs"}, + {"text": "Tatínku, udělej den", "label": "cs"}, + {"text": "frecuencia con Mary.", "label": "es"}, + {"text": "Свеаборге.", "label": "ru"}, + {"text": "opatření slovenské strany o certifikaci nejvíce", "label": "cs"}, + {"text": "En todas me decían: 'Espera que hagamos un estudio", "label": "es"}, + {"text": "Die Demonstration sollte nach Darstellung der", "label": "de"}, + {"text": "Ci vorrà un assoluto rigore se dietro i disavanzi", "label": "it"}, + {"text": "Tatínku, víš, že Honzovi odešla maminka?", "label": "cs"}, + {"text": "Die Anzahl der Rechner wuchs um % auf und die", "label": "de"}, + {"text": "האמריקאית על אדמת סעודיה עלולה לסבך את ישראל, אין", "label": "he"}, + {"text": "Volán Egyesülés, a Közlekedési Főfelügyelet is.", "label": "hu"}, + {"text": "Schejbala, který stejnou hru s velkým úspěchem", "label": "cs"}, + {"text": "depends on the data type of the field.", "label": "en"}, + {"text": "Umsatzwarnung zu Wochenbeginn zeitweise auf ein", "label": "de"}, + {"text": "niin heti nukun .", "label": "fi"}, + {"text": "Mobilfunkunternehmen gegen die Anwendung der so", "label": "de"}, + {"text": "sapessi le intenzioni del governo Monti e dell'UE", "label": "it"}, + {"text": "Di chi è figlia Martine Aubry?", "label": "it"}, + {"text": "avec le reste du monde.", "label": "fr"}, + {"text": "Այդ մաքոքը ինքնին նոր չէ, աշխարհը արդեն մի քանի", "label": "hy"}, + {"text": "și în cazul destrămării cenaclului.", "label": "ro"}, + {"text": "befriedigen kann , und ohne die auftretenden", "label": "de"}, + {"text": "Κύκνον τ̓ ἐξεναρεῖν καὶ ἀπὸ κλυτὰ τεύχεα δῦσαι.", "label": "grc"}, + {"text": "færdiguddannede.", "label": "da"}, + {"text": "Schmidt war Sohn eines Rittergutsbesitzers.", "label": "de"}, + {"text": "и вдаша попадь ѡпрати", "label": "orv"}, + {"text": "cine nu știe învățătură”.", "label": "ro"}, + {"text": "détacha et cette dernière tenta de tuer le jeune", "label": "fr"}, + {"text": "Der har saka også ei lengre forhistorie.", "label": "nn"}, + {"text": "Pieprz roztłuc w moździerzu, dodać do pasty,", "label": "pl"}, + {"text": "Лежа за гребнем оврага, как за бруствером, Ушаков", "label": "ru"}, + {"text": "gesucht habe, vielen Dank nochmals!", "label": "de"}, + {"text": "инструментальных сталей, повышения", "label": "ru"}, + {"text": "im Halbfinale Patrick Smith und im Finale dann", "label": "de"}, + {"text": "البنوك التريث في منح تسهيلات جديدة لمنتجي حديد", "label": "ar"}, + {"text": "una bolsa ventral, la cual se encuentra debajo de", "label": "es"}, + {"text": "za SETimes.", "label": "sr"}, + {"text": "de Irak, a un piloto italiano que había violado el", "label": "es"}, + {"text": "Er könne sich nicht erklären , wie die Zeitung auf", "label": "de"}, + {"text": "Прохорова.", "label": "ru"}, + {"text": "la democrazia perde sulla tecnocrazia? 
#", "label": "it"}, + {"text": "entre ambas instituciones, confirmó al medio que", "label": "es"}, + {"text": "Austlandet, vart det funne om lag førti", "label": "nn"}, + {"text": "уровнями власти.", "label": "ru"}, + {"text": "Dá tedy primáři úplatek, a často ne malý.", "label": "cs"}, + {"text": "brillantes del acto, al llevar a cabo en el", "label": "es"}, + {"text": "eee druga zadeva je majhen priročen gre kamorkoli", "label": "sl"}, + {"text": "Das ATX-Board paßt in herkömmliche PC-ATX-Gehäuse", "label": "de"}, + {"text": "Za vodné bylo v prvním pololetí zaplaceno v ČR", "label": "cs"}, + {"text": "Даже на полсантиметра.", "label": "ru"}, + {"text": "com la del primer tinent d'alcalde en funcions,", "label": "ca"}, + {"text": "кількох оповідань в цілості — щось на зразок того", "label": "uk"}, + {"text": "sed ad divitias congregandas, vel superfluum", "label": "la"}, + {"text": "Norma Talmadge, spela mot Valentino i en version", "label": "sv"}, + {"text": "Dlatego chciał się jej oświadczyć w niezwykłym", "label": "pl"}, + {"text": "будут выступать на одинаковых снарядах.", "label": "ru"}, + {"text": "Orang-orang terbunuh di sana.", "label": "id"}, + {"text": "لدى رايت شقيق اسمه أوسكار, وهو يعمل كرسام للكتب", "label": "ar"}, + {"text": "Wirklichkeit verlagerten und kaum noch", "label": "de"}, + {"text": "как перемешивают костяшки перед игрой в домино, и", "label": "ru"}, + {"text": "В средине дня, когда солнце светило в нашу", "label": "ru"}, + {"text": "d'aventure aux rôles de jeune romantique avec une", "label": "fr"}, + {"text": "My teď hledáme organizace, jež by s námi chtěly", "label": "cs"}, + {"text": "Urteilsfähigkeit einbüßen , wenn ich eigene", "label": "de"}, + {"text": "sua appartenenza anche a voci diverse da quella in", "label": "it"}, + {"text": "Aufträge dieses Jahr verdoppeln werden .", "label": "de"}, + {"text": "M.E.: Miała szanse mnie odnaleźć, gdyby naprawdę", "label": "pl"}, + {"text": "secundum contactum virtutis, cum careat dimensiva", "label": "la"}, + {"text": "ezinbestekoa dela esan zuen.", "label": "eu"}, + {"text": "Anek hurbiltzeko eskatzen zion besaulkitik, eta", "label": "eu"}, + {"text": "perfectius alio videat, quamvis uterque videat", "label": "la"}, + {"text": "Die Strecke war anspruchsvoll und führte unter", "label": "de"}, + {"text": "саморазоблачительным уроком, западные СМИ не", "label": "ru"}, + {"text": "han representerer radikal islamisme .", "label": "no"}, + {"text": "Què s'hi respira pel que fa a la reforma del", "label": "ca"}, + {"text": "previsto para também ser desconstruido.", "label": "pt"}, + {"text": "Ὠκεανοῦ βαθυκόλποις ἄνθεά τ̓ αἰνυμένην, ῥόδα καὶ", "label": "grc"}, + {"text": "para jovens de a anos nos Cieps.", "label": "pt"}, + {"text": "संघर्ष को अंजाम तक पहुंचाने का ऐलान किया है ।", "label": "hi"}, + {"text": "objeví i u nás.", "label": "cs"}, + {"text": "kvitteringer.", "label": "da"}, + {"text": "This report is no exception.", "label": "en"}, + {"text": "Разлепват доносниците до избирателните списъци", "label": "bg"}, + {"text": "anderem ihre Bewegungsfreiheit in den USA", "label": "de"}, + {"text": "Ñu tegoon ca kaw gor ña ay njotti bopp yu kenn", "label": "wo"}, + {"text": "Struktur kann beispielsweise der Schwerpunkt mehr", "label": "de"}, + {"text": "% la velocidad permitida, la sanción es muy grave.", "label": "es"}, + {"text": "Teles-Einstieg in ADSL-Markt", "label": "de"}, + {"text": "ettekäändeks liiga suure osamaksu.", "label": "et"}, + {"text": "als Indiz für die geänderte Marktpolitik des", "label": "de"}, + 
{"text": "quod quidem aperte consequitur ponentes", "label": "la"}, + {"text": "de negociación para el próximo de junio.", "label": "es"}, + {"text": "Tyto důmyslné dekorace doznaly v poslední době", "label": "cs"}, + {"text": "največjega uspeha doslej.", "label": "sl"}, + {"text": "Paul Allen je jedan od suosnivača Interval", "label": "hr"}, + {"text": "Federal (Seac / DF) eo Sindicato das Empresas de", "label": "pt"}, + {"text": "Quartal mit . Mark gegenüber dem gleichen Quartal", "label": "de"}, + {"text": "otros clubes y del Barça B saldrán varios", "label": "es"}, + {"text": "Jaskula (Pol.) -", "label": "cs"}, + {"text": "umožnily říci, že je možné přejít k mnohem", "label": "cs"}, + {"text": "اعلن الجنرال تومي فرانكس قائد القوات الامريكية", "label": "ar"}, + {"text": "Telekom-Chef Ron Sommer und der Vorstandssprecher", "label": "de"}, + {"text": "My, jako průmyslový a finanční holding, můžeme", "label": "cs"}, + {"text": "voorlichting onder andere betrekking kan hebben:", "label": "nl"}, + {"text": "Hinrichtung geistig Behinderter applaudiert oder", "label": "de"}, + {"text": "wie beispielsweise Anzahl erzielte Klicks ,", "label": "de"}, + {"text": "Intel-PC-SDRAM-Spezifikation in der Version . (", "label": "de"}, + {"text": "plângere în termen de zile de la comunicarea", "label": "ro"}, + {"text": "и Испания ще изгубят втория си комисар в ЕК.", "label": "bg"}, + {"text": "इसके चलते इस आदिवासी जनजाति का क्षरण हो रहा है ।", "label": "hi"}, + {"text": "aunque se mostró contrario a establecer un", "label": "es"}, + {"text": "des letzten Jahres von auf Millionen Euro .", "label": "de"}, + {"text": "Ankara se također poziva da u cijelosti ratificira", "label": "hr"}, + {"text": "herunterlädt .", "label": "de"}, + {"text": "стрессовую ситуацию для организма, каковой", "label": "ru"}, + {"text": "Státního shromáždění (parlamentu).", "label": "cs"}, + {"text": "diskutieren , ob und wie dieser Dienst weiterhin", "label": "de"}, + {"text": "Verbindungen zu FPÖ-nahen Polizisten gepflegt und", "label": "de"}, + {"text": "Pražského volebního lídra ovšem nevybírá Miloš", "label": "cs"}, + {"text": "Nach einem Bericht der Washington Post bleibt das", "label": "de"}, + {"text": "للوضع آنذاك، لكني في قرارة نفسي كنت سعيداً لما", "label": "ar"}, + {"text": "не желаят запазването на статуквото.", "label": "bg"}, + {"text": "Offenburg gewesen .", "label": "de"}, + {"text": "ἐὰν ὑμῖν εἴπω οὐ μὴ πιστεύσητε", "label": "grc"}, + {"text": "all'odiato compagno di squadra Prost, il quale", "label": "it"}, + {"text": "historischen Gänselieselbrunnens.", "label": "de"}, + {"text": "למידע מלווייני הריגול האמריקאיים העוקבים אחר", "label": "he"}, + {"text": "οὐδὲν ἄρα διαφέρεις Ἀμάσιος τοῦ Ἠλείου, ὃν", "label": "grc"}, + {"text": "movementos migratorios.", "label": "gl"}, + {"text": "Handy und ein Spracherkennungsprogramm sämtliche", "label": "de"}, + {"text": "Kümne aasta jooksul on Eestisse ohjeldamatult", "label": "et"}, + {"text": "H.G. Bücknera.", "label": "pl"}, + {"text": "protiv krijumčarenja, ili pak traženju ukidanja", "label": "hr"}, + {"text": "Topware-Anteile mehrere Millionen Mark gefordert", "label": "de"}, + {"text": "Maar de mensen die nu over Van Dijk bij FC Twente", "label": "nl"}, + {"text": "poidan experimentar as percepcións do interesado,", "label": "gl"}, + {"text": "Miał przecież w kieszeni nóż.", "label": "pl"}, + {"text": "Avšak žádná z nich nepronikla za hranice přímé", "label": "cs"}, + {"text": "esim. 
helpottamalla luottoja muiden", "label": "fi"}, + {"text": "Podle předběžných výsledků zvítězila v", "label": "cs"}, + {"text": "Nicht nur das Web-Frontend , auch die", "label": "de"}, + {"text": "Regierungsinstitutionen oder Universitäten bei", "label": "de"}, + {"text": "Խուլեն Լոպետեգիին, պատճառաբանելով, որ վերջինս", "label": "hy"}, + {"text": "Афганистана, где в последние дни идут ожесточенные", "label": "ru"}, + {"text": "лѧхове же не идоша", "label": "orv"}, + {"text": "Mit Hilfe von IBMs Chip-Management-Systemen sollen", "label": "de"}, + {"text": ", als Manager zu Telefonica zu wechseln .", "label": "de"}, + {"text": "którym zajmuje się człowiek, zmienia go i pozwala", "label": "pl"}, + {"text": "činí kyperských liber, to je asi USD.", "label": "cs"}, + {"text": "Studienplätze getauscht werden .", "label": "de"}, + {"text": "учёных, орнитологов признают вид.", "label": "ru"}, + {"text": "acordare a concediilor prevăzute de legislațiile", "label": "ro"}, + {"text": "at større innsats for fornybar, berekraftig energi", "label": "nn"}, + {"text": "Politiet veit ikkje kor mange personar som deltok", "label": "nn"}, + {"text": "offentligheten av unge , sinte menn som har", "label": "no"}, + {"text": "însuși în jurul lapunei, care încet DISPARE în", "label": "ro"}, + {"text": "O motivo da decisão é evitar uma sobrecarga ainda", "label": "pt"}, + {"text": "El Apostolado de la prensa contribuye en modo", "label": "es"}, + {"text": "Teltow ( Kreis Teltow-Fläming ) ist Schmitt einer", "label": "de"}, + {"text": "grozījumus un iesniegt tos Apvienoto Nāciju", "label": "lv"}, + {"text": "Gestalt einer deutschen Nationalmannschaft als", "label": "de"}, + {"text": "D überholt zu haben , konterte am heutigen Montag", "label": "de"}, + {"text": "Softwarehersteller Oracle hat im dritten Quartal", "label": "de"}, + {"text": "Během nich se ekonomické podmínky mohou radikálně", "label": "cs"}, + {"text": "Dziki kot w górach zeskakuje z kamienia.", "label": "pl"}, + {"text": "Ačkoliv ligový nováček prohrál, opět potvrdil, že", "label": "cs"}, + {"text": "des Tages , Portraits internationaler Stars sowie", "label": "de"}, + {"text": "Communicator bekannt wurde .", "label": "de"}, + {"text": "τῷ δ’ ἄρα καὶ αὐτῷ ἡ γυνή ἐπίτεξ ἐοῦσα πᾶσαν", "label": "grc"}, + {"text": "Triadú tenia, mentre redactava 'Dies de memòria',", "label": "ca"}, + {"text": "دسته‌جمعی در درخشندگی ماه سیم‌گون زمزمه ستاینده و", "label": "fa"}, + {"text": "Книгу, наполненную мелочной заботой об одежде,", "label": "ru"}, + {"text": "putares canem leporem persequi.", "label": "la"}, + {"text": "В дальнейшем эта яркость слегка померкла, но в", "label": "ru"}, + {"text": "offizielles Verfahren gegen die Telekom", "label": "de"}, + {"text": "podrían haber sido habitantes de la Península", "label": "es"}, + {"text": "Grundlage für dieses Verfahren sind spezielle", "label": "de"}, + {"text": "Rechtsausschuß vorgelegten Entwurf der Richtlinie", "label": "de"}, + {"text": "Im so genannten Portalgeschäft sei das Unternehmen", "label": "de"}, + {"text": "ⲏ ⲉⲓϣⲁⲛϥⲓ ⲛⲉⲓⲇⲱⲗⲟⲛ ⲉⲧϩⲙⲡⲉⲕⲏⲓ ⲙⲏ ⲉⲓⲛⲁϣϩⲱⲡ ⲟⲛ ⲙⲡⲣⲏ", "label": "cop"}, + {"text": "juego podían matar a cualquier herbívoro, pero", "label": "es"}, + {"text": "Nach Angaben von Axent nutzen Unternehmen aus der", "label": "de"}, + {"text": "hrdiny Havlovy Zahradní slavnosti (premiéra ) se", "label": "cs"}, + {"text": "Een zin van heb ik jou daar", "label": "nl"}, + {"text": "hat sein Hirn an der CeBIT-Kasse vergessen .", "label": "de"}, + {"text": "καὶ τοὺς ἐκπλαγέντας οὐκ ἔχειν ἔτι ἐλεγχομένους", 
"label": "grc"}, + {"text": "nachgewiesenen langfristigen Kosten , sowie den im", "label": "de"}, + {"text": "jučer nakon četiri dana putovanja u Helsinki.", "label": "hr"}, + {"text": "pašto paslaugos teikėjas gali susitarti su", "label": "lt"}, + {"text": "В результате, эти золотые кадры переходят из одной", "label": "ru"}, + {"text": "द फाइव-ईयर एंगेजमेंट में अभिनय किया जिसमें जैसन", "label": "hi"}, + {"text": "výpis o počtu akcií.", "label": "cs"}, + {"text": "Enfin, elles arrivent à un pavillon chinois", "label": "fr"}, + {"text": "Tentu saja, tren yang berhubungandengan", "label": "id"}, + {"text": "Arbeidarpartiet og SV har sikra seg fleirtal mot", "label": "nn"}, + {"text": "eles: 'Tudo isso está errado' , disse um", "label": "pt"}, + {"text": "The islands are in their own time zone, minutes", "label": "en"}, + {"text": "Auswahl debütierte er am .", "label": "de"}, + {"text": "Bu komisyonlar, arazilerini satın almak için", "label": "tr"}, + {"text": "Geschütze gegen Redmond aufgefahren .", "label": "de"}, + {"text": "Time scything the hours, but at the top, over the", "label": "en"}, + {"text": "Di musim semi , berharap mengadaptasi Tintin untuk", "label": "id"}, + {"text": "крупнейшей геополитической катастрофой XX века.", "label": "ru"}, + {"text": "Rajojen avaaminen ei suju ongelmitta .", "label": "fi"}, + {"text": "непроницаемым, как для СССР.", "label": "ru"}, + {"text": "Ma non mancano le polemiche.", "label": "it"}, + {"text": "Internet als Ort politischer Diskussion und auch", "label": "de"}, + {"text": "incomplets.", "label": "ca"}, + {"text": "Su padre luchó al lado de Luis Moya, primer Jefe", "label": "es"}, + {"text": "informazione.", "label": "it"}, + {"text": "Primacom bietet für Telekom-Kabelnetz", "label": "de"}, + {"text": "Oświadczenie prezydencji w imieniu Unii", "label": "pl"}, + {"text": "foran rattet i familiens gamle Baleno hvis døra på", "label": "no"}, + {"text": "[speaker:laughter]", "label": "sl"}, + {"text": "Dog med langt mindre utstyr med seg.", "label": "nn"}, + {"text": "dass es nicht schon mit der anfänglichen", "label": "de"}, + {"text": "इस पर दोनों पक्षों में नोकझोंक शुरू हो गई ।", "label": "hi"}, + {"text": "کے ترجمان منیش تیواری اور دگ وجئے سنگھ نے بھی یہ", "label": "ur"}, + {"text": "dell'Assemblea Costituente che posseggono i", "label": "it"}, + {"text": "и аште вьси съблазнѧтъ сѧ нъ не азъ", "label": "cu"}, + {"text": "In Irvine hat auch das Logistikunternehmen Atlas", "label": "de"}, + {"text": "законодательных норм, принимаемых существующей", "label": "ru"}, + {"text": "Κροίσῳ προτείνων τὰς χεῖρας ἐπικατασφάξαι μιν", "label": "grc"}, + {"text": "МИНУСЫ: ИНФЛЯЦИЯ И КРИЗИС В ЖИВОТНОВОДСТВЕ.", "label": "ru"}, + {"text": "unterschiedlicher Meinung .", "label": "de"}, + {"text": "Jospa joku ystävällinen sielu auttaisi kassieni", "label": "fi"}, + {"text": "Añadió que, en el futuro se harán otros", "label": "es"}, + {"text": "Sessiz tonlama hem Fince, hem de Kuzey Sami", "label": "tr"}, + {"text": "nicht ihnen gehört und sie nicht alles , was sie", "label": "de"}, + {"text": "Etelästä Kuivajärveen laskee Tammelan Liesjärvestä", "label": "fi"}, + {"text": "ICANNs Vorsitzender Vint Cerf warb mit dem Hinweis", "label": "de"}, + {"text": "Norsk politikk frå til kan dermed, i", "label": "nn"}, + {"text": "Głosowało posłów.", "label": "pl"}, + {"text": "Danny Jones -- smithjones@ev.net", "label": "en"}, + {"text": "sebeuvědomění moderní civilizace sehrála lučavka", "label": "cs"}, + {"text": "относительно спокойный сон: тому гарантия", "label": "ru"}, 
+ {"text": "A halte voiz prist li pedra a crïer", "label": "fro"}, + {"text": "آن‌ها امیدوارند این واکسن به‌زودی در دسترس بیماران", "label": "fa"}, + {"text": "vlastní důstojnou vousatou tváří.", "label": "cs"}, + {"text": "ora aprire la strada a nuove cause e alimentare il", "label": "it"}, + {"text": "Die Zahl der Vielleser nahm von auf Prozent zu ,", "label": "de"}, + {"text": "Finanzvorstand von Hotline-Dienstleister InfoGenie", "label": "de"}, + {"text": "entwickeln .", "label": "de"}, + {"text": "incolumità pubblica.", "label": "it"}, + {"text": "lehtija televisiomainonta", "label": "fi"}, + {"text": "joistakin kohdista eri mieltä.", "label": "fi"}, + {"text": "Hlavně anglická nezávislá scéna, Dead Can Dance,", "label": "cs"}, + {"text": "pásmech od do bodů bodové stupnice.", "label": "cs"}, + {"text": "Zu Beginn des Ersten Weltkrieges zählte das", "label": "de"}, + {"text": "Així van sorgir, damunt els antics cementiris,", "label": "ca"}, + {"text": "In manchem Gedicht der spätern Alten, wie zum", "label": "de"}, + {"text": "gaweihaida jah insandida in þana fairƕu jus qiþiþ", "label": "got"}, + {"text": "Beides sollte gelöscht werden!", "label": "de"}, + {"text": "modifiqués la seva petició inicial de anys de", "label": "ca"}, + {"text": "В день открытия симпозиума состоялась закладка", "label": "ru"}, + {"text": "tõestatud.", "label": "et"}, + {"text": "ἵππῳ πίπτει αὐτοῦ ταύτῃ", "label": "grc"}, + {"text": "bisher nie enttäuscht!", "label": "de"}, + {"text": "De bohte ollu tuollárat ja suttolaččat ja", "label": "sme"}, + {"text": "Klarsignal från röstlängdsläsaren, tre tryck i", "label": "sv"}, + {"text": "Tvůrcem nového termínu je Joseph Fisher.", "label": "cs"}, + {"text": "Nie miałem czasu na reakcję twierdzi Norbert,", "label": "pl"}, + {"text": "potentia Schöpfer.", "label": "de"}, + {"text": "Un poquito caro, pero vale mucho la pena;", "label": "es"}, + {"text": "οὔ τε γὰρ ἴφθιμοι Λύκιοι Δαναῶν ἐδύναντο τεῖχος", "label": "grc"}, + {"text": "vajec, sladového výtažku a některých vitamínových", "label": "cs"}, + {"text": "Настоящие герои, те, чьи истории потом", "label": "ru"}, + {"text": "praesumptio:", "label": "la"}, + {"text": "Olin justkui nende vastutusel.", "label": "et"}, + {"text": "Jokainen keinahdus tuo lähemmäksi hetkeä jolloin", "label": "fi"}, + {"text": "ekonomicky výhodných způsobů odvodnění těžkých,", "label": "cs"}, + {"text": "Poprvé ve své historii dokázala v kvalifikaci pro", "label": "cs"}, + {"text": "zpracovatelského a spotřebního průmyslu bude nutné", "label": "cs"}, + {"text": "Windows CE zu integrieren .", "label": "de"}, + {"text": "Armangué, a través d'un decret, ordenés l'aturada", "label": "ca"}, + {"text": "to, co nás Evropany spojuje, než to, co nás od", "label": "cs"}, + {"text": "ergänzt durch einen gesetzlich verankertes", "label": "de"}, + {"text": "Насчитал, что с начала года всего три дня были", "label": "ru"}, + {"text": "Borisovu tražeći od njega da prihvati njenu", "label": "sr"}, + {"text": "la presenza di ben veleni diversi: . 
chili di", "label": "it"}, + {"text": "καὶ τῶν ἐκλεκτῶν ἀγγέλων ἵνα ταῦτα φυλάξῃς χωρὶς", "label": "grc"}, + {"text": "pretraživale obližnju bolnicu i stambene zgrade u", "label": "hr"}, + {"text": "An rund Katzen habe Wolf seine Spiele getestet ,", "label": "de"}, + {"text": "investigating since March.", "label": "en"}, + {"text": "Tonböden (Mullböden).", "label": "de"}, + {"text": "Stálý dopisovatel LN v SRN Bedřich Utitz", "label": "cs"}, + {"text": "červnu předložené smlouvy.", "label": "cs"}, + {"text": "πνεύματι ᾧ ἐλάλει", "label": "grc"}, + {"text": ".%의 신장세를 보였다.", "label": "ko"}, + {"text": "Foae verde, foi de nuc, Prin pădure, prin colnic,", "label": "ro"}, + {"text": "διαπέμψας ἄλλους ἄλλῃ τοὺς μὲν ἐς Δελφοὺς ἰέναι", "label": "grc"}, + {"text": "المسلمين أو أي تيار سياسي طالما عمل ذلك التيار في", "label": "ar"}, + {"text": "As informações são da Dow Jones.", "label": "pt"}, + {"text": "Milliarde DM ausgestattet sein .", "label": "de"}, + {"text": "De utgår fortfarande från att kvinnans jämlikhet", "label": "sv"}, + {"text": "Sneeuw maakte in Davos bij de voorbereiding een", "label": "nl"}, + {"text": "De ahí que en este mercado puedan negociarse", "label": "es"}, + {"text": "intenzívnějšímu sbírání a studiu.", "label": "cs"}, + {"text": "और औसकर ४.० पैकेज का प्रयोग किया गया है ।", "label": "hi"}, + {"text": "Adipati Kuningan karena Kuningan menjadi bagian", "label": "id"}, + {"text": "Svako je bar jednom poželeo da mašine prosto umeju", "label": "sr"}, + {"text": "Im vergangenen Jahr haben die Regierungen einen", "label": "de"}, + {"text": "durat motus, aliquid fit et non est;", "label": "la"}, + {"text": "Dominować będą piosenki do tekstów Edwarda", "label": "pl"}, + {"text": "beantwortet .", "label": "de"}, + {"text": "О гуманитариях было кому рассказывать, а вот за", "label": "ru"}, + {"text": "Helsingin kaupunki riitautti vuokrasopimuksen", "label": "fi"}, + {"text": "chợt tan biến.", "label": "vi"}, + {"text": "avtomobil ločuje od drugih.", "label": "sl"}, + {"text": "Congress has proven itself ineffective as a body.", "label": "en"}, + {"text": "मैक्सिको ने इस तरह का शो इस समय आयोजित करने का", "label": "hi"}, + {"text": "No minimum order amount.", "label": "en"}, + {"text": "Convertassa .", "label": "fi"}, + {"text": "Как это можно сделать?", "label": "ru"}, + {"text": "tha mi creidsinn gu robh iad ceart cho saor shuas", "label": "gd"}, + {"text": "실제 일제는 이런 만해의 논리를 묵살하고 한반도를 침략한 다음 , 이어 만주를 침략하고", "label": "ko"}, + {"text": "Da un semplice richiamo all'ordine fino a grandi", "label": "it"}, + {"text": "pozoruhodný nejen po umělecké stránce, jež", "label": "cs"}, + {"text": "La comida y el servicio aprueban.", "label": "es"}, + {"text": "again, connected not with each other but to the", "label": "en"}, + {"text": "Protokol výslovně stanoví, že nikdo nemůže být", "label": "cs"}, + {"text": "ఒక విషయం అడగాలని ఉంది .", "label": "te"}, + {"text": "Безгранично почитая дирекцию, ловя на лету каждое", "label": "ru"}, + {"text": "rovnoběžných růstových vrstev, zůstávají krychlové", "label": "cs"}, + {"text": "प्रवेश और पूर्व प्रधानमंत्री लाल बहादुर शास्त्री", "label": "hi"}, + {"text": "Bronzen medaille in de Europese marathon.", "label": "nl"}, + {"text": "- gadu vecumā viņi to nesaprot.", "label": "lv"}, + {"text": "Realizó sus estudios primarios en la Escuela Julia", "label": "es"}, + {"text": "cuartos de final, su clasificación para la final a", "label": "es"}, + {"text": "Sem si pro něho přiletí americký raketoplán, na", "label": "cs"}, + {"text": "Way to go!", "label": "en"}, 
+ {"text": "gehört der neuen SPD-Führung unter Parteichef", "label": "de"}, + {"text": "Somit simuliert der Player mit einer GByte-Platte", "label": "de"}, + {"text": "Berufung auf kommissionsnahe Kreise , die bereits", "label": "de"}, + {"text": "Dist Clarïen", "label": "fro"}, + {"text": "Schon nach den Gerüchten , die Telekom wolle den", "label": "de"}, + {"text": "Software von NetObjects ist nach Angaben des", "label": "de"}, + {"text": "si enim per legem iustitia ergo Christus gratis", "label": "la"}, + {"text": "ducerent in ipsam magis quam in corpus christi,", "label": "la"}, + {"text": "Neustar-Melbourne-IT-Partnerschaft NeuLevel .", "label": "de"}, + {"text": "forderte dagegen seine drastische Verschärfung.", "label": "de"}, + {"text": "pemmican på hundrede forskellige måder.", "label": "da"}, + {"text": "Lehån, själv matematiklärare, visar hur den nya", "label": "sv"}, + {"text": "I highly recommend his shop.", "label": "en"}, + {"text": "verità, giovani fedeli prostratevi #amen", "label": "it"}, + {"text": "उत्तर प्रदेश के अध्यक्ष पद से हटाए गए विनय कटियार", "label": "hi"}, + {"text": "() روزی مےں کشادگی ہوتی ہے۔", "label": "ur"}, + {"text": "Prozessorgeschäft profitieren kann , stellen", "label": "de"}, + {"text": "školy začalo počítat pytle s moukou a zjistilo, že", "label": "cs"}, + {"text": "प्रभावशाली पर गैर सरकारी लोगों के घरों में भी", "label": "hi"}, + {"text": "geschichtslos , oder eine Farce , wie sich", "label": "de"}, + {"text": "Ústrednými mocnosťami v marci však spôsobilo, že", "label": "sk"}, + {"text": "التسليح بدون مبرر، واستمرار الأضرار الناجمة عن فرض", "label": "ar"}, + {"text": "Například Pedagogická fakulta Univerzity Karlovy", "label": "cs"}, + {"text": "nostris ut eriperet nos de praesenti saeculo", "label": "la"}] + + nlp = Pipeline(lang="multilingual", processors="langid") + docs = [Document([], text=example["text"]) for example in examples] + gold_labels = [example["label"] for example in examples] + nlp(docs) + accuracy = sum([(doc.lang == label) for doc,label in zip(docs,gold_labels)])/len(docs) + assert accuracy >= 0.98 + + +def test_text_cleaning(): + """ + Basic test of cleaning text + """ + docs = ["Bonjour le monde! #thisisfrench #ilovefrance", + "Bonjour le monde! https://t.co/U0Zjp3tusD"] + docs = [Document([], text=text) for text in docs] + + nlp = Pipeline(lang="multilingual", processors="langid") + nlp(docs) + assert [doc.lang for doc in docs] == ["it", "it"] + + nlp = Pipeline(lang="multilingual", processors="langid", langid_clean_text=True) + assert nlp.processors["langid"]._clean_text + nlp(docs) + assert [doc.lang for doc in docs] == ["fr", "fr"] + +def test_lang_subset(): + """ + Basic test of restricting output to subset of languages + """ + docs = ["Bonjour le monde! #thisisfrench #ilovefrance", + "Bonjour le monde! 
https://t.co/U0Zjp3tusD"] + docs = [Document([], text=text) for text in docs] + + nlp = Pipeline(lang="multilingual", processors="langid") + nlp(docs) + assert [doc.lang for doc in docs] == ["it", "it"] + + nlp = Pipeline(lang="multilingual", processors="langid", langid_lang_subset=["en","fr"]) + assert nlp.processors["langid"]._model.lang_subset == ["en", "fr"] + nlp(docs) + assert [doc.lang for doc in docs] == ["fr", "fr"] + + nlp = Pipeline(lang="multilingual", processors="langid", langid_lang_subset=["en"]) + assert nlp.processors["langid"]._model.lang_subset == ["en"] + nlp(docs) + assert [doc.lang for doc in docs] == ["en", "en"] + +def test_multilingual_pipeline(): + """ + Basic test of multilingual pipeline + """ + english_text = "This is an English sentence." + english_deps_gold = "\n".join(( + "('This', 5, 'nsubj')", + "('is', 5, 'cop')", + "('an', 5, 'det')", + "('English', 5, 'amod')", + "('sentence', 0, 'root')", + "('.', 5, 'punct')" + )) + + french_text = "C'est une phrase française." + french_deps_gold = "\n".join(( + "(\"C'\", 4, 'nsubj')", + "('est', 4, 'cop')", + "('une', 4, 'det')", + "('phrase', 0, 'root')", + "('française', 4, 'amod')", + "('.', 4, 'punct')" + )) + + nlp = MultilingualPipeline() + docs = [english_text, french_text] + docs = nlp(docs) + + assert docs[0].lang == "en" + assert docs[0].sentences[0].dependencies_string() == english_deps_gold + assert docs[1].lang == "fr" + assert docs[1].sentences[0].dependencies_string() == french_deps_gold + -- cgit v1.2.3 From 52352c266eab614c4f8cc65cf799c225774a124b Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 25 Jun 2021 22:44:57 -0700 Subject: The next release will be 1.3.0. In the meantime, we can include new models for the langid processor in a 1.3.0 directory --- stanza/_version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stanza/_version.py b/stanza/_version.py index 100ecb6e..87647499 100644 --- a/stanza/_version.py +++ b/stanza/_version.py @@ -1,4 +1,4 @@ """ Single source of truth for version number """ -__version__ = "1.2.2" -__resources_version__ = '1.2.2' +__version__ = "1.3.0" +__resources_version__ = '1.3.0' -- cgit v1.2.3 From 29534158efa2be63c2ed9ed566557c90c966dec4 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 25 Jun 2021 22:45:39 -0700 Subject: Move the parse_args to the top to be easier to find. Change dashes back to underscores (for consistency with the rest of the codebase) and put in reasonable defaults for input_dir and output_dir --- stanza/resources/prepare_resources.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py index 16f64a7f..9421d3a1 100644 --- a/stanza/resources/prepare_resources.py +++ b/stanza/resources/prepare_resources.py @@ -6,6 +6,16 @@ import hashlib import shutil import zipfile +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_dir', type=str, default="/u/nlp/software/stanza/current-models", help='Input dir for various models. 
Defaults to the recommended home on the nlp cluster') + parser.add_argument('--output_dir', type=str, default="/u/nlp/software/stanza/built-models", help='Output dir for various models.') + args = parser.parse_args() + args.input_dir = os.path.abspath(args.input_dir) + args.output_dir = os.path.abspath(args.output_dir) + return args + + # default treebank for languages default_treebanks = { "af": "afribooms", @@ -267,16 +277,6 @@ def get_md5(path): return hashlib.md5(data).hexdigest() -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--input-dir', type=str, help='Input dir for various models.') - parser.add_argument('--output-dir', type=str, help='Output dir for various models.') - args = parser.parse_args() - args.input_dir = os.path.abspath(args.input_dir) - args.output_dir = os.path.abspath(args.output_dir) - return args - - def split_model_name(model): """ Split model names by _ -- cgit v1.2.3 From 63329da46fdb2eebfb669df34ddea9942f23c7bb Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 8 Jun 2021 15:59:02 -0700 Subject: This small hack is not needed now that a new version of CoreNLP has been released --- stanza/server/ud_enhancer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/server/ud_enhancer.py b/stanza/server/ud_enhancer.py index e3d64afa..92d6b7ff 100644 --- a/stanza/server/ud_enhancer.py +++ b/stanza/server/ud_enhancer.py @@ -72,7 +72,7 @@ def main(): nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse') - with UniversalEnhancer(language="en", classpath="$CLASSPATH") as enhancer: + with UniversalEnhancer(language="en") as enhancer: doc = nlp("This is the car that I bought") result = enhancer.process(doc) print(result.sentence[0].enhancedDependencies) -- cgit v1.2.3 From 5005a779a959dbbe7a87539dd2ed542df8cecaa5 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 8 Jun 2021 18:54:04 -0700 Subject: Typos in doc --- stanza/utils/datasets/ner/convert_fire_2013.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stanza/utils/datasets/ner/convert_fire_2013.py b/stanza/utils/datasets/ner/convert_fire_2013.py index f76aa696..9c0f5cd7 100644 --- a/stanza/utils/datasets/ner/convert_fire_2013.py +++ b/stanza/utils/datasets/ner/convert_fire_2013.py @@ -68,8 +68,8 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/FIRE2013/hindi_train", help="Directory with raw files to read") parser.add_argument('--train_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.train.csv", help="Where to put the train file") - parser.add_argument('--dev_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.dev.csv", help="Where to put the train file") - parser.add_argument('--test_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.test.csv", help="Where to put the train file") + parser.add_argument('--dev_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.dev.csv", help="Where to put the dev file") + parser.add_argument('--test_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.test.csv", help="Where to put the test file") args = parser.parse_args() convert_fire_2013(args.input_path, args.train_file, args.dev_file, args.test_file) -- cgit v1.2.3 From f761a82cbca740dc518da6e26c676a0dca74151c Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 8 Jun 2021 21:16:43 -0700 Subject: Add a couple sanity checks on the FIRE 2013 data --- 
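The checks below rely on the column layout this converter assumes for the FIRE 2013 files: six tab-separated fields per token line, with the surface word in column 0 and three nested NER layers in columns 3-5, of which only the outermost is kept. A minimal sketch of the same validation on a made-up input line (the non-NER column values are invented for illustration and are not taken from the corpus):

    def check_fire_line(line):
        pieces = line.split("\t")
        if len(pieces) != 6:
            raise ValueError("Found %d pieces instead of the expected 6" % len(pieces))
        # an outer label of 'o' should never hide a labeled inner NER layer
        if pieces[3] == 'o' and (pieces[4] != 'o' or pieces[5] != 'o'):
            raise ValueError("Inner NER labeled but the top layer was O")
        # only the word and the outermost NER layer are used downstream
        return pieces[0], pieces[3]

    word, ner1 = check_fire_line("शहर\tx\tx\tb-loc\to\to")  # hypothetical input line
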
stanza/utils/datasets/ner/convert_fire_2013.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stanza/utils/datasets/ner/convert_fire_2013.py b/stanza/utils/datasets/ner/convert_fire_2013.py index 9c0f5cd7..dfafdae7 100644 --- a/stanza/utils/datasets/ner/convert_fire_2013.py +++ b/stanza/utils/datasets/ner/convert_fire_2013.py @@ -41,6 +41,10 @@ def convert_fileset(output_csv_file, filenames): for sentence in sentences: for line in sentence: pieces = line.split("\t") + if len(pieces) != 6: + raise ValueError("Found %d pieces instead of the expected 6" % len(pieces)) + if pieces[3] == 'o' and (pieces[4] != 'o' or pieces[5] != 'o'): + raise ValueError("Inner NER labeled but the top layer was O") fout.write("%s\t%s\n" % (pieces[0], normalize(pieces[3]))) fout.write("\n") -- cgit v1.2.3 From 6023bae51d845138f16eedcc57ae1d3fb19f97de Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 8 Jun 2021 22:26:47 -0700 Subject: This snippet has been useful a couple times, so... check in a script to count the number of times a word in an NER dataset can be found in a WV file --- stanza/models/common/count_ner_coverage.py | 38 ++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 stanza/models/common/count_ner_coverage.py diff --git a/stanza/models/common/count_ner_coverage.py b/stanza/models/common/count_ner_coverage.py new file mode 100644 index 00000000..b5a592c7 --- /dev/null +++ b/stanza/models/common/count_ner_coverage.py @@ -0,0 +1,38 @@ +from stanza.models.common import pretrain +import argparse + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('ners', type=str, nargs='*', help='Which treebanks to run on') + parser.add_argument('--pretrain', type=str, default="/home/john/stanza_resources/hi/pretrain/hdtb.pt", help='Which pretrain to use') + parser.set_defaults(ners=["/home/john/stanza/data/ner/hi_fire2013.train.csv", + "/home/john/stanza/data/ner/hi_fire2013.dev.csv"]) + args = parser.parse_args() + return args + + +def read_ner(filename): + words = [] + for line in open(filename).readlines(): + line = line.strip() + if not line: + continue + if line.split("\t")[1] == 'O': + continue + words.append(line.split("\t")[0]) + return words + +def count_coverage(pretrain, words): + count = 0 + for w in words: + if w in pretrain.vocab: + count = count + 1 + return count / len(words) + +args = parse_args() +pt = pretrain.Pretrain(args.pretrain) +for dataset in args.ners: + words = read_ner(dataset) + print(dataset) + print(count_coverage(pt, words)) + print() -- cgit v1.2.3 From 6400cfd5b046b53089ada61294d3e9d06504abef Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 9 Jun 2021 08:15:58 -0700 Subject: Separate the test of the scorer from the test of the ner_tagger itself --- stanza/tests/test_models_ner_scorer.py | 27 +++++++++++++++++++++++++++ stanza/tests/test_ner_tagger.py | 16 ---------------- 2 files changed, 27 insertions(+), 16 deletions(-) create mode 100644 stanza/tests/test_models_ner_scorer.py diff --git a/stanza/tests/test_models_ner_scorer.py b/stanza/tests/test_models_ner_scorer.py new file mode 100644 index 00000000..b6993f09 --- /dev/null +++ b/stanza/tests/test_models_ner_scorer.py @@ -0,0 +1,27 @@ +""" +Simple test of the scorer module for NER +""" + +import pytest +import stanza + +from stanza.tests import * +from stanza.models.ner.scorer import score_by_token, score_by_entity + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +def test_ner_scorer(): + pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 
'E-PER'], + ['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']] + gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'], + ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']] + + token_p, token_r, token_f = score_by_token(pred_sequences, gold_sequences) + assert pytest.approx(token_p, abs=0.00001) == 0.625 + assert pytest.approx(token_r, abs=0.00001) == 0.5 + assert pytest.approx(token_f, abs=0.00001) == 0.55555 + + entity_p, entity_r, entity_f = score_by_entity(pred_sequences, gold_sequences) + assert pytest.approx(entity_p, abs=0.00001) == 0.4 + assert pytest.approx(entity_r, abs=0.00001) == 0.33333 + assert pytest.approx(entity_f, abs=0.00001) == 0.36363 diff --git a/stanza/tests/test_ner_tagger.py b/stanza/tests/test_ner_tagger.py index a6ea6b8a..fde10013 100644 --- a/stanza/tests/test_ner_tagger.py +++ b/stanza/tests/test_ner_tagger.py @@ -23,19 +23,3 @@ def test_ner(): doc = nlp(EN_DOC) assert EN_DOC_GOLD == '\n'.join([ent.pretty_print() for ent in doc.ents]) - -def test_ner_scorer(): - pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'], - ['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']] - gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'], - ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']] - - token_p, token_r, token_f = score_by_token(pred_sequences, gold_sequences) - assert pytest.approx(token_p, abs=0.00001) == 0.625 - assert pytest.approx(token_r, abs=0.00001) == 0.5 - assert pytest.approx(token_f, abs=0.00001) == 0.55555 - - entity_p, entity_r, entity_f = score_by_entity(pred_sequences, gold_sequences) - assert pytest.approx(entity_p, abs=0.00001) == 0.4 - assert pytest.approx(entity_r, abs=0.00001) == 0.33333 - assert pytest.approx(entity_f, abs=0.00001) == 0.36363 -- cgit v1.2.3 From a869c4abf27b5fd1983f1dc031fe76aae29c72ea Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 9 Jun 2021 14:06:19 -0700 Subject: Allow fire2013 as well as FIRE2013... 
cuts down on hassle --- stanza/utils/datasets/ner/prepare_ner_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py index 54a2c7e3..22befde0 100644 --- a/stanza/utils/datasets/ner/prepare_ner_dataset.py +++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py @@ -76,11 +76,11 @@ import tempfile from stanza.models.common.constant import treebank_to_short_name, lcode2lang import stanza.utils.default_paths as default_paths -from stanza.utils.datasets.ner.convert_fire_2013 import convert_fire_2013 from stanza.utils.datasets.ner.preprocess_wikiner import preprocess_wikiner from stanza.utils.datasets.ner.split_wikiner import split_wikiner import stanza.utils.datasets.ner.convert_bsf_to_beios as convert_bsf_to_beios import stanza.utils.datasets.ner.convert_bsnlp as convert_bsnlp +import stanza.utils.datasets.ner.convert_fire_2013 as convert_fire_2013 import stanza.utils.datasets.ner.convert_ijc as convert_ijc import stanza.utils.datasets.ner.convert_rgai as convert_rgai import stanza.utils.datasets.ner.convert_nytk as convert_nytk @@ -181,7 +181,7 @@ def process_fire_2013(paths, dataset): dev_csv_file = os.path.join(base_output_path, "%s.dev.csv" % short_name) test_csv_file = os.path.join(base_output_path, "%s.test.csv" % short_name) - convert_fire_2013(base_input_path, train_csv_file, dev_csv_file, test_csv_file) + convert_fire_2013.convert_fire_2013(base_input_path, train_csv_file, dev_csv_file, test_csv_file) for csv_file, shard in zip((train_csv_file, dev_csv_file, test_csv_file), SHARDS): output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) @@ -327,7 +327,7 @@ def main(dataset_name): process_languk(paths) elif dataset_name == 'hi_ijc': process_ijc(paths, dataset_name) - elif dataset_name.endswith("FIRE2013"): + elif dataset_name.endswith("FIRE2013") or dataset_name.endswith("fire2013"): process_fire_2013(paths, dataset_name) elif dataset_name.endswith('WikiNER'): process_wikiner(paths, dataset_name) -- cgit v1.2.3 From 606c4f79565f37ab43cc029e24e41fed99e06abe Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 14 Jun 2021 00:12:41 -0700 Subject: Prefer to split by \t first, in case the text has nbsp or something like that --- stanza/utils/datasets/ner/prepare_ner_file.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/stanza/utils/datasets/ner/prepare_ner_file.py b/stanza/utils/datasets/ner/prepare_ner_file.py index e5fe2220..f4894fa2 100644 --- a/stanza/utils/datasets/ner/prepare_ner_file.py +++ b/stanza/utils/datasets/ner/prepare_ner_file.py @@ -47,7 +47,9 @@ def load_conll03(filename, skip_doc_start=True): if skip_doc_start and DOC_START_TOKEN in line: continue if len(line) > 0: - array = line.split() + array = line.split("\t") + if len(array) < MIN_NUM_FIELD: + array = line.split() if len(array) < MIN_NUM_FIELD: continue else: @@ -64,8 +66,10 @@ def process_cache(cached_lines): tokens = [] ner_tags = [] for line in cached_lines: - array = line.split() - assert len(array) >= MIN_NUM_FIELD and len(array) <= MAX_NUM_FIELD + array = line.split("\t") + if len(array) < MIN_NUM_FIELD: + array = line.split() + assert len(array) >= MIN_NUM_FIELD and len(array) <= MAX_NUM_FIELD, "Got unexpected line length: {}".format(array) tokens.append(array[0]) ner_tags.append(array[-1]) return (tokens, ner_tags) -- cgit v1.2.3 From 0b4deab05a1c34b9c3bbe9ce87b026e65b131912 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 
17 Jun 2021 09:34:18 -0700 Subject: Update 2.7 -> 2.8 as the default UDBASE --- stanza/utils/default_paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/utils/default_paths.py b/stanza/utils/default_paths.py index ce40efc2..977df186 100644 --- a/stanza/utils/default_paths.py +++ b/stanza/utils/default_paths.py @@ -25,7 +25,7 @@ def get_default_paths(): # TODO: not sure what other people actually have # TODO: also, could make this automatically update to the latest - "UDBASE": "extern_data/ud2/ud-treebanks-v2.7", + "UDBASE": "extern_data/ud2/ud-treebanks-v2.8", "NERBASE": "extern_data/ner", -- cgit v1.2.3 From 65640c46642ddb6c02fe891971568e06a9505a34 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 30 Jun 2021 22:10:52 -0700 Subject: A combined ES model will be similar to both AnCora and GSD --- stanza/models/pos/xpos_vocab_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/models/pos/xpos_vocab_factory.py b/stanza/models/pos/xpos_vocab_factory.py index 39da44fd..5abbb0ec 100644 --- a/stanza/models/pos/xpos_vocab_factory.py +++ b/stanza/models/pos/xpos_vocab_factory.py @@ -6,7 +6,7 @@ from stanza.models.pos.vocab import WordVocab, XPOSVocab def xpos_vocab_factory(data, shorthand): if shorthand in ["af_afribooms", "ar_padt", "bg_btb", "cs_cac", "cs_cltt", "cs_fictree", "cs_pdt", "en_partut", "fr_partut", "gd_arcosg", "gl_ctg", "gl_treegal", "grc_perseus", "hr_set", "is_icepahc", "is_modern", "it_combined", "it_isdt", "it_partut", "it_postwita", "it_twittiro", "it_vit", "la_perseus", "la_udante", "lt_alksnis", "lv_lvtb", "ro_nonstandard", "ro_rrt", "ro_simonero", "sk_snk", "sl_ssj", "sl_sst", "sr_set", "ta_ttb", "uk_iu"]: return XPOSVocab(data, shorthand, idx=2, sep="") - elif shorthand in ["be_hse", "ca_ancora", "cop_scriptorium", "cu_proiel", "cy_ccg", "da_ddt", "de_gsd", "de_hdt", "el_gdt", "en_combined", "en_ewt", "en_gum", "es_ancora", "es_gsd", "et_edt", "et_ewt", "eu_bdt", "fa_perdt", "fa_seraji", "fi_tdt", "fr_gsd", "fro_srcmf", "fr_sequoia", "fr_spoken", "ga_idt", "got_proiel", "grc_proiel", "he_htb", "hi_hdtb", "hu_szeged", "hy_armtdp", "hyw_armtdp", "id_csui", "ja_gsd", "la_proiel", "lt_hse", "lzh_kyoto", "mr_ufal", "mt_mudt", "nb_bokmaal", "nn_nynorsk", "nn_nynorsklia", "orv_rnc", "orv_torot", "pcm_nsc", "pt_bosque", "pt_gsd", "qtd_sagt", "ru_gsd", "ru_syntagrus", "ru_taiga", "sa_vedic", "sme_giella", "swl_sslc", "te_mtg", "tr_boun", "tr_framenet", "tr_imst", "tr_kenet", "tr_penn", "tr_tourism", "ug_udt", "vi_vtb", "wo_wtb", "zh_gsdsimp", "zh-hant_gsd", "bxr_bdt", "hsb_ufal", "ja_bccwj", "kk_ktb", "kmr_mg", "olo_kkpp"]: + elif shorthand in ["be_hse", "ca_ancora", "cop_scriptorium", "cu_proiel", "cy_ccg", "da_ddt", "de_gsd", "de_hdt", "el_gdt", "en_combined", "en_ewt", "en_gum", "es_ancora", "es_gsd", "es_combined", "et_edt", "et_ewt", "eu_bdt", "fa_perdt", "fa_seraji", "fi_tdt", "fr_gsd", "fro_srcmf", "fr_sequoia", "fr_spoken", "ga_idt", "got_proiel", "grc_proiel", "he_htb", "hi_hdtb", "hu_szeged", "hy_armtdp", "hyw_armtdp", "id_csui", "ja_gsd", "la_proiel", "lt_hse", "lzh_kyoto", "mr_ufal", "mt_mudt", "nb_bokmaal", "nn_nynorsk", "nn_nynorsklia", "orv_rnc", "orv_torot", "pcm_nsc", "pt_bosque", "pt_gsd", "qtd_sagt", "ru_gsd", "ru_syntagrus", "ru_taiga", "sa_vedic", "sme_giella", "swl_sslc", "te_mtg", "tr_boun", "tr_framenet", "tr_imst", "tr_kenet", "tr_penn", "tr_tourism", "ug_udt", "vi_vtb", "wo_wtb", "zh_gsdsimp", "zh-hant_gsd", "bxr_bdt", "hsb_ufal", "ja_bccwj", "kk_ktb", "kmr_mg", "olo_kkpp"]: return 
WordVocab(data, shorthand, idx=2, ignore=["_"]) elif shorthand in ["en_lines", "fo_farpahc", "sv_lines", "ur_udtb"]: return XPOSVocab(data, shorthand, idx=2, sep="-") -- cgit v1.2.3 From e0ed63b13e1952a663485e74e7c0ad325234417b Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 9 Jun 2021 16:00:20 -0700 Subject: Shuffle files randomly --- stanza/utils/datasets/ner/convert_fire_2013.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stanza/utils/datasets/ner/convert_fire_2013.py b/stanza/utils/datasets/ner/convert_fire_2013.py index dfafdae7..b95275be 100644 --- a/stanza/utils/datasets/ner/convert_fire_2013.py +++ b/stanza/utils/datasets/ner/convert_fire_2013.py @@ -13,6 +13,7 @@ This script keeps just the word and the ner1. It is quite possible that using t import argparse import glob import os +import random def normalize(entity): if entity == 'o': @@ -53,6 +54,7 @@ def convert_fire_2013(input_path, train_csv_file, dev_csv_file, test_csv_file): # won't be numerically sorted... shouldn't matter filenames = sorted(filenames) + random.shuffle(filenames) train_cutoff = int(0.8 * len(filenames)) dev_cutoff = int(0.9 * len(filenames)) @@ -69,6 +71,8 @@ def convert_fire_2013(input_path, train_csv_file, dev_csv_file, test_csv_file): convert_fileset(test_csv_file, test_files) if __name__ == '__main__': + random.seed(1234) + parser = argparse.ArgumentParser() parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/FIRE2013/hindi_train", help="Directory with raw files to read") parser.add_argument('--train_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.train.csv", help="Where to put the train file") -- cgit v1.2.3 From d26df4e4a8828ccc9f454d78c1ca435f469e0233 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 2 Jul 2021 13:02:11 -0700 Subject: Add a comment on why no softmax after the prediction --- stanza/models/classifiers/cnn_classifier.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stanza/models/classifiers/cnn_classifier.py b/stanza/models/classifiers/cnn_classifier.py index fa5160bf..f7bf4c8e 100644 --- a/stanza/models/classifiers/cnn_classifier.py +++ b/stanza/models/classifiers/cnn_classifier.py @@ -367,6 +367,8 @@ class CNNClassifier(nn.Module): for fc in self.fc_layers[:-1]: previous_layer = self.dropout(F.relu(fc(previous_layer))) out = self.fc_layers[-1](previous_layer) + # note that we return the raw logits rather than use a softmax + # https://discuss.pytorch.org/t/multi-class-cross-entropy-loss-and-softmax-in-pytorch/24920/4 return out -- cgit v1.2.3 From 666c8efdd1be14774b5bc470e2b9872d9e166126 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 5 Jul 2021 00:18:12 -0700 Subject: Need to refactor this from the deparse, pos tagger, and elsewhere --- stanza/models/parser.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stanza/models/parser.py b/stanza/models/parser.py index 4d605dcb..f74c8b6d 100644 --- a/stanza/models/parser.py +++ b/stanza/models/parser.py @@ -115,6 +115,7 @@ def model_file_name(args): return os.path.join(args['save_dir'], save_name) +# TODO: refactor with everywhere def load_pretrain(args): pt = None if args['pretrain']: -- cgit v1.2.3 From 0342af1da4b9c503f1560495c11d76784c26efbb Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 24 Jun 2021 08:15:26 -0700 Subject: This is a bug in the classifier I think --- stanza/models/classifiers/cnn_classifier.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stanza/models/classifiers/cnn_classifier.py b/stanza/models/classifiers/cnn_classifier.py index 
f7bf4c8e..e912fc3d 100644 --- a/stanza/models/classifiers/cnn_classifier.py +++ b/stanza/models/classifiers/cnn_classifier.py @@ -112,6 +112,7 @@ class CNNClassifier(nn.Module): self.extra_vocab = list(extra_vocab) self.extra_vocab_map = { word: i for i, word in enumerate(self.extra_vocab) } # TODO: possibly add regularization specifically on the extra embedding? + # TODO FIXME: word of idx 0 is being shared with the padding! self.extra_embedding = nn.Embedding(num_embeddings = len(extra_vocab), embedding_dim = self.config.extra_wordvec_dim, max_norm = self.config.extra_wordvec_max_norm, -- cgit v1.2.3 From c782368d1083c4b595b229035657e2aa218c7930 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 14 Jun 2021 22:26:27 -0700 Subject: Add a path for overall extra datasets --- stanza/utils/default_paths.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stanza/utils/default_paths.py b/stanza/utils/default_paths.py index 977df186..6326ba10 100644 --- a/stanza/utils/default_paths.py +++ b/stanza/utils/default_paths.py @@ -32,6 +32,9 @@ def get_default_paths(): # there's a stanford github, stanfordnlp/handparsed-treebank, # with some data for different languages "HANDPARSED_DIR": "extern_data/handparsed-treebank", + + # data root for other general input files, such as VI_VLSP + "EXTERN_DIR": "extern_data" } paths = { "DATA_ROOT" : DATA_ROOT } -- cgit v1.2.3 From 0777cd456adc44f1bc28c3cb5757011654cead2c Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 15 Jun 2021 14:14:11 -0700 Subject: Add an error when the conllu can't be found, since the perl script apparently doesn't throw an error --- stanza/utils/datasets/prepare_tokenizer_treebank.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index 459a6a74..afc3b6f1 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -145,6 +145,9 @@ def convert_conllu_to_txt(tokenizer_dir, short_name): output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu" output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt" + if not os.path.exists(output_conllu): + # the perl script doesn't raise an error code for file not found! 
+ raise FileNotFoundError("Cannot convert %s as the file cannot be found" % output_conllu) # use an external script to produce the txt files subprocess.check_output(f"perl {CONLLU_TO_TXT_PERL} {output_conllu} > {output_txt}", shell=True) -- cgit v1.2.3 From fb16f8df5a1331dc3ed7f09195ad68af03679ef8 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 15 Jun 2021 16:10:09 -0700 Subject: Add a version of a conversion script which turns VI_VLSP into a dataset suitable for our tokenizer Tries to condense unnecessary spaces out of the VLSP tokenization dataset --- stanza/utils/datasets/common.py | 3 + .../utils/datasets/prepare_tokenizer_treebank.py | 6 +- stanza/utils/datasets/tokenization/__init__.py | 0 .../utils/datasets/tokenization/convert_vi_vlsp.py | 94 ++++++++++++++++++++++ 4 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 stanza/utils/datasets/tokenization/__init__.py create mode 100644 stanza/utils/datasets/tokenization/convert_vi_vlsp.py diff --git a/stanza/utils/datasets/common.py b/stanza/utils/datasets/common.py index 7ea7e617..87fe1490 100644 --- a/stanza/utils/datasets/common.py +++ b/stanza/utils/datasets/common.py @@ -115,6 +115,9 @@ def get_ud_treebanks(udbase_dir, filtered=True): def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. Use all_ud or ud_all for all UD treebanks') + + # TODO: not sure this is the best place for dataset-specific arguments. + parser.add_argument('--vlsp_include_spaces', action='store_true', default=False, help='When processing vi_vlsp tokenization, include all of the spaces. Otherwise, we try to turn the text back into standard text') return parser diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index afc3b6f1..eca93013 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -34,7 +34,7 @@ from collections import Counter import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data - +import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp def copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name): original = f"{tokenizer_dir}/{short_name}.{tokenizer_file}.conllu" @@ -1033,7 +1033,9 @@ def process_treebank(treebank, paths, args): os.makedirs(tokenizer_dir, exist_ok=True) - if short_name.startswith("ko_combined"): + if short_name == "vi_vlsp": + convert_vi_vlsp.convert_vi_vlsp(paths["EXTERN_DIR"], tokenizer_dir, args) + elif short_name.startswith("ko_combined"): build_combined_korean(udbase_dir, tokenizer_dir, short_name) elif short_name in ("it_combined", "en_combined", "es_combined"): build_combined_dataset(udbase_dir, tokenizer_dir, handparsed_dir, short_name, args.augment) diff --git a/stanza/utils/datasets/tokenization/__init__.py b/stanza/utils/datasets/tokenization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py new file mode 100644 index 00000000..b6073e19 --- /dev/null +++ b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py @@ -0,0 +1,94 @@ + +import os + +def find_spaces(sentence): + # TODO: there are some sentences where there is only one quote, + # and some of them should be attached to the previous word instead + # of the next word. 
Training should work this way, though + odd_quotes = False + + spaces = [] + for word_idx, word in enumerate(sentence): + space = True + if word_idx < len(sentence) - 1: + if sentence[word_idx+1] in (',', '.', '!', '?', ')', ':', ';', '”', '…', '...'): + space = False + if word in ('(', '“'): + space = False + if word == '"': + if odd_quotes: + # already saw one quote. put this one at the end of the PREVIOUS word + # note that we know there must be at least one word already + odd_quotes = False + spaces[word_idx-1] = False + else: + odd_quotes = True + space = False + spaces.append(space) + return spaces + +def write_file(vlsp_include_spaces, output_filename, sentences, shard): + with open(output_filename, "w") as fout: + for sent_idx, sentence in enumerate(sentences): + fout.write("# sent_id = %s.%d\n" % (shard, sent_idx)) + orig_text = " ".join(sentence) + if vlsp_include_spaces: + fout.write("# text = %s\n" % orig_text) + else: + spaces = find_spaces(sentence) + full_text = "" + for word, space in zip(sentence, spaces): + # could be made more efficient, but shouldn't matter + full_text = full_text + word + if space: + full_text = full_text + " " + fout.write("# text = %s\n" % full_text) + fout.write("# orig_text = %s\n" % orig_text) + for word_idx, word in enumerate(sentence): + fake_dep = "root" if word_idx == 0 else "dep" + fout.write("%d\t%s\t%s" % ((word_idx+1), word, word)) + fout.write("\t_\t_\t_") + fout.write("\t%d\t%s" % (word_idx, fake_dep)) + fout.write("\t_\t") + if vlsp_include_spaces or spaces[word_idx]: + fout.write("_") + else: + fout.write("SpaceAfter=No") + fout.write("\n") + fout.write("\n") + +def convert_file(vlsp_include_spaces, input_filename, output_filename, shard, split_filename=None, split_shard=None): + with open(input_filename) as fin: + lines = fin.readlines() + + sentences = [] + for line in lines: + words = line.split() + words = [w.replace("_", " ") for w in words] + sentences.append(words) + + if split_filename is not None: + # even this is a larger dev set than the train set + split_point = int(len(sentences) * 0.95) + write_file(vlsp_include_spaces, output_filename, sentences[:split_point], shard) + write_file(vlsp_include_spaces, split_filename, sentences[split_point:], split_shard) + else: + write_file(vlsp_include_spaces, output_filename, sentences, shard) + +def convert_vi_vlsp(extern_dir, tokenizer_dir, args): + input_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-WS-data") + + input_train_filename = os.path.join(input_path, "VLSP2013_WS_train_gold.txt") + input_test_filename = os.path.join(input_path, "VLSP2013_WS_test_gold.txt") + if not os.path.exists(input_train_filename): + raise FileNotFoundError("Cannot find train set for VLSP at %s" % input_train_filename) + if not os.path.exists(input_test_filename): + raise FileNotFoundError("Cannot find test set for VLSP at %s" % input_test_filename) + + output_train_filename = os.path.join(tokenizer_dir, "vi_vlsp.train.gold.conllu") + output_dev_filename = os.path.join(tokenizer_dir, "vi_vlsp.dev.gold.conllu") + output_test_filename = os.path.join(tokenizer_dir, "vi_vlsp.test.gold.conllu") + + convert_file(args.vlsp_include_spaces, input_train_filename, output_train_filename, "train", output_dev_filename, "dev") + convert_file(args.vlsp_include_spaces, input_test_filename, output_test_filename, "test") + -- cgit v1.2.3 From 780a320706b73f02d20000225a441254c1472b20 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 15 Jul 2021 14:43:16 -0700 Subject: Fail if the expected data split isn't 
where it's supposed to be --- stanza/utils/datasets/prepare_tokenizer_treebank.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index 459a6a74..102e86f8 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -944,7 +944,7 @@ def build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, augment): build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, dataset, augment) def prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, dataset, augment=True): - input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu") + input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True) output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu" if short_name == "te_mtg" and dataset == 'train' and augment: -- cgit v1.2.3 From da949c889927f71ce9d82765212d7a334e2473f2 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 16 Jul 2021 09:38:50 -0700 Subject: Attach ". to the end of a sentence --- stanza/utils/datasets/tokenization/convert_vi_vlsp.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py index b6073e19..2c99c49a 100644 --- a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py +++ b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py @@ -10,6 +10,14 @@ def find_spaces(sentence): spaces = [] for word_idx, word in enumerate(sentence): space = True + # Quote period at the end of a sentence needs to be attached + # to the rest of the text. Some sentences have `"... 
text` + # in the middle, though, so look for that + if word_idx < len(sentence) - 2 and sentence[word_idx+1] == '"': + if sentence[word_idx+2] == '.': + space = False + elif word_idx == len(sentence) - 3 and sentence[word_idx+2] == '...': + space = False if word_idx < len(sentence) - 1: if sentence[word_idx+1] in (',', '.', '!', '?', ')', ':', ';', '”', '…', '...'): space = False -- cgit v1.2.3 From 537329692fc17424a91368ae2b7a88bcb2db0dc0 Mon Sep 17 00:00:00 2001 From: Vy Date: Fri, 16 Jul 2021 12:58:48 -0700 Subject: only convert file --- .../utils/datasets/tokenization/convert_vi_vlsp.py | 27 +++++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py index b6073e19..37a22077 100644 --- a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py +++ b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py @@ -1,6 +1,8 @@ import os +punctuation_set = (',', '.', '!', '?', ')', ':', ';', '”', '…', '...') + def find_spaces(sentence): # TODO: there are some sentences where there is only one quote, # and some of them should be attached to the previous word instead @@ -11,9 +13,9 @@ def find_spaces(sentence): for word_idx, word in enumerate(sentence): space = True if word_idx < len(sentence) - 1: - if sentence[word_idx+1] in (',', '.', '!', '?', ')', ':', ';', '”', '…', '...'): + if sentence[word_idx+1] in (',', '.', '!', '?', ')', ':', ';', '”', '…', '...','/', '%'): space = False - if word in ('(', '“'): + if word in ('(', '“', '/'): space = False if word == '"': if odd_quotes: @@ -29,9 +31,17 @@ def find_spaces(sentence): def write_file(vlsp_include_spaces, output_filename, sentences, shard): with open(output_filename, "w") as fout: + check_headlines = False for sent_idx, sentence in enumerate(sentences): fout.write("# sent_id = %s.%d\n" % (shard, sent_idx)) orig_text = " ".join(sentence) + #check if the previous line is a headline (no ending mark at the end) then make this sentence a new par + if check_headlines: + fout.write("# newpar id =%s.%d.1\n" % (shard, sent_idx)) + check_headlines = False + if sentence[len(sentence) - 1] not in punctuation_set: + check_headlines = True + if vlsp_include_spaces: fout.write("# text = %s\n" % orig_text) else: @@ -63,10 +73,15 @@ def convert_file(vlsp_include_spaces, input_filename, output_filename, shard, sp sentences = [] for line in lines: - words = line.split() - words = [w.replace("_", " ") for w in words] - sentences.append(words) - + if len(line.replace("_", " ").split())>1: + words = line.split() + #one syllable lines are eliminated + if len(words) == 1 and len(words[0].split("_")) == 1: + continue + else: + words = [w.replace("_", " ") for w in words] + sentences.append(words) + if split_filename is not None: # even this is a larger dev set than the train set split_point = int(len(sentences) * 0.95) -- cgit v1.2.3 From 72bc6a3c6cac9152e53adc6c645c424a603b1c15 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 16 Jul 2021 13:34:05 -0700 Subject: Adjust the newpar title --- stanza/utils/datasets/tokenization/convert_vi_vlsp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py index 31e7a985..2c00a51c 100644 --- a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py +++ b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py @@ -45,7 +45,7 @@ def write_file(vlsp_include_spaces, 
output_filename, sentences, shard): orig_text = " ".join(sentence) #check if the previous line is a headline (no ending mark at the end) then make this sentence a new par if check_headlines: - fout.write("# newpar id =%s.%d.1\n" % (shard, sent_idx)) + fout.write("# newpar_id = %s.%d.1\n" % (shard, sent_idx)) check_headlines = False if sentence[len(sentence) - 1] not in punctuation_set: check_headlines = True -- cgit v1.2.3 From b1de854cb0d94fe38867ac11db90062814d18cd5 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 16 Jul 2021 14:55:30 -0700 Subject: Move thai orchid & best tokenization to the tokenization specific directory --- stanza/utils/datasets/process_best.py | 174 --------------------- stanza/utils/datasets/process_orchid.py | 153 ------------------ stanza/utils/datasets/process_thai_tokenization.py | 66 -------- stanza/utils/datasets/tokenization/process_best.py | 174 +++++++++++++++++++++ .../utils/datasets/tokenization/process_orchid.py | 153 ++++++++++++++++++ .../tokenization/process_thai_tokenization.py | 66 ++++++++ 6 files changed, 393 insertions(+), 393 deletions(-) delete mode 100644 stanza/utils/datasets/process_best.py delete mode 100644 stanza/utils/datasets/process_orchid.py delete mode 100644 stanza/utils/datasets/process_thai_tokenization.py create mode 100644 stanza/utils/datasets/tokenization/process_best.py create mode 100644 stanza/utils/datasets/tokenization/process_orchid.py create mode 100644 stanza/utils/datasets/tokenization/process_thai_tokenization.py diff --git a/stanza/utils/datasets/process_best.py b/stanza/utils/datasets/process_best.py deleted file mode 100644 index 21125455..00000000 --- a/stanza/utils/datasets/process_best.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Parses the BEST Thai dataset. - -That is to say, the dataset named BEST. We have not yet figured out -which segmentation standard we prefer. - -Note that the version of BEST we used actually had some strange -sentence splits according to a native Thai speaker. Not sure how to -fix that. Options include doing it automatically or finding some -knowledgable annotators to resplit it for us (or just not using BEST) - -This outputs the tokenization results in a conll format similar to -that of the UD treebanks, so we pretend to be a UD treebank for ease -of compatibility with the stanza tools. 
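The pretend-UD rows written for each token have ten tab-separated columns, but only the surface form, a faked root/dep head chain, and SpaceAfter carry information; everything else is left as "_". A minimal sketch for one two-token sentence [("ab", False), ("cd", True)] (placeholder tokens, not corpus text):

    1   ab   _   _   _   _   0   root   0:root   SpaceAfter=No
    2   cd   _   _   _   _   1   dep    1:dep    _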
- -python3 -m stanza.utils.datasets.process_best extern_data/thai/best data/tokenize -./scripts/run_tokenize.sh UD_Thai-best --dropout 0.05 --unit_dropout 0.05 --steps 50000 -""" - -import glob -import os -import random -import re -import sys - -from pythainlp import sent_tokenize - -from stanza.utils.datasets.process_thai_tokenization import write_dataset - -def clean_line(line): - line = line.replace("html>", "html|>") - # news_00089.txt - line = line.replace("", "") - line = line.replace("", "") - # specific error that occurs in encyclopedia_00095.txt - line = line.replace("Penn", "|Penn>") - # news_00058.txt - line = line.replace("จม.เปิดผนึก", "จม.|เปิดผนึก") - # news_00015.txt - line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) - # news_00024.txt - line = re.sub("([^|<>]+)", "\\1", line) - # news_00055.txt - line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) - line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) - line = re.sub("([^|<>]+)([^|<>]+) ([^|<>]+)", "\\1|\\2|\\3", line) - # news_00008.txt and other news articles - line = re.sub("([0-9])", "|\\1", line) - line = line.replace(" ", "|") - line = line.strip() - return line - - -def clean_word(word): - # novel_00078.txt - if word == '': - return 'พี่มน' - if word.startswith("") and word.endswith(""): - return word[4:-5] - if word.startswith("") and word.endswith(""): - return word[4:-5] - if word.startswith("") and word.endswith(""): - return word[6:-7] - if word.startswith(""): - return word[4:] - if word.endswith(""): - return word[:-5] - if word.startswith(""): - return word[6:] - if word.endswith(""): - return word[:-7] - if word == '<': - return word - return word - -def reprocess_lines(processed_lines): - reprocessed_lines = [] - for line in processed_lines: - text = "".join(line) - chunks = sent_tokenize(text) - if sum(len(x) for x in chunks) != len(text): - raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks)) - - chunk_lengths = [len(x) for x in chunks] - - current_length = 0 - new_line = [] - for word in line: - if len(word) + current_length < chunk_lengths[0]: - new_line.append(word) - current_length = current_length + len(word) - elif len(word) + current_length == chunk_lengths[0]: - new_line.append(word) - reprocessed_lines.append(new_line) - new_line = [] - chunk_lengths = chunk_lengths[1:] - current_length = 0 - else: - remaining_len = chunk_lengths[0] - current_length - new_line.append(word[:remaining_len]) - reprocessed_lines.append(new_line) - word = word[remaining_len:] - chunk_lengths = chunk_lengths[1:] - while len(word) > chunk_lengths[0]: - new_line = [word[:chunk_lengths[0]]] - reprocessed_lines.append(new_line) - word = word[chunk_lengths[0]:] - chunk_lengths = chunk_lengths[1:] - new_line = [word] - current_length = len(word) - reprocessed_lines.append(new_line) - return reprocessed_lines - -def read_data(input_dir): - subdirs = [os.path.join(input_dir, 'article'), - os.path.join(input_dir, 'encyclopedia'), - os.path.join(input_dir, 'news'), - os.path.join(input_dir, 'novel')] - files = [] - for subdir in subdirs: - if not os.path.exists(subdir): - raise FileNotFoundError("Expected a directory that did not exist: {}".format(subdir)) - files.extend(glob.glob(os.path.join(subdir, '*.txt'))) - - documents = [] - for filename in files: - with open(filename) as fin: - sentences = [] - processed_lines = [] - for line in fin.readlines(): - line = clean_line(line) - words = line.split("|") - words = [clean_word(x) for x in words] - for word in words: - if len(word) > 1 and 
word[0] == '<': - raise ValueError("Unexpected word '{}' in document {}".format(word, filename)) - words = [x for x in words if x] - processed_lines.append(words) - - processed_lines = reprocess_lines(processed_lines) - - for words in processed_lines: - # turn the words into a sentence - sentence = [] - for word in words: - word = word.strip() - if not word: - if len(sentence) == 0: - raise ValueError("Unexpected space at start of sentence in document {}".format(filename)) - sentence[-1] = (sentence[-1][0], True) - else: - sentence.append((word, False)) - # blank lines are very rare in best, but why not treat them as a paragraph break - if len(sentence) == 0: - paragraphs = [sentences] - documents.append(paragraphs) - sentences = [] - continue - sentence[-1] = (sentence[-1][0], True) - sentences.append(sentence) - paragraphs = [sentences] - documents.append(paragraphs) - - return documents - -def main(): - random.seed(1000) - input_dir = sys.argv[1] - output_dir = sys.argv[2] - documents = read_data(input_dir) - write_dataset(documents, output_dir, "best") - - -if __name__ == '__main__': - main() diff --git a/stanza/utils/datasets/process_orchid.py b/stanza/utils/datasets/process_orchid.py deleted file mode 100644 index 794c3925..00000000 --- a/stanza/utils/datasets/process_orchid.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Parses the xml conversion of orchid - -https://github.com/korakot/thainlp/blob/master/xmlchid.xml - -For example, if you put the data file in the above link in -extern_data/thai/orchid/xmlchid.xml -you would then run -python3 -m stanza.utils.datasets.process_orchid extern_data/thai/orchid/xmlchid.xml data/tokenize - -Because there is no definitive train/dev/test split that we have found -so far, we randomly shuffle the data on a paragraph level and split it -80/10/10. A random seed is chosen so that the splits are reproducible. - -The datasets produced have a similar format to the UD datasets, so we -give it a fake UD name to make life easier for the downstream tools. - -Training on this dataset seems to work best with low dropout numbers. -For example: - -./scripts/run_tokenize.sh UD_Thai-orchid --dropout 0.05 --unit_dropout 0.05 - -This results in a model with dev set scores: - th_orchid 87.98 70.94 -test set scores: - 91.60 72.43 - -Apparently the random split produced a test set easier than the dev set. -""" - -import random -import sys -import xml.etree.ElementTree as ET - -from stanza.utils.datasets.process_thai_tokenization import write_dataset - -# line "122819" has some error in the tokenization of the musical notation -# line "209380" is also messed up -# others have @ followed by a part of speech, which is clearly wrong - -skipped_lines = { - "122819", - "209380", - "227769", - "245992", - "347163", - "409708", - "431227", -} - -escape_sequences = { - '': '(', - '': ')', - '': '^', - '': '.', - '': '-', - '': '*', - '': '"', - '': '/', - '': ':', - '': '=', - '': ',', - '': ';', - '': '<', - '': '>', - '': '&', - '': '{', - '': '}', - '': "'", - '': '+', - '': '#', - '': '$', - '': '@', - '': '?', - '': '!', - 'app
<li>ances': 'appliances', - 'intel
<li>gence': 'intelligence', - "'": "/'", - '<100>': '100', -} - -allowed_sequences = { - '', - '', - '', - '', - '', - '
  • ', - '<---vp', - '<---', - '<----', -} - -def read_data(input_filename): - tree = ET.parse(input_filename) - - # we will put each paragraph in a separate block in the output file - # we won't pay any attention to the document boundaries unless we - # later find out it was necessary - # a paragraph will be a list of sentences - # a sentence is a list of words, where each word is a string - documents = [] - - root = tree.getroot() - for document in root: - # these should all be documents - if document.tag != 'document': - raise ValueError("Unexpected orchid xml layout: {}".format(document.tag)) - paragraphs = [] - for paragraph in document: - if paragraph.tag != 'paragraph': - raise ValueError("Unexpected orchid xml layout: {} under {}".format(paragraph.tag, document.tag)) - sentences = [] - for sentence in paragraph: - if sentence.tag != 'sentence': - raise ValueError("Unexpected orchid xml layout: {} under {}".format(sentence.tag, document.tag)) - if sentence.attrib['line_num'] in skipped_lines: - continue - words = [] - for word_idx, word in enumerate(sentence): - if word.tag != 'word': - raise ValueError("Unexpected orchid xml layout: {} under {}".format(word.tag, sentence.tag)) - word = word.attrib['surface'] - word = escape_sequences.get(word, word) - if word == '': - if word_idx == 0: - raise ValueError("Space character was the first token in a sentence: {}".format(sentence.attrib['line_num'])) - else: - words[-1] = (words[-1][0], True) - continue - if len(word) > 1 and word[0] == '<' and word not in allowed_sequences: - raise ValueError("Unknown escape sequence {}".format(word)) - words.append((word, False)) - if len(words) == 0: - continue - sentences.append(words) - paragraphs.append(sentences) - documents.append(paragraphs) - - print("Number of documents: {}".format(len(documents))) - print("Number of paragraphs: {}".format(sum(len(document) for document in documents))) - return documents - - -def main(): - random.seed(1007) - input_filename = sys.argv[1] - output_dir = sys.argv[2] - documents = read_data(input_filename) - write_dataset(documents, output_dir, "orchid") - - -if __name__ == '__main__': - main() diff --git a/stanza/utils/datasets/process_thai_tokenization.py b/stanza/utils/datasets/process_thai_tokenization.py deleted file mode 100644 index 27e347dd..00000000 --- a/stanza/utils/datasets/process_thai_tokenization.py +++ /dev/null @@ -1,66 +0,0 @@ -import os -import random - -def write_section(output_dir, dataset_name, section, documents): - """ - Writes a list of documents for tokenization, including a file in conll format - - The Thai datasets generally have no MWT (apparently not relevant for Thai) - - output_dir: the destination directory for the output files - dataset_name: orchid, BEST, lst20, etc - section: train/dev/test - documents: a nested list of documents, paragraphs, sentences, words - words is a list of (word, space_follows) - """ - with open(os.path.join(output_dir, 'th_%s-ud-%s-mwt.json' % (dataset_name, section)), 'w') as fout: - fout.write("[]\n") - - text_out = open(os.path.join(output_dir, 'th_%s.%s.txt' % (dataset_name, section)), 'w') - label_out = open(os.path.join(output_dir, 'th_%s-ud-%s.toklabels' % (dataset_name, section)), 'w') - for document in documents: - for paragraph in document: - for sentence_idx, sentence in enumerate(paragraph): - for word_idx, word in enumerate(sentence): - # TODO: split with newlines to make it more readable? 
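# For example (an illustrative paragraph, not real corpus data): the two
# sentences [("ab", False), ("cd", True)] and [("ef", True)] put "abcd ef"
# into the .txt file and "0102002" into the .toklabels file -- "1" marks the
# last character of a token, "2" the last character of a sentence, and "0"
# any other character, including the space emitted after "cd".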
- text_out.write(word[0]) - for i in range(len(word[0]) - 1): - label_out.write("0") - if word_idx == len(sentence) - 1: - label_out.write("2") - else: - label_out.write("1") - if word[1] and sentence_idx != len(paragraph) - 1: - text_out.write(' ') - label_out.write('0') - - text_out.write("\n\n") - label_out.write("\n\n") - - text_out.close() - label_out.close() - - with open(os.path.join(output_dir, 'th_%s.%s.gold.conllu' % (dataset_name, section)), 'w') as fout: - for document in documents: - for paragraph in document: - for sentence in paragraph: - for word_idx, word in enumerate(sentence): - # SpaceAfter is left blank if there is space after the word - space = '_' if word[1] else 'SpaceAfter=No' - # Note the faked dependency structure: the conll reading code - # needs it even if it isn't being used in any way - fake_dep = 'root' if word_idx == 0 else 'dep' - fout.write('{}\t{}\t_\t_\t_\t_\t{}\t{}\t{}:{}\t{}\n'.format(word_idx+1, word[0], word_idx, fake_dep, word_idx, fake_dep, space)) - fout.write('\n') - -def write_dataset(documents, output_dir, dataset_name): - """ - Shuffle a list of documents, write three sections - """ - random.shuffle(documents) - num_train = int(len(documents) * 0.8) - num_dev = int(len(documents) * 0.1) - os.makedirs(output_dir, exist_ok=True) - write_section(output_dir, dataset_name, 'train', documents[:num_train]) - write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev]) - write_section(output_dir, dataset_name, 'test', documents[num_train+num_dev:]) diff --git a/stanza/utils/datasets/tokenization/process_best.py b/stanza/utils/datasets/tokenization/process_best.py new file mode 100644 index 00000000..acfcffa7 --- /dev/null +++ b/stanza/utils/datasets/tokenization/process_best.py @@ -0,0 +1,174 @@ +"""Parses the BEST Thai dataset. + +That is to say, the dataset named BEST. We have not yet figured out +which segmentation standard we prefer. + +Note that the version of BEST we used actually had some strange +sentence splits according to a native Thai speaker. Not sure how to +fix that. Options include doing it automatically or finding some +knowledgable annotators to resplit it for us (or just not using BEST) + +This outputs the tokenization results in a conll format similar to +that of the UD treebanks, so we pretend to be a UD treebank for ease +of compatibility with the stanza tools. 
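Concretely, this converter and process_orchid.py both build the same nested structure -- documents, then paragraphs, then sentences, then (word, space_follows) pairs -- and hand it to write_dataset, which shuffles the documents and writes an 80/10/10 train/dev/test split. A minimal sketch with made-up placeholder tokens (not real corpus text):

    from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset

    documents = [
        [                                       # one document
            [                                   # one paragraph
                [("ab", False), ("cd", True)],  # first sentence: (word, space_follows)
                [("ef", True)],                 # second sentence
            ],
        ],
    ]
    # a real dataset has many documents; with only one, everything lands in the test split
    write_dataset(documents, "data/tokenize", "toy")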
+ +python3 -m stanza.utils.datasets.tokenization.process_best extern_data/thai/best data/tokenize +./scripts/run_tokenize.sh UD_Thai-best --dropout 0.05 --unit_dropout 0.05 --steps 50000 +""" + +import glob +import os +import random +import re +import sys + +from pythainlp import sent_tokenize + +from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset + +def clean_line(line): + line = line.replace("html>", "html|>") + # news_00089.txt + line = line.replace("", "") + line = line.replace("", "") + # specific error that occurs in encyclopedia_00095.txt + line = line.replace("Penn", "|Penn>") + # news_00058.txt + line = line.replace("จม.เปิดผนึก", "จม.|เปิดผนึก") + # news_00015.txt + line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) + # news_00024.txt + line = re.sub("([^|<>]+)", "\\1", line) + # news_00055.txt + line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) + line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) + line = re.sub("([^|<>]+)([^|<>]+) ([^|<>]+)", "\\1|\\2|\\3", line) + # news_00008.txt and other news articles + line = re.sub("([0-9])", "|\\1", line) + line = line.replace(" ", "|") + line = line.strip() + return line + + +def clean_word(word): + # novel_00078.txt + if word == '': + return 'พี่มน' + if word.startswith("") and word.endswith(""): + return word[4:-5] + if word.startswith("") and word.endswith(""): + return word[4:-5] + if word.startswith("") and word.endswith(""): + return word[6:-7] + if word.startswith(""): + return word[4:] + if word.endswith(""): + return word[:-5] + if word.startswith(""): + return word[6:] + if word.endswith(""): + return word[:-7] + if word == '<': + return word + return word + +def reprocess_lines(processed_lines): + reprocessed_lines = [] + for line in processed_lines: + text = "".join(line) + chunks = sent_tokenize(text) + if sum(len(x) for x in chunks) != len(text): + raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks)) + + chunk_lengths = [len(x) for x in chunks] + + current_length = 0 + new_line = [] + for word in line: + if len(word) + current_length < chunk_lengths[0]: + new_line.append(word) + current_length = current_length + len(word) + elif len(word) + current_length == chunk_lengths[0]: + new_line.append(word) + reprocessed_lines.append(new_line) + new_line = [] + chunk_lengths = chunk_lengths[1:] + current_length = 0 + else: + remaining_len = chunk_lengths[0] - current_length + new_line.append(word[:remaining_len]) + reprocessed_lines.append(new_line) + word = word[remaining_len:] + chunk_lengths = chunk_lengths[1:] + while len(word) > chunk_lengths[0]: + new_line = [word[:chunk_lengths[0]]] + reprocessed_lines.append(new_line) + word = word[chunk_lengths[0]:] + chunk_lengths = chunk_lengths[1:] + new_line = [word] + current_length = len(word) + reprocessed_lines.append(new_line) + return reprocessed_lines + +def read_data(input_dir): + subdirs = [os.path.join(input_dir, 'article'), + os.path.join(input_dir, 'encyclopedia'), + os.path.join(input_dir, 'news'), + os.path.join(input_dir, 'novel')] + files = [] + for subdir in subdirs: + if not os.path.exists(subdir): + raise FileNotFoundError("Expected a directory that did not exist: {}".format(subdir)) + files.extend(glob.glob(os.path.join(subdir, '*.txt'))) + + documents = [] + for filename in files: + with open(filename) as fin: + sentences = [] + processed_lines = [] + for line in fin.readlines(): + line = clean_line(line) + words = line.split("|") + words = [clean_word(x) for x in words] + for word in words: + 
if len(word) > 1 and word[0] == '<': + raise ValueError("Unexpected word '{}' in document {}".format(word, filename)) + words = [x for x in words if x] + processed_lines.append(words) + + processed_lines = reprocess_lines(processed_lines) + + for words in processed_lines: + # turn the words into a sentence + sentence = [] + for word in words: + word = word.strip() + if not word: + if len(sentence) == 0: + raise ValueError("Unexpected space at start of sentence in document {}".format(filename)) + sentence[-1] = (sentence[-1][0], True) + else: + sentence.append((word, False)) + # blank lines are very rare in best, but why not treat them as a paragraph break + if len(sentence) == 0: + paragraphs = [sentences] + documents.append(paragraphs) + sentences = [] + continue + sentence[-1] = (sentence[-1][0], True) + sentences.append(sentence) + paragraphs = [sentences] + documents.append(paragraphs) + + return documents + +def main(): + random.seed(1000) + input_dir = sys.argv[1] + output_dir = sys.argv[2] + documents = read_data(input_dir) + write_dataset(documents, output_dir, "best") + + +if __name__ == '__main__': + main() diff --git a/stanza/utils/datasets/tokenization/process_orchid.py b/stanza/utils/datasets/tokenization/process_orchid.py new file mode 100644 index 00000000..e7064d47 --- /dev/null +++ b/stanza/utils/datasets/tokenization/process_orchid.py @@ -0,0 +1,153 @@ +"""Parses the xml conversion of orchid + +https://github.com/korakot/thainlp/blob/master/xmlchid.xml + +For example, if you put the data file in the above link in +extern_data/thai/orchid/xmlchid.xml +you would then run +python3 -m stanza.utils.datasets.tokenization.process_orchid extern_data/thai/orchid/xmlchid.xml data/tokenize + +Because there is no definitive train/dev/test split that we have found +so far, we randomly shuffle the data on a paragraph level and split it +80/10/10. A random seed is chosen so that the splits are reproducible. + +The datasets produced have a similar format to the UD datasets, so we +give it a fake UD name to make life easier for the downstream tools. + +Training on this dataset seems to work best with low dropout numbers. +For example: + +python3 -m stanza.utils.training.run_tokenizer th_orchid --dropout 0.05 --unit_dropout 0.05 + +This results in a model with dev set scores: + th_orchid 87.98 70.94 +test set scores: + 91.60 72.43 + +Apparently the random split produced a test set easier than the dev set. +""" + +import random +import sys +import xml.etree.ElementTree as ET + +from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset + +# line "122819" has some error in the tokenization of the musical notation +# line "209380" is also messed up +# others have @ followed by a part of speech, which is clearly wrong + +skipped_lines = { + "122819", + "209380", + "227769", + "245992", + "347163", + "409708", + "431227", +} + +escape_sequences = { + '': '(', + '': ')', + '': '^', + '': '.', + '': '-', + '': '*', + '': '"', + '': '/', + '': ':', + '': '=', + '': ',', + '': ';', + '': '<', + '': '>', + '': '&', + '': '{', + '': '}', + '': "'", + '': '+', + '': '#', + '': '$', + '': '@', + '': '?', + '': '!', + 'app
<li>ances': 'appliances', + 'intel
<li>gence': 'intelligence', + "'": "/'", + '<100>': '100', +} + +allowed_sequences = { + '', + '', + '', + '', + '', + '
  • ', + '<---vp', + '<---', + '<----', +} + +def read_data(input_filename): + tree = ET.parse(input_filename) + + # we will put each paragraph in a separate block in the output file + # we won't pay any attention to the document boundaries unless we + # later find out it was necessary + # a paragraph will be a list of sentences + # a sentence is a list of words, where each word is a string + documents = [] + + root = tree.getroot() + for document in root: + # these should all be documents + if document.tag != 'document': + raise ValueError("Unexpected orchid xml layout: {}".format(document.tag)) + paragraphs = [] + for paragraph in document: + if paragraph.tag != 'paragraph': + raise ValueError("Unexpected orchid xml layout: {} under {}".format(paragraph.tag, document.tag)) + sentences = [] + for sentence in paragraph: + if sentence.tag != 'sentence': + raise ValueError("Unexpected orchid xml layout: {} under {}".format(sentence.tag, document.tag)) + if sentence.attrib['line_num'] in skipped_lines: + continue + words = [] + for word_idx, word in enumerate(sentence): + if word.tag != 'word': + raise ValueError("Unexpected orchid xml layout: {} under {}".format(word.tag, sentence.tag)) + word = word.attrib['surface'] + word = escape_sequences.get(word, word) + if word == '': + if word_idx == 0: + raise ValueError("Space character was the first token in a sentence: {}".format(sentence.attrib['line_num'])) + else: + words[-1] = (words[-1][0], True) + continue + if len(word) > 1 and word[0] == '<' and word not in allowed_sequences: + raise ValueError("Unknown escape sequence {}".format(word)) + words.append((word, False)) + if len(words) == 0: + continue + sentences.append(words) + paragraphs.append(sentences) + documents.append(paragraphs) + + print("Number of documents: {}".format(len(documents))) + print("Number of paragraphs: {}".format(sum(len(document) for document in documents))) + return documents + + +def main(): + random.seed(1007) + input_filename = sys.argv[1] + output_dir = sys.argv[2] + documents = read_data(input_filename) + write_dataset(documents, output_dir, "orchid") + + +if __name__ == '__main__': + main() diff --git a/stanza/utils/datasets/tokenization/process_thai_tokenization.py b/stanza/utils/datasets/tokenization/process_thai_tokenization.py new file mode 100644 index 00000000..27e347dd --- /dev/null +++ b/stanza/utils/datasets/tokenization/process_thai_tokenization.py @@ -0,0 +1,66 @@ +import os +import random + +def write_section(output_dir, dataset_name, section, documents): + """ + Writes a list of documents for tokenization, including a file in conll format + + The Thai datasets generally have no MWT (apparently not relevant for Thai) + + output_dir: the destination directory for the output files + dataset_name: orchid, BEST, lst20, etc + section: train/dev/test + documents: a nested list of documents, paragraphs, sentences, words + words is a list of (word, space_follows) + """ + with open(os.path.join(output_dir, 'th_%s-ud-%s-mwt.json' % (dataset_name, section)), 'w') as fout: + fout.write("[]\n") + + text_out = open(os.path.join(output_dir, 'th_%s.%s.txt' % (dataset_name, section)), 'w') + label_out = open(os.path.join(output_dir, 'th_%s-ud-%s.toklabels' % (dataset_name, section)), 'w') + for document in documents: + for paragraph in document: + for sentence_idx, sentence in enumerate(paragraph): + for word_idx, word in enumerate(sentence): + # TODO: split with newlines to make it more readable? 
+ text_out.write(word[0]) + for i in range(len(word[0]) - 1): + label_out.write("0") + if word_idx == len(sentence) - 1: + label_out.write("2") + else: + label_out.write("1") + if word[1] and sentence_idx != len(paragraph) - 1: + text_out.write(' ') + label_out.write('0') + + text_out.write("\n\n") + label_out.write("\n\n") + + text_out.close() + label_out.close() + + with open(os.path.join(output_dir, 'th_%s.%s.gold.conllu' % (dataset_name, section)), 'w') as fout: + for document in documents: + for paragraph in document: + for sentence in paragraph: + for word_idx, word in enumerate(sentence): + # SpaceAfter is left blank if there is space after the word + space = '_' if word[1] else 'SpaceAfter=No' + # Note the faked dependency structure: the conll reading code + # needs it even if it isn't being used in any way + fake_dep = 'root' if word_idx == 0 else 'dep' + fout.write('{}\t{}\t_\t_\t_\t_\t{}\t{}\t{}:{}\t{}\n'.format(word_idx+1, word[0], word_idx, fake_dep, word_idx, fake_dep, space)) + fout.write('\n') + +def write_dataset(documents, output_dir, dataset_name): + """ + Shuffle a list of documents, write three sections + """ + random.shuffle(documents) + num_train = int(len(documents) * 0.8) + num_dev = int(len(documents) * 0.1) + os.makedirs(output_dir, exist_ok=True) + write_section(output_dir, dataset_name, 'train', documents[:num_train]) + write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev]) + write_section(output_dir, dataset_name, 'test', documents[num_train+num_dev:]) -- cgit v1.2.3 From 5fb045e02b556ed9db311182999f40075ecadc0b Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 16 Jul 2021 16:17:46 -0700 Subject: Add th_orchid to the prepare_tokenizer_treebank script --- .../utils/datasets/prepare_tokenizer_treebank.py | 8 +- .../datasets/tokenization/convert_th_orchid.py | 159 +++++++++++++++++++++ .../utils/datasets/tokenization/process_orchid.py | 153 -------------------- 3 files changed, 164 insertions(+), 156 deletions(-) create mode 100644 stanza/utils/datasets/tokenization/convert_th_orchid.py delete mode 100644 stanza/utils/datasets/tokenization/process_orchid.py diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index f7e79dcb..4ebbf7a3 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -35,6 +35,7 @@ from collections import Counter import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp +import stanza.utils.datasets.tokenization.convert_th_orchid as convert_th_orchid def copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name): original = f"{tokenizer_dir}/{short_name}.{tokenizer_file}.conllu" @@ -1017,9 +1018,8 @@ def process_treebank(treebank, paths, args): """ Processes a single treebank into train, dev, test parts - TODO - Currently assumes it is always a UD treebank. There are Thai - treebanks which are not included in UD. 
+ Includes processing for a few external tokenization datasets: + vi_vlsp, th_orchid Also, there is no specific mechanism for UD_Arabic-NYUAD or similar treebanks, which need integration with LDC datsets @@ -1035,6 +1035,8 @@ def process_treebank(treebank, paths, args): if short_name == "vi_vlsp": convert_vi_vlsp.convert_vi_vlsp(paths["EXTERN_DIR"], tokenizer_dir, args) + elif short_name == "th_orchid": + convert_th_orchid.main(paths["EXTERN_DIR"], tokenizer_dir) elif short_name.startswith("ko_combined"): build_combined_korean(udbase_dir, tokenizer_dir, short_name) elif short_name in ("it_combined", "en_combined", "es_combined"): diff --git a/stanza/utils/datasets/tokenization/convert_th_orchid.py b/stanza/utils/datasets/tokenization/convert_th_orchid.py new file mode 100644 index 00000000..4cecb491 --- /dev/null +++ b/stanza/utils/datasets/tokenization/convert_th_orchid.py @@ -0,0 +1,159 @@ +"""Parses the xml conversion of orchid + +https://github.com/korakot/thainlp/blob/master/xmlchid.xml + +For example, if you put the data file in the above link in +extern_data/thai/orchid/xmlchid.xml +you would then run +python3 -m stanza.utils.datasets.tokenization.convert_th_orchid extern_data/thai/orchid/xmlchid.xml data/tokenize + +Because there is no definitive train/dev/test split that we have found +so far, we randomly shuffle the data on a paragraph level and split it +80/10/10. A random seed is chosen so that the splits are reproducible. + +The datasets produced have a similar format to the UD datasets, so we +give it a fake UD name to make life easier for the downstream tools. + +Training on this dataset seems to work best with low dropout numbers. +For example: + +python3 -m stanza.utils.training.run_tokenizer th_orchid --dropout 0.05 --unit_dropout 0.05 + +This results in a model with dev set scores: + th_orchid 87.98 70.94 +test set scores: + 91.60 72.43 + +Apparently the random split produced a test set easier than the dev set. +""" + +import os +import random +import sys +import xml.etree.ElementTree as ET + +from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset + +# line "122819" has some error in the tokenization of the musical notation +# line "209380" is also messed up +# others have @ followed by a part of speech, which is clearly wrong + +skipped_lines = { + "122819", + "209380", + "227769", + "245992", + "347163", + "409708", + "431227", +} + +escape_sequences = { + '': '(', + '': ')', + '': '^', + '': '.', + '': '-', + '': '*', + '': '"', + '': '/', + '': ':', + '': '=', + '': ',', + '': ';', + '': '<', + '': '>', + '': '&', + '': '{', + '': '}', + '': "'", + '': '+', + '': '#', + '': '$', + '': '@', + '': '?', + '': '!', + 'app
<li>ances': 'appliances', + 'intel
<li>gence': 'intelligence', + "'": "/'", + '<100>': '100', +} + +allowed_sequences = { + '', + '', + '', + '', + '', + '
  • ', + '<---vp', + '<---', + '<----', +} + +def read_data(input_filename): + print("Reading {}".format(input_filename)) + tree = ET.parse(input_filename) + + # we will put each paragraph in a separate block in the output file + # we won't pay any attention to the document boundaries unless we + # later find out it was necessary + # a paragraph will be a list of sentences + # a sentence is a list of words, where each word is a string + documents = [] + + root = tree.getroot() + for document in root: + # these should all be documents + if document.tag != 'document': + raise ValueError("Unexpected orchid xml layout: {}".format(document.tag)) + paragraphs = [] + for paragraph in document: + if paragraph.tag != 'paragraph': + raise ValueError("Unexpected orchid xml layout: {} under {}".format(paragraph.tag, document.tag)) + sentences = [] + for sentence in paragraph: + if sentence.tag != 'sentence': + raise ValueError("Unexpected orchid xml layout: {} under {}".format(sentence.tag, document.tag)) + if sentence.attrib['line_num'] in skipped_lines: + continue + words = [] + for word_idx, word in enumerate(sentence): + if word.tag != 'word': + raise ValueError("Unexpected orchid xml layout: {} under {}".format(word.tag, sentence.tag)) + word = word.attrib['surface'] + word = escape_sequences.get(word, word) + if word == '': + if word_idx == 0: + raise ValueError("Space character was the first token in a sentence: {}".format(sentence.attrib['line_num'])) + else: + words[-1] = (words[-1][0], True) + continue + if len(word) > 1 and word[0] == '<' and word not in allowed_sequences: + raise ValueError("Unknown escape sequence {}".format(word)) + words.append((word, False)) + if len(words) == 0: + continue + sentences.append(words) + paragraphs.append(sentences) + documents.append(paragraphs) + + print("Number of documents: {}".format(len(documents))) + print("Number of paragraphs: {}".format(sum(len(document) for document in documents))) + return documents + + +def main(*args): + random.seed(1007) + if not args: + args = sys.argv[1:] + input_filename = args[0] + if os.path.isdir(input_filename): + input_filename = os.path.join(input_filename, "thai", "orchid", "xmlchid.xml") + output_dir = args[1] + documents = read_data(input_filename) + write_dataset(documents, output_dir, "orchid") + + +if __name__ == '__main__': + main() diff --git a/stanza/utils/datasets/tokenization/process_orchid.py b/stanza/utils/datasets/tokenization/process_orchid.py deleted file mode 100644 index e7064d47..00000000 --- a/stanza/utils/datasets/tokenization/process_orchid.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Parses the xml conversion of orchid - -https://github.com/korakot/thainlp/blob/master/xmlchid.xml - -For example, if you put the data file in the above link in -extern_data/thai/orchid/xmlchid.xml -you would then run -python3 -m stanza.utils.datasets.tokenization.process_orchid extern_data/thai/orchid/xmlchid.xml data/tokenize - -Because there is no definitive train/dev/test split that we have found -so far, we randomly shuffle the data on a paragraph level and split it -80/10/10. A random seed is chosen so that the splits are reproducible. - -The datasets produced have a similar format to the UD datasets, so we -give it a fake UD name to make life easier for the downstream tools. - -Training on this dataset seems to work best with low dropout numbers. 
-For example: - -python3 -m stanza.utils.training.run_tokenizer th_orchid --dropout 0.05 --unit_dropout 0.05 - -This results in a model with dev set scores: - th_orchid 87.98 70.94 -test set scores: - 91.60 72.43 - -Apparently the random split produced a test set easier than the dev set. -""" - -import random -import sys -import xml.etree.ElementTree as ET - -from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset - -# line "122819" has some error in the tokenization of the musical notation -# line "209380" is also messed up -# others have @ followed by a part of speech, which is clearly wrong - -skipped_lines = { - "122819", - "209380", - "227769", - "245992", - "347163", - "409708", - "431227", -} - -escape_sequences = { - '': '(', - '': ')', - '': '^', - '': '.', - '': '-', - '': '*', - '': '"', - '': '/', - '': ':', - '': '=', - '': ',', - '': ';', - '': '<', - '': '>', - '': '&', - '': '{', - '': '}', - '': "'", - '': '+', - '': '#', - '': '$', - '': '@', - '': '?', - '': '!', - 'app
<li>ances': 'appliances', - 'intel
<li>gence': 'intelligence', - "'": "/'", - '<100>': '100', -} - -allowed_sequences = { - '', - '', - '', - '', - '', - '
  • ', - '<---vp', - '<---', - '<----', -} - -def read_data(input_filename): - tree = ET.parse(input_filename) - - # we will put each paragraph in a separate block in the output file - # we won't pay any attention to the document boundaries unless we - # later find out it was necessary - # a paragraph will be a list of sentences - # a sentence is a list of words, where each word is a string - documents = [] - - root = tree.getroot() - for document in root: - # these should all be documents - if document.tag != 'document': - raise ValueError("Unexpected orchid xml layout: {}".format(document.tag)) - paragraphs = [] - for paragraph in document: - if paragraph.tag != 'paragraph': - raise ValueError("Unexpected orchid xml layout: {} under {}".format(paragraph.tag, document.tag)) - sentences = [] - for sentence in paragraph: - if sentence.tag != 'sentence': - raise ValueError("Unexpected orchid xml layout: {} under {}".format(sentence.tag, document.tag)) - if sentence.attrib['line_num'] in skipped_lines: - continue - words = [] - for word_idx, word in enumerate(sentence): - if word.tag != 'word': - raise ValueError("Unexpected orchid xml layout: {} under {}".format(word.tag, sentence.tag)) - word = word.attrib['surface'] - word = escape_sequences.get(word, word) - if word == '': - if word_idx == 0: - raise ValueError("Space character was the first token in a sentence: {}".format(sentence.attrib['line_num'])) - else: - words[-1] = (words[-1][0], True) - continue - if len(word) > 1 and word[0] == '<' and word not in allowed_sequences: - raise ValueError("Unknown escape sequence {}".format(word)) - words.append((word, False)) - if len(words) == 0: - continue - sentences.append(words) - paragraphs.append(sentences) - documents.append(paragraphs) - - print("Number of documents: {}".format(len(documents))) - print("Number of paragraphs: {}".format(sum(len(document) for document in documents))) - return documents - - -def main(): - random.seed(1007) - input_filename = sys.argv[1] - output_dir = sys.argv[2] - documents = read_data(input_filename) - write_dataset(documents, output_dir, "orchid") - - -if __name__ == '__main__': - main() -- cgit v1.2.3 From 23b3c227cb2203f20df86ac96c454b6427bde861 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 16 Jul 2021 16:55:12 -0700 Subject: Connect BEST to the conversion script --- .../utils/datasets/prepare_tokenizer_treebank.py | 5 +- .../utils/datasets/tokenization/convert_th_best.py | 186 +++++++++++++++++++++ stanza/utils/datasets/tokenization/process_best.py | 174 ------------------- 3 files changed, 190 insertions(+), 175 deletions(-) create mode 100644 stanza/utils/datasets/tokenization/convert_th_best.py delete mode 100644 stanza/utils/datasets/tokenization/process_best.py diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index 4ebbf7a3..5bbb9da5 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -35,6 +35,7 @@ from collections import Counter import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp +import stanza.utils.datasets.tokenization.convert_th_best as convert_th_best import stanza.utils.datasets.tokenization.convert_th_orchid as convert_th_orchid def copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name): @@ -1019,7 +1020,7 @@ def 
process_treebank(treebank, paths, args): Processes a single treebank into train, dev, test parts Includes processing for a few external tokenization datasets: - vi_vlsp, th_orchid + vi_vlsp, th_orchid, th_best Also, there is no specific mechanism for UD_Arabic-NYUAD or similar treebanks, which need integration with LDC datsets @@ -1037,6 +1038,8 @@ def process_treebank(treebank, paths, args): convert_vi_vlsp.convert_vi_vlsp(paths["EXTERN_DIR"], tokenizer_dir, args) elif short_name == "th_orchid": convert_th_orchid.main(paths["EXTERN_DIR"], tokenizer_dir) + elif short_name == "th_best": + convert_th_best.main(paths["EXTERN_DIR"], tokenizer_dir) elif short_name.startswith("ko_combined"): build_combined_korean(udbase_dir, tokenizer_dir, short_name) elif short_name in ("it_combined", "en_combined", "es_combined"): diff --git a/stanza/utils/datasets/tokenization/convert_th_best.py b/stanza/utils/datasets/tokenization/convert_th_best.py new file mode 100644 index 00000000..a642702c --- /dev/null +++ b/stanza/utils/datasets/tokenization/convert_th_best.py @@ -0,0 +1,186 @@ +"""Parses the BEST Thai dataset. + +That is to say, the dataset named BEST. We have not yet figured out +which segmentation standard we prefer. + +Note that the version of BEST we used actually had some strange +sentence splits according to a native Thai speaker. Not sure how to +fix that. Options include doing it automatically or finding some +knowledgable annotators to resplit it for us (or just not using BEST) + +This outputs the tokenization results in a conll format similar to +that of the UD treebanks, so we pretend to be a UD treebank for ease +of compatibility with the stanza tools. + +BEST can be downloaded from here: + +https://aiforthai.in.th/corpus.php + +python3 -m stanza.utils.datasets.tokenization.process_best extern_data/thai/best data/tokenize +./scripts/run_tokenize.sh UD_Thai-best --dropout 0.05 --unit_dropout 0.05 --steps 50000 +""" + +import glob +import os +import random +import re +import sys + +from pythainlp import sent_tokenize + +from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset + +def clean_line(line): + line = line.replace("html>", "html|>") + # news_00089.txt + line = line.replace("", "") + line = line.replace("", "") + # specific error that occurs in encyclopedia_00095.txt + line = line.replace("Penn", "|Penn>") + # news_00058.txt + line = line.replace("จม.เปิดผนึก", "จม.|เปิดผนึก") + # news_00015.txt + line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) + # news_00024.txt + line = re.sub("([^|<>]+)", "\\1", line) + # news_00055.txt + line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) + line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) + line = re.sub("([^|<>]+)([^|<>]+) ([^|<>]+)", "\\1|\\2|\\3", line) + # news_00008.txt and other news articles + line = re.sub("([0-9])", "|\\1", line) + line = line.replace(" ", "|") + line = line.strip() + return line + + +def clean_word(word): + # novel_00078.txt + if word == '': + return 'พี่มน' + if word.startswith("") and word.endswith(""): + return word[4:-5] + if word.startswith("") and word.endswith(""): + return word[4:-5] + if word.startswith("") and word.endswith(""): + return word[6:-7] + if word.startswith(""): + return word[4:] + if word.endswith(""): + return word[:-5] + if word.startswith(""): + return word[6:] + if word.endswith(""): + return word[:-7] + if word == '<': + return word + return word + +def reprocess_lines(processed_lines): + reprocessed_lines = [] + for line in processed_lines: + text = 
"".join(line) + chunks = sent_tokenize(text) + if sum(len(x) for x in chunks) != len(text): + raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks)) + + chunk_lengths = [len(x) for x in chunks] + + current_length = 0 + new_line = [] + for word in line: + if len(word) + current_length < chunk_lengths[0]: + new_line.append(word) + current_length = current_length + len(word) + elif len(word) + current_length == chunk_lengths[0]: + new_line.append(word) + reprocessed_lines.append(new_line) + new_line = [] + chunk_lengths = chunk_lengths[1:] + current_length = 0 + else: + remaining_len = chunk_lengths[0] - current_length + new_line.append(word[:remaining_len]) + reprocessed_lines.append(new_line) + word = word[remaining_len:] + chunk_lengths = chunk_lengths[1:] + while len(word) > chunk_lengths[0]: + new_line = [word[:chunk_lengths[0]]] + reprocessed_lines.append(new_line) + word = word[chunk_lengths[0]:] + chunk_lengths = chunk_lengths[1:] + new_line = [word] + current_length = len(word) + reprocessed_lines.append(new_line) + return reprocessed_lines + +def read_data(input_dir): + subdirs = [os.path.join(input_dir, 'article'), + os.path.join(input_dir, 'encyclopedia'), + os.path.join(input_dir, 'news'), + os.path.join(input_dir, 'novel')] + files = [] + for subdir in subdirs: + if not os.path.exists(subdir): + raise FileNotFoundError("Expected a directory that did not exist: {}".format(subdir)) + files.extend(glob.glob(os.path.join(subdir, '*.txt'))) + + documents = [] + for filename in files: + with open(filename) as fin: + sentences = [] + processed_lines = [] + for line in fin.readlines(): + line = clean_line(line) + words = line.split("|") + words = [clean_word(x) for x in words] + for word in words: + if len(word) > 1 and word[0] == '<': + raise ValueError("Unexpected word '{}' in document {}".format(word, filename)) + words = [x for x in words if x] + processed_lines.append(words) + + processed_lines = reprocess_lines(processed_lines) + + for words in processed_lines: + # turn the words into a sentence + sentence = [] + for word in words: + word = word.strip() + if not word: + if len(sentence) == 0: + raise ValueError("Unexpected space at start of sentence in document {}".format(filename)) + sentence[-1] = (sentence[-1][0], True) + else: + sentence.append((word, False)) + # blank lines are very rare in best, but why not treat them as a paragraph break + if len(sentence) == 0: + paragraphs = [sentences] + documents.append(paragraphs) + sentences = [] + continue + sentence[-1] = (sentence[-1][0], True) + sentences.append(sentence) + paragraphs = [sentences] + documents.append(paragraphs) + + return documents + +def main(*args): + random.seed(1000) + if not args: + args = sys.argv[1:] + + input_dir = args[0] + full_input_dir = os.path.join(input_dir, "thai", "best") + if os.path.exists(full_input_dir): + # otherwise hopefully the user gave us the full path? + input_dir = full_input_dir + + output_dir = args[1] + documents = read_data(input_dir) + write_dataset(documents, output_dir, "best") + + +if __name__ == '__main__': + main() diff --git a/stanza/utils/datasets/tokenization/process_best.py b/stanza/utils/datasets/tokenization/process_best.py deleted file mode 100644 index acfcffa7..00000000 --- a/stanza/utils/datasets/tokenization/process_best.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Parses the BEST Thai dataset. - -That is to say, the dataset named BEST. We have not yet figured out -which segmentation standard we prefer. 
- -Note that the version of BEST we used actually had some strange -sentence splits according to a native Thai speaker. Not sure how to -fix that. Options include doing it automatically or finding some -knowledgable annotators to resplit it for us (or just not using BEST) - -This outputs the tokenization results in a conll format similar to -that of the UD treebanks, so we pretend to be a UD treebank for ease -of compatibility with the stanza tools. - -python3 -m stanza.utils.datasets.tokenization.process_best extern_data/thai/best data/tokenize -./scripts/run_tokenize.sh UD_Thai-best --dropout 0.05 --unit_dropout 0.05 --steps 50000 -""" - -import glob -import os -import random -import re -import sys - -from pythainlp import sent_tokenize - -from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset - -def clean_line(line): - line = line.replace("html>", "html|>") - # news_00089.txt - line = line.replace("", "") - line = line.replace("", "") - # specific error that occurs in encyclopedia_00095.txt - line = line.replace("Penn", "|Penn>") - # news_00058.txt - line = line.replace("จม.เปิดผนึก", "จม.|เปิดผนึก") - # news_00015.txt - line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) - # news_00024.txt - line = re.sub("([^|<>]+)", "\\1", line) - # news_00055.txt - line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) - line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) - line = re.sub("([^|<>]+)([^|<>]+) ([^|<>]+)", "\\1|\\2|\\3", line) - # news_00008.txt and other news articles - line = re.sub("([0-9])", "|\\1", line) - line = line.replace(" ", "|") - line = line.strip() - return line - - -def clean_word(word): - # novel_00078.txt - if word == '': - return 'พี่มน' - if word.startswith("") and word.endswith(""): - return word[4:-5] - if word.startswith("") and word.endswith(""): - return word[4:-5] - if word.startswith("") and word.endswith(""): - return word[6:-7] - if word.startswith(""): - return word[4:] - if word.endswith(""): - return word[:-5] - if word.startswith(""): - return word[6:] - if word.endswith(""): - return word[:-7] - if word == '<': - return word - return word - -def reprocess_lines(processed_lines): - reprocessed_lines = [] - for line in processed_lines: - text = "".join(line) - chunks = sent_tokenize(text) - if sum(len(x) for x in chunks) != len(text): - raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks)) - - chunk_lengths = [len(x) for x in chunks] - - current_length = 0 - new_line = [] - for word in line: - if len(word) + current_length < chunk_lengths[0]: - new_line.append(word) - current_length = current_length + len(word) - elif len(word) + current_length == chunk_lengths[0]: - new_line.append(word) - reprocessed_lines.append(new_line) - new_line = [] - chunk_lengths = chunk_lengths[1:] - current_length = 0 - else: - remaining_len = chunk_lengths[0] - current_length - new_line.append(word[:remaining_len]) - reprocessed_lines.append(new_line) - word = word[remaining_len:] - chunk_lengths = chunk_lengths[1:] - while len(word) > chunk_lengths[0]: - new_line = [word[:chunk_lengths[0]]] - reprocessed_lines.append(new_line) - word = word[chunk_lengths[0]:] - chunk_lengths = chunk_lengths[1:] - new_line = [word] - current_length = len(word) - reprocessed_lines.append(new_line) - return reprocessed_lines - -def read_data(input_dir): - subdirs = [os.path.join(input_dir, 'article'), - os.path.join(input_dir, 'encyclopedia'), - os.path.join(input_dir, 'news'), - os.path.join(input_dir, 'novel')] - files = [] - for 
subdir in subdirs: - if not os.path.exists(subdir): - raise FileNotFoundError("Expected a directory that did not exist: {}".format(subdir)) - files.extend(glob.glob(os.path.join(subdir, '*.txt'))) - - documents = [] - for filename in files: - with open(filename) as fin: - sentences = [] - processed_lines = [] - for line in fin.readlines(): - line = clean_line(line) - words = line.split("|") - words = [clean_word(x) for x in words] - for word in words: - if len(word) > 1 and word[0] == '<': - raise ValueError("Unexpected word '{}' in document {}".format(word, filename)) - words = [x for x in words if x] - processed_lines.append(words) - - processed_lines = reprocess_lines(processed_lines) - - for words in processed_lines: - # turn the words into a sentence - sentence = [] - for word in words: - word = word.strip() - if not word: - if len(sentence) == 0: - raise ValueError("Unexpected space at start of sentence in document {}".format(filename)) - sentence[-1] = (sentence[-1][0], True) - else: - sentence.append((word, False)) - # blank lines are very rare in best, but why not treat them as a paragraph break - if len(sentence) == 0: - paragraphs = [sentences] - documents.append(paragraphs) - sentences = [] - continue - sentence[-1] = (sentence[-1][0], True) - sentences.append(sentence) - paragraphs = [sentences] - documents.append(paragraphs) - - return documents - -def main(): - random.seed(1000) - input_dir = sys.argv[1] - output_dir = sys.argv[2] - documents = read_data(input_dir) - write_dataset(documents, output_dir, "best") - - -if __name__ == '__main__': - main() -- cgit v1.2.3 From 174bd543e74e256badb1fda6d7c59f1922f031d6 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 10 Feb 2021 08:06:43 -0800 Subject: Add NewPar to new paragraphs. --- .../utils/datasets/tokenization/process_thai_tokenization.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/stanza/utils/datasets/tokenization/process_thai_tokenization.py b/stanza/utils/datasets/tokenization/process_thai_tokenization.py index 27e347dd..e2950d7c 100644 --- a/stanza/utils/datasets/tokenization/process_thai_tokenization.py +++ b/stanza/utils/datasets/tokenization/process_thai_tokenization.py @@ -43,10 +43,20 @@ def write_section(output_dir, dataset_name, section, documents): with open(os.path.join(output_dir, 'th_%s.%s.gold.conllu' % (dataset_name, section)), 'w') as fout: for document in documents: for paragraph in document: + new_par = True for sentence in paragraph: for word_idx, word in enumerate(sentence): # SpaceAfter is left blank if there is space after the word - space = '_' if word[1] else 'SpaceAfter=No' + if word[1] and new_par: + space = 'NewPar=Yes' + elif word[1]: + space = '_' + elif new_par: + space = 'SpaceAfter=No|NewPar=Yes' + else: + space = 'SpaceAfter=No' + new_par = False + # Note the faked dependency structure: the conll reading code # needs it even if it isn't being used in any way fake_dep = 'root' if word_idx == 0 else 'dep' -- cgit v1.2.3 From 24c7943e3bf163cbd2c2e0066f78ad7e1ae245c7 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 16 Jul 2021 18:47:45 -0700 Subject: Problem with space separation --- stanza/utils/datasets/tokenization/process_thai_tokenization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/utils/datasets/tokenization/process_thai_tokenization.py b/stanza/utils/datasets/tokenization/process_thai_tokenization.py index e2950d7c..135bf543 100644 --- a/stanza/utils/datasets/tokenization/process_thai_tokenization.py +++ 
b/stanza/utils/datasets/tokenization/process_thai_tokenization.py @@ -30,7 +30,7 @@ def write_section(output_dir, dataset_name, section, documents): label_out.write("2") else: label_out.write("1") - if word[1] and sentence_idx != len(paragraph) - 1: + if word[1] and (sentence_idx != len(paragraph) - 1 or word_idx != len(sentence) - 1): text_out.write(' ') label_out.write('0') -- cgit v1.2.3 From 32a227299a9b16f25e4aa9f8d65d81356904ccea Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 18 Jul 2021 19:23:44 -0700 Subject: Test updates based on changes to the underlying data, which changed the results of the model --- stanza/tests/test_english_pipeline.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/stanza/tests/test_english_pipeline.py b/stanza/tests/test_english_pipeline.py index 73569a9a..c8003297 100644 --- a/stanza/tests/test_english_pipeline.py +++ b/stanza/tests/test_english_pipeline.py @@ -27,10 +27,10 @@ EN_DOC_TOKENS_GOLD = """ ]> ]> -]> +]> ]> ]> -]> +]> ]> ]> @@ -50,10 +50,10 @@ EN_DOC_WORDS_GOLD = """ - + - + @@ -96,10 +96,10 @@ EN_DOC_CONLLU_GOLD = """ 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=34|end_char=36 2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=37|end_char=40 -3 elected elect VERB VBN Tense=Past|VerbForm=Part 0 root _ start_char=41|end_char=48 +3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=41|end_char=48 4 president president NOUN NN Number=Sing 3 xcomp _ start_char=49|end_char=58 5 in in ADP IN _ 6 case _ start_char=59|end_char=61 -6 2008 2008 NUM CD NumType=Card 3 obl _ start_char=62|end_char=66 +6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ start_char=62|end_char=66 7 . . PUNCT . _ 3 punct _ start_char=66|end_char=67 1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=69|end_char=74 @@ -120,10 +120,10 @@ EN_DOC_CONLLU_GOLD_MULTIDOC = """ 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=0|end_char=2 2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=3|end_char=6 -3 elected elect VERB VBN Tense=Past|VerbForm=Part 0 root _ start_char=7|end_char=14 +3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=7|end_char=14 4 president president NOUN NN Number=Sing 3 xcomp _ start_char=15|end_char=24 5 in in ADP IN _ 6 case _ start_char=25|end_char=27 -6 2008 2008 NUM CD NumType=Card 3 obl _ start_char=28|end_char=32 +6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ start_char=28|end_char=32 7 . . PUNCT . 
_ 3 punct _ start_char=32|end_char=33 1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=0|end_char=5 -- cgit v1.2.3 From a107c132830140168f7a640fa8dba911d7a5e825 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 18 Jul 2021 19:42:19 -0700 Subject: Test updates based on changes to the underlying data, which changed the results of the model --- stanza/tests/test_french_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stanza/tests/test_french_pipeline.py b/stanza/tests/test_french_pipeline.py index fe781dc2..d4056aa4 100644 --- a/stanza/tests/test_french_pipeline.py +++ b/stanza/tests/test_french_pipeline.py @@ -95,7 +95,7 @@ EXPECTED_RESULT = """ "upos": "NOUN", "feats": "Gender=Masc|Number=Sing", "head": 3, - "deprel": "obl:arg", + "deprel": "obl:mod", "start_char": 30, "end_char": 36 }, @@ -168,7 +168,7 @@ EXPECTED_RESULT = """ "upos": "NOUN", "feats": "Gender=Masc|Number=Sing", "head": 11, - "deprel": "xcomp:pred", + "deprel": "xcomp", "start_char": 70, "end_char": 78 }, -- cgit v1.2.3 From 5c087af134b7abadc1fba44bccb5ada6997911c1 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 20 Aug 2020 16:23:56 -0700 Subject: Add a script which converts the LST20 dataset for tokenization --- stanza/utils/datasets/process_lst20.py | 63 ++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 stanza/utils/datasets/process_lst20.py diff --git a/stanza/utils/datasets/process_lst20.py b/stanza/utils/datasets/process_lst20.py new file mode 100644 index 00000000..77e48592 --- /dev/null +++ b/stanza/utils/datasets/process_lst20.py @@ -0,0 +1,63 @@ +"""Processes the tokenization section of the LST20 Thai dataset + +The dataset is available here: + +https://aiforthai.in.th/corpus.php + + +python3 -m stanza.utils.datasets.process_lst20 extern_data/thai/LST20_Corpus data/tokenize + +Unlike Orchid and BEST, LST20 has train/eval/test splits, which we relabel train/dev/test. + +./scripts/run_tokenize.sh UD_Thai-lst20 --dropout 0.05 --unit_dropout 0.05 +""" + + +import glob +import os +import sys + +from stanza.utils.datasets.process_thai_tokenization import write_section + +def read_data(input_dir, section): + input_dir = os.path.join(input_dir, section) + filenames = glob.glob(os.path.join(input_dir, "*.txt")) + documents = [] + for filename in filenames: + document = [] + lines = open(filename).readlines() + sentence = [] + for line in lines: + line = line.strip() + if not line: + if sentence: + #sentence[-1] = (sentence[-1][0], True) + document.append(sentence) + sentence = [] + else: + pieces = line.split("\t") + if pieces[0] == '_': + sentence[-1] = (sentence[-1][0], True) + else: + sentence.append((pieces[0], False)) + + if sentence: + #sentence[-1] = (sentence[-1][0], True) + document.append(sentence) + sentence = [] + # TODO: is there any way to divide up a single document into paragraphs? 
+ documents.append([document]) + return documents + +def main(): + input_dir = sys.argv[1] + output_dir = sys.argv[2] + for (in_section, out_section) in (("train", "train"), + ("eval", "dev"), + ("test", "test")): + documents = read_data(input_dir, in_section) + write_section(output_dir, "lst20", out_section, documents) + + +if __name__ == '__main__': + main() -- cgit v1.2.3 From d837c8bef7d37d03e4f729fcc18a2663bbea02fa Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 16 Jul 2021 20:50:42 -0700 Subject: Move process_lst20 to tokenization --- stanza/utils/datasets/process_lst20.py | 63 ---------------------- .../datasets/tokenization/convert_th_lst20.py | 63 ++++++++++++++++++++++ 2 files changed, 63 insertions(+), 63 deletions(-) delete mode 100644 stanza/utils/datasets/process_lst20.py create mode 100644 stanza/utils/datasets/tokenization/convert_th_lst20.py diff --git a/stanza/utils/datasets/process_lst20.py b/stanza/utils/datasets/process_lst20.py deleted file mode 100644 index 77e48592..00000000 --- a/stanza/utils/datasets/process_lst20.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Processes the tokenization section of the LST20 Thai dataset - -The dataset is available here: - -https://aiforthai.in.th/corpus.php - - -python3 -m stanza.utils.datasets.process_lst20 extern_data/thai/LST20_Corpus data/tokenize - -Unlike Orchid and BEST, LST20 has train/eval/test splits, which we relabel train/dev/test. - -./scripts/run_tokenize.sh UD_Thai-lst20 --dropout 0.05 --unit_dropout 0.05 -""" - - -import glob -import os -import sys - -from stanza.utils.datasets.process_thai_tokenization import write_section - -def read_data(input_dir, section): - input_dir = os.path.join(input_dir, section) - filenames = glob.glob(os.path.join(input_dir, "*.txt")) - documents = [] - for filename in filenames: - document = [] - lines = open(filename).readlines() - sentence = [] - for line in lines: - line = line.strip() - if not line: - if sentence: - #sentence[-1] = (sentence[-1][0], True) - document.append(sentence) - sentence = [] - else: - pieces = line.split("\t") - if pieces[0] == '_': - sentence[-1] = (sentence[-1][0], True) - else: - sentence.append((pieces[0], False)) - - if sentence: - #sentence[-1] = (sentence[-1][0], True) - document.append(sentence) - sentence = [] - # TODO: is there any way to divide up a single document into paragraphs? - documents.append([document]) - return documents - -def main(): - input_dir = sys.argv[1] - output_dir = sys.argv[2] - for (in_section, out_section) in (("train", "train"), - ("eval", "dev"), - ("test", "test")): - documents = read_data(input_dir, in_section) - write_section(output_dir, "lst20", out_section, documents) - - -if __name__ == '__main__': - main() diff --git a/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/utils/datasets/tokenization/convert_th_lst20.py new file mode 100644 index 00000000..d256ead5 --- /dev/null +++ b/stanza/utils/datasets/tokenization/convert_th_lst20.py @@ -0,0 +1,63 @@ +"""Processes the tokenization section of the LST20 Thai dataset + +The dataset is available here: + +https://aiforthai.in.th/corpus.php + + +python3 -m stanza.utils.datasets.tokenization.convert_th_lst20 extern_data/thai/LST20_Corpus data/tokenize + +Unlike Orchid and BEST, LST20 has train/eval/test splits, which we relabel train/dev/test. 
+ +./scripts/run_tokenize.sh UD_Thai-lst20 --dropout 0.05 --unit_dropout 0.05 +""" + + +import glob +import os +import sys + +from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section + +def read_data(input_dir, section): + input_dir = os.path.join(input_dir, section) + filenames = glob.glob(os.path.join(input_dir, "*.txt")) + documents = [] + for filename in filenames: + document = [] + lines = open(filename).readlines() + sentence = [] + for line in lines: + line = line.strip() + if not line: + if sentence: + #sentence[-1] = (sentence[-1][0], True) + document.append(sentence) + sentence = [] + else: + pieces = line.split("\t") + if pieces[0] == '_': + sentence[-1] = (sentence[-1][0], True) + else: + sentence.append((pieces[0], False)) + + if sentence: + #sentence[-1] = (sentence[-1][0], True) + document.append(sentence) + sentence = [] + # TODO: is there any way to divide up a single document into paragraphs? + documents.append([document]) + return documents + +def main(): + input_dir = sys.argv[1] + output_dir = sys.argv[2] + for (in_section, out_section) in (("train", "train"), + ("eval", "dev"), + ("test", "test")): + documents = read_data(input_dir, in_section) + write_section(output_dir, "lst20", out_section, documents) + + +if __name__ == '__main__': + main() -- cgit v1.2.3 From 46ca9f3baf1d6c9e355304a5288c5b3779f6648a Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sat, 17 Jul 2021 23:42:28 -0700 Subject: Integrate lst20 into the tokenization script --- stanza/utils/datasets/prepare_tokenizer_treebank.py | 11 ++++++++++- stanza/utils/datasets/tokenization/convert_th_lst20.py | 16 +++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index 5bbb9da5..8251e267 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -36,6 +36,7 @@ import stanza.utils.datasets.common as common import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp import stanza.utils.datasets.tokenization.convert_th_best as convert_th_best +import stanza.utils.datasets.tokenization.convert_th_lst20 as convert_th_lst20 import stanza.utils.datasets.tokenization.convert_th_orchid as convert_th_orchid def copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name): @@ -138,7 +139,13 @@ def prepare_treebank_labels(tokenizer_dir, short_name): for dataset in ("train", "dev", "test"): output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt" output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu" - prepare_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset) + try: + prepare_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset) + except (KeyboardInterrupt, SystemExit): + raise + except: + print("Failed to convert %s to %s" % (output_txt, output_conllu)) + raise CONLLU_TO_TXT_PERL = os.path.join(os.path.split(__file__)[0], "conllu_to_text.pl") @@ -1038,6 +1045,8 @@ def process_treebank(treebank, paths, args): convert_vi_vlsp.convert_vi_vlsp(paths["EXTERN_DIR"], tokenizer_dir, args) elif short_name == "th_orchid": convert_th_orchid.main(paths["EXTERN_DIR"], tokenizer_dir) + elif short_name == "th_lst20": + convert_th_lst20.main(paths["EXTERN_DIR"], tokenizer_dir) elif short_name == "th_best": 
convert_th_best.main(paths["EXTERN_DIR"], tokenizer_dir) elif short_name.startswith("ko_combined"): diff --git a/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/utils/datasets/tokenization/convert_th_lst20.py index d256ead5..2f15e3e5 100644 --- a/stanza/utils/datasets/tokenization/convert_th_lst20.py +++ b/stanza/utils/datasets/tokenization/convert_th_lst20.py @@ -36,6 +36,8 @@ def read_data(input_dir, section): sentence = [] else: pieces = line.split("\t") + # there are some nbsp in tokens in lst20, but the downstream tools expect spaces + pieces = [p.replace("\xa0", " ") for p in pieces] if pieces[0] == '_': sentence[-1] = (sentence[-1][0], True) else: @@ -49,13 +51,21 @@ def read_data(input_dir, section): documents.append([document]) return documents -def main(): - input_dir = sys.argv[1] - output_dir = sys.argv[2] +def main(*args): + if not args: + args = sys.argv[1:] + input_dir = args[0] + full_input_dir = os.path.join(input_dir, "thai", "LST20_Corpus") + if os.path.exists(full_input_dir): + # otherwise hopefully the user gave us the full path? + input_dir = full_input_dir + output_dir = args[1] for (in_section, out_section) in (("train", "train"), ("eval", "dev"), ("test", "test")): + print("Processing %s" % out_section) documents = read_data(input_dir, in_section) + print(" Read in %d files" % len(documents)) write_section(output_dir, "lst20", out_section, documents) -- cgit v1.2.3 From 8063055dcca91a70f239d0dc12d8230c206d10b9 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 19 Jul 2021 09:29:57 -0700 Subject: Don't make new text files for datasets which already produced text files --- stanza/utils/datasets/prepare_tokenizer_treebank.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index 8251e267..694f2b2e 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -1068,7 +1068,8 @@ def process_treebank(treebank, paths, args): else: process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, args.augment) - convert_conllu_to_txt(tokenizer_dir, short_name) + if not short_name in ('th_orchid', 'th_lst20'): + convert_conllu_to_txt(tokenizer_dir, short_name) if args.prepare_labels: prepare_treebank_labels(tokenizer_dir, short_name) -- cgit v1.2.3 From 31a9b8c9837c3471930a5deeb37ec50615c86936 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 19 Jul 2021 09:30:27 -0700 Subject: Refactor some to make it easier to test the lst20 script --- .../utils/datasets/prepare_tokenizer_treebank.py | 4 +- .../datasets/tokenization/convert_th_lst20.py | 53 ++++++++++++---------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index 694f2b2e..c2188513 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -149,8 +149,8 @@ def prepare_treebank_labels(tokenizer_dir, short_name): CONLLU_TO_TXT_PERL = os.path.join(os.path.split(__file__)[0], "conllu_to_text.pl") -def convert_conllu_to_txt(tokenizer_dir, short_name): - for dataset in ("train", "dev", "test"): +def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")): + for dataset in shards: output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu" output_txt = 
f"{tokenizer_dir}/{short_name}.{dataset}.txt" diff --git a/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/utils/datasets/tokenization/convert_th_lst20.py index 2f15e3e5..275ac21b 100644 --- a/stanza/utils/datasets/tokenization/convert_th_lst20.py +++ b/stanza/utils/datasets/tokenization/convert_th_lst20.py @@ -19,35 +19,40 @@ import sys from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section +def read_document(lines): + document = [] + sentence = [] + for line in lines: + line = line.strip() + if not line: + if sentence: + #sentence[-1] = (sentence[-1][0], True) + document.append(sentence) + sentence = [] + else: + pieces = line.split("\t") + # there are some nbsp in tokens in lst20, but the downstream tools expect spaces + pieces = [p.replace("\xa0", " ") for p in pieces] + if pieces[0] == '_': + sentence[-1] = (sentence[-1][0], True) + else: + sentence.append((pieces[0], False)) + + if sentence: + #sentence[-1] = (sentence[-1][0], True) + document.append(sentence) + sentence = [] + # TODO: is there any way to divide up a single document into paragraphs? + return document + def read_data(input_dir, section): input_dir = os.path.join(input_dir, section) filenames = glob.glob(os.path.join(input_dir, "*.txt")) documents = [] for filename in filenames: - document = [] - lines = open(filename).readlines() - sentence = [] - for line in lines: - line = line.strip() - if not line: - if sentence: - #sentence[-1] = (sentence[-1][0], True) - document.append(sentence) - sentence = [] - else: - pieces = line.split("\t") - # there are some nbsp in tokens in lst20, but the downstream tools expect spaces - pieces = [p.replace("\xa0", " ") for p in pieces] - if pieces[0] == '_': - sentence[-1] = (sentence[-1][0], True) - else: - sentence.append((pieces[0], False)) - - if sentence: - #sentence[-1] = (sentence[-1][0], True) - document.append(sentence) - sentence = [] - # TODO: is there any way to divide up a single document into paragraphs? + with open(filename) as fin: + lines = fin.readlines() + document = read_document(lines) documents.append([document]) return documents -- cgit v1.2.3 From 6f0a50a26d08976b01c30c7bba3465397b25aac6 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 19 Jul 2021 13:05:24 -0700 Subject: Improve prepare_ner_dataset doc --- stanza/utils/datasets/ner/prepare_ner_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py index 22befde0..4652bdae 100644 --- a/stanza/utils/datasets/ner/prepare_ner_dataset.py +++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py @@ -10,7 +10,7 @@ Also, Finnish Turku dataset, available here: - https://turkunlp.org/fin-ner.html - Download and unzip the corpus, putting the .tsv files into $NERBASE/fi_turku - - prepare_ner_dataset.py hu_nytk fi_turku + - prepare_ner_dataset.py fi_turku IJCNLP 2008 produced a few Indian language NER datasets. description: @@ -18,6 +18,7 @@ IJCNLP 2008 produced a few Indian language NER datasets. download: http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=5 The models produced from these datasets have extremely low recall, unfortunately. + - prepare_ner_dataset.py hi-fire2013 FIRE 2013 also produced NER datasets for Indian languages. 
http://au-kbc.org/nlp/NER-FIRE2013/index.html -- cgit v1.2.3 From c04e622e130e786b0c3163cac3a4be76ad3052b4 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 19 Jul 2021 13:42:05 -0700 Subject: Add some more command lines to the prepare_ner_dataset.py doc --- stanza/utils/datasets/ner/prepare_ner_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py index 4652bdae..e63f501d 100644 --- a/stanza/utils/datasets/ner/prepare_ner_dataset.py +++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py @@ -40,6 +40,7 @@ There are two Hungarian datasets are available here: You can also build individual pieces with hu_rgai_business or hu_rgai_criminal Create a subdirectory of $NERBASE, $NERBASE/hu_rgai, and download both of the pieces and unzip them in that directory. + - prepare_ner_dataset.py hu_rgai Another Hungarian dataset is here: - https://github.com/nytud/NYTK-NerKor @@ -48,6 +49,7 @@ Another Hungarian dataset is here: The two Hungarian datasets can be combined with hu_combined TODO: verify that there is no overlap in text + - prepare_ner_dataset.py hu_combined BSNLP publishes NER datasets for Eastern European languages. - In 2019 they published BG, CS, PL, RU. @@ -66,6 +68,7 @@ BSNLP publishes NER datasets for Eastern European languages. - we use the code name "bg_bsnlp19". Other languages from bsnlp 2019 can be supported by adding the appropriate functionality in convert_bsnlp.py. + - prepare_ner_dataset.py bg_bsnlp19 """ import glob -- cgit v1.2.3 From fdf9c7d954595449b1fd0ee71e7d3fa7b6121a2f Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 19 Jul 2021 14:50:15 -0700 Subject: Add a few extra cases to treebank_to_short_name so that calling on an already short name should generally work. 
Add a test as well --- stanza/models/common/constant.py | 8 +++++++- stanza/tests/test_constant.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 stanza/tests/test_constant.py diff --git a/stanza/models/common/constant.py b/stanza/models/common/constant.py index 8fdabb4b..9b39b7c2 100644 --- a/stanza/models/common/constant.py +++ b/stanza/models/common/constant.py @@ -160,7 +160,13 @@ def treebank_to_short_name(treebank): if treebank.startswith('UD_'): treebank = treebank[3:] - splits = treebank.split('-') + # special case starting with zh in case the input is an already-converted ZH treebank + if treebank.startswith("zh-hans") or treebank.startswith("zh-hant"): + splits = (treebank[:len("zh-hans")], treebank[len("zh-hans")+1:]) + else: + splits = treebank.split('-') + if len(splits) == 1: + splits = treebank.split("_", 1) assert len(splits) == 2, "Unable to process %s" % treebank lang, corpus = splits diff --git a/stanza/tests/test_constant.py b/stanza/tests/test_constant.py new file mode 100644 index 00000000..3afcc8d6 --- /dev/null +++ b/stanza/tests/test_constant.py @@ -0,0 +1,35 @@ +""" +Test the conversion to lcodes and splitting of dataset names +""" + +import tempfile + +import pytest + +import stanza +from stanza.models.common.constant import treebank_to_short_name +from stanza.tests import * + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +def test_treebank(): + """ + Test the entire treebank name conversion + """ + # conversion of a UD_ name + assert "hi_hdtb" == treebank_to_short_name("UD_Hindi-HDTB") + # conversion of names without UD + assert "hi_fire2013" == treebank_to_short_name("Hindi-fire2013") + assert "hi_fire2013" == treebank_to_short_name("Hindi-Fire2013") + assert "hi_fire2013" == treebank_to_short_name("Hindi-FIRE2013") + # already short names are generally preserved + assert "hi_fire2013" == treebank_to_short_name("hi-fire2013") + assert "hi_fire2013" == treebank_to_short_name("hi_fire2013") + # a special case + assert "zh-hant_pud" == treebank_to_short_name("UD_Chinese-PUD") + # a special case already converted once + assert "zh-hant_pud" == treebank_to_short_name("zh-hant_pud") + assert "zh-hant_pud" == treebank_to_short_name("zh-hant-pud") + assert "zh-hans_gsdsimp" == treebank_to_short_name("zh-hans_gsdsimp") + + -- cgit v1.2.3 From bf6a6298b2370dde29565785edabe694a9b8cbee Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 19 Jul 2021 14:53:38 -0700 Subject: Standardize the final short_name of the hindi ner dataset regardless of which of the plausible short_names was given --- stanza/utils/datasets/ner/prepare_ner_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py index e63f501d..d936c158 100644 --- a/stanza/utils/datasets/ner/prepare_ner_dataset.py +++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py @@ -173,6 +173,7 @@ def process_fire_2013(paths, dataset): """ short_name = treebank_to_short_name(dataset) langcode, _ = short_name.split("_") + short_name = "%s_fire2013" % langcode if not langcode in ("hi", "en", "ta", "bn", "mal"): raise ValueError("Language %s not one of the FIRE 2013 languages") language = lcode2lang[langcode].lower() -- cgit v1.2.3 From c1aedb867911efd700d12b63a8f94778551c5e02 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 20 Jul 2021 15:39:20 -0700 Subject: Revert "Adjust the newpar title" This reverts commit 
72bc6a3c6cac9152e53adc6c645c424a603b1c15. --- stanza/utils/datasets/tokenization/convert_vi_vlsp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py index 2c00a51c..31e7a985 100644 --- a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py +++ b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py @@ -45,7 +45,7 @@ def write_file(vlsp_include_spaces, output_filename, sentences, shard): orig_text = " ".join(sentence) #check if the previous line is a headline (no ending mark at the end) then make this sentence a new par if check_headlines: - fout.write("# newpar_id = %s.%d.1\n" % (shard, sent_idx)) + fout.write("# newpar id =%s.%d.1\n" % (shard, sent_idx)) check_headlines = False if sentence[len(sentence) - 1] not in punctuation_set: check_headlines = True -- cgit v1.2.3 From 8e30b16c28942ebac0b71e3ff9d3a09da0952f7b Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 21 Aug 2020 10:19:08 -0700 Subject: Refactor some of the processing code which uses pythainlp --- .../utils/datasets/tokenization/convert_th_best.py | 64 +---------------- .../tokenization/process_thai_tokenization.py | 84 ++++++++++++++++++++++ 2 files changed, 87 insertions(+), 61 deletions(-) diff --git a/stanza/utils/datasets/tokenization/convert_th_best.py b/stanza/utils/datasets/tokenization/convert_th_best.py index a642702c..416c84b2 100644 --- a/stanza/utils/datasets/tokenization/convert_th_best.py +++ b/stanza/utils/datasets/tokenization/convert_th_best.py @@ -28,7 +28,7 @@ import sys from pythainlp import sent_tokenize -from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset +from stanza.utils.datasets.tokenization.process_thai_tokenization import reprocess_lines, write_dataset, convert_processed_lines def clean_line(line): line = line.replace("html>", "html|>") @@ -76,44 +76,6 @@ def clean_word(word): return word return word -def reprocess_lines(processed_lines): - reprocessed_lines = [] - for line in processed_lines: - text = "".join(line) - chunks = sent_tokenize(text) - if sum(len(x) for x in chunks) != len(text): - raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks)) - - chunk_lengths = [len(x) for x in chunks] - - current_length = 0 - new_line = [] - for word in line: - if len(word) + current_length < chunk_lengths[0]: - new_line.append(word) - current_length = current_length + len(word) - elif len(word) + current_length == chunk_lengths[0]: - new_line.append(word) - reprocessed_lines.append(new_line) - new_line = [] - chunk_lengths = chunk_lengths[1:] - current_length = 0 - else: - remaining_len = chunk_lengths[0] - current_length - new_line.append(word[:remaining_len]) - reprocessed_lines.append(new_line) - word = word[remaining_len:] - chunk_lengths = chunk_lengths[1:] - while len(word) > chunk_lengths[0]: - new_line = [word[:chunk_lengths[0]]] - reprocessed_lines.append(new_line) - word = word[chunk_lengths[0]:] - chunk_lengths = chunk_lengths[1:] - new_line = [word] - current_length = len(word) - reprocessed_lines.append(new_line) - return reprocessed_lines - def read_data(input_dir): subdirs = [os.path.join(input_dir, 'article'), os.path.join(input_dir, 'encyclopedia'), @@ -128,7 +90,6 @@ def read_data(input_dir): documents = [] for filename in files: with open(filename) as fin: - sentences = [] processed_lines = [] for line in fin.readlines(): line = clean_line(line) @@ -141,28 +102,9 @@ def read_data(input_dir): 
processed_lines.append(words) processed_lines = reprocess_lines(processed_lines) + paragraphs = convert_processed_lines(processed_lines) - for words in processed_lines: - # turn the words into a sentence - sentence = [] - for word in words: - word = word.strip() - if not word: - if len(sentence) == 0: - raise ValueError("Unexpected space at start of sentence in document {}".format(filename)) - sentence[-1] = (sentence[-1][0], True) - else: - sentence.append((word, False)) - # blank lines are very rare in best, but why not treat them as a paragraph break - if len(sentence) == 0: - paragraphs = [sentences] - documents.append(paragraphs) - sentences = [] - continue - sentence[-1] = (sentence[-1][0], True) - sentences.append(sentence) - paragraphs = [sentences] - documents.append(paragraphs) + documents.extend(paragraphs) return documents diff --git a/stanza/utils/datasets/tokenization/process_thai_tokenization.py b/stanza/utils/datasets/tokenization/process_thai_tokenization.py index 135bf543..d863402c 100644 --- a/stanza/utils/datasets/tokenization/process_thai_tokenization.py +++ b/stanza/utils/datasets/tokenization/process_thai_tokenization.py @@ -1,6 +1,8 @@ import os import random +from pythainlp import sent_tokenize + def write_section(output_dir, dataset_name, section, documents): """ Writes a list of documents for tokenization, including a file in conll format @@ -74,3 +76,85 @@ def write_dataset(documents, output_dir, dataset_name): write_section(output_dir, dataset_name, 'train', documents[:num_train]) write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev]) write_section(output_dir, dataset_name, 'test', documents[num_train+num_dev:]) + + + +def reprocess_lines(processed_lines): + """ + Reprocesses lines using pythainlp to cut up sentences into shorter sentences. + + Many of the lines in BEST seem to be multiple Thai sentences concatenated, according to native Thai speakers. + + Input: a list of lines, where each line is a list of words. Space characters can be included as words + Output: a new list of lines, resplit using pythainlp + """ + reprocessed_lines = [] + for line in processed_lines: + text = "".join(line) + chunks = sent_tokenize(text) + # Check that the total text back is the same as the text in + if sum(len(x) for x in chunks) != len(text): + raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks)) + + chunk_lengths = [len(x) for x in chunks] + + current_length = 0 + new_line = [] + for word in line: + if len(word) + current_length < chunk_lengths[0]: + new_line.append(word) + current_length = current_length + len(word) + elif len(word) + current_length == chunk_lengths[0]: + new_line.append(word) + reprocessed_lines.append(new_line) + new_line = [] + chunk_lengths = chunk_lengths[1:] + current_length = 0 + else: + remaining_len = chunk_lengths[0] - current_length + new_line.append(word[:remaining_len]) + reprocessed_lines.append(new_line) + word = word[remaining_len:] + chunk_lengths = chunk_lengths[1:] + while len(word) > chunk_lengths[0]: + new_line = [word[:chunk_lengths[0]]] + reprocessed_lines.append(new_line) + word = word[chunk_lengths[0]:] + chunk_lengths = chunk_lengths[1:] + new_line = [word] + current_length = len(word) + reprocessed_lines.append(new_line) + return reprocessed_lines + +def convert_processed_lines(processed_lines): + """ + Convert a list of sentences into documents suitable for the output methods in this module. 
+ + Input: a list of lines, including space words + Output: a list of documents, each document containing a list of sentences + Each sentence is a list of words: (text, space_follows) + Space words will be eliminated. + """ + paragraphs = [] + sentences = [] + for words in processed_lines: + # turn the words into a sentence + sentence = [] + for word in words: + word = word.strip() + if not word: + if len(sentence) == 0: + raise ValueError("Unexpected space at start of sentence in document {}".format(filename)) + sentence[-1] = (sentence[-1][0], True) + else: + sentence.append((word, False)) + # blank lines are very rare in best, but why not treat them as a paragraph break + if len(sentence) == 0: + paragraphs.append([sentences]) + sentences = [] + continue + sentence[-1] = (sentence[-1][0], True) + sentences.append(sentence) + paragraphs.append([sentences]) + return paragraphs + -- cgit v1.2.3 From e933b076e410fe29df4a1ec94c4a5fa808353641 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 21 Aug 2020 12:03:43 -0700 Subject: Use pythainlp to resplit lst20 sentences as well --- .../datasets/tokenization/convert_th_lst20.py | 31 +++++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/utils/datasets/tokenization/convert_th_lst20.py index 275ac21b..c067bf73 100644 --- a/stanza/utils/datasets/tokenization/convert_th_lst20.py +++ b/stanza/utils/datasets/tokenization/convert_th_lst20.py @@ -17,7 +17,7 @@ import glob import os import sys -from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section +from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section, convert_processed_lines, reprocess_lines def read_document(lines): document = [] @@ -43,7 +43,30 @@ def read_document(lines): document.append(sentence) sentence = [] # TODO: is there any way to divide up a single document into paragraphs? 
- return document + return [[document]] + +def retokenize_document(lines): + processed_lines = [] + sentence = [] + for line in lines: + line = line.strip() + if not line: + if sentence: + processed_lines.append(sentence) + sentence = [] + else: + pieces = line.split("\t") + if pieces[0] == '_': + sentence.append(' ') + else: + sentence.append(pieces[0]) + if sentence: + processed_lines.append(sentence) + + processed_lines = reprocess_lines(processed_lines) + paragraphs = convert_processed_lines(processed_lines) + return paragraphs + def read_data(input_dir, section): input_dir = os.path.join(input_dir, section) @@ -52,8 +75,8 @@ def read_data(input_dir, section): for filename in filenames: with open(filename) as fin: lines = fin.readlines() - document = read_document(lines) - documents.append([document]) + document = retokenize_document(lines) + documents.extend(document) return documents def main(*args): -- cgit v1.2.3 From c86b430e6333ba2dcb146da486093f58102928d5 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 19 Jul 2021 19:54:23 -0700 Subject: Make the retokenization an option for the lst20 dataset --- .../utils/datasets/prepare_tokenizer_treebank.py | 3 +- .../datasets/tokenization/convert_th_lst20.py | 33 ++++++++++++++++------ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index c2188513..d84e36c3 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -1021,6 +1021,7 @@ def add_specific_args(parser): help='Augment the dataset in various ways') parser.add_argument('--no_prepare_labels', action='store_false', dest='prepare_labels', default=True, help='Prepare tokenizer and MWT labels. 
Expensive, but obviously necessary for training those models.') + convert_th_lst20.add_lst20_args(parser) def process_treebank(treebank, paths, args): """ @@ -1046,7 +1047,7 @@ def process_treebank(treebank, paths, args): elif short_name == "th_orchid": convert_th_orchid.main(paths["EXTERN_DIR"], tokenizer_dir) elif short_name == "th_lst20": - convert_th_lst20.main(paths["EXTERN_DIR"], tokenizer_dir) + convert_th_lst20.convert(paths["EXTERN_DIR"], tokenizer_dir, args) elif short_name == "th_best": convert_th_best.main(paths["EXTERN_DIR"], tokenizer_dir) elif short_name.startswith("ko_combined"): diff --git a/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/utils/datasets/tokenization/convert_th_lst20.py index c067bf73..6cea7d77 100644 --- a/stanza/utils/datasets/tokenization/convert_th_lst20.py +++ b/stanza/utils/datasets/tokenization/convert_th_lst20.py @@ -13,6 +13,7 @@ Unlike Orchid and BEST, LST20 has train/eval/test splits, which we relabel train """ +import argparse import glob import os import sys @@ -68,34 +69,48 @@ def retokenize_document(lines): return paragraphs -def read_data(input_dir, section): +def read_data(input_dir, section, resegment): input_dir = os.path.join(input_dir, section) filenames = glob.glob(os.path.join(input_dir, "*.txt")) documents = [] for filename in filenames: with open(filename) as fin: lines = fin.readlines() - document = retokenize_document(lines) + if resegment: + document = retokenize_document(lines) + else: + document = read_document(lines) documents.extend(document) return documents -def main(*args): - if not args: - args = sys.argv[1:] - input_dir = args[0] +def add_lst20_args(parser): + parser.add_argument('--no_lst20_resegment', action='store_false', dest="lst20_resegment", default=True, help='When processing th_lst20 tokenization, use pythainlp to resegment the text. The other option is to keep the original sentence segmentation. Currently our model is not good at that') + +def parse_lst20_args(): + parser = argparse.ArgumentParser() + parser.add_argument('input_dir', help="Directory to use when processing lst20") + parser.add_argument('output_dir', help="Directory to use when saving lst20") + add_lst20_args(parser) + return parser.parse_args() + + + +def convert(input_dir, output_dir, args): full_input_dir = os.path.join(input_dir, "thai", "LST20_Corpus") if os.path.exists(full_input_dir): # otherwise hopefully the user gave us the full path? 
input_dir = full_input_dir - output_dir = args[1] for (in_section, out_section) in (("train", "train"), ("eval", "dev"), ("test", "test")): print("Processing %s" % out_section) - documents = read_data(input_dir, in_section) - print(" Read in %d files" % len(documents)) + documents = read_data(input_dir, in_section, args.lst20_resegment) + print(" Read in %d documents" % len(documents)) write_section(output_dir, "lst20", out_section, documents) +def main(): + args = parse_lst20_args() + convert(args.input_dir, args.output_dir, args) if __name__ == '__main__': main() -- cgit v1.2.3 From 2cd732b76cb2139d7a7ba25b95df5bec9425f873 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 20 Jul 2021 23:48:20 -0700 Subject: Add a tiny test for part of the LST20 preparation --- stanza/tests/test_tokenization_lst20.py | 84 ++++++++++++++++++++++ .../tokenization/process_thai_tokenization.py | 10 ++- 2 files changed, 92 insertions(+), 2 deletions(-) create mode 100644 stanza/tests/test_tokenization_lst20.py diff --git a/stanza/tests/test_tokenization_lst20.py b/stanza/tests/test_tokenization_lst20.py new file mode 100644 index 00000000..5b94c84a --- /dev/null +++ b/stanza/tests/test_tokenization_lst20.py @@ -0,0 +1,84 @@ +import os +import tempfile + +import pytest + +import stanza +from stanza.tests import * + +from stanza.utils.datasets.prepare_tokenizer_treebank import convert_conllu_to_txt +from stanza.utils.datasets.tokenization.convert_th_lst20 import read_document +from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +SMALL_LST_SAMPLE=""" +สุรยุทธ์ NN B_PER B_CLS +ยัน VV O I_CLS +ปฏิเสธ VV O I_CLS +ลงนาม VV O I_CLS +_ PU O I_CLS +MOU NN O I_CLS +_ PU O I_CLS +กับ PS O I_CLS +อียู NN B_ORG I_CLS +ไม่ NG O I_CLS +กระทบ VV O I_CLS +สัมพันธ์ NN O E_CLS + +1 NU B_DTM B_CLS +_ PU I_DTM I_CLS +กันยายน NN I_DTM I_CLS +_ PU I_DTM I_CLS +2550 NU E_DTM I_CLS +_ PU O I_CLS +12:21 NU B_DTM I_CLS +_ PU I_DTM I_CLS +น. CL E_DTM E_CLS +""".strip() + +EXPECTED_CONLLU=""" +1 สุรยุทธ์ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes +2 ยัน _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 ปฏิเสธ _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 ลงนาม _ _ _ _ 3 dep 3:dep _ +5 MOU _ _ _ _ 4 dep 4:dep _ +6 กับ _ _ _ _ 5 dep 5:dep SpaceAfter=No +7 อียู _ _ _ _ 6 dep 6:dep SpaceAfter=No +8 ไม่ _ _ _ _ 7 dep 7:dep SpaceAfter=No +9 กระทบ _ _ _ _ 8 dep 8:dep SpaceAfter=No +10 สัมพันธ์ _ _ _ _ 9 dep 9:dep SpaceAfter=No + +1 1 _ _ _ _ 0 root 0:root _ +2 กันยายน _ _ _ _ 1 dep 1:dep _ +3 2550 _ _ _ _ 2 dep 2:dep _ +4 12:21 _ _ _ _ 3 dep 3:dep _ +5 น. _ _ _ _ 4 dep 4:dep SpaceAfter=No +""".strip() + +EXPECTED_TXT="สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์1 กันยายน 2550 12:21 น.\n\n" +EXPECTED_LABELS="0000000100100000100001000100010001001000010000000210000000100001000001002\n\n" + +def test_small(): + """ + A small test just to verify that the output is being produced as we want + + Note that there currently are no spaces after the first sentence. + Apparently this is wrong, but weirdly, doing that makes the model even worse. 
+ """ + lines = SMALL_LST_SAMPLE.strip().split("\n") + documents = read_document(lines) + + with tempfile.TemporaryDirectory() as output_dir: + write_section(output_dir, "lst20", "train", documents) + with open(os.path.join(output_dir, "th_lst20.train.gold.conllu")) as fin: + conllu = fin.read().strip() + with open(os.path.join(output_dir, "th_lst20.train.txt")) as fin: + txt = fin.read() + with open(os.path.join(output_dir, "th_lst20-ud-train.toklabels")) as fin: + labels = fin.read() + assert conllu == EXPECTED_CONLLU + assert txt == EXPECTED_TXT + assert labels == EXPECTED_LABELS + + assert len(txt) == len(labels) diff --git a/stanza/utils/datasets/tokenization/process_thai_tokenization.py b/stanza/utils/datasets/tokenization/process_thai_tokenization.py index d863402c..d92ab674 100644 --- a/stanza/utils/datasets/tokenization/process_thai_tokenization.py +++ b/stanza/utils/datasets/tokenization/process_thai_tokenization.py @@ -1,7 +1,10 @@ import os import random -from pythainlp import sent_tokenize +try: + from pythainlp import sent_tokenize +except ImportError: + pass def write_section(output_dir, dataset_name, section, documents): """ @@ -91,7 +94,10 @@ def reprocess_lines(processed_lines): reprocessed_lines = [] for line in processed_lines: text = "".join(line) - chunks = sent_tokenize(text) + try: + chunks = sent_tokenize(text) + except NameError as e: + raise NameError("Sentences cannot be reprocessed without first installing pythainlp") from e # Check that the total text back is the same as the text in if sum(len(x) for x in chunks) != len(text): raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks)) -- cgit v1.2.3 From e4ddfedd48e6b27385a575f7ad3a73096e957d8d Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 21 Jul 2021 00:14:11 -0700 Subject: Attempt to add a helpful error explaining where it looked for LST20 --- stanza/utils/datasets/tokenization/convert_th_lst20.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/utils/datasets/tokenization/convert_th_lst20.py index 6cea7d77..4c48aad8 100644 --- a/stanza/utils/datasets/tokenization/convert_th_lst20.py +++ b/stanza/utils/datasets/tokenization/convert_th_lst20.py @@ -4,8 +4,9 @@ The dataset is available here: https://aiforthai.in.th/corpus.php +The data should be installed under ${EXTERN_DATA}/thai/LST20_Corpus -python3 -m stanza.utils.datasets.tokenization.convert_th_lst20 extern_data/thai/LST20_Corpus data/tokenize +python3 -m stanza.utils.datasets.tokenization.convert_th_lst20 extern_data data/tokenize Unlike Orchid and BEST, LST20 has train/eval/test splits, which we relabel train/dev/test. @@ -70,8 +71,11 @@ def retokenize_document(lines): def read_data(input_dir, section, resegment): - input_dir = os.path.join(input_dir, section) - filenames = glob.glob(os.path.join(input_dir, "*.txt")) + glob_path = os.path.join(input_dir, section, "*.txt") + filenames = glob.glob(glob_path) + print(" Found {} files in {}".format(len(filenames), glob_path)) + if len(filenames) == 0: + raise FileNotFoundError("Could not find any files for the {} section. 
Is LST20 installed in {}?".format(section, input_dir)) documents = [] for filename in filenames: with open(filename) as fin: @@ -96,10 +100,10 @@ def parse_lst20_args(): def convert(input_dir, output_dir, args): - full_input_dir = os.path.join(input_dir, "thai", "LST20_Corpus") - if os.path.exists(full_input_dir): - # otherwise hopefully the user gave us the full path? - input_dir = full_input_dir + input_dir = os.path.join(input_dir, "thai", "LST20_Corpus") + if not os.path.exists(input_dir): + raise FileNotFoundError("Could not find LST20 corpus in {}".format(input_dir)) + for (in_section, out_section) in (("train", "train"), ("eval", "dev"), ("test", "test")): -- cgit v1.2.3 From 720e09ae605210c8b87ea0028c0a8287acfc69b5 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 21 Jul 2021 16:47:57 -0700 Subject: Add a lot of notes on how the characters are expected to line up in the test --- stanza/tests/test_tokenization_lst20.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/stanza/tests/test_tokenization_lst20.py b/stanza/tests/test_tokenization_lst20.py index 5b94c84a..d232d61c 100644 --- a/stanza/tests/test_tokenization_lst20.py +++ b/stanza/tests/test_tokenization_lst20.py @@ -56,8 +56,27 @@ EXPECTED_CONLLU=""" 5 น. _ _ _ _ 4 dep 4:dep SpaceAfter=No """.strip() -EXPECTED_TXT="สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์1 กันยายน 2550 12:21 น.\n\n" -EXPECTED_LABELS="0000000100100000100001000100010001001000010000000210000000100001000001002\n\n" +# Note: these DO NOT line up perfectly (in an emacs window, at least) +# because Thai characters have a length greater than 1. +# The lengths of the words are: +# สุรยุทธ์ 8 +# ยัน 3 +# ปฏิเสธ 6 +# ลงนาม 5 +# MOU 3 +# กับ 3 +# อียู 4 +# ไม่ 3 +# กระทบ 5 +# สัมพันธ์ 8 +# 1 1 +# กันยายน 7 +# 2550 4 +# 12:21 5 +# น. 2 +EXPECTED_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์1 กันยายน 2550 12:21 น.\n\n" +EXPECTED_LABELS = "0000000100100000100001000100010001001000010000000210000000100001000001002\n\n" +# counting spaces 1234567812312345612345_123_123123412312345123456781_1234567_1234_12345_12 def test_small(): """ -- cgit v1.2.3 From 4f27e1144f4927a3a5a056c7e66872b885720d3d Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 21 Jul 2021 16:55:16 -0700 Subject: Add an option to add spaces after the sentence ends (which is actually more correct) --- stanza/tests/test_tokenization_lst20.py | 2 +- stanza/utils/datasets/tokenization/convert_th_lst20.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/stanza/tests/test_tokenization_lst20.py b/stanza/tests/test_tokenization_lst20.py index d232d61c..f93759fe 100644 --- a/stanza/tests/test_tokenization_lst20.py +++ b/stanza/tests/test_tokenization_lst20.py @@ -86,7 +86,7 @@ def test_small(): Apparently this is wrong, but weirdly, doing that makes the model even worse. 
""" lines = SMALL_LST_SAMPLE.strip().split("\n") - documents = read_document(lines) + documents = read_document(lines, spaces_after=False) with tempfile.TemporaryDirectory() as output_dir: write_section(output_dir, "lst20", "train", documents) diff --git a/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/utils/datasets/tokenization/convert_th_lst20.py index 4c48aad8..deb60a1b 100644 --- a/stanza/utils/datasets/tokenization/convert_th_lst20.py +++ b/stanza/utils/datasets/tokenization/convert_th_lst20.py @@ -21,14 +21,15 @@ import sys from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section, convert_processed_lines, reprocess_lines -def read_document(lines): +def read_document(lines, spaces_after): document = [] sentence = [] for line in lines: line = line.strip() if not line: if sentence: - #sentence[-1] = (sentence[-1][0], True) + if spaces_after: + sentence[-1] = (sentence[-1][0], True) document.append(sentence) sentence = [] else: @@ -41,7 +42,8 @@ def read_document(lines): sentence.append((pieces[0], False)) if sentence: - #sentence[-1] = (sentence[-1][0], True) + if spaces_after: + sentence[-1] = (sentence[-1][0], True) document.append(sentence) sentence = [] # TODO: is there any way to divide up a single document into paragraphs? @@ -70,7 +72,7 @@ def retokenize_document(lines): return paragraphs -def read_data(input_dir, section, resegment): +def read_data(input_dir, section, resegment, spaces_after): glob_path = os.path.join(input_dir, section, "*.txt") filenames = glob.glob(glob_path) print(" Found {} files in {}".format(len(filenames), glob_path)) @@ -83,12 +85,13 @@ def read_data(input_dir, section, resegment): if resegment: document = retokenize_document(lines) else: - document = read_document(lines) + document = read_document(lines, spaces_after) documents.extend(document) return documents def add_lst20_args(parser): parser.add_argument('--no_lst20_resegment', action='store_false', dest="lst20_resegment", default=True, help='When processing th_lst20 tokenization, use pythainlp to resegment the text. The other option is to keep the original sentence segmentation. Currently our model is not good at that') + parser.add_argument('--lst20_spaces_after', action='store_true', dest="lst20_spaces_after", default=False, help='When processing th_lst20 without pythainlp, put spaces after each sentence. 
This better fits the language but gets lower scores for some reason') def parse_lst20_args(): parser = argparse.ArgumentParser() @@ -108,7 +111,7 @@ def convert(input_dir, output_dir, args): ("eval", "dev"), ("test", "test")): print("Processing %s" % out_section) - documents = read_data(input_dir, in_section, args.lst20_resegment) + documents = read_data(input_dir, in_section, args.lst20_resegment, args.lst20_spaces_after) print(" Read in %d documents" % len(documents)) write_section(output_dir, "lst20", out_section, documents) -- cgit v1.2.3 From 71b9aece2cacb4f371cf61297298f318d2248289 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 21 Jul 2021 19:55:42 -0700 Subject: Add more notes on how the tokenization boundaries are determined --- stanza/tests/test_tokenization_lst20.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/stanza/tests/test_tokenization_lst20.py b/stanza/tests/test_tokenization_lst20.py index f93759fe..3959592b 100644 --- a/stanza/tests/test_tokenization_lst20.py +++ b/stanza/tests/test_tokenization_lst20.py @@ -78,6 +78,12 @@ EXPECTED_TXT = "สุรยุทธ์ยันปฏิเสธลง EXPECTED_LABELS = "0000000100100000100001000100010001001000010000000210000000100001000001002\n\n" # counting spaces 1234567812312345612345_123_123123412312345123456781_1234567_1234_12345_12 +# note that the word splits go on the final letter of the word in the +# UD conllu datasets, so that is what we mimic here +# for example, from EWT: +# Al-Zaman : American forces killed Shaikh Abdullah +# 0110000101000000001000000100000010000001000000001 + def test_small(): """ A small test just to verify that the output is being produced as we want -- cgit v1.2.3 From 377be71e432eee5451b1627f5af4e0a4612f489c Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 21 Jul 2021 22:52:50 -0700 Subject: Add an option to split clauses into sentences if a space is between clauses --- stanza/tests/test_tokenization_lst20.py | 2 +- stanza/utils/datasets/tokenization/convert_th_lst20.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/stanza/tests/test_tokenization_lst20.py b/stanza/tests/test_tokenization_lst20.py index 3959592b..b4af3cc7 100644 --- a/stanza/tests/test_tokenization_lst20.py +++ b/stanza/tests/test_tokenization_lst20.py @@ -92,7 +92,7 @@ def test_small(): Apparently this is wrong, but weirdly, doing that makes the model even worse. 
""" lines = SMALL_LST_SAMPLE.strip().split("\n") - documents = read_document(lines, spaces_after=False) + documents = read_document(lines, spaces_after=False, split_clauses=False) with tempfile.TemporaryDirectory() as output_dir: write_section(output_dir, "lst20", "train", documents) diff --git a/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/utils/datasets/tokenization/convert_th_lst20.py index deb60a1b..744c44cd 100644 --- a/stanza/utils/datasets/tokenization/convert_th_lst20.py +++ b/stanza/utils/datasets/tokenization/convert_th_lst20.py @@ -21,7 +21,7 @@ import sys from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section, convert_processed_lines, reprocess_lines -def read_document(lines, spaces_after): +def read_document(lines, spaces_after, split_clauses): document = [] sentence = [] for line in lines: @@ -36,7 +36,14 @@ def read_document(lines, spaces_after): pieces = line.split("\t") # there are some nbsp in tokens in lst20, but the downstream tools expect spaces pieces = [p.replace("\xa0", " ") for p in pieces] - if pieces[0] == '_': + if split_clauses and pieces[0] == '_' and pieces[3] == 'O': + if sentence: + # note that we don't need to check spaces_after + # the "token" is a space anyway + sentence[-1] = (sentence[-1][0], True) + document.append(sentence) + sentence = [] + elif pieces[0] == '_': sentence[-1] = (sentence[-1][0], True) else: sentence.append((pieces[0], False)) @@ -72,7 +79,7 @@ def retokenize_document(lines): return paragraphs -def read_data(input_dir, section, resegment, spaces_after): +def read_data(input_dir, section, resegment, spaces_after, split_clauses): glob_path = os.path.join(input_dir, section, "*.txt") filenames = glob.glob(glob_path) print(" Found {} files in {}".format(len(filenames), glob_path)) @@ -85,13 +92,14 @@ def read_data(input_dir, section, resegment, spaces_after): if resegment: document = retokenize_document(lines) else: - document = read_document(lines, spaces_after) + document = read_document(lines, spaces_after, split_clauses) documents.extend(document) return documents def add_lst20_args(parser): parser.add_argument('--no_lst20_resegment', action='store_false', dest="lst20_resegment", default=True, help='When processing th_lst20 tokenization, use pythainlp to resegment the text. The other option is to keep the original sentence segmentation. Currently our model is not good at that') parser.add_argument('--lst20_spaces_after', action='store_true', dest="lst20_spaces_after", default=False, help='When processing th_lst20 without pythainlp, put spaces after each sentence. 
This better fits the language but gets lower scores for some reason') + parser.add_argument('--split_clauses', action='store_true', dest="split_clauses", default=False, help='When processing th_lst20 without pythainlp, turn spaces which are labeled as between clauses into sentence splits') def parse_lst20_args(): parser = argparse.ArgumentParser() @@ -111,7 +119,7 @@ def convert(input_dir, output_dir, args): ("eval", "dev"), ("test", "test")): print("Processing %s" % out_section) - documents = read_data(input_dir, in_section, args.lst20_resegment, args.lst20_spaces_after) + documents = read_data(input_dir, in_section, args.lst20_resegment, args.lst20_spaces_after, args.split_clauses) print(" Read in %d documents" % len(documents)) write_section(output_dir, "lst20", out_section, documents) -- cgit v1.2.3 From 204ad0d450bcf05a3bcdc5d31b3796bc5ab948c0 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 22 Jul 2021 00:04:35 -0700 Subject: Add a longer test for a couple different variations on processing text --- stanza/tests/test_tokenization_lst20.py | 157 +++++++++++++++++++++++++++++--- 1 file changed, 142 insertions(+), 15 deletions(-) diff --git a/stanza/tests/test_tokenization_lst20.py b/stanza/tests/test_tokenization_lst20.py index b4af3cc7..a0728123 100644 --- a/stanza/tests/test_tokenization_lst20.py +++ b/stanza/tests/test_tokenization_lst20.py @@ -35,6 +35,18 @@ _ PU O I_CLS 12:21 NU B_DTM I_CLS _ PU I_DTM I_CLS น. CL E_DTM E_CLS + +ผู้สื่อข่าว NN O B_CLS +รายงาน VV O I_CLS +เพิ่มเติม VV O I_CLS +ว่า CC O E_CLS +_ PU O O +จาก PS O B_CLS +การ FX O I_CLS +ลง VV O I_CLS +พื้นที่ NN O I_CLS +พบ VV O I_CLS +ว่า CC O E_CLS """.strip() EXPECTED_CONLLU=""" @@ -54,6 +66,17 @@ EXPECTED_CONLLU=""" 3 2550 _ _ _ _ 2 dep 2:dep _ 4 12:21 _ _ _ _ 3 dep 3:dep _ 5 น. _ _ _ _ 4 dep 4:dep SpaceAfter=No + +1 ผู้สื่อข่าว _ _ _ _ 0 root 0:root SpaceAfter=No +2 รายงาน _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 เพิ่มเติม _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 ว่า _ _ _ _ 3 dep 3:dep _ +5 จาก _ _ _ _ 4 dep 4:dep SpaceAfter=No +6 การ _ _ _ _ 5 dep 5:dep SpaceAfter=No +7 ลง _ _ _ _ 6 dep 6:dep SpaceAfter=No +8 พื้นที่ _ _ _ _ 7 dep 7:dep SpaceAfter=No +9 พบ _ _ _ _ 8 dep 8:dep SpaceAfter=No +10 ว่า _ _ _ _ 9 dep 9:dep SpaceAfter=No """.strip() # Note: these DO NOT line up perfectly (in an emacs window, at least) @@ -74,9 +97,19 @@ EXPECTED_CONLLU=""" # 2550 4 # 12:21 5 # น. 
2 -EXPECTED_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์1 กันยายน 2550 12:21 น.\n\n" -EXPECTED_LABELS = "0000000100100000100001000100010001001000010000000210000000100001000001002\n\n" -# counting spaces 1234567812312345612345_123_123123412312345123456781_1234567_1234_12345_12 +# ผู้สื่อข่าว 11 +# รายงาน 6 +# เพิ่มเติม 9 +# ว่า 3 +# จาก 3 +# การ 3 +# ลง 2 +# พื้นที่ 7 +# พบ 2 +# ว่า 3 +EXPECTED_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์1 กันยายน 2550 12:21 น.ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\n\n" +EXPECTED_LABELS = "000000010010000010000100010001000100100001000000021000000010000100000100200000000001000001000000001001000100101000000101002\n\n" +# counting spaces 1234567812312345612345_123_123123412312345123456781_1234567_1234_12345_12123456789AB123456123456789123_12312312123456712123 # note that the word splits go on the final letter of the word in the # UD conllu datasets, so that is what we mimic here @@ -84,6 +117,22 @@ EXPECTED_LABELS = "00000001001000001000010001000100010010000100000002100000001 # Al-Zaman : American forces killed Shaikh Abdullah # 0110000101000000001000000100000010000001000000001 +def check_results(documents, expected_conllu, expected_txt, expected_labels): + with tempfile.TemporaryDirectory() as output_dir: + write_section(output_dir, "lst20", "train", documents) + with open(os.path.join(output_dir, "th_lst20.train.gold.conllu")) as fin: + conllu = fin.read().strip() + with open(os.path.join(output_dir, "th_lst20.train.txt")) as fin: + txt = fin.read() + with open(os.path.join(output_dir, "th_lst20-ud-train.toklabels")) as fin: + labels = fin.read() + assert conllu == expected_conllu + assert txt == expected_txt + assert labels == expected_labels + + assert len(txt) == len(labels) + + def test_small(): """ A small test just to verify that the output is being produced as we want @@ -93,17 +142,95 @@ def test_small(): """ lines = SMALL_LST_SAMPLE.strip().split("\n") documents = read_document(lines, spaces_after=False, split_clauses=False) + check_results(documents, EXPECTED_CONLLU, EXPECTED_TXT, EXPECTED_LABELS) - with tempfile.TemporaryDirectory() as output_dir: - write_section(output_dir, "lst20", "train", documents) - with open(os.path.join(output_dir, "th_lst20.train.gold.conllu")) as fin: - conllu = fin.read().strip() - with open(os.path.join(output_dir, "th_lst20.train.txt")) as fin: - txt = fin.read() - with open(os.path.join(output_dir, "th_lst20-ud-train.toklabels")) as fin: - labels = fin.read() - assert conllu == EXPECTED_CONLLU - assert txt == EXPECTED_TXT - assert labels == EXPECTED_LABELS +EXPECTED_SPACE_CONLLU=""" +1 สุรยุทธ์ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes +2 ยัน _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 ปฏิเสธ _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 ลงนาม _ _ _ _ 3 dep 3:dep _ +5 MOU _ _ _ _ 4 dep 4:dep _ +6 กับ _ _ _ _ 5 dep 5:dep SpaceAfter=No +7 อียู _ _ _ _ 6 dep 6:dep SpaceAfter=No +8 ไม่ _ _ _ _ 7 dep 7:dep SpaceAfter=No +9 กระทบ _ _ _ _ 8 dep 8:dep SpaceAfter=No +10 สัมพันธ์ _ _ _ _ 9 dep 9:dep _ - assert len(txt) == len(labels) +1 1 _ _ _ _ 0 root 0:root _ +2 กันยายน _ _ _ _ 1 dep 1:dep _ +3 2550 _ _ _ _ 2 dep 2:dep _ +4 12:21 _ _ _ _ 3 dep 3:dep _ +5 น. 
_ _ _ _ 4 dep 4:dep _ + +1 ผู้สื่อข่าว _ _ _ _ 0 root 0:root SpaceAfter=No +2 รายงาน _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 เพิ่มเติม _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 ว่า _ _ _ _ 3 dep 3:dep _ +5 จาก _ _ _ _ 4 dep 4:dep SpaceAfter=No +6 การ _ _ _ _ 5 dep 5:dep SpaceAfter=No +7 ลง _ _ _ _ 6 dep 6:dep SpaceAfter=No +8 พื้นที่ _ _ _ _ 7 dep 7:dep SpaceAfter=No +9 พบ _ _ _ _ 8 dep 8:dep SpaceAfter=No +10 ว่า _ _ _ _ 9 dep 9:dep _ +""".strip() + +EXPECTED_SPACE_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์ 1 กันยายน 2550 12:21 น. ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\n\n" +EXPECTED_SPACE_LABELS = "00000001001000001000010001000100010010000100000002010000000100001000001002000000000001000001000000001001000100101000000101002\n\n" + +def test_space_after(): + """ + This version of the test adds the space after attribute + """ + lines = SMALL_LST_SAMPLE.strip().split("\n") + documents = read_document(lines, spaces_after=True, split_clauses=False) + check_results(documents, EXPECTED_SPACE_CONLLU, EXPECTED_SPACE_TXT, EXPECTED_SPACE_LABELS) + + +EXPECTED_CLAUSE_CONLLU=""" +1 สุรยุทธ์ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes +2 ยัน _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 ปฏิเสธ _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 ลงนาม _ _ _ _ 3 dep 3:dep _ +5 MOU _ _ _ _ 4 dep 4:dep _ +6 กับ _ _ _ _ 5 dep 5:dep SpaceAfter=No +7 อียู _ _ _ _ 6 dep 6:dep SpaceAfter=No +8 ไม่ _ _ _ _ 7 dep 7:dep SpaceAfter=No +9 กระทบ _ _ _ _ 8 dep 8:dep SpaceAfter=No +10 สัมพันธ์ _ _ _ _ 9 dep 9:dep _ + +1 1 _ _ _ _ 0 root 0:root _ +2 กันยายน _ _ _ _ 1 dep 1:dep _ +3 2550 _ _ _ _ 2 dep 2:dep _ +4 12:21 _ _ _ _ 3 dep 3:dep _ +5 น. _ _ _ _ 4 dep 4:dep _ + +1 ผู้สื่อข่าว _ _ _ _ 0 root 0:root SpaceAfter=No +2 รายงาน _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 เพิ่มเติม _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 ว่า _ _ _ _ 3 dep 3:dep _ + +1 จาก _ _ _ _ 0 root 0:root SpaceAfter=No +2 การ _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 ลง _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 พื้นที่ _ _ _ _ 3 dep 3:dep SpaceAfter=No +5 พบ _ _ _ _ 4 dep 4:dep SpaceAfter=No +6 ว่า _ _ _ _ 5 dep 5:dep _ +""".strip() + +EXPECTED_CLAUSE_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์ 1 กันยายน 2550 12:21 น. 
ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\n\n" +EXPECTED_CLAUSE_LABELS = "00000001001000001000010001000100010010000100000002010000000100001000001002000000000001000001000000001002000100101000000101002\n\n" + + +def test_split_clause(): + """ + This version of the test also resplits on spaces between clauses + """ + lines = SMALL_LST_SAMPLE.strip().split("\n") + documents = read_document(lines, spaces_after=True, split_clauses=True) + check_results(documents, EXPECTED_CLAUSE_CONLLU, EXPECTED_CLAUSE_TXT, EXPECTED_CLAUSE_LABELS) + +if __name__ == "__main__": + lines = SMALL_LST_SAMPLE.strip().split("\n") + documents = read_document(lines, spaces_after=False, split_clauses=False) + + write_section("foo", "lst20", "train", documents) -- cgit v1.2.3 From ff0bb70dfc983ee028f5ed94f97ae1e659e7c570 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 22 Jul 2021 00:24:48 -0700 Subject: Add a test which checks that the orchid results are consistent This includes refactoring the xml parsing a bit in the orchid script --- stanza/tests/test_tokenization_orchid.py | 107 +++++++++++++++++++++ .../datasets/tokenization/convert_th_orchid.py | 7 +- 2 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 stanza/tests/test_tokenization_orchid.py diff --git a/stanza/tests/test_tokenization_orchid.py b/stanza/tests/test_tokenization_orchid.py new file mode 100644 index 00000000..eed003f4 --- /dev/null +++ b/stanza/tests/test_tokenization_orchid.py @@ -0,0 +1,107 @@ +import os +import tempfile + +import pytest + +import xml.etree.ElementTree as ET + +import stanza +from stanza.tests import * + +from stanza.utils.datasets.prepare_tokenizer_treebank import convert_conllu_to_txt +from stanza.utils.datasets.tokenization.convert_th_orchid import parse_xml +from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + + +SMALL_DOC=""" + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +""" + + +EXPECTED_RESULTS=""" +1 การ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes +2 ประชุม _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 ทาง _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 วิชาการ _ _ _ _ 3 dep 3:dep _ +5 ครั้ง _ _ _ _ 4 dep 4:dep SpaceAfter=No +6 ที่ 1 _ _ _ _ 5 dep 5:dep SpaceAfter=No + +1 โครงการวิจัยและพัฒนา _ _ _ _ 0 root 0:root SpaceAfter=No +2 อิเล็กทรอนิกส์ _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 และ _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 คอมพิวเตอร์ _ _ _ _ 3 dep 3:dep SpaceAfter=No + +1 วัน _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes +2 ที่ 15 _ _ _ _ 1 dep 1:dep SpaceAfter=No +3 - _ _ _ _ 2 dep 2:dep SpaceAfter=No +4 16 _ _ _ _ 3 dep 3:dep _ +5 สิงหาคม _ _ _ _ 4 dep 4:dep _ +6 2532 _ _ _ _ 5 dep 5:dep SpaceAfter=No +""".strip() + +EXPECTED_TEXT="""การประชุมทางวิชาการ ครั้งที่ 1โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์ + +วันที่ 15-16 สิงหาคม 2532 + +""" + +EXPECTED_LABELS="""001000001001000000100000100002000000000000000000010000000000000100100000000002 + +0010000011010000000100002 + +""" + +def check_results(documents, expected_conllu, expected_txt, expected_labels): + with tempfile.TemporaryDirectory() as output_dir: + write_section(output_dir, "orchid", "train", documents) + with open(os.path.join(output_dir, "th_orchid.train.gold.conllu")) as fin: + conllu = fin.read().strip() + with open(os.path.join(output_dir, "th_orchid.train.txt")) as fin: + txt = fin.read() + with open(os.path.join(output_dir, "th_orchid-ud-train.toklabels")) as fin: + labels = fin.read() + assert conllu == expected_conllu + 
assert txt == expected_txt + assert labels == expected_labels + + assert len(txt) == len(labels) + +def test_orchid(): + tree = ET.ElementTree(ET.fromstring(SMALL_DOC)) + documents = parse_xml(tree) + check_results(documents, EXPECTED_RESULTS, EXPECTED_TEXT, EXPECTED_LABELS) + diff --git a/stanza/utils/datasets/tokenization/convert_th_orchid.py b/stanza/utils/datasets/tokenization/convert_th_orchid.py index 4cecb491..fc60e636 100644 --- a/stanza/utils/datasets/tokenization/convert_th_orchid.py +++ b/stanza/utils/datasets/tokenization/convert_th_orchid.py @@ -94,7 +94,12 @@ allowed_sequences = { def read_data(input_filename): print("Reading {}".format(input_filename)) tree = ET.parse(input_filename) + documents = parse_xml(tree) + print("Number of documents: {}".format(len(documents))) + print("Number of paragraphs: {}".format(sum(len(document) for document in documents))) + return documents +def parse_xml(tree): # we will put each paragraph in a separate block in the output file # we won't pay any attention to the document boundaries unless we # later find out it was necessary @@ -138,8 +143,6 @@ def read_data(input_filename): paragraphs.append(sentences) documents.append(paragraphs) - print("Number of documents: {}".format(len(documents))) - print("Number of paragraphs: {}".format(sum(len(document) for document in documents))) return documents -- cgit v1.2.3 From a2ed479e6669ff0a384e5f2d2f4f7a29227a1dca Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 22 Jul 2021 00:27:11 -0700 Subject: Adjust orchid preparation script to always include spaces after sentences --- stanza/tests/test_tokenization_orchid.py | 10 +++++----- stanza/utils/datasets/tokenization/convert_th_orchid.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/stanza/tests/test_tokenization_orchid.py b/stanza/tests/test_tokenization_orchid.py index eed003f4..8c0fb9f5 100644 --- a/stanza/tests/test_tokenization_orchid.py +++ b/stanza/tests/test_tokenization_orchid.py @@ -58,28 +58,28 @@ EXPECTED_RESULTS=""" 3 ทาง _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 วิชาการ _ _ _ _ 3 dep 3:dep _ 5 ครั้ง _ _ _ _ 4 dep 4:dep SpaceAfter=No -6 ที่ 1 _ _ _ _ 5 dep 5:dep SpaceAfter=No +6 ที่ 1 _ _ _ _ 5 dep 5:dep _ 1 โครงการวิจัยและพัฒนา _ _ _ _ 0 root 0:root SpaceAfter=No 2 อิเล็กทรอนิกส์ _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 และ _ _ _ _ 2 dep 2:dep SpaceAfter=No -4 คอมพิวเตอร์ _ _ _ _ 3 dep 3:dep SpaceAfter=No +4 คอมพิวเตอร์ _ _ _ _ 3 dep 3:dep _ 1 วัน _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes 2 ที่ 15 _ _ _ _ 1 dep 1:dep SpaceAfter=No 3 - _ _ _ _ 2 dep 2:dep SpaceAfter=No 4 16 _ _ _ _ 3 dep 3:dep _ 5 สิงหาคม _ _ _ _ 4 dep 4:dep _ -6 2532 _ _ _ _ 5 dep 5:dep SpaceAfter=No +6 2532 _ _ _ _ 5 dep 5:dep _ """.strip() -EXPECTED_TEXT="""การประชุมทางวิชาการ ครั้งที่ 1โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์ +EXPECTED_TEXT="""การประชุมทางวิชาการ ครั้งที่ 1 โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์ วันที่ 15-16 สิงหาคม 2532 """ -EXPECTED_LABELS="""001000001001000000100000100002000000000000000000010000000000000100100000000002 +EXPECTED_LABELS="""0010000010010000001000001000020000000000000000000010000000000000100100000000002 0010000011010000000100002 diff --git a/stanza/utils/datasets/tokenization/convert_th_orchid.py b/stanza/utils/datasets/tokenization/convert_th_orchid.py index fc60e636..871e87d1 100644 --- a/stanza/utils/datasets/tokenization/convert_th_orchid.py +++ b/stanza/utils/datasets/tokenization/convert_th_orchid.py @@ -139,6 +139,7 @@ def parse_xml(tree): words.append((word, False)) if len(words) == 0: 
continue + words[-1] = (words[-1][0], True) sentences.append(words) paragraphs.append(sentences) documents.append(paragraphs) -- cgit v1.2.3 From 3314d049bf5343b11890242cdaa0d6a83e09d90c Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 23 Jul 2021 00:02:07 -0700 Subject: Process gz files as well as .txt and .txt.xz in the charlm --- stanza/utils/charlm/make_lm_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stanza/utils/charlm/make_lm_data.py b/stanza/utils/charlm/make_lm_data.py index a2a7e3e8..e1a8ca16 100644 --- a/stanza/utils/charlm/make_lm_data.py +++ b/stanza/utils/charlm/make_lm_data.py @@ -86,6 +86,9 @@ def prepare_lm_data(src_dir, tgt_dir, lang, dataset_name): for src_fn in glob.glob(str(src_dir) + '/*.txt.xz'): cmd = f"xzcat {src_fn} >> {tgt_tmp}" subprocess.run(cmd, shell=True) + for src_fn in glob.glob(str(src_dir) + '/*.txt.gz'): + cmd = f"zcat {src_fn} >> {tgt_tmp}" + subprocess.run(cmd, shell=True) tgt_tmp_shuffled = Path(str(tgt_tmp) + ".shuffled") print(f"--> Shuffling files into {tgt_tmp_shuffled}...") -- cgit v1.2.3 From 11f290ed681f180c5df9a64efa0cc6c41e03b91a Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 23 Jul 2021 09:46:58 -0700 Subject: Fix command line for hindi datasets --- stanza/utils/datasets/ner/prepare_ner_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py index d936c158..80d63cec 100644 --- a/stanza/utils/datasets/ner/prepare_ner_dataset.py +++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py @@ -18,13 +18,14 @@ IJCNLP 2008 produced a few Indian language NER datasets. download: http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=5 The models produced from these datasets have extremely low recall, unfortunately. - - prepare_ner_dataset.py hi-fire2013 + - prepare_ner_dataset.py hi_ijc FIRE 2013 also produced NER datasets for Indian languages. http://au-kbc.org/nlp/NER-FIRE2013/index.html The datasets are password locked. For Stanford users, contact Chris Manning for license details. For external users, please contact the organizers for more information. 
+ - prepare_ner_dataset.py hi-fire2013 Ukranian NER is provided by lang-uk, available here: https://github.com/lang-uk/ner-uk -- cgit v1.2.3 From e0203f1d214eb12422a8becfe50436f81e29f1c9 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 23 Jul 2021 10:20:43 -0700 Subject: Add indentation to the json rather than saving it in one large dump --- stanza/utils/datasets/ner/prepare_ner_file.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stanza/utils/datasets/ner/prepare_ner_file.py b/stanza/utils/datasets/ner/prepare_ner_file.py index f4894fa2..71f2aeee 100644 --- a/stanza/utils/datasets/ner/prepare_ner_file.py +++ b/stanza/utils/datasets/ner/prepare_ner_file.py @@ -34,8 +34,8 @@ def process_dataset(input_filename, output_filename): document += [sent] with open(output_filename, 'w') as outfile: - json.dump(document, outfile) - print("Generated json file {}.".format(output_filename)) + json.dump(document, outfile, indent=1) + print("Generated json file {}".format(output_filename)) # TODO: make skip_doc_start an argument def load_conll03(filename, skip_doc_start=True): -- cgit v1.2.3 From 7b0a9984fb0c87934cb07de8cb5a3b00e527b1a9 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 23 Jul 2021 10:24:44 -0700 Subject: Add a test of empty text for the pipeline --- stanza/tests/test_english_pipeline.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stanza/tests/test_english_pipeline.py b/stanza/tests/test_english_pipeline.py index c8003297..f270c1d4 100644 --- a/stanza/tests/test_english_pipeline.py +++ b/stanza/tests/test_english_pipeline.py @@ -163,6 +163,9 @@ def test_dependency_parse(processed_doc): assert "\n\n".join([sent.dependencies_string() for sent in processed_doc.sentences]) == \ EN_DOC_DEPENDENCY_PARSES_GOLD +def test_empty(pipeline): + # make sure that various models handle the degenerate empty case + pipeline("") @pytest.fixture(scope="module") def processed_multidoc(pipeline): -- cgit v1.2.3 From ccc9c1bf8e557e12e9ba7d970d9f30e7c9bc24cb Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 20 Jul 2021 09:35:58 -0700 Subject: If given an empty list, simply return an empty list when sort is called. Fixes issue 769. --- stanza/models/common/utils.py | 2 ++ stanza/tests/test_utils.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/stanza/models/common/utils.py b/stanza/models/common/utils.py index 32f1b2f8..25491754 100644 --- a/stanza/models/common/utils.py +++ b/stanza/models/common/utils.py @@ -214,6 +214,8 @@ def sort_with_indices(data, key=None, reverse=False): One useful application is to sort by length, which can be done with key=len Returns the data as a sorted list, then the indices of the original list. 
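For example, sorting a batch of sentences by length and then restoring the original order with unsort from this same module (a minimal usage sketch, not taken from the patch itself):

    from stanza.models.common.utils import sort_with_indices, unsort

    batch = [["a", "b", "c"], ["d"], ["e", "f"]]
    ordered, orig_idx = sort_with_indices(batch, key=len, reverse=True)
    # ordered puts the longest sentence first: [["a", "b", "c"], ["e", "f"], ["d"]]
    # orig_idx records each item's position in the original batch, e.g. (0, 2, 1)
    assert list(unsort(ordered, orig_idx)) == batch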
""" + if not data: + return [], [] if key: ordered = sorted(enumerate(data), key=lambda x: key(x[1]), reverse=reverse) else: diff --git a/stanza/tests/test_utils.py b/stanza/tests/test_utils.py index bc5cf4e4..220cd224 100644 --- a/stanza/tests/test_utils.py +++ b/stanza/tests/test_utils.py @@ -84,6 +84,15 @@ def test_sort_with_indices(): unsorted = utils.unsort(ordered, orig_idx) assert data == unsorted +def test_empty_sort_with_indices(): + ordered, orig_idx = utils.sort_with_indices([]) + assert len(ordered) == 0 + assert len(orig_idx) == 0 + + unsorted = utils.unsort(ordered, orig_idx) + assert [] == unsorted + + def test_split_into_batches(): data = [] for i in range(5): -- cgit v1.2.3 From 8ecf43025ad93d82a0a847ef1af8bb49a2b599ab Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 20 Jul 2021 14:25:59 -0700 Subject: Add a flag for finetuning from a different load name from the save_name --- stanza/models/ner_tagger.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py index b4b0a09f..50b566c8 100644 --- a/stanza/models/ner_tagger.py +++ b/stanza/models/ner_tagger.py @@ -40,6 +40,7 @@ def parse_args(args=None): parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `save_dir` path') + parser.add_argument('--finetune_load_name', type=str, default=None, help='Model to load when finetuning') parser.add_argument('--train_classifier_only', action='store_true', help='In case of applying Transfer-learning approach and training only the classifier layer this will freeze gradient propagation for all other layers.') parser.add_argument('--lang', type=str, help='Language') @@ -117,7 +118,10 @@ def train(args): vocab = None trainer = None - if args['finetune'] and os.path.exists(model_file): + if args['finetune'] and args['finetune_load_name']: + logger.warning('Finetune is ON. Using model from "{}"'.format(args['finetune_load_name'])) + _, trainer, vocab = load_model(args, args['finetune_load_name']) + elif args['finetune'] and os.path.exists(model_file): logger.warning('Finetune is ON. 
Using model from "{}"'.format(model_file)) _, trainer, vocab = load_model(args, model_file) else: -- cgit v1.2.3 From 71cec2e9011fcc51901849fd4927cab212d4bb5c Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 9 Jun 2021 08:43:45 -0700 Subject: Add a confusion matrix over tokens to the output of the ner_tagger --- stanza/models/classifier.py | 34 ++------------------ stanza/models/ner/scorer.py | 6 ++-- stanza/models/ner_tagger.py | 4 +++ stanza/tests/test_models_ner_scorer.py | 2 +- stanza/utils/confusion.py | 58 ++++++++++++++++++++++++++++++++++ 5 files changed, 69 insertions(+), 35 deletions(-) create mode 100644 stanza/utils/confusion.py diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py index 89105178..2bbe81d9 100644 --- a/stanza/models/classifier.py +++ b/stanza/models/classifier.py @@ -22,6 +22,8 @@ import stanza.models.classifiers.classifier_args as classifier_args import stanza.models.classifiers.cnn_classifier as cnn_classifier import stanza.models.classifiers.data as data +from stanza.utils.confusion impmort format_confusion + class Loss(Enum): CROSS = 1 @@ -312,38 +314,6 @@ def confusion_to_macro_f1(confusion): return sum_f1 / len(keys) -def format_confusion(confusion, labels, hide_zeroes=False): - """ - pretty print for confusion matrixes - adapted from https://gist.github.com/zachguo/10296432 - """ - columnwidth = max([len(x) for x in labels] + [5]) # 5 is value length - empty_cell = " " * columnwidth - - fst_empty_cell = (columnwidth-3)//2 * " " + "t/p" + (columnwidth-3)//2 * " " - - if len(fst_empty_cell) < len(empty_cell): - fst_empty_cell = " " * (len(empty_cell) - len(fst_empty_cell)) + fst_empty_cell - # Print header - header = " " + fst_empty_cell + " " - - for label in labels: - header = header + "%{0}s ".format(columnwidth) % label - text = [header] - - # Print rows - for i, label1 in enumerate(labels): - row = " %{0}s ".format(columnwidth) % label1 - for j, label2 in enumerate(labels): - confusion_cell = confusion.get(label1, {}).get(label2, 0) - cell = "%{0}.1f".format(columnwidth) % confusion_cell - if hide_zeroes: - cell = cell if confusion_cell else empty_cell - row = row + cell + " " - text.append(row) - return "\n".join(text) - - def score_dataset(model, dataset, label_map=None, device=None, remap_labels=None, forgive_unmapped_labels=False): """ diff --git a/stanza/models/ner/scorer.py b/stanza/models/ner/scorer.py index 06eb82be..c72ccdcd 100644 --- a/stanza/models/ner/scorer.py +++ b/stanza/models/ner/scorer.py @@ -4,7 +4,7 @@ An NER scorer that calculates F1 score given gold and predicted tags. import sys import os import logging -from collections import Counter +from collections import Counter, defaultdict from stanza.models.ner.utils import decode_from_bioes @@ -82,11 +82,13 @@ def score_by_token(pred_tag_sequences, gold_tag_sequences, verbose=True): correct_by_tag = Counter() guessed_by_tag = Counter() gold_by_tag = Counter() + confusion = defaultdict(lambda: defaultdict(int)) for gold_tags, pred_tags in zip(gold_tag_sequences, pred_tag_sequences): assert(len(gold_tags) == len(pred_tags)), \ "Number of predicted tags does not match gold." 
for g, p in zip(gold_tags, pred_tags): + confusion[g][p] = confusion[g][p] + 1 if g == 'O' and p == 'O': continue elif g == 'O' and p != 'O': @@ -113,7 +115,7 @@ def score_by_token(pred_tag_sequences, gold_tag_sequences, verbose=True): logger.info("Prec.\tRec.\tF1") logger.info("{:.2f}\t{:.2f}\t{:.2f}".format( \ prec_micro*100, rec_micro*100, f_micro*100)) - return prec_micro, rec_micro, f_micro + return prec_micro, rec_micro, f_micro, confusion def test(): pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'], diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py index 50b566c8..86a51caa 100644 --- a/stanza/models/ner_tagger.py +++ b/stanza/models/ner_tagger.py @@ -27,6 +27,8 @@ from stanza.utils.conll import CoNLL from stanza.models.common.doc import * from stanza.models import _training_logging +from stanza.utils.confusion import format_confusion + logger = logging.getLogger('stanza') def parse_args(args=None): @@ -259,9 +261,11 @@ def evaluate(args): gold_tags = batch.tags _, _, score = scorer.score_by_entity(preds, gold_tags) + _, _, _, confusion = scorer.score_by_token(preds, gold_tags) logger.info("NER tagger score:") logger.info("{} {:.2f}".format(args['shorthand'], score*100)) + logger.info("NER token confusion matrix:\n{}".format(format_confusion(confusion))) def load_model(args, model_file): diff --git a/stanza/tests/test_models_ner_scorer.py b/stanza/tests/test_models_ner_scorer.py index b6993f09..040f9297 100644 --- a/stanza/tests/test_models_ner_scorer.py +++ b/stanza/tests/test_models_ner_scorer.py @@ -16,7 +16,7 @@ def test_ner_scorer(): gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'], ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']] - token_p, token_r, token_f = score_by_token(pred_sequences, gold_sequences) + token_p, token_r, token_f, confusion = score_by_token(pred_sequences, gold_sequences) assert pytest.approx(token_p, abs=0.00001) == 0.625 assert pytest.approx(token_r, abs=0.00001) == 0.5 assert pytest.approx(token_f, abs=0.00001) == 0.55555 diff --git a/stanza/utils/confusion.py b/stanza/utils/confusion.py new file mode 100644 index 00000000..0895d43f --- /dev/null +++ b/stanza/utils/confusion.py @@ -0,0 +1,58 @@ +def format_confusion(confusion, labels=None, hide_zeroes=False): + """ + pretty print for confusion matrixes + adapted from https://gist.github.com/zachguo/10296432 + + The matrix should look like this: + confusion[gold][pred] + """ + if labels is None: + labels = set(confusion.keys()) + for key in confusion.keys(): + labels = labels.union(confusion[key].keys()) + if 'O' in labels: + labels.remove('O') + labels = ['O'] + sorted(labels) + else: + labels = labels.sorted() + + columnwidth = max([len(x) for x in labels] + [5]) # 5 is value length + empty_cell = " " * columnwidth + + fst_empty_cell = (columnwidth-3)//2 * " " + "t/p" + (columnwidth-3)//2 * " " + + if len(fst_empty_cell) < len(empty_cell): + fst_empty_cell = " " * (len(empty_cell) - len(fst_empty_cell)) + fst_empty_cell + # Print header + header = " " + fst_empty_cell + " " + + for label in labels: + header = header + "%{0}s ".format(columnwidth) % label + text = [header] + + # Print rows + all_ints = True + for i, label1 in enumerate(labels): + for j, label2 in enumerate(labels): + if not isinstance(confusion.get(label1, {}).get(label2, 0), int): + all_ints = False + break + if not all_ints: + break + + # Print rows + for i, label1 in enumerate(labels): + row = " %{0}s ".format(columnwidth) % label1 + for j, label2 in enumerate(labels): 
+ confusion_cell = confusion.get(label1, {}).get(label2, 0) + if all_ints: + cell = "%{0}d".format(columnwidth) % confusion_cell + else: + cell = "%{0}.1f".format(columnwidth) % confusion_cell + if hide_zeroes: + cell = cell if confusion_cell else empty_cell + row = row + cell + " " + text.append(row) + return "\n".join(text) + + -- cgit v1.2.3 From c1744a5ac9357f66b7c4bd5055e30d86b1a4a777 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 9 Jun 2021 12:33:15 -0700 Subject: Format ints differently from floats in the confusion matrix --- stanza/utils/confusion.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/stanza/utils/confusion.py b/stanza/utils/confusion.py index 0895d43f..56aa5c63 100644 --- a/stanza/utils/confusion.py +++ b/stanza/utils/confusion.py @@ -26,11 +26,7 @@ def format_confusion(confusion, labels=None, hide_zeroes=False): # Print header header = " " + fst_empty_cell + " " - for label in labels: - header = header + "%{0}s ".format(columnwidth) % label - text = [header] - - # Print rows + # If the numbers are all ints, no need to include the .0 at the end of each entry all_ints = True for i, label1 in enumerate(labels): for j, label2 in enumerate(labels): @@ -39,16 +35,28 @@ def format_confusion(confusion, labels=None, hide_zeroes=False): break if not all_ints: break - + + if all_ints: + format_cell = lambda confusion_cell: "%{0}d".format(columnwidth) % confusion_cell + else: + format_cell = lambda confusion_cell: "%{0}.1f".format(columnwidth) % confusion_cell + + # make sure the columnwidth can handle long numbers + for i, label1 in enumerate(labels): + for j, label2 in enumerate(labels): + cell = confusion.get(label1, {}).get(label2, 0) + columnwidth = max(columnwidth, len(format_cell(cell))) + + for label in labels: + header = header + "%{0}s ".format(columnwidth) % label + text = [header] + # Print rows for i, label1 in enumerate(labels): row = " %{0}s ".format(columnwidth) % label1 for j, label2 in enumerate(labels): confusion_cell = confusion.get(label1, {}).get(label2, 0) - if all_ints: - cell = "%{0}d".format(columnwidth) % confusion_cell - else: - cell = "%{0}.1f".format(columnwidth) % confusion_cell + cell = format_cell(confusion_cell) if hide_zeroes: cell = cell if confusion_cell else empty_cell row = row + cell + " " -- cgit v1.2.3 From 7e5e7f6bfe53e01c6a75df03e9bfe3b94191d025 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 9 Jun 2021 12:57:05 -0700 Subject: Make the matrix more readable when there are a ton of categories --- stanza/utils/confusion.py | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/stanza/utils/confusion.py b/stanza/utils/confusion.py index 56aa5c63..574aa7a9 100644 --- a/stanza/utils/confusion.py +++ b/stanza/utils/confusion.py @@ -1,3 +1,27 @@ + +from collections import defaultdict + +def condense_ner_labels(confusion, labels): + new_confusion = defaultdict(lambda: defaultdict(int)) + new_labels = [] + for l1 in labels: + if l1.find("-") >= 0: + new_l1 = l1.split("-", 1)[1] + else: + new_l1 = l1 + if new_l1 not in new_labels: + new_labels.append(new_l1) + for l2 in labels: + if l2.find("-") >= 0: + new_l2 = l2.split("-", 1)[1] + else: + new_l2 = l2 + + old_value = confusion.get(l1, {}).get(l2, 0) + new_confusion[new_l1][new_l2] = new_confusion[new_l1][new_l2] + old_value + return new_confusion, new_labels + + def format_confusion(confusion, labels=None, hide_zeroes=False): """ pretty print for confusion matrixes @@ -19,13 +43,6 
@@ def format_confusion(confusion, labels=None, hide_zeroes=False): columnwidth = max([len(x) for x in labels] + [5]) # 5 is value length empty_cell = " " * columnwidth - fst_empty_cell = (columnwidth-3)//2 * " " + "t/p" + (columnwidth-3)//2 * " " - - if len(fst_empty_cell) < len(empty_cell): - fst_empty_cell = " " * (len(empty_cell) - len(fst_empty_cell)) + fst_empty_cell - # Print header - header = " " + fst_empty_cell + " " - # If the numbers are all ints, no need to include the .0 at the end of each entry all_ints = True for i, label1 in enumerate(labels): @@ -47,6 +64,16 @@ def format_confusion(confusion, labels=None, hide_zeroes=False): cell = confusion.get(label1, {}).get(label2, 0) columnwidth = max(columnwidth, len(format_cell(cell))) + # if this is an NER confusion matrix (well, if it has - in the labels) + # try to drop a bunch of labels to make the matrix easier to display + if columnwidth * len(labels) > 150: + confusion, labels = condense_ner_labels(confusion, labels) + + # Print header + fst_empty_cell = (columnwidth-3)//2 * " " + "t/p" + (columnwidth-3)//2 * " " + if len(fst_empty_cell) < len(empty_cell): + fst_empty_cell = " " * (len(empty_cell) - len(fst_empty_cell)) + fst_empty_cell + header = " " + fst_empty_cell + " " for label in labels: header = header + "%{0}s ".format(columnwidth) % label text = [header] -- cgit v1.2.3 From 6f00fde81feac718600df88ff9ab2767bf49cce5 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 11 Jun 2021 23:40:29 -0700 Subject: Add a processing step for NHCLT datasets. Currently Afrikaans is the most useful, as we don't have other tokenizers --- stanza/utils/datasets/ner/prepare_ner_dataset.py | 53 ++++++++++++++++++++++++ stanza/utils/datasets/ner/split_wikiner.py | 14 ++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py index 80d63cec..47e3d8b3 100644 --- a/stanza/utils/datasets/ner/prepare_ner_dataset.py +++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py @@ -70,6 +70,23 @@ BSNLP publishes NER datasets for Eastern European languages. bsnlp 2019 can be supported by adding the appropriate functionality in convert_bsnlp.py. - prepare_ner_dataset.py bg_bsnlp19 + +NCHLT produced NER datasets for many African languages. + Unfortunately, it is difficult to make use of many of these, + as there is no corresponding UD data from which to build a + tokenizer or other tools. + - Afrikaans: https://repo.sadilar.org/handle/20.500.12185/299 + - isiNdebele: https://repo.sadilar.org/handle/20.500.12185/306 + - isiXhosa: https://repo.sadilar.org/handle/20.500.12185/312 + - isiZulu: https://repo.sadilar.org/handle/20.500.12185/319 + - Sepedi: https://repo.sadilar.org/handle/20.500.12185/328 + - Sesotho: https://repo.sadilar.org/handle/20.500.12185/334 + - Setswana: https://repo.sadilar.org/handle/20.500.12185/341 + - Siswati: https://repo.sadilar.org/handle/20.500.12185/346 + - Tsivenda: https://repo.sadilar.org/handle/20.500.12185/355 + - Xitsonga: https://repo.sadilar.org/handle/20.500.12185/362 + Agree to the license, download the zip, and unzip it in + $NERBASE/NCHLT """ import glob @@ -322,6 +339,40 @@ def process_bsnlp(paths, short_name): output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) prepare_ner_file.process_dataset(csv_file, output_filename) +NCHLT_LANGUAGE_MAP = { + "af": "NCHLT Afrikaans Named Entity Annotated Corpus", + # none of the following have UD datasets as of 2.8. 
Until they + # exist, we assume the language codes NCHTL are sufficient + "nr": "NCHLT isiNdebele Named Entity Annotated Corpus", + "nso": "NCHLT Sepedi Named Entity Annotated Corpus", + "ss": "NCHLT Siswati Named Entity Annotated Corpus", + "st": "NCHLT Sesotho Named Entity Annotated Corpus", + "tn": "NCHLT Setswana Named Entity Annotated Corpus", + "ts": "NCHLT Xitsonga Named Entity Annotated Corpus", + "ve": "NCHLT Tshivenda Named Entity Annotated Corpus", + "xh": "NCHLT isiXhosa Named Entity Annotated Corpus", + "zu": "NCHLT isiZulu Named Entity Annotated Corpus", +} + +def process_nchlt(paths, short_name): + language = short_name.split("_")[0] + if not language in NCHLT_LANGUAGE_MAP: + raise ValueError("Language %s not part of NCHLT" % language) + short_name = "%s_nchlt" % language + + base_input_path = os.path.join(paths["NERBASE"], "NCHLT", NCHLT_LANGUAGE_MAP[language], "*Full.txt") + input_files = glob.glob(base_input_path) + if len(input_files) == 0: + raise FileNotFoundError("Cannot find NCHLT dataset in '%s' Did you remember to download the file?" % base_input_path) + + if len(input_files) > 1: + raise ValueError("Unexpected number of files matched '%s' There should only be one" % base_input_path) + + base_output_path = paths["NER_DATA_DIR"] + split_wikiner(base_output_path, input_files[0], prefix=short_name, remap={"OUT": "O"}) + convert_bio_to_json(base_output_path, base_output_path, short_name) + + def main(dataset_name): paths = default_paths.get_default_paths() @@ -345,6 +396,8 @@ def main(dataset_name): process_hu_combined(paths) elif dataset_name.endswith("_bsnlp19"): process_bsnlp(paths, dataset_name) + elif dataset_name.endswith("_nchlt"): + process_nchlt(paths, dataset_name) else: raise ValueError(f"dataset {dataset_name} currently not handled") diff --git a/stanza/utils/datasets/ner/split_wikiner.py b/stanza/utils/datasets/ner/split_wikiner.py index 8c4b3d3d..fea2b02f 100644 --- a/stanza/utils/datasets/ner/split_wikiner.py +++ b/stanza/utils/datasets/ner/split_wikiner.py @@ -50,13 +50,25 @@ def write_sentences_to_file(sents, filename): print(f"{pair[0]}\t{pair[1]}", file=outfile) print("", file=outfile) -def split_wikiner(directory, *in_filenames, encoding="utf-8", prefix=""): +def remap_labels(sents, remap): + new_sentences = [] + for sentence in sents: + new_sent = [] + for word in sentence: + new_sent.append([word[0], remap.get(word[1], word[1])]) + new_sentences.append(new_sent) + return new_sentences + +def split_wikiner(directory, *in_filenames, encoding="utf-8", prefix="", remap=None): sents = [] for filename in in_filenames: new_sents = read_sentences(filename, encoding) print(f"{len(new_sents)} sentences read from {filename}.") sents.extend(new_sents) + if remap: + sents = remap_labels(sents, remap) + # split num = len(sents) train_num = int(num*0.7) -- cgit v1.2.3 From 09056dcd1b5613c5ead9a950ceae3c1cb4fa46ae Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 22 Jul 2021 15:51:47 -0700 Subject: Add the ability for the ner model to upscale basic (no B- or I-) tagging -> BIOES --- stanza/models/ner/data.py | 26 ++------------- stanza/models/ner/utils.py | 74 ++++++++++++++++++++++++++++++++++++++++++ stanza/tests/test_ner_utils.py | 49 ++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 24 deletions(-) create mode 100644 stanza/tests/test_ner_utils.py diff --git a/stanza/models/ner/data.py b/stanza/models/ner/data.py index 80d6ac1d..f80dc10f 100644 --- a/stanza/models/ner/data.py +++ b/stanza/models/ner/data.py @@ -7,7 +7,7 @@ from 
stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX from stanza.models.pos.vocab import CharVocab, WordVocab from stanza.models.ner.vocab import TagVocab, MultiVocab from stanza.models.common.doc import * -from stanza.models.ner.utils import is_bio_scheme, to_bio2, bio2_to_bioes +from stanza.models.ner.utils import process_tags logger = logging.getLogger('stanza') @@ -135,31 +135,9 @@ class DataLoader: def load_doc(self, doc): data = doc.get([TEXT, NER], as_sentences=True, from_token=True) if self.preprocess_tags: # preprocess tags - data = self.process_tags(data) + data = process_tags(data, self.args.get('scheme', 'bio')) return data - def process_tags(self, sentences): - res = [] - # check if tag conversion is needed - convert_to_bioes = False - is_bio = is_bio_scheme([x[1] for sent in sentences for x in sent]) - if is_bio and self.args.get('scheme', 'bio').lower() == 'bioes': - convert_to_bioes = True - logger.debug("BIO tagging scheme found in input; converting into BIOES scheme...") - # process tags - for sent in sentences: - words, tags = zip(*sent) - # NER field sanity checking - if any([x is None or x == '_' for x in tags]): - raise ValueError("NER tag not found for some input data.") - # first ensure BIO2 scheme - tags = to_bio2(tags) - # then convert to BIOES - if convert_to_bioes: - tags = bio2_to_bioes(tags) - res.append([[w,t] for w,t in zip(words, tags)]) - return res - def process_chars(self, sents): start_id, end_id = self.vocab['char'].unit2id('\n'), self.vocab['char'].unit2id(' ') # special token start_offset, end_offset = 1, 1 diff --git a/stanza/models/ner/utils.py b/stanza/models/ner/utils.py index 0b1d4b1e..c107b14e 100644 --- a/stanza/models/ner/utils.py +++ b/stanza/models/ner/utils.py @@ -2,6 +2,26 @@ Utility functions for dealing with NER tagging. """ +import logging + +logger = logging.getLogger('stanza') + +def is_basic_scheme(all_tags): + """ + Check if a basic tagging scheme is used. Return True if so. + + Args: + all_tags: a list of NER tags + + Returns: + True if the tagging scheme does not use B-, I-, etc, otherwise False + """ + for tag in all_tags: + if len(tag) > 2 and tag[:2] in ('B-', 'I-', 'S-', 'E-'): + return False + return True + + def is_bio_scheme(all_tags): """ Check if BIO tagging scheme is used. Return True if so. @@ -45,6 +65,28 @@ def to_bio2(tags): new_tags.append(tag) return new_tags +def basic_to_bio(tags): + """ + Convert a basic tag sequence into a BIO sequence. + You can compose this with bio2_to_bioes to convert to bioes + + Args: + tags: a list of tags in basic (no B-, I-, etc) format + + Returns: + new_tags: a list of tags in BIO format + """ + new_tags = [] + for i, tag in enumerate(tags): + if tag == 'O': + new_tags.append(tag) + elif i == 0 or tags[i-1] == 'O' or tags[i-1] != tag: + new_tags.append('B-' + tag) + else: + new_tags.append('I-' + tag) + return new_tags + + def bio2_to_bioes(tags): """ Convert the BIO2 tag sequence into a BIOES sequence. 
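The intended composition of the helpers above is basic -> BIO -> BIOES, as the basic_to_bio docstring notes. A minimal sketch of that chain (illustration only, mirroring the functions defined in this file):

    from stanza.models.ner.utils import basic_to_bio, bio2_to_bioes

    tags = ["O", "PER", "PER", "O", "LOC"]
    bio = basic_to_bio(tags)
    # bio   == ['O', 'B-PER', 'I-PER', 'O', 'B-LOC']
    bioes = bio2_to_bioes(bio)
    # bioes == ['O', 'B-PER', 'E-PER', 'O', 'S-LOC']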
@@ -77,6 +119,38 @@ def bio2_to_bioes(tags): raise Exception(f"Invalid IOB tag found: {tag}") return new_tags +def process_tags(sentences, scheme): + res = [] + # check if tag conversion is needed + convert_bio_to_bioes = False + convert_basic_to_bioes = False + is_bio = is_bio_scheme([x[1] for sent in sentences for x in sent]) + is_basic = not is_bio and is_basic_scheme([x[1] for sent in sentences for x in sent]) + if is_bio and scheme.lower() == 'bioes': + convert_bio_to_bioes = True + logger.debug("BIO tagging scheme found in input; converting into BIOES scheme...") + elif is_basic and scheme.lower() == 'bioes': + convert_basic_to_bioes = True + logger.debug("Basic tagging scheme found in input; converting into BIOES scheme...") + # process tags + for sent in sentences: + words, tags = zip(*sent) + # NER field sanity checking + if any([x is None or x == '_' for x in tags]): + raise ValueError("NER tag not found for some input data.") + if convert_basic_to_bioes: + # if basic, convert tags -> bio -> bioes + tags = bio2_to_bioes(basic_to_bio(tags)) + else: + # first ensure BIO2 scheme + tags = to_bio2(tags) + # then convert to BIOES + if convert_bio_to_bioes: + tags = bio2_to_bioes(tags) + res.append([(w,t) for w,t in zip(words, tags)]) + return res + + def decode_from_bioes(tags): """ Decode from a sequence of BIOES tags, assuming default tag is 'O'. diff --git a/stanza/tests/test_ner_utils.py b/stanza/tests/test_ner_utils.py new file mode 100644 index 00000000..21d9262b --- /dev/null +++ b/stanza/tests/test_ner_utils.py @@ -0,0 +1,49 @@ +import pytest + +from stanza.tests import * + +from stanza.models.ner import utils + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +WORDS = [["Unban", "Mox", "Opal"], ["Ragavan", "is", "red"], ["Urza", "Lord", "High", "Artificer", "goes", "infinite", "with", "Thopter", "Sword"]] +BIO_TAGS = [["O", "B-ART", "I-ART"], ["B-MONKEY", "O", "B-COLOR"], ["B-PER", "I-PER", "I-PER", "I-PER", "O", "O", "O", "B-WEAPON", "B-WEAPON"]] +BIOES_TAGS = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "S-WEAPON", "S-WEAPON"]] +# note the problem with not using BIO tags - the consecutive tags for thopter/sword get treated as one item +BASIC_TAGS = [["O", "ART", "ART"], ["MONKEY", "O", "COLOR"], [ "PER", "PER", "PER", "PER", "O", "O", "O", "WEAPON", "WEAPON"]] +BASIC_BIOES = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "B-WEAPON", "E-WEAPON"]] + +def check_reprocessed_tags(words, input_tags, expected_tags): + sentences = [list(zip(x, y)) for x, y in zip(words, input_tags)] + retagged = utils.process_tags(sentences=sentences, scheme="bioes") + expected_retagged = [list(zip(x, y)) for x, y in zip(words, expected_tags)] + assert retagged == expected_retagged + +def test_process_tags_bio(): + check_reprocessed_tags(WORDS, BIO_TAGS, BIOES_TAGS) + +def test_process_tags_basic(): + check_reprocessed_tags(WORDS, BASIC_TAGS, BASIC_BIOES) + +def test_process_tags_bioes(): + """ + This one should not change, naturally + """ + check_reprocessed_tags(WORDS, BIOES_TAGS, BIOES_TAGS) + check_reprocessed_tags(WORDS, BASIC_BIOES, BASIC_BIOES) + +def run_flattened(fn, tags): + return fn([x for x in y for y in tags]) + +def test_check_bio(): + assert utils.is_bio_scheme([x for y in BIO_TAGS for x in y]) + assert not utils.is_bio_scheme([x for y in BIOES_TAGS for x in y]) + assert not utils.is_bio_scheme([x for y in BASIC_TAGS for x in y]) + assert not 
utils.is_bio_scheme([x for y in BASIC_BIOES for x in y]) + +def test_check_basic(): + assert not utils.is_basic_scheme([x for y in BIO_TAGS for x in y]) + assert not utils.is_basic_scheme([x for y in BIOES_TAGS for x in y]) + assert utils.is_basic_scheme([x for y in BASIC_TAGS for x in y]) + assert not utils.is_basic_scheme([x for y in BASIC_BIOES for x in y]) + -- cgit v1.2.3 From 76ffb458baba63a44a6d5db148c3f46ca4a63a82 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 22 Jul 2021 09:06:27 -0700 Subject: Add processing for it_fbk. Uses the .tsv file they sent us and their recommendation to use 0.8/0.1/0.1 splits --- stanza/utils/datasets/ner/prepare_ner_dataset.py | 21 +++++++++++++++++++-- stanza/utils/datasets/ner/split_wikiner.py | 13 ++++++++----- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py index 47e3d8b3..33d2d04a 100644 --- a/stanza/utils/datasets/ner/prepare_ner_dataset.py +++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py @@ -12,6 +12,10 @@ Also, Finnish Turku dataset, available here: $NERBASE/fi_turku - prepare_ner_dataset.py fi_turku +FBK in Italy produced an Italian dataset. + The processing here is for a combined .tsv file they sent us. + - prepare_ner_dataset.py it_fbk + IJCNLP 2008 produced a few Indian language NER datasets. description: http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=3 @@ -110,7 +114,7 @@ import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file SHARDS = ('train', 'dev', 'test') -def convert_bio_to_json(base_input_path, base_output_path, short_name): +def convert_bio_to_json(base_input_path, base_output_path, short_name, suffix="bio"): """ Convert BIO files to json @@ -119,7 +123,7 @@ def convert_bio_to_json(base_input_path, base_output_path, short_name): in same path for both base_input_path and base_output_path. 
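For instance, the it_fbk recipe further down calls convert_bio_to_json(ner_data_dir, ner_data_dir, "it_fbk", suffix="io"), where ner_data_dir stands for paths["NER_DATA_DIR"]; the naming it then expects is roughly:

    # reads   it_fbk.train.io    it_fbk.dev.io    it_fbk.test.io
    # writes  it_fbk.train.json  it_fbk.dev.json  it_fbk.test.json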
""" for shard in SHARDS: - input_filename = os.path.join(base_input_path, '%s.%s.bio' % (short_name, shard)) + input_filename = os.path.join(base_input_path, '%s.%s.%s' % (short_name, shard, suffix)) if not os.path.exists(input_filename): raise FileNotFoundError('Cannot find %s component of %s in %s' % (shard, short_name, input_filename)) output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) @@ -137,6 +141,17 @@ def process_turku(paths): output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard)) prepare_ner_file.process_dataset(input_filename, output_filename) +def process_it_fbk(paths): + short_name = "it_fbk" + base_input_path = os.path.join(paths["NERBASE"], short_name) + csv_file = os.path.join(base_input_path, "all-wiki-split.tsv") + if not os.path.exists(csv_file): + raise FileNotFoundError("Cannot find the FBK dataset in its expected location: {}".format(csv_file)) + base_output_path = paths["NER_DATA_DIR"] + split_wikiner(base_output_path, csv_file, prefix=short_name, suffix="io", shuffle=False, train_fraction=0.8, dev_fraction=0.1) + convert_bio_to_json(base_output_path, base_output_path, short_name, suffix="io") + + def process_languk(paths): short_name = 'uk_languk' base_input_path = os.path.join(paths["NERBASE"], 'lang-uk', 'ner-uk', 'data') @@ -380,6 +395,8 @@ def main(dataset_name): if dataset_name == 'fi_turku': process_turku(paths) + elif dataset_name == 'it_fbk': + process_it_fbk(paths) elif dataset_name in ('uk_languk', 'Ukranian_languk', 'Ukranian-languk'): process_languk(paths) elif dataset_name == 'hi_ijc': diff --git a/stanza/utils/datasets/ner/split_wikiner.py b/stanza/utils/datasets/ner/split_wikiner.py index fea2b02f..8c7e8d47 100644 --- a/stanza/utils/datasets/ner/split_wikiner.py +++ b/stanza/utils/datasets/ner/split_wikiner.py @@ -59,7 +59,7 @@ def remap_labels(sents, remap): new_sentences.append(new_sent) return new_sentences -def split_wikiner(directory, *in_filenames, encoding="utf-8", prefix="", remap=None): +def split_wikiner(directory, *in_filenames, encoding="utf-8", prefix="", suffix="bio", remap=None, shuffle=True, train_fraction=0.7, dev_fraction=0.15): sents = [] for filename in in_filenames: new_sents = read_sentences(filename, encoding) @@ -71,16 +71,19 @@ def split_wikiner(directory, *in_filenames, encoding="utf-8", prefix="", remap=N # split num = len(sents) - train_num = int(num*0.7) - dev_num = int(num*0.15) + train_num = int(num*train_fraction) + dev_num = int(num*dev_fraction) + if train_fraction + dev_fraction > 1.0: + raise ValueError("Train and dev fractions added up to more than 1: {} {} {}".format(train_fraction, dev_fraction)) - random.shuffle(sents) + if shuffle: + random.shuffle(sents) train_sents = sents[:train_num] dev_sents = sents[train_num:train_num+dev_num] test_sents = sents[train_num+dev_num:] batches = [train_sents, dev_sents, test_sents] - filenames = ['train.bio', 'dev.bio', 'test.bio'] + filenames = [f'train.{suffix}', f'dev.{suffix}', f'test.{suffix}'] if prefix: filenames = ['%s.%s' % (prefix, f) for f in filenames] for batch, filename in zip(batches, filenames): -- cgit v1.2.3 From 237333c5378c49ba049bd7161b1688e26357dfe0 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 22 Jul 2021 15:50:37 -0700 Subject: Add some explanation to the logging output for the NER scores --- stanza/models/ner/scorer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/stanza/models/ner/scorer.py b/stanza/models/ner/scorer.py index c72ccdcd..20c22210 100644 --- 
a/stanza/models/ner/scorer.py +++ b/stanza/models/ner/scorer.py @@ -59,8 +59,7 @@ def score_by_entity(pred_tag_sequences, gold_tag_sequences, verbose=True): f_micro = 2.0 * prec_micro * rec_micro / (prec_micro + rec_micro) if verbose: - logger.info("Prec.\tRec.\tF1") - logger.info("{:.2f}\t{:.2f}\t{:.2f}".format( \ + logger.info("Score by entity:\nPrec.\tRec.\tF1\n{:.2f}\t{:.2f}\t{:.2f}".format( prec_micro*100, rec_micro*100, f_micro*100)) return prec_micro, rec_micro, f_micro @@ -112,8 +111,7 @@ def score_by_token(pred_tag_sequences, gold_tag_sequences, verbose=True): f_micro = 2.0 * prec_micro * rec_micro / (prec_micro + rec_micro) if verbose: - logger.info("Prec.\tRec.\tF1") - logger.info("{:.2f}\t{:.2f}\t{:.2f}".format( \ + logger.info("Score by token:\nPrec.\tRec.\tF1\n{:.2f}\t{:.2f}\t{:.2f}".format( prec_micro*100, rec_micro*100, f_micro*100)) return prec_micro, rec_micro, f_micro, confusion -- cgit v1.2.3 From 0227367109536eba7e2ac1d0f97a60f8436b15b1 Mon Sep 17 00:00:00 2001 From: vythaihn <68755973+vythaihn@users.noreply.github.com> Date: Tue, 27 Jul 2021 13:47:23 -0700 Subject: Add vlsp pos dataset option for VLSP WS task (#772) Combines the VLSP POS dataset with the VLSP WS dataset for building WS models --- stanza/utils/datasets/common.py | 2 - .../utils/datasets/prepare_tokenizer_treebank.py | 1 + .../utils/datasets/tokenization/convert_vi_vlsp.py | 48 +++++++++++++++++++--- 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/stanza/utils/datasets/common.py b/stanza/utils/datasets/common.py index 87fe1490..8c1631fb 100644 --- a/stanza/utils/datasets/common.py +++ b/stanza/utils/datasets/common.py @@ -116,8 +116,6 @@ def build_argparse(): parser = argparse.ArgumentParser() parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. Use all_ud or ud_all for all UD treebanks') - # TODO: not sure this is the best place for dataset-specific arguments. - parser.add_argument('--vlsp_include_spaces', action='store_true', default=False, help='When processing vi_vlsp tokenization, include all of the spaces. Otherwise, we try to turn the text back into standard text') return parser diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py index d84e36c3..35941cce 100755 --- a/stanza/utils/datasets/prepare_tokenizer_treebank.py +++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py @@ -1023,6 +1023,7 @@ def add_specific_args(parser): help='Prepare tokenizer and MWT labels. Expensive, but obviously necessary for training those models.') convert_th_lst20.add_lst20_args(parser) + convert_vi_vlsp.add_vlsp_args(parser) def process_treebank(treebank, paths, args): """ Processes a single treebank into train, dev, test parts diff --git a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py index 31e7a985..947fe17f 100644 --- a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py +++ b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py @@ -37,6 +37,9 @@ def find_spaces(sentence): spaces.append(space) return spaces +def add_vlsp_args(parser): + parser.add_argument('--include_pos_data', action='store_true', default=False, help='To include or not POS training dataset for tokenization training. The path to POS dataset is expected to be in the same dir with WS path. 
For example, extern_dir/vietnamese/VLSP2013-POS-data') + parser.add_argument('--vlsp_include_spaces', action='store_true', default=False, help='When processing vi_vlsp tokenization, include all of the spaces. Otherwise, we try to turn the text back into standard text') def write_file(vlsp_include_spaces, output_filename, sentences, shard): with open(output_filename, "w") as fout: check_headlines = False @@ -75,11 +78,30 @@ def write_file(vlsp_include_spaces, output_filename, sentences, shard): fout.write("\n") fout.write("\n") -def convert_file(vlsp_include_spaces, input_filename, output_filename, shard, split_filename=None, split_shard=None): +def convert_pos_dataset(file_path): + """ + This function is to process the pos dataset + """ + + file = open(file_path, "r") + document = file.readlines() + sentences = [] + sent = [] + for line in document: + if line == "\n" and len(sent)>1: + if sent not in sentences: + sentences.append(sent) + sent = [] + elif line != "\n": + sent.append(line.split("\t")[0].replace("_"," ").strip()) + return sentences + +def convert_file(vlsp_include_spaces, input_filename, output_filename, shard, split_filename=None, split_shard=None, pos_data = None): with open(input_filename) as fin: lines = fin.readlines() sentences = [] + set_sentences = set() for line in lines: if len(line.replace("_", " ").split())>1: words = line.split() @@ -88,30 +110,44 @@ def convert_file(vlsp_include_spaces, input_filename, output_filename, shard, sp continue else: words = [w.replace("_", " ") for w in words] - sentences.append(words) - + #only add sentences that hasn't been added before + if words not in sentences: + sentences.append(words) + set_sentences.add(' '.join(words)) + if split_filename is not None: # even this is a larger dev set than the train set split_point = int(len(sentences) * 0.95) - write_file(vlsp_include_spaces, output_filename, sentences[:split_point], shard) + #check pos_data that aren't overlapping with current VLSP WS dataset + sentences_pos = [] if pos_data is None else [sent for sent in pos_data if ' '.join(sent) not in set_sentences] + print("Added ", len(sentences_pos), " sentences from POS dataset.") + write_file(vlsp_include_spaces, output_filename, sentences[:split_point]+sentences_pos, shard) write_file(vlsp_include_spaces, split_filename, sentences[split_point:], split_shard) else: write_file(vlsp_include_spaces, output_filename, sentences, shard) def convert_vi_vlsp(extern_dir, tokenizer_dir, args): input_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-WS-data") - + input_pos_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-POS-data") input_train_filename = os.path.join(input_path, "VLSP2013_WS_train_gold.txt") input_test_filename = os.path.join(input_path, "VLSP2013_WS_test_gold.txt") + + input_pos_filename = os.path.join(input_pos_path, "VLSP2013_POS_train_BI_POS_Column.txt.goldSeg") if not os.path.exists(input_train_filename): raise FileNotFoundError("Cannot find train set for VLSP at %s" % input_train_filename) if not os.path.exists(input_test_filename): raise FileNotFoundError("Cannot find test set for VLSP at %s" % input_test_filename) + pos_data = None + if args.include_pos_data: + if not os.path.exists(input_pos_filename): + raise FileNotFoundError("Cannot find pos dataset for VLSP at %" % input_pos_filename) + else: + pos_data = convert_pos_dataset(input_pos_filename) output_train_filename = os.path.join(tokenizer_dir, "vi_vlsp.train.gold.conllu") output_dev_filename = os.path.join(tokenizer_dir, "vi_vlsp.dev.gold.conllu") 
output_test_filename = os.path.join(tokenizer_dir, "vi_vlsp.test.gold.conllu") - convert_file(args.vlsp_include_spaces, input_train_filename, output_train_filename, "train", output_dev_filename, "dev") + convert_file(args.vlsp_include_spaces, input_train_filename, output_train_filename, "train", output_dev_filename, "dev", pos_data) convert_file(args.vlsp_include_spaces, input_test_filename, output_test_filename, "test") -- cgit v1.2.3 From ff5be9af94ff9093933d3c29e5e9452af4faf244 Mon Sep 17 00:00:00 2001 From: David Riff Date: Wed, 28 Jul 2021 07:34:20 +0100 Subject: Add mwt if tokenize is passed without MWT (#777) Add mwt if only tokenize is specified in both download and pipeline --- stanza/pipeline/core.py | 34 +++++----------------------------- stanza/resources/common.py | 19 +++++++++++++++++++ stanza/tests/test_installation.py | 6 ++++++ 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/stanza/pipeline/core.py b/stanza/pipeline/core.py index 09299e05..fda90e2d 100644 --- a/stanza/pipeline/core.py +++ b/stanza/pipeline/core.py @@ -24,7 +24,7 @@ from stanza.pipeline.depparse_processor import DepparseProcessor from stanza.pipeline.sentiment_processor import SentimentProcessor from stanza.pipeline.ner_processor import NERProcessor from stanza.resources.common import DEFAULT_MODEL_DIR, \ - maintain_processor_list, add_dependencies, build_default_config, set_logging_level, process_pipeline_parameters, sort_processors + maintain_processor_list, add_dependencies, add_mwt, build_default_config, set_logging_level, process_pipeline_parameters, sort_processors from stanza.utils.helper_func import make_table logger = logging.getLogger('stanza') @@ -98,7 +98,10 @@ class Pipeline: logger.warning(f'Unsupported language: {lang}.') # Maintain load list - processors = self.maybe_add_mwt(kwargs, resources, lang, processors) + if (not kwargs.get("tokenize_pretokenized") + and TOKENIZE in processors + and MWT not in processors): + add_mwt(processors, resources, lang) self.load_list = maintain_processor_list(resources, lang, package, processors) if lang in resources else [] self.load_list = add_dependencies(resources, lang, self.load_list) if lang in resources else [] self.load_list = self.update_kwargs(kwargs, self.load_list) @@ -174,33 +177,6 @@ class Pipeline: logger.info("Done loading processors!") - @staticmethod - def maybe_add_mwt(kwargs, resources, lang, processors): - """ - A hack to add MWT to languages which need it - - If tokenize is in the list, but mwt is not, and there is a corresponding - tokenize & mwt pair in the resources file, we add mwt - otherwise we'll get another 10 bugs regarding missing mwt errors - """ - # first check to see if tokenize_pretokenized is True. 
- # if so, then we assume MWT is already present - if kwargs.get("tokenize_pretokenized", None): - return processors - - if TOKENIZE in processors and MWT not in processors: - value = processors[TOKENIZE] - if value == 'default' and MWT in resources[lang]['default_processors']: - logger.warning("Language %s package default expects mwt, which has been added" % lang) - processors[MWT] = 'default' - elif (value in resources[lang][TOKENIZE] and MWT in resources[lang] and - value in resources[lang][MWT]): - logger.warning("Language %s package %s expects mwt, which has been added" % (lang, value)) - processors[MWT] = value - - return processors - - @staticmethod def update_kwargs(kwargs, processor_list): processor_dict = {processor: {'package': package, 'dependencies': dependencies} for (processor, package, dependencies) in processor_list} diff --git a/stanza/resources/common.py b/stanza/resources/common.py index d9837468..a1b7e690 100644 --- a/stanza/resources/common.py +++ b/stanza/resources/common.py @@ -150,11 +150,30 @@ def sort_processors(processor_list): sorted_list.append(item) return sorted_list +def add_mwt(processors, resources, lang): + """Add mwt if tokenize is passed without mwt. + + If tokenize is in the list, but mwt is not, and there is a corresponding + tokenize and mwt pair in the resources file, mwt is added so no missing + mwt errors are raised. + """ + value = processors[TOKENIZE] + if value == "default" and MWT in resources[lang]['default_processors']: + logger.warning("Language %s package default expects mwt, which has been added", lang) + processors[MWT] = 'default' + elif (value in resources[lang][TOKENIZE] + and MWT in resources[lang] + and value in resources[lang][MWT]): + logger.warning("Language %s package %s expects mwt, which has been added", lang, value) + processors[MWT] = value + def maintain_processor_list(resources, lang, package, processors): processor_list = {} # resolve processor models if processors: logger.debug(f'Processing parameter "processors"...') + if TOKENIZE in processors and MWT not in processors: + add_mwt(processors, resources, lang) for key, value in processors.items(): assert(isinstance(key, str) and isinstance(value, str)) if key not in PIPELINE_NAMES: diff --git a/stanza/tests/test_installation.py b/stanza/tests/test_installation.py index 13b8c4f9..3ec287b1 100644 --- a/stanza/tests/test_installation.py +++ b/stanza/tests/test_installation.py @@ -37,3 +37,9 @@ def test_download_corenlp_models(): dest_file = os.path.join(test_dir, f"stanford-corenlp-{version}-models-{model_name}.jar") assert os.path.isfile(dest_file), "Downloaded model file not found." 
+ +def test_download_tokenize_mwt(): + with tempfile.TemporaryDirectory(dir=".") as test_dir: + stanza.download("en", model_dir=test_dir, processors="tokenize", package="ewt", verbose=False) + pipeline = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package="ewt") + assert isinstance(pipeline, stanza.Pipeline) -- cgit v1.2.3 From 55433ce8d5d532ed855096deb14a366cf13f14ec Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 28 Jul 2021 00:00:43 -0700 Subject: Double check that the length of the processors list is 2 when adding just tokenize from EWT --- stanza/tests/test_installation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stanza/tests/test_installation.py b/stanza/tests/test_installation.py index 3ec287b1..69a7bb0f 100644 --- a/stanza/tests/test_installation.py +++ b/stanza/tests/test_installation.py @@ -43,3 +43,5 @@ def test_download_tokenize_mwt(): stanza.download("en", model_dir=test_dir, processors="tokenize", package="ewt", verbose=False) pipeline = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package="ewt") assert isinstance(pipeline, stanza.Pipeline) + # mwt should be added to the list + assert len(pipeline.loaded_processors) == 2 -- cgit v1.2.3 From 595bf7fb11e18069dc5106351ec822cbfc7e4c09 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 28 Jul 2021 22:33:39 -0700 Subject: Update word embedding to the dimension in the file when creating a new model --- stanza/models/ner_tagger.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py index 86a51caa..e255f575 100644 --- a/stanza/models/ner_tagger.py +++ b/stanza/models/ner_tagger.py @@ -142,6 +142,12 @@ def train(args): # do not save pretrained embeddings individually pretrain = Pretrain(None, vec_file, args['pretrain_max_vocab'], save_to_file=False) + if pretrain is not None: + word_emb_dim = pretrain.emb.shape[1] + if args['word_emb_dim'] and args['word_emb_dim'] != word_emb_dim: + logger.warning("Embedding file has a dimension of {}. 
Model will be built with that size instead of {}".format(word_emb_dim, args['word_emb_dim'])) + args['word_emb_dim'] = word_emb_dim + if args['charlm']: if args['charlm_shorthand'] is None: raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...") -- cgit v1.2.3 From 33d7d61f974ee267b0806b5024669ab78513c332 Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 30 Jul 2021 14:04:47 -0700 Subject: skip langid tests until resources set up --- stanza/tests/test_langid.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stanza/tests/test_langid.py b/stanza/tests/test_langid.py index 9d42f9b3..3f8cc1a7 100644 --- a/stanza/tests/test_langid.py +++ b/stanza/tests/test_langid.py @@ -2,10 +2,14 @@ Basic tests of langid module """ +import pytest + from stanza.models.common.doc import Document from stanza.pipeline.core import Pipeline from stanza.pipeline.multilingual import MultilingualPipeline +pytestmark = pytest.mark.skip + def test_langid(): """ Basic test of language identification -- cgit v1.2.3 From 50f153f78e9668a24d8403ad291b7cc0622fd044 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 30 Jul 2021 15:22:12 -0700 Subject: This test was backwards, causing a bunch of stray java processes when a context closes --- stanza/server/java_protobuf_requests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/server/java_protobuf_requests.py b/stanza/server/java_protobuf_requests.py index fda8e8d5..ee036387 100644 --- a/stanza/server/java_protobuf_requests.py +++ b/stanza/server/java_protobuf_requests.py @@ -77,7 +77,7 @@ class JavaProtobufContext(object): return self def __exit__(self, type, value, traceback): - if self.pipe.poll() is not None: + if self.pipe.poll() is None: self.pipe.stdin.write((0).to_bytes(4, 'big')) self.pipe.stdin.flush() -- cgit v1.2.3 From 682567762cb01315723d736b440ca1782a97bea3 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 2 Aug 2021 12:02:39 -0700 Subject: Add two new NER models to the resources --- stanza/resources/prepare_resources.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py index 9421d3a1..305deb7f 100644 --- a/stanza/resources/prepare_resources.py +++ b/stanza/resources/prepare_resources.py @@ -99,6 +99,7 @@ default_treebanks = { # default ner for languages default_ners = { + "af": "nchlt", "ar": "aqmar", "bg": "bsnlp19", "de": "conll03", @@ -107,6 +108,7 @@ default_ners = { "fi": "turku", "fr": "wikiner", "hu": "combined", + "it": "fbk", "nl": "conll02", "ru": "wikiner", "uk": "languk", @@ -116,6 +118,7 @@ default_ners = { # default charlms for languages default_charlms = { + "af": "oscar", "ar": "ccwiki", "bg": "conll17", "de": "newswiki", @@ -123,6 +126,7 @@ default_charlms = { "es": "newswiki", "fi": "conll17", "fr": "newswiki", + "it": "conll17", "nl": "ccwiki", "ru": "newswiki", "vi": "conll17", -- cgit v1.2.3 From 586e96a0f3297dd1ca5a8f016905855db465a4a4 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 3 Aug 2021 08:51:53 -0700 Subject: Change file not found to an error --- stanza/models/ner_tagger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py index e255f575..c34e9036 100644 --- a/stanza/models/ner_tagger.py +++ b/stanza/models/ner_tagger.py @@ -128,7 +128,7 @@ def train(args): _, trainer, vocab = load_model(args, model_file) else: if args['finetune']: - logger.warning('Finetune is set to true but model file is not found. 
Continuing with training from scratch.') + raise FileNotFoundError('Finetune is set to true but model file is not found: {}'.format(model_file)) # load pretrained vectors if args['wordvec_pretrain_file']: -- cgit v1.2.3 From 60f3cf96db625d378d5e428d94036eeee2bf5419 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 3 Aug 2021 09:30:08 -0700 Subject: Add a test to see if any tags are in the dev set but not the train set --- stanza/models/common/utils.py | 20 ++++++++++++++++++++ stanza/models/ner_tagger.py | 2 ++ stanza/tests/test_utils.py | 6 ++++++ 3 files changed, 28 insertions(+) diff --git a/stanza/models/common/utils.py b/stanza/models/common/utils.py index 25491754..842bf6fd 100644 --- a/stanza/models/common/utils.py +++ b/stanza/models/common/utils.py @@ -279,3 +279,23 @@ def set_random_seed(seed, cuda): if cuda: torch.cuda.manual_seed(seed) return seed + +def find_missing_tags(known_tags, test_tags): + if isinstance(known_tags, list) and isinstance(known_tags[0], list): + known_tags = set(x for y in known_tags for x in y) + if isinstance(test_tags, list) and isinstance(test_tags[0], list): + test_tags = sorted(set(x for y in test_tags for x in y)) + missing_tags = sorted(x for x in test_tags if x not in known_tags) + return missing_tags + +def warn_missing_tags(known_tags, test_tags, test_set_name): + """ + Print a warning if any tags present in the second list are not in the first list. + + Can also handle a list of lists. + """ + missing_tags = find_missing_tags(known_tags, test_tags) + if len(missing_tags) > 0: + logger.warning("Found tags in {} missing from the expected tag set: {}".format(test_set_name, missing_tags)) + return True + return False diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py index c34e9036..585065fa 100644 --- a/stanza/models/ner_tagger.py +++ b/stanza/models/ner_tagger.py @@ -166,6 +166,8 @@ def train(args): dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True) dev_gold_tags = dev_batch.tags + utils.warn_missing_tags(train_batch.tags, dev_batch.tags, "dev") + # skip training if the language does not have training or dev data if len(train_batch) == 0 or len(dev_batch) == 0: logger.info("Skip training because no data available...") diff --git a/stanza/tests/test_utils.py b/stanza/tests/test_utils.py index 220cd224..4b02ab07 100644 --- a/stanza/tests/test_utils.py +++ b/stanza/tests/test_utils.py @@ -123,3 +123,9 @@ def test_split_into_batches(): # double check that unsort is working as expected assert data == utils.unsort(ordered, orig_idx) + + +def test_find_missing_tags(): + assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC"]) == [] + assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC", "ORG"]) == ['ORG'] + assert utils.find_missing_tags([["O", "PER"], ["O", "LOC"]], [["O", "PER"], ["LOC", "ORG"]]) == ['ORG'] -- cgit v1.2.3 From 6b573dce04ad958027f79bc075d7bc086b0ed85c Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 3 Aug 2021 10:31:01 -0700 Subject: Also check if the test set has tags not present in the tagger or if the train set has tags not presenti in a finetune NER model --- stanza/models/ner_tagger.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py index 585065fa..020e2c68 100644 --- a/stanza/models/ner_tagger.py +++ b/stanza/models/ner_tagger.py @@ -166,7 +166,9 @@ def train(args): dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, 
vocab=vocab, evaluation=True) dev_gold_tags = dev_batch.tags - utils.warn_missing_tags(train_batch.tags, dev_batch.tags, "dev") + if args['finetune']: + utils.warn_missing_tags([i for i in trainer.vocab['tag']], train_batch.tags, "training set") + utils.warn_missing_tags(train_batch.tags, dev_batch.tags, "dev set") # skip training if the language does not have training or dev data if len(train_batch) == 0 or len(dev_batch) == 0: @@ -261,7 +263,8 @@ def evaluate(args): logger.info("Loading data with batch size {}...".format(args['batch_size'])) doc = Document(json.load(open(args['eval_file']))) batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True) - + utils.warn_missing_tags([i for i in trainer.vocab['tag']], batch.tags, "eval_file") + logger.info("Start evaluation...") preds = [] for i, b in enumerate(batch): -- cgit v1.2.3 From c720a29be47396414c628ce02a81ff01c5efe081 Mon Sep 17 00:00:00 2001 From: Gordon Date: Wed, 4 Aug 2021 14:29:03 +0800 Subject: Create thai_syllable_dict_generator.py Takes in ssg data and converts it to syllable .json dictionary file. --- .../utils/datasets/thai_syllable_dict_generator.py | 53 ++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 stanza/utils/datasets/thai_syllable_dict_generator.py diff --git a/stanza/utils/datasets/thai_syllable_dict_generator.py b/stanza/utils/datasets/thai_syllable_dict_generator.py new file mode 100644 index 00000000..ca658e16 --- /dev/null +++ b/stanza/utils/datasets/thai_syllable_dict_generator.py @@ -0,0 +1,53 @@ +import glob +import pathlib +import argparse + + +def create_dictionary(dataset_dir, save_dir): + syllables = set() + + for p in pathlib.Path(dataset_dir).rglob("*.ssg"): # iterate through all files + + with open(p) as f: # for each file + sentences = f.readlines() + + for i in range(len(sentences)): + + sentences[i] = sentences[i].replace("\n", "") + sentences[i] = sentences[i].replace("", "~") + sentences[i] = sentences[i].split("~") # create list of all syllables + + syllables = syllables.union(sentences[i]) + + + print(len(syllables)) + + # Filter out syllables with English words + import re + + a = [] + + for s in syllables: + print("---") + if bool(re.match("^[\u0E00-\u0E7F]*$", s)) and s != "" and " " not in s: + a.append(s) + else: + pass + + a = set(a) + a = dict(zip(list(a), range(len(a)))) + + import json + print(a) + print(len(a)) + with open(save_dir, "w") as fp: + json.dump(a, fp) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--dataset_dir', type=str, default="syllable_segmentation_data", help="Directory for syllable dataset") + parser.add_argument('--save_dir', type=str, default="thai-syllable.json", help="Directory for generated file") + args = parser.parse_args() + + create_dictionary(args.dataset_dir, args.save_dir) -- cgit v1.2.3 From 6eddfafd74ce9564358beffa469da70bf6d784cb Mon Sep 17 00:00:00 2001 From: vythaihn <68755973+vythaihn@users.noreply.github.com> Date: Mon, 9 Aug 2021 17:47:33 -0700 Subject: Dictionary redo (#776) This new version of the tokenizer model incorporates the dictionary feature, especially useful for languages that have multi-syllable words such as Vietnamese, Chinese or Thai. In summary, a lexicon contains all unique words found in training dataset, and an external lexicon (if any) is created during training and saved alongside the model after training. Using this lexicon, a dictionary is created which includes "words", "prefixes" and "suffixes" sets. 
During data preparation, dictionary features are extracted at each character position, to "look ahead" and "look backward" to see if any words formed found in the dictionary. The window size (or the dictionary feature-length) is defined at the 95-percentile among all the existing words in the lexicon, this is to eliminate the less frequent but long words (avoid having a high-dimension feat vector). Prefixes and suffixes are used to stop early during the window-dictionary checking process. --- stanza/models/tokenization/data.py | 63 ++++++++++++++--- stanza/models/tokenization/model.py | 7 +- stanza/models/tokenization/trainer.py | 9 ++- stanza/models/tokenization/utils.py | 126 ++++++++++++++++++++++++++++++++++ stanza/models/tokenizer.py | 42 +++++++++--- 5 files changed, 225 insertions(+), 22 deletions(-) diff --git a/stanza/models/tokenization/data.py b/stanza/models/tokenization/data.py index ce059b6a..9039f818 100644 --- a/stanza/models/tokenization/data.py +++ b/stanza/models/tokenization/data.py @@ -1,14 +1,11 @@ from bisect import bisect_right from copy import copy -import json import numpy as np import random import logging import re import torch - from .vocab import Vocab - logger = logging.getLogger('stanza') def filter_consecutive_whitespaces(para): @@ -26,11 +23,11 @@ NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n') NUMERIC_RE = re.compile(r'^([\d]+[,\.]*)+$') WHITESPACE_RE = re.compile(r'\s') - class DataLoader: - def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, input_data=None, vocab=None, evaluation=False): + def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, input_data=None, vocab=None, evaluation=False, dictionary=None): self.args = args self.eval = evaluation + self.dictionary = dictionary # get input files txt_file = input_files['txt'] @@ -107,8 +104,6 @@ class DataLoader: func = lambda x: 1 if x.startswith(' ') else 0 elif feat_func == 'capitalized': func = lambda x: 1 if x[0].isupper() else 0 - elif feat_func == 'all_caps': - func = lambda x: 1 if x.isupper() else 0 elif feat_func == 'numeric': func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0 else: @@ -119,6 +114,40 @@ class DataLoader: # stacking all featurize functions composite_func = lambda x: [f(x) for f in funcs] + length = len(para) + #This function is to extract dictionary features for each character + def extract_dict_feat(idx): + dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])] + dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])] + forward_word = para[idx][0] + backward_word = para[idx][0] + prefix = True + suffix = True + for window in range(1,self.args['num_dict_feat']+1): + # concatenate each character and check if words found in dict not, stop if prefix not found + #check if idx+t is out of bound and if the prefix is already not found + if (idx + window) <= length-1 and prefix: + forward_word += para[idx+window][0].lower() + #check in json file if the word is present as prefix or word or None. + feat = 1 if forward_word in self.dictionary["words"] else 0 + #if the return value is not 2 or 3 then the checking word is not a valid word in dict. + dict_forward_feats[window-1] = feat + #if the dict return 0 means no prefixes found, thus, stop looking for forward. 
+ if forward_word not in self.dictionary["prefixes"]: + prefix = False + #backward check: similar to forward + if (idx - window) >= 0 and suffix: + backward_word = para[idx-window][0].lower() + backward_word + feat = 1 if backward_word in self.dictionary["words"] else 0 + dict_backward_feats[window-1] = feat + if backward_word not in self.dictionary["suffixes"]: + suffix = False + #if cannot find both prefix and suffix, then exit the loop + if not prefix and not suffix: + break + + return dict_forward_feats + dict_backward_feats + def process_sentence(sent): return [self.vocab.unit2id(y[0]) for y in sent], [y[1] for y in sent], [y[2] for y in sent], [y[0] for y in sent] @@ -135,6 +164,12 @@ class DataLoader: if use_start_of_para: f = 1 if i == 0 else 0 feats.append(f) + + #if dictionary feature is selected + if self.args['use_dictionary']: + dict_feats = extract_dict_feat(i) + feats = feats + dict_feats + current += [(unit, label, feats)] if label1 == 2 or label1 == 4: # end of sentence if len(current) <= self.args['max_seqlen']: @@ -156,7 +191,7 @@ class DataLoader: random.shuffle(para) self.init_sent_ids() - def next(self, eval_offsets=None, unit_dropout=0.0, old_batch=None): + def next(self, eval_offsets=None, unit_dropout=0.0, old_batch=None, feat_unit_dropout=0.0): ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction. ''' feat_size = len(self.sentences[0][0][2][0]) unkid = self.vocab.unit2id('') @@ -277,6 +312,18 @@ class DataLoader: if mask[i, j]: raw_units[i][j] = '' + # dropout unit feature vector in addition to only torch.dropout in the model. + # experiments showed that only torch.dropout hurts the model + # we believe it is because the dict feature vector is mostly scarse so it makes + # more sense to drop out the whole vector instead of only single element. 
+ if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval: + mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout + mask_feat[units == padid] = 0 + for i in range(len(raw_units)): + for j in range(len(raw_units[i])): + if mask_feat[i,j]: + features[i,j,:] = 0 + units = torch.from_numpy(units) labels = torch.from_numpy(labels) features = torch.from_numpy(features) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 8c4f3198..1f609871 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -3,7 +3,7 @@ import torch.nn.functional as F import torch.nn as nn class Tokenizer(nn.Module): - def __init__(self, args, nchars, emb_dim, hidden_dim, dropout): + def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout): super().__init__() self.args = args @@ -37,12 +37,15 @@ class Tokenizer(nn.Module): self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False) self.dropout = nn.Dropout(dropout) + self.dropout_feat = nn.Dropout(feat_dropout) + self.toknoise = nn.Dropout(self.args['tok_noise']) def forward(self, x, feats): emb = self.embeddings(x) - emb = self.dropout(emb) + feats = self.dropout_feat(feats) + emb = torch.cat([emb, feats], 2) diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py index bb0deb85..37fe66da 100644 --- a/stanza/models/tokenization/trainer.py +++ b/stanza/models/tokenization/trainer.py @@ -12,7 +12,7 @@ from .vocab import Vocab logger = logging.getLogger('stanza') class Trainer(BaseTrainer): - def __init__(self, args=None, vocab=None, model_file=None, use_cuda=False): + def __init__(self, args=None, vocab=None, lexicon=None, model_file=None, use_cuda=False): self.use_cuda = use_cuda if model_file is not None: # load everything from file @@ -21,7 +21,8 @@ class Trainer(BaseTrainer): # build model from scratch self.args = args self.vocab = vocab - self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout']) + self.lexicon = lexicon + self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout']) self.criterion = nn.CrossEntropyLoss(ignore_index=-1) if use_cuda: self.model.cuda() @@ -72,6 +73,7 @@ class Trainer(BaseTrainer): params = { 'model': self.model.state_dict() if self.model is not None else None, 'vocab': self.vocab.state_dict(), + 'lexicon': self.lexicon, 'config': self.args } try: @@ -91,6 +93,7 @@ class Trainer(BaseTrainer): # Default to True as many currently saved models # were built with mwt layers self.args['use_mwt'] = True - self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout']) + self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout']) self.model.load_state_dict(checkpoint['model']) self.vocab = Vocab.load_state_dict(checkpoint['vocab']) + self.lexicon = checkpoint['lexicon'] diff --git a/stanza/models/tokenization/utils.py b/stanza/models/tokenization/utils.py index ea7bda47..da123b58 100644 --- a/stanza/models/tokenization/utils.py +++ b/stanza/models/tokenization/utils.py @@ -4,12 +4,138 @@ import json import numpy as np import re import logging +import os +import stanza.utils.default_paths as default_paths from 
stanza.models.common.utils import ud_scores, harmonic_mean from stanza.utils.conll import CoNLL from stanza.models.common.doc import * logger = logging.getLogger('stanza') +paths = default_paths.get_default_paths() + +def create_dictionary(lexicon=None): + """ + This function is to create a new dictionary used for improving tokenization model for multi-syllable words languages + such as vi, zh or th. This function takes the lexicon as input and output a dictionary that contains three set: + words, prefixes and suffixes where prefixes set should contains all the prefixes in the lexicon and similar for suffixes. + The point of having prefixes/suffixes sets in the dictionary is just to make it easier to check during data preparation. + + :param shorthand - language and dataset, eg: vi_vlsp, zh_gsdsimp + :param lexicon - set of words used to create dictionary + :return a dictionary object that contains words and their prefixes and suffixes. + """ + + dictionary = {"words":set(), "prefixes":set(), "suffixes":set()} + + def add_word(word): + if word not in dictionary["words"]: + dictionary["words"].add(word) + prefix = "" + suffix = "" + for i in range(0,len(word)-1): + prefix = prefix + word[i] + suffix = word[len(word) - i - 1] + suffix + dictionary["prefixes"].add(prefix) + dictionary["suffixes"].add(suffix) + + for word in lexicon: + if len(word)>1: + add_word(word) + + return dictionary +def create_lexicon(shorthand=None, train_path=None, external_path=None): + """ + This function is to create a lexicon to store all the words from the training set and external dictionary. + This lexicon will be saved with the model and will be used to create dictionary when the model is loaded. + The idea of separating lexicon and dictionary in two different phases is a good tradeoff between time and space. + Note that we eliminate all the long words but less frequently appeared in the lexicon by only taking 95-percentile + list of words. + + :param shorthand - language and dataset, eg: vi_vlsp, zh_gsdsimp + :param train_path - path to conllu train file + :param external_path - path to extenral dict, expected to be inside the training dataset dir with format of: SHORTHAND-externaldict.txt + :return a set lexicon object that contains all distinct words + """ + lexicon = set() + length_freq = [] + #this regex is to check if a character is an actual Thai character as seems .isalpha() python method doesn't pick up Thai accent characters.. + pattern_thai = re.compile(r"(?:[^\d\W]+)|\s") + + def check_valid_word(shorthand, word): + """ + This function is to check if the word are multi-syllable words and not numbers. + For vi, whitespaces are syllabe-separator. + """ + if shorthand.startswith("vi_"): + return True if len(word.split(" ")) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word)) else False + elif shorthand.startswith("th_"): + return True if len(word) > 1 and any(map(pattern_thai.match, word)) and not any(map(str.isdigit, word)) else False + else: + return True if len(word) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word)) else False + + #checking for words in the training set to add them to lexicon. 
+ if train_path is not None: + if not os.path.isfile(train_path): + raise FileNotFoundError(f"Cannot open train set at {train_path}") + + doc_conll,_ = CoNLL.conll2dict(input_file=train_path) + + for sent_conll in doc_conll: + for token_conll in sent_conll: + word = token_conll['text'].lower() + if check_valid_word(shorthand, word) and word not in lexicon: + lexicon.add(word) + length_freq.append(len(word)) + count_word = len(lexicon) + logger.info(f"Added {count_word} words from the training data to the lexicon.") + + #checking for external dictionary and add them to lexicon. + if external_path is not None: + if not os.path.isfile(external_path): + raise FileNotFoundError(f"Cannot open external dictionary at {external_path}") + + external_file = open(external_path, "r", encoding="utf-8") + lines = external_file.readlines() + for line in lines: + word = line.lower() + word = word.replace("\n","") + if check_valid_word(shorthand, word) and word not in lexicon: + lexicon.add(word) + length_freq.append(len(word)) + external_file.close() + logger.info(f"Added another {len(lexicon) - count_word} words from the external dict to dictionary.") + + + #automatically calculate the number of dictionary features (window size to look for words) based on the frequency of word length + #take the length at 95-percentile to eliminate all the longest (maybe) compounds words in the lexicon + num_dict_feat = int(np.percentile(length_freq, 95)) + lexicon = {word for word in lexicon if len(word) <= num_dict_feat } + logger.info(f"Final lexicon consists of {len(lexicon)} words after getting rid of long words.") + + return lexicon, num_dict_feat + +def load_lexicon(args): + """ + This function is to create a new dictionary and load it to training. + The external dictionary is expected to be inside the training dataset dir with format of: SHORTHAND-externaldict.txt + For example, vi_vlsp-externaldict.txt + """ + shorthand = args["shorthand"] + tokenize_dir = paths["TOKENIZE_DATA_DIR"] + train_path = f"{tokenize_dir}/{shorthand}.train.gold.conllu" + external_dict_path = f"{tokenize_dir}/{shorthand}-externaldict.txt" + if not os.path.exists(external_dict_path): + logger.info("External dictionary not found! Checking training data...") + external_dict_path = None + if not os.path.exists(train_path): + logger.info(f"Training dataset does not exist, thus cannot create dictionary {shorthand}") + train_path = None + if train_path is None and external_dict_path is None: + raise FileNotFoundError(f"Cannot find training set / external dictionary at {train_path} and {external_dict_path}") + + return create_lexicon(shorthand, train_path, external_dict_path) + def load_mwt_dict(filename): if filename is not None: diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index 54bc729f..c51720ee 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -4,6 +4,15 @@ Entry point for training and evaluating a neural tokenizer. This tokenizer treats tokenization and sentence segmentation as a tagging problem, and uses a combination of recurrent and convolutional architectures. For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf. + +Updated: This new version of tokenizer model incorporates the dictionary feature, especially useful for languages that +have multi-syllable words such as Vietnamese, Chinese or Thai. 
In summary, a lexicon contains all unique words found in +training dataset and external lexicon (if any) is created during training and saved alongside the model after training. +Using this lexicon, a dictionary is created which includes "words", "prefixes" and "suffixes" sets. During data preparation, +dictionary features are extracted at each character position, to "look ahead" and "look backward" to see if any words formed +found in the dictionary. The window size (or the dictionary feature length) is defined at the 95-percentile among all the existing +words in the lexicon, this is to eliminate the less frequent but long words (avoid having a high-dimension feat vector). Prefixes +and suffixes are used to stop early during the window-dictionary checking process. """ import argparse @@ -13,11 +22,11 @@ import random import numpy as np import os import torch - +import json from stanza.models.common import utils from stanza.models.tokenization.trainer import Trainer from stanza.models.tokenization.data import DataLoader -from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions +from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary from stanza.models import _training_logging logger = logging.getLogger('stanza') @@ -49,6 +58,7 @@ def parse_args(args=None): parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well") parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN") parser.add_argument('--rnn_layers', type=int, default=1, help="Layers of RNN in the tokenizer") + parser.add_argument('--use_dictionary', action='store_true', help="Use dictionary feature. The lexicon is created using the training data and external dict (if any) expected to be found under the same folder of training dataset, formatted as SHORTHAND-externaldict.txt where each line in this file is a word. For example, data/zh_gsdsimp-externaldict.txt") parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm to clip to") parser.add_argument('--anneal', type=float, default=.999, help="Anneal the learning rate by this amount when dev performance deteriorate") @@ -56,6 +66,8 @@ def parse_args(args=None): parser.add_argument('--lr0', type=float, default=2e-3, help="Initial learning rate") parser.add_argument('--dropout', type=float, default=0.33, help="Dropout probability") parser.add_argument('--unit_dropout', type=float, default=0.33, help="Unit dropout probability") + parser.add_argument('--feat_dropout', type=float, default=0.05, help="Features dropout probability for each element in feature vector") + parser.add_argument('--feat_unit_dropout', type=float, default=0.33, help="The whole feature of units dropout probability") parser.add_argument('--tok_noise', type=float, default=0.02, help="Probability to induce noise to the input of the higher RNN") parser.add_argument('--sent_drop_prob', type=float, default=0.2, help="Probability to drop sentences at the end of batches during training uniformly at random. 
Idea is to fake paragraph endings.") parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay") @@ -90,13 +102,19 @@ def main(args=None): args = vars(args) logger.info("Running tokenizer in {} mode".format(args['mode'])) - args['feat_funcs'] = ['space_before', 'capitalized', 'all_caps', 'numeric'] + args['feat_funcs'] = ['space_before', 'capitalized', 'numeric', 'end_of_para', 'start_of_para'] args['feat_dim'] = len(args['feat_funcs']) save_name = args['save_name'] if args['save_name'] else '{}_tokenizer.pt'.format(args['shorthand']) args['save_name'] = os.path.join(args['save_dir'], save_name) utils.ensure_dir(args['save_dir']) if args['mode'] == 'train': + #load lexicon + args['lexicon'], args['num_dict_feat'] = (None, None) if not args["use_dictionary"] else load_lexicon(args) + #create the dictionary + args['dictionary'] = None if not args["use_dictionary"] else create_dictionary(args['lexicon']) + #adjust the feat_dim + args['feat_dim'] += args['num_dict_feat']*2 if args["use_dictionary"] else 0 train(args) else: evaluate(args) @@ -108,21 +126,22 @@ def train(args): 'txt': args['txt_file'], 'label': args['label_file'] } - train_batches = DataLoader(args, input_files=train_input_files) + train_batches = DataLoader(args, input_files=train_input_files, dictionary=args["dictionary"]) vocab = train_batches.vocab + args['vocab_size'] = len(vocab) dev_input_files = { 'txt': args['dev_txt_file'], 'label': args['dev_label_file'] } - dev_batches = DataLoader(args, input_files=dev_input_files, vocab=vocab, evaluation=True) + dev_batches = DataLoader(args, input_files=dev_input_files, vocab=vocab, evaluation=True, dictionary=args["dictionary"]) if args['use_mwt'] is None: args['use_mwt'] = train_batches.has_mwt() logger.info("Found {}mwts in the training data. 
Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt'])) - trainer = Trainer(args=args, vocab=vocab, use_cuda=args['cuda']) + trainer = Trainer(args=args, vocab=vocab, lexicon=args['lexicon'], use_cuda=args['cuda']) if args['load_name'] is not None: load_name = os.path.join(args['save_dir'], args['load_name']) @@ -138,7 +157,7 @@ def train(args): best_dev_step = -1 for step in range(1, steps+1): - batch = train_batches.next(unit_dropout=args['unit_dropout']) + batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout = args['feat_unit_dropout']) loss = trainer.update(batch) if step % args['report_steps'] == 0: @@ -180,16 +199,21 @@ def evaluate(args): use_cuda = args['cuda'] and not args['cpu'] trainer = Trainer(model_file=args['load_name'] or args['save_name'], use_cuda=use_cuda) loaded_args, vocab = trainer.args, trainer.vocab + for k in loaded_args: if not k.endswith('_file') and k not in ['cuda', 'mode', 'save_dir', 'load_name', 'save_name']: args[k] = loaded_args[k] - + + args['lexicon'] = None if not args['use_dictionary'] else trainer.lexicon + args['dictionary'] = None if not args['use_dictionary'] else create_dictionary(lexicon) + eval_input_files = { 'txt': args['txt_file'], 'label': args['label_file'] } - batches = DataLoader(args, input_files=eval_input_files, vocab=vocab, evaluation=True) + + batches = DataLoader(args, input_files=eval_input_files, vocab=vocab, evaluation=True, dictionary=args['dictionary']) oov_count, N, _, _ = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen']) -- cgit v1.2.3 From a7b78d69c264f6983df38c36a8aefe900459e212 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 9 Aug 2021 18:08:31 -0700 Subject: Open/close files in a context to guarantee handles are closed --- stanza/models/tokenization/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/stanza/models/tokenization/utils.py b/stanza/models/tokenization/utils.py index da123b58..28156a44 100644 --- a/stanza/models/tokenization/utils.py +++ b/stanza/models/tokenization/utils.py @@ -95,15 +95,14 @@ def create_lexicon(shorthand=None, train_path=None, external_path=None): if not os.path.isfile(external_path): raise FileNotFoundError(f"Cannot open external dictionary at {external_path}") - external_file = open(external_path, "r", encoding="utf-8") - lines = external_file.readlines() + with open(external_path, "r", encoding="utf-8") as external_file: + lines = external_file.readlines() for line in lines: word = line.lower() word = word.replace("\n","") if check_valid_word(shorthand, word) and word not in lexicon: lexicon.add(word) length_freq.append(len(word)) - external_file.close() logger.info(f"Added another {len(lexicon) - count_word} words from the external dict to dictionary.") -- cgit v1.2.3 From d3ed438b2e1437e9c7d31417a6488b3e862360b8 Mon Sep 17 00:00:00 2001 From: Gordon Date: Tue, 10 Aug 2021 12:54:54 +0800 Subject: Updated BEST to include TEST_100K --- .../utils/datasets/tokenization/convert_th_best.py | 70 +++++++++++++++++----- 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/stanza/utils/datasets/tokenization/convert_th_best.py b/stanza/utils/datasets/tokenization/convert_th_best.py index 416c84b2..b6e0c6cc 100644 --- a/stanza/utils/datasets/tokenization/convert_th_best.py +++ b/stanza/utils/datasets/tokenization/convert_th_best.py @@ -19,7 +19,6 @@ https://aiforthai.in.th/corpus.php python3 -m stanza.utils.datasets.tokenization.process_best 
extern_data/thai/best data/tokenize ./scripts/run_tokenize.sh UD_Thai-best --dropout 0.05 --unit_dropout 0.05 --steps 50000 """ - import glob import os import random @@ -28,34 +27,36 @@ import sys from pythainlp import sent_tokenize -from stanza.utils.datasets.tokenization.process_thai_tokenization import reprocess_lines, write_dataset, convert_processed_lines +from stanza.utils.datasets.tokenization.process_thai_tokenization import reprocess_lines, write_dataset, convert_processed_lines, write_dataset_best, write_dataset def clean_line(line): line = line.replace("html>", "html|>") - # news_00089.txt + # news_00089.txt line = line.replace("", "") line = line.replace("", "") - # specific error that occurs in encyclopedia_00095.txt + # specific error that occurs in encyclopedia_00095.txt line = line.replace("Penn", "|Penn>") - # news_00058.txt + # news_00058.txt line = line.replace("จม.เปิดผนึก", "จม.|เปิดผนึก") - # news_00015.txt + # news_00015.txt line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) - # news_00024.txt + # news_00024.txt line = re.sub("([^|<>]+)", "\\1", line) - # news_00055.txt + # news_00055.txt line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) line = re.sub("([^|<>]+)([^|<>]+) ([^|<>]+)", "\\1|\\2|\\3", line) - # news_00008.txt and other news articles + # news_00008.txt and other news articles line = re.sub("([0-9])", "|\\1", line) line = line.replace(" ", "|") + line = line.replace("", "") + line = line.replace("", "") line = line.strip() return line def clean_word(word): - # novel_00078.txt + # novel_00078.txt if word == '': return 'พี่มน' if word.startswith("") and word.endswith(""): @@ -64,6 +65,12 @@ def clean_word(word): return word[4:-5] if word.startswith("") and word.endswith(""): return word[6:-7] + """ + if word.startswith(""): + return word[4:] + if word.endswith(""): + return word[:-5] + """ if word.startswith(""): return word[4:] if word.endswith(""): @@ -77,6 +84,12 @@ def clean_word(word): return word def read_data(input_dir): + + # data for test sets + test_files = [os.path.join(input_dir, 'TEST_100K_ANS.txt')] + print(test_files) + + # data for train and dev sets subdirs = [os.path.join(input_dir, 'article'), os.path.join(input_dir, 'encyclopedia'), os.path.join(input_dir, 'news'), @@ -87,9 +100,32 @@ def read_data(input_dir): raise FileNotFoundError("Expected a directory that did not exist: {}".format(subdir)) files.extend(glob.glob(os.path.join(subdir, '*.txt'))) + test_documents = [] + for filename in test_files: + print("File name:", filename) + with open(filename) as fin: + processed_lines = [] + for line in fin.readlines(): + line = clean_line(line) + words = line.split("|") + words = [clean_word(x) for x in words] + for word in words: + if len(word) > 1 and word[0] == '<': + raise ValueError("Unexpected word '{}' in document {}".format(word, filename)) + words = [x for x in words if x] + processed_lines.append(words) + + processed_lines = reprocess_lines(processed_lines) + paragraphs = convert_processed_lines(processed_lines) + + test_documents.extend(paragraphs) + print("Test document finished.") + documents = [] + for filename in files: with open(filename) as fin: + print("File:", filename) processed_lines = [] for line in fin.readlines(): line = clean_line(line) @@ -106,7 +142,10 @@ def read_data(input_dir): documents.extend(paragraphs) - return documents + print("All documents finished.") + + return documents, test_documents + def main(*args): random.seed(1000) @@ -116,13 +155,16 @@ def 
main(*args): input_dir = args[0] full_input_dir = os.path.join(input_dir, "thai", "best") if os.path.exists(full_input_dir): - # otherwise hopefully the user gave us the full path? + # otherwise hopefully the user gave us the full path? input_dir = full_input_dir output_dir = args[1] - documents = read_data(input_dir) - write_dataset(documents, output_dir, "best") + documents, test_documents = read_data(input_dir) + print("Finished reading data.") + write_dataset_best(documents, test_documents, output_dir, "best") if __name__ == '__main__': main() + + -- cgit v1.2.3 From 91fb4a97b2942f7823e0b7c7ee60adc2be44396c Mon Sep 17 00:00:00 2001 From: Gordon Date: Tue, 10 Aug 2021 12:57:03 +0800 Subject: includes TEST_100K as test set for BEST eval --- .../tokenization/process_thai_tokenization.py | 40 +++++++++++++++++----- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/stanza/utils/datasets/tokenization/process_thai_tokenization.py b/stanza/utils/datasets/tokenization/process_thai_tokenization.py index d92ab674..064b2e10 100644 --- a/stanza/utils/datasets/tokenization/process_thai_tokenization.py +++ b/stanza/utils/datasets/tokenization/process_thai_tokenization.py @@ -80,6 +80,17 @@ def write_dataset(documents, output_dir, dataset_name): write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev]) write_section(output_dir, dataset_name, 'test', documents[num_train+num_dev:]) +def write_dataset_best(documents, test_documents, output_dir, dataset_name): + """ + Shuffle a list of documents, write three sections + """ + random.shuffle(documents) + num_train = int(len(documents) * 0.85) + num_dev = int(len(documents) * 0.15) + os.makedirs(output_dir, exist_ok=True) + write_section(output_dir, dataset_name, 'train', documents[:num_train]) + write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev]) + write_section(output_dir, dataset_name, 'test', test_documents) def reprocess_lines(processed_lines): @@ -133,28 +144,35 @@ def reprocess_lines(processed_lines): return reprocessed_lines def convert_processed_lines(processed_lines): - """ - Convert a list of sentences into documents suitable for the output methods in this module. - - Input: a list of lines, including space words - Output: a list of documents, each document containing a list of sentences - Each sentence is a list of words: (text, space_follows) - Space words will be eliminated. + """ + Convert a list of sentences into documents suitable for the output methods in this module. + + Input: a list of lines, including space words + Output: a list of documents, each document containing a list of sentences + Each sentence is a list of words: (text, space_follows) + Space words will be eliminated. 
""" paragraphs = [] sentences = [] for words in processed_lines: - # turn the words into a sentence + # turn the words into a sentence + + if len(words) > 1 and " " == words[0]: + words = words[1:] + elif len(words) == 1 and " " == words[0]: + words = [] + sentence = [] for word in words: word = word.strip() if not word: if len(sentence) == 0: + print(word) raise ValueError("Unexpected space at start of sentence in document {}".format(filename)) sentence[-1] = (sentence[-1][0], True) else: sentence.append((word, False)) - # blank lines are very rare in best, but why not treat them as a paragraph break + # blank lines are very rare in best, but why not treat them as a paragraph break if len(sentence) == 0: paragraphs.append([sentences]) sentences = [] @@ -164,3 +182,7 @@ def convert_processed_lines(processed_lines): paragraphs.append([sentences]) return paragraphs + + + + -- cgit v1.2.3 -- cgit v1.2.3 From c9b6c0d98fb2cea97e97623306c09f1481072aea Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 11 Aug 2021 16:13:49 -0700 Subject: Fix some whitespace --- .../utils/datasets/tokenization/convert_th_best.py | 34 ++++++++++------------ .../tokenization/process_thai_tokenization.py | 23 +++++++-------- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/stanza/utils/datasets/tokenization/convert_th_best.py b/stanza/utils/datasets/tokenization/convert_th_best.py index b6e0c6cc..778f2dac 100644 --- a/stanza/utils/datasets/tokenization/convert_th_best.py +++ b/stanza/utils/datasets/tokenization/convert_th_best.py @@ -31,22 +31,22 @@ from stanza.utils.datasets.tokenization.process_thai_tokenization import reproce def clean_line(line): line = line.replace("html>", "html|>") - # news_00089.txt + # news_00089.txt line = line.replace("", "") line = line.replace("", "") - # specific error that occurs in encyclopedia_00095.txt + # specific error that occurs in encyclopedia_00095.txt line = line.replace("Penn", "|Penn>") - # news_00058.txt + # news_00058.txt line = line.replace("จม.เปิดผนึก", "จม.|เปิดผนึก") - # news_00015.txt + # news_00015.txt line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) - # news_00024.txt + # news_00024.txt line = re.sub("([^|<>]+)", "\\1", line) - # news_00055.txt + # news_00055.txt line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) line = re.sub("([^|<>]+)([^|<>]+) ([^|<>]+)", "\\1|\\2|\\3", line) - # news_00008.txt and other news articles + # news_00008.txt and other news articles line = re.sub("([0-9])", "|\\1", line) line = line.replace(" ", "|") line = line.replace("", "") @@ -56,7 +56,7 @@ def clean_line(line): def clean_word(word): - # novel_00078.txt + # novel_00078.txt if word == '': return 'พี่มน' if word.startswith("") and word.endswith(""): @@ -65,11 +65,11 @@ def clean_word(word): return word[4:-5] if word.startswith("") and word.endswith(""): return word[6:-7] - """ - if word.startswith(""): - return word[4:] - if word.endswith(""): - return word[:-5] + """ + if word.startswith(""): + return word[4:] + if word.endswith(""): + return word[:-5] """ if word.startswith(""): return word[4:] @@ -84,12 +84,11 @@ def clean_word(word): return word def read_data(input_dir): - - # data for test sets + # data for test sets test_files = [os.path.join(input_dir, 'TEST_100K_ANS.txt')] print(test_files) - # data for train and dev sets + # data for train and dev sets subdirs = [os.path.join(input_dir, 'article'), os.path.join(input_dir, 'encyclopedia'), os.path.join(input_dir, 'news'), @@ -155,7 +154,7 @@ def 
main(*args): input_dir = args[0] full_input_dir = os.path.join(input_dir, "thai", "best") if os.path.exists(full_input_dir): - # otherwise hopefully the user gave us the full path? + # otherwise hopefully the user gave us the full path? input_dir = full_input_dir output_dir = args[1] @@ -167,4 +166,3 @@ def main(*args): if __name__ == '__main__': main() - diff --git a/stanza/utils/datasets/tokenization/process_thai_tokenization.py b/stanza/utils/datasets/tokenization/process_thai_tokenization.py index 064b2e10..5ef0e3d5 100644 --- a/stanza/utils/datasets/tokenization/process_thai_tokenization.py +++ b/stanza/utils/datasets/tokenization/process_thai_tokenization.py @@ -81,8 +81,8 @@ def write_dataset(documents, output_dir, dataset_name): write_section(output_dir, dataset_name, 'test', documents[num_train+num_dev:]) def write_dataset_best(documents, test_documents, output_dir, dataset_name): - """ - Shuffle a list of documents, write three sections + """ + Shuffle a list of documents, write three sections """ random.shuffle(documents) num_train = int(len(documents) * 0.85) @@ -144,19 +144,18 @@ def reprocess_lines(processed_lines): return reprocessed_lines def convert_processed_lines(processed_lines): - """ - Convert a list of sentences into documents suitable for the output methods in this module. - - Input: a list of lines, including space words - Output: a list of documents, each document containing a list of sentences - Each sentence is a list of words: (text, space_follows) - Space words will be eliminated. + """ + Convert a list of sentences into documents suitable for the output methods in this module. + + Input: a list of lines, including space words + Output: a list of documents, each document containing a list of sentences + Each sentence is a list of words: (text, space_follows) + Space words will be eliminated. """ paragraphs = [] sentences = [] for words in processed_lines: - # turn the words into a sentence - + # turn the words into a sentence if len(words) > 1 and " " == words[0]: words = words[1:] elif len(words) == 1 and " " == words[0]: @@ -172,7 +171,7 @@ def convert_processed_lines(processed_lines): sentence[-1] = (sentence[-1][0], True) else: sentence.append((word, False)) - # blank lines are very rare in best, but why not treat them as a paragraph break + # blank lines are very rare in best, but why not treat them as a paragraph break if len(sentence) == 0: paragraphs.append([sentences]) sentences = [] -- cgit v1.2.3 From 47b6d7b55e91c07944fb139a8a9e367d5ad10ce4 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 16 Aug 2021 18:32:17 -0700 Subject: Fix a comment --- stanza/models/classifiers/cnn_classifier.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stanza/models/classifiers/cnn_classifier.py b/stanza/models/classifiers/cnn_classifier.py index e912fc3d..9cb1bd45 100644 --- a/stanza/models/classifiers/cnn_classifier.py +++ b/stanza/models/classifiers/cnn_classifier.py @@ -112,7 +112,8 @@ class CNNClassifier(nn.Module): self.extra_vocab = list(extra_vocab) self.extra_vocab_map = { word: i for i, word in enumerate(self.extra_vocab) } # TODO: possibly add regularization specifically on the extra embedding? - # TODO FIXME: word of idx 0 is being shared with the padding! 
+ # note: it looks like a bug that this doesn't add UNK or PAD, but actually + # those are expected to already be the first two entries self.extra_embedding = nn.Embedding(num_embeddings = len(extra_vocab), embedding_dim = self.config.extra_wordvec_dim, max_norm = self.config.extra_wordvec_max_norm, -- cgit v1.2.3 From 5f910c027700ca99c449035ac16bbcb503d74cb3 Mon Sep 17 00:00:00 2001 From: Andrew Garkavyi Date: Tue, 17 Aug 2021 04:13:04 +0300 Subject: Ukrainian Ner: use train/test split file instead of random (#787) * feat: Ner-languk: added ability to read predefined split of train/test from file - read split from doc/dev-test-split.txt - old logic to randomly split is still there The prior version of code that created train/dev/test split used a completely random approach with weights by sets. languk community whose data set is used provides files that list recommended split for train and test sets. The reason to use it is that split was verified not to contain documents from the same sources: several documents can be from the same book and to avoid bias in train data, such occurrences were manually removed. So train set does not contain parts from the book that exists in the test set. The dev set is randomly created from train as before. Overall this allows getting more realistic results of the model. But no performance changes in the trained model were observed after introducing this change. --- stanza/utils/datasets/ner/convert_bsf_to_beios.py | 63 +++++++++++++++++++++-- stanza/utils/datasets/ner/prepare_ner_dataset.py | 3 +- 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/stanza/utils/datasets/ner/convert_bsf_to_beios.py b/stanza/utils/datasets/ner/convert_bsf_to_beios.py index 6309efe2..16b7150d 100644 --- a/stanza/utils/datasets/ner/convert_bsf_to_beios.py +++ b/stanza/utils/datasets/ner/convert_bsf_to_beios.py @@ -4,8 +4,9 @@ import os import glob from collections import namedtuple import re +from typing import Tuple from tqdm import tqdm -from random import choices +from random import choices, shuffle BsfInfo = namedtuple('BsfInfo', 'id, tag, start_idx, end_idx, token') @@ -93,14 +94,16 @@ def parse_bsf(bsf_data: str) -> list: CORPUS_NAME = 'Ukrainian-languk' + def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str = 'beios', - doc_delim: str = '\n') -> None: + doc_delim: str = '\n', train_test_split_file: str = None) -> None: """ :param doc_delim: delimiter to be used between documents :param src_dir_path: path to directory with BSF marked files :param dst_dir_path: where to save output data :param converter: `beios` or `iob` output formats + :param train_test_split_file: path to file cotaining train/test lists of file names :return: """ ann_path = os.path.join(src_dir_path, '*.tok.ann') @@ -127,7 +130,10 @@ def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str = data_sets = [train_set, dev_set, test_set] split_weights = (8, 1, 1) - log.info(f'Found {len(tok_files)} files') + if train_test_split_file is not None: + train_names, dev_names, test_names = read_languk_train_test_split(train_test_split_file) + + log.info(f'Found {len(tok_files)} files in data folder "{src_dir_path}"') for (tok_fname, ann_fname) in tqdm(zip(tok_files, ann_files), total=len(tok_files), unit='file'): if tok_fname[:-3] != ann_fname[:-3]: tqdm.write(f'Token and Annotation file names do not match ann={ann_fname}, tok={tok_fname}') @@ -138,7 +144,16 @@ def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str = ann_data = 
ann_file.read() out_data = convert_bsf(token_data, ann_data, converter) - target_dataset = choices(data_sets, split_weights)[0] + if train_test_split_file is None: + target_dataset = choices(data_sets, split_weights)[0] + else: + target_dataset = train_set + fkey = os.path.basename(tok_fname)[:-4] + if fkey in dev_names: + target_dataset = dev_set + elif fkey in test_names: + target_dataset = test_set + target_dataset.append(out_data) log.info(f'Data is split as following: train={len(train_set)}, dev={len(dev_set)}, test={len(test_set)}') @@ -155,6 +170,43 @@ def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str = log.info('All done') +def read_languk_train_test_split(file_path: str, dev_split: float = 0.1) -> Tuple: + """ + Read predefined split of train and test files in data set. + Originally located under doc/dev-test-split.txt + :param file_path: path to dev-test-split.txt file (should include file name with extension) + :param dev_split: 0 to 1 float value defining how much to allocate to dev split + :return: tuple of (train, dev, test) each containing list of files to be used for respective data sets + """ + log.info(f'Trying to read train/dev/test split from file "{file_path}". Dev allocation = {dev_split}') + train_files, test_files, dev_files = [], [], [] + container = test_files + with open(file_path, 'r') as f: + for ln in f: + ln = ln.strip() + if ln == 'DEV': + container = train_files + elif ln == 'TEST': + container = test_files + elif ln == '': + pass + else: + container.append(ln) + + # split in file only contains train and test split. + # For Stanza training we need train, dev, test + # We will take part of train as dev set + # This way anyone using test set outside of this code base can be sure that there was no data set polution + shuffle(train_files) + dev_files = train_files[: int(len(train_files) * dev_split)] + train_files = train_files[int(len(train_files) * dev_split):] + + assert len(set(train_files).intersection(set(dev_files))) == 0 + + log.info(f'Files in each set: train={len(train_files)}, dev={len(dev_files)}, test={len(test_files)}') + return train_files, dev_files, test_files + + if __name__ == '__main__': logging.basicConfig() @@ -165,7 +217,8 @@ if __name__ == '__main__': parser.add_argument('--dst', type=str, default='data/ner', help='Where to store the converted dataset') parser.add_argument('-c', type=str, default='beios', help='`beios` or `iob` formats to be used for output') parser.add_argument('--doc_delim', type=str, default='\n', help='Delimiter to be used to separate documents in the output data') + parser.add_argument('--split_file', type=str, help='Name of a file containing Train/Test split (files in train and test set)') parser.print_help() args = parser.parse_args() - convert_bsf_in_folder(args.src_dataset, args.dst, args.c, args.doc_delim) + convert_bsf_in_folder(args.src_dataset, args.dst, args.c, args.doc_delim, train_test_split_file=args.split_file) diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py index 33d2d04a..c72ce1fa 100644 --- a/stanza/utils/datasets/ner/prepare_ner_dataset.py +++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py @@ -156,7 +156,8 @@ def process_languk(paths): short_name = 'uk_languk' base_input_path = os.path.join(paths["NERBASE"], 'lang-uk', 'ner-uk', 'data') base_output_path = paths["NER_DATA_DIR"] - convert_bsf_to_beios.convert_bsf_in_folder(base_input_path, base_output_path) + train_test_split_fname = 
os.path.join(paths["NERBASE"], 'lang-uk', 'ner-uk', 'doc', 'dev-test-split.txt') + convert_bsf_to_beios.convert_bsf_in_folder(base_input_path, base_output_path, train_test_split_file=train_test_split_fname) for shard in SHARDS: input_filename = os.path.join(base_output_path, convert_bsf_to_beios.CORPUS_NAME, "%s.bio" % shard) if not os.path.exists(input_filename): -- cgit v1.2.3 From 46118e61bb90ae28ee605b8394f6a813c969c4a4 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 22 Aug 2021 00:13:35 -0700 Subject: Only report max dev score if ran dev set at least once --- stanza/models/ner_tagger.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py index 020e2c68..dd24dbbb 100644 --- a/stanza/models/ner_tagger.py +++ b/stanza/models/ner_tagger.py @@ -249,8 +249,9 @@ def train(args): logger.info("Training ended with {} steps.".format(global_step)) - best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1 - logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval'])) + if len(dev_score_history) > 0: + best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1 + logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval'])) def evaluate(args): # file paths -- cgit v1.2.3 From 80474a913363b7095c6bb292b1b9a20bd6c852f4 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 22 Aug 2021 00:28:34 -0700 Subject: Raise a more descriptive error for missing charlm files --- stanza/models/ner/model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/stanza/models/ner/model.py b/stanza/models/ner/model.py index efad8d51..dbd93b35 100644 --- a/stanza/models/ner/model.py +++ b/stanza/models/ner/model.py @@ -1,3 +1,4 @@ +import os import numpy as np import torch import torch.nn as nn @@ -35,6 +36,10 @@ class NERTagger(nn.Module): if self.args['char'] and self.args['char_emb_dim'] > 0: if self.args['charlm']: + if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']): + raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file'])) + if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']): + raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file'])) add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False)) add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False)) input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim() -- cgit v1.2.3 From 5e3017430c41f09fcf1739b24494294d3175f7b9 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 25 Aug 2021 23:58:42 -0700 Subject: Fix a typo... 
technically this bug is in the currently released 1.2.3 --- stanza/models/classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py index 2bbe81d9..445b2557 100644 --- a/stanza/models/classifier.py +++ b/stanza/models/classifier.py @@ -22,7 +22,7 @@ import stanza.models.classifiers.classifier_args as classifier_args import stanza.models.classifiers.cnn_classifier as cnn_classifier import stanza.models.classifiers.data as data -from stanza.utils.confusion impmort format_confusion +from stanza.utils.confusion import format_confusion class Loss(Enum): -- cgit v1.2.3 From 3cba41b721f6403c836edda2a6bf9f185aa5b86d Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 26 Aug 2021 11:14:47 -0700 Subject: Add a conversion of the extra train to the prep_sentiment script --- scripts/sentiment/process_sst.sh | 19 +++++++++++++++++++ stanza/models/classifier.py | 4 ++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/scripts/sentiment/process_sst.sh b/scripts/sentiment/process_sst.sh index 7ee7fb67..ac33990f 100755 --- a/scripts/sentiment/process_sst.sh +++ b/scripts/sentiment/process_sst.sh @@ -39,6 +39,12 @@ java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt > echo $OUTPUT_DIR/fiveclass/test-phrases.txt java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/test.txt > $OUTPUT_DIR/fiveclass/test-phrases.txt +echo $OUTPUT_DIR/fiveclass/extra-train-phrases.txt +java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/extra-train.txt > $OUTPUT_DIR/fiveclass/extra-train-phrases.txt + +echo $OUTPUT_DIR/fiveclass/checked-extra-train-phrases.txt +java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/checked-extra-train.txt > $OUTPUT_DIR/fiveclass/checked-extra-train-phrases.txt + echo $OUTPUT_DIR/fiveclass/train-roots.txt java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/train.txt -root_only > $OUTPUT_DIR/fiveclass/train-roots.txt @@ -59,6 +65,12 @@ java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt - echo $OUTPUT_DIR/binary/test-binary-phrases.txt java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/test.txt -ignore_labels 2 -remap_labels "1=0,2=-1,3=1,4=1" > $OUTPUT_DIR/binary/test-binary-phrases.txt +echo $OUTPUT_DIR/binary/extra-train-binary-phrases.txt +java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/extra-train.txt -ignore_labels 2 -remap_labels "1=0,2=-1,3=1,4=1" > $OUTPUT_DIR/binary/extra-train-binary-phrases.txt + +echo $OUTPUT_DIR/binary/checked-extra-train-binary-phrases.txt +java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/checked-extra-train.txt -ignore_labels 2 -remap_labels "1=0,2=-1,3=1,4=1" > $OUTPUT_DIR/binary/checked-extra-train-binary-phrases.txt + echo $OUTPUT_DIR/binary/dev-binary-roots.txt java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt -root_only -ignore_labels 2 -remap_labels "1=0,2=-1,3=1,4=1" > $OUTPUT_DIR/binary/dev-binary-roots.txt @@ -76,6 +88,13 @@ java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt - echo $OUTPUT_DIR/threeclass/test-threeclass-phrases.txt java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/test.txt -remap_labels "0=0,1=0,2=1,3=2,4=2" > $OUTPUT_DIR/threeclass/test-threeclass-phrases.txt +echo $OUTPUT_DIR/threeclass/extra-train-threeclass-phrases.txt +java edu.stanford.nlp.trees.OutputSubtrees -input 
$INPUT_DIR/fiveclass/extra-train.txt -remap_labels "0=0,1=0,2=1,3=2,4=2" > $OUTPUT_DIR/threeclass/extra-train-threeclass-phrases.txt + +echo $OUTPUT_DIR/threeclass/checked-extra-train-threeclass-phrases.txt +java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/checked-extra-train.txt -remap_labels "0=0,1=0,2=1,3=2,4=2" > $OUTPUT_DIR/threeclass/checked-extra-train-threeclass-phrases.txt + + echo $OUTPUT_DIR/threeclass/dev-threeclass-roots.txt java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt -root_only -remap_labels "0=0,1=0,2=1,3=2,4=2" > $OUTPUT_DIR/threeclass/dev-threeclass-roots.txt diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py index 445b2557..ed834843 100644 --- a/stanza/models/classifier.py +++ b/stanza/models/classifier.py @@ -80,7 +80,7 @@ python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir exte To train models on combined 3 class datasets: -nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_3class --extra_wordvec_method CONCAT --extra_wordvec_dim 200 --train_file extern_data/sentiment/sst-processed/threeclass/train-threeclass-phrases.txt,extern_data/sentiment/MELD/train.txt,extern_data/sentiment/slsd/train.txt,extern_data/sentiment/arguana/train.txt,extern_data/sentiment/airline/train.txt,extern_data/sentiment/sst-processed/threeclass/extra-train-threeclass-phrases.txt,extern_data/sentiment/sst-processed/threeclass/checked-extra-threeclass-phrases.txt --dev_file extern_data/sentiment/sst-processed/threeclass/dev-threeclass-roots.txt --test_file extern_data/sentiment/sst-processed/threeclass/test-threeclass-roots.txt > FC41_3class.out 2>&1 & +nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_3class --extra_wordvec_method CONCAT --extra_wordvec_dim 200 --train_file extern_data/sentiment/sst-processed/threeclass/train-threeclass-phrases.txt,extern_data/sentiment/MELD/train.txt,extern_data/sentiment/slsd/train.txt,extern_data/sentiment/arguana/train.txt,extern_data/sentiment/airline/train.txt,extern_data/sentiment/sst-processed/threeclass/extra-train-threeclass-phrases.txt,extern_data/sentiment/sst-processed/threeclass/checked-extra-train-threeclass-phrases.txt --dev_file extern_data/sentiment/sst-processed/threeclass/dev-threeclass-roots.txt --test_file extern_data/sentiment/sst-processed/threeclass/test-threeclass-roots.txt > FC41_3class.out 2>&1 & This tests that model: @@ -488,7 +488,7 @@ def train_model(model, model_file, args, train_set, dev_set, labels): # Add any leftover loss to the epoch_loss epoch_loss += running_loss - logger.info("Finished epoch %d" % (epoch + 1)) + logger.info("Finished epoch %d Total loss %.3f" % (epoch + 1, epoch_loss)) dev_score = score_dev_set(model, dev_set, args.dev_eval_scoring) if args.save_intermediate_models: checkpoint_file = checkpoint_name(model_file, epoch + 1, args.dev_eval_scoring, dev_score) -- cgit v1.2.3 From 23029fececba6737a6552629ac447732809277ea Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 1 Sep 2021 08:48:08 -0700 Subject: Copy files using shutil.copy2 to preserve metadata. Don't include the directories so they don't have any metadata at all. 
This makes the zip files have the same md5sum at the end --- stanza/resources/prepare_resources.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py index 305deb7f..982be443 100644 --- a/stanza/resources/prepare_resources.py +++ b/stanza/resources/prepare_resources.py @@ -273,7 +273,7 @@ def ensure_dir(dir): def copy_file(src, dst): ensure_dir(Path(dst).parent) - shutil.copy(src, dst) + shutil.copy2(src, dst) def get_md5(path): @@ -325,8 +325,7 @@ def process_dirs(args): # copy file input_path = os.path.join(args.input_dir, model_dir, model) output_path = os.path.join(args.output_dir, lang, processor, package + '.pt') - ensure_dir(Path(output_path).parent) - shutil.copy(input_path, output_path) + copy_file(input_path, output_path) # maintain md5 md5 = get_md5(output_path) # maintain dependencies @@ -405,7 +404,6 @@ def process_defaults(args): print(" Model {} package {}: file {}".format(processor, package, filename)) if processor in ['tokenize', 'mwt', 'lemma', 'pos', 'depparse', 'ner', 'sentiment', 'langid']: default_processors[processor] = package - zipf.write(processor) zipf.write(os.path.join(processor, package + '.pt')) elif lang in allowed_empty_languages: # we don't have a lot of Thai support yet -- cgit v1.2.3 From e8b5b70064f818b7e09d5aad698f3339413eb829 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 5 Sep 2021 20:55:16 -0700 Subject: Add a mechanism for importing tqdm differently depending on current context, eg writing to file, stdout, etc --- stanza/models/common/utils.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/stanza/models/common/utils.py b/stanza/models/common/utils.py index 842bf6fd..69a5ee7d 100644 --- a/stanza/models/common/utils.py +++ b/stanza/models/common/utils.py @@ -6,7 +6,9 @@ import os from collections import Counter import random import json +import sys import unicodedata + import torch import numpy as np @@ -299,3 +301,32 @@ def warn_missing_tags(known_tags, test_tags, test_set_name): logger.warning("Found tags in {} missing from the expected tag set: {}".format(test_set_name, missing_tags)) return True return False + +def get_tqdm(): + """ + Return a tqdm appropriate for the situation + + imports tqdm depending on if we're at a console, redir to a file, notebook, etc + + from @tcrimi at https://github.com/tqdm/tqdm/issues/506 + + This replaces `import tqdm`, so for example, you do this: + tqdm = utils.get_tqdm() + then do this when you want a scroll bar or regular iterator depending on context: + tqdm(list) + """ + try: + ipy_str = str(type(get_ipython())) + if 'zmqshell' in ipy_str: + from tqdm import tqdm_notebook as tqdm + return tqdm + if 'terminal' in ipy_str: + from tqdm import tqdm + return tqdm + except: + if sys.stderr.isatty(): + from tqdm import tqdm + return tqdm + def tqdm(iterable, **kwargs): + return iterable + return tqdm -- cgit v1.2.3 From f7af5049568f81a716106fee5403d339ca246f38 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 22 Aug 2021 00:00:30 -0700 Subject: HuggingFace integration Download corenlp models from HuggingFace instead of stanford by default Also, download all of corenlp from HuggingFace --- stanza/resources/installation.py | 36 ++++++++++++++++++++++++++---------- stanza/tests/test_installation.py | 2 +- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/stanza/resources/installation.py b/stanza/resources/installation.py index 7c5e5b2b..0a942bd8 100644 --- 
a/stanza/resources/installation.py +++ b/stanza/resources/installation.py @@ -12,19 +12,26 @@ from stanza.resources.common import HOME_DIR, request_file, unzip, \ logger = logging.getLogger('stanza') +DEFAULT_CORENLP_MODEL_URL = os.getenv( + 'CORENLP_MODEL_URL', + 'https://huggingface.co/stanfordnlp/corenlp-{model}/resolve/{tag}/stanford-corenlp-models-{model}.jar' +) +BACKUP_CORENLP_MODEL_URL = "http://nlp.stanford.edu/software/stanford-corenlp-{version}-models-{model}.jar" + DEFAULT_CORENLP_URL = os.getenv( - 'CORENLP_URL', - "http://nlp.stanford.edu/software/" + 'CORENLP_MODEL_URL', + 'https://huggingface.co/stanfordnlp/CoreNLP/resolve/{tag}/stanford-corenlp-latest.zip' ) + DEFAULT_CORENLP_DIR = os.getenv( 'CORENLP_HOME', os.path.join(HOME_DIR, 'stanza_corenlp') ) -AVAILABLE_MODELS = set(['arabic', 'chinese', 'english', 'english-kbp', 'french', 'german', 'spanish']) +AVAILABLE_MODELS = set(['arabic', 'chinese', 'english-extra', 'english-kbp', 'french', 'german', 'spanish']) -def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level='INFO', proxies=None): +def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_MODEL_URL, logging_level='INFO', proxies=None): """ A automatic way to download the CoreNLP models. @@ -34,11 +41,12 @@ def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT version: the version of the model dir: the directory to download CoreNLP model into; alternatively can be set up with environment variable $CORENLP_HOME - url: the link to download CoreNLP models + url: The link to download CoreNLP models. + It will need {model} and either {version} or {tag} to properly format the URL logging_level: logging level to use during installation """ dir = os.path.expanduser(dir) - if model is None or version is None: + if not model or not version: raise ValueError( "Both model and model version should be specified." ) @@ -49,9 +57,13 @@ def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT f'{model} is currently not supported. ' f'Must be one of: {list(AVAILABLE_MODELS)}.' ) + # for example: + # https://huggingface.co/stanfordnlp/CoreNLP/resolve/v4.2.2/stanford-corenlp-models-french.jar + tag = version if version == 'main' else 'v' + version + download_url = url.format(tag=tag, model=model, version=version) try: request_file( - url + f'stanford-corenlp-{version}-models-{model}.jar', + download_url, os.path.join(dir, f'stanford-corenlp-{version}-models-{model}.jar'), proxies ) @@ -64,7 +76,7 @@ def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT ) from e -def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level=None, proxies=None): +def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level=None, proxies=None, version="main"): """ A fully automatic way to install and setting up the CoreNLP library to use the client functionality. 
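# --- Editor's note (illustrative sketch, not part of this patch) ---
# Rough picture of how the templated URLs added above are meant to be filled in
# by download_corenlp_models(); the model/version values here are examples only.
#
#     version = "4.2.2"
#     tag = "v" + version   # mirrors: tag = version if version == 'main' else 'v' + version
#     url = DEFAULT_CORENLP_MODEL_URL.format(tag=tag, model="french", version=version)
#     # url == "https://huggingface.co/stanfordnlp/corenlp-french/resolve/v4.2.2/stanford-corenlp-models-french.jar"
# --------------------------------------------------------------------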
@@ -72,7 +84,8 @@ def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_le Args: dir: the directory to download CoreNLP model into; alternatively can be set up with environment variable $CORENLP_HOME - url: the link to download CoreNLP models + url: The link to download CoreNLP models + Needs a {version} or {tag} parameter to specify the version logging_level: logging level to use during installation """ dir = os.path.expanduser(dir) @@ -86,8 +99,11 @@ def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_le logger.info(f"Installing CoreNLP package into {dir}...") # First download the URL package logger.debug(f"Download to destination file: {os.path.join(dir, 'corenlp.zip')}") + tag = version if version == 'main' else 'v' + version + url = url.format(version=version, tag=tag) try: - request_file(url + 'stanford-corenlp-latest.zip', os.path.join(dir, 'corenlp.zip'), proxies) + request_file(url, os.path.join(dir, 'corenlp.zip'), proxies) + except (KeyboardInterrupt, SystemExit): raise except Exception as e: diff --git a/stanza/tests/test_installation.py b/stanza/tests/test_installation.py index 69a7bb0f..03fff24d 100644 --- a/stanza/tests/test_installation.py +++ b/stanza/tests/test_installation.py @@ -18,7 +18,7 @@ def test_install_corenlp(): # the download method doesn't install over existing directories shutil.rmtree(test_dir) - stanza.install_corenlp(dir=test_dir, url='http://nlp.stanford.edu/software/') + stanza.install_corenlp(dir=test_dir) assert os.path.isdir(test_dir), "Installation destination directory not found." jar_files = [f for f in os.listdir(test_dir) \ -- cgit v1.2.3 From 717f6b2fd118848790a36676ba3a30383b955046 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 8 Sep 2021 10:55:11 -0700 Subject: Move the installation test to its own directory --- stanza/tests/resources/test_installation.py | 47 +++++++++++++++++++++++++++++ stanza/tests/test_installation.py | 47 ----------------------------- 2 files changed, 47 insertions(+), 47 deletions(-) create mode 100644 stanza/tests/resources/test_installation.py delete mode 100644 stanza/tests/test_installation.py diff --git a/stanza/tests/resources/test_installation.py b/stanza/tests/resources/test_installation.py new file mode 100644 index 00000000..03fff24d --- /dev/null +++ b/stanza/tests/resources/test_installation.py @@ -0,0 +1,47 @@ +""" +Test installation functions. +""" + +import os +import pytest +import shutil +import tempfile + +import stanza + +pytestmark = [pytest.mark.travis, pytest.mark.client] + +def test_install_corenlp(): + # we do not reset the CORENLP_HOME variable since this may impact the + # client tests + with tempfile.TemporaryDirectory(dir=".") as test_dir: + + # the download method doesn't install over existing directories + shutil.rmtree(test_dir) + stanza.install_corenlp(dir=test_dir) + + assert os.path.isdir(test_dir), "Installation destination directory not found." + jar_files = [f for f in os.listdir(test_dir) \ + if f.endswith('.jar') and f.startswith('stanford-corenlp')] + assert len(jar_files) > 0, \ + "Cannot find stanford-corenlp jar files in the installation directory." + assert not os.path.exists(os.path.join(test_dir, 'corenlp.zip')), \ + "Downloaded zip file was not removed." 
+ +def test_download_corenlp_models(): + model_name = "arabic" + version = "4.2.2" + + with tempfile.TemporaryDirectory(dir=".") as test_dir: + stanza.download_corenlp_models(model=model_name, version=version, dir=test_dir) + + dest_file = os.path.join(test_dir, f"stanford-corenlp-{version}-models-{model_name}.jar") + assert os.path.isfile(dest_file), "Downloaded model file not found." + +def test_download_tokenize_mwt(): + with tempfile.TemporaryDirectory(dir=".") as test_dir: + stanza.download("en", model_dir=test_dir, processors="tokenize", package="ewt", verbose=False) + pipeline = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package="ewt") + assert isinstance(pipeline, stanza.Pipeline) + # mwt should be added to the list + assert len(pipeline.loaded_processors) == 2 diff --git a/stanza/tests/test_installation.py b/stanza/tests/test_installation.py deleted file mode 100644 index 03fff24d..00000000 --- a/stanza/tests/test_installation.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Test installation functions. -""" - -import os -import pytest -import shutil -import tempfile - -import stanza - -pytestmark = [pytest.mark.travis, pytest.mark.client] - -def test_install_corenlp(): - # we do not reset the CORENLP_HOME variable since this may impact the - # client tests - with tempfile.TemporaryDirectory(dir=".") as test_dir: - - # the download method doesn't install over existing directories - shutil.rmtree(test_dir) - stanza.install_corenlp(dir=test_dir) - - assert os.path.isdir(test_dir), "Installation destination directory not found." - jar_files = [f for f in os.listdir(test_dir) \ - if f.endswith('.jar') and f.startswith('stanford-corenlp')] - assert len(jar_files) > 0, \ - "Cannot find stanford-corenlp jar files in the installation directory." - assert not os.path.exists(os.path.join(test_dir, 'corenlp.zip')), \ - "Downloaded zip file was not removed." - -def test_download_corenlp_models(): - model_name = "arabic" - version = "4.2.2" - - with tempfile.TemporaryDirectory(dir=".") as test_dir: - stanza.download_corenlp_models(model=model_name, version=version, dir=test_dir) - - dest_file = os.path.join(test_dir, f"stanford-corenlp-{version}-models-{model_name}.jar") - assert os.path.isfile(dest_file), "Downloaded model file not found." - -def test_download_tokenize_mwt(): - with tempfile.TemporaryDirectory(dir=".") as test_dir: - stanza.download("en", model_dir=test_dir, processors="tokenize", package="ewt", verbose=False) - pipeline = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package="ewt") - assert isinstance(pipeline, stanza.Pipeline) - # mwt should be added to the list - assert len(pipeline.loaded_processors) == 2 -- cgit v1.2.3 From 5337e5a547939514fa5b59694a91790315a845ce Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 8 Sep 2021 11:11:46 -0700 Subject: Integrate the main downloads with huggingface --- stanza/resources/common.py | 8 ++++++-- stanza/resources/prepare_resources.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/stanza/resources/common.py b/stanza/resources/common.py index a1b7e690..01dca151 100644 --- a/stanza/resources/common.py +++ b/stanza/resources/common.py @@ -424,8 +424,12 @@ def download( logger.info( f'Downloading default packages for language: {lang} ({lang_name})...' 
) + # want the URL to become, for example: + # https://huggingface.co/stanfordnlp/stanza-af/resolve/v1.3.0/models/default.zip + # so we hopefully start from + # https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename} request_file( - f'{url}/{resources_version}/{lang}/default.zip', + url.format(resources_version=resources_version, lang=lang, filename="default.zip"), os.path.join(model_dir, lang, f'default.zip'), proxies, md5=resources[lang]['default_md5'], @@ -448,7 +452,7 @@ def download( for key, value in download_list: try: request_file( - f'{url}/{resources_version}/{lang}/{key}/{value}.pt', + url.format(resources_version=resources_version, lang=lang, filename=f"{key}/{value}.pt"), os.path.join(model_dir, lang, key, f'{value}.pt'), proxies, md5=resources[lang][key][value]['md5'] diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py index 982be443..543c14ee 100644 --- a/stanza/resources/prepare_resources.py +++ b/stanza/resources/prepare_resources.py @@ -446,7 +446,7 @@ def process_misc(args): resources = json.load(open(os.path.join(args.output_dir, 'resources.json'))) resources['no'] = {'alias': 'nb'} resources['zh'] = {'alias': 'zh-hans'} - resources['url'] = 'http://nlp.stanford.edu/software/stanza' + resources['url'] = 'https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}' json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2) -- cgit v1.2.3 From 538549c6210da026b1f4254fe6a74df31040805d Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 8 Sep 2021 11:13:11 -0700 Subject: Separate tests of different components into different files --- stanza/tests/resources/test_common.py | 19 +++++++++++++++++++ stanza/tests/resources/test_installation.py | 8 -------- 2 files changed, 19 insertions(+), 8 deletions(-) create mode 100644 stanza/tests/resources/test_common.py diff --git a/stanza/tests/resources/test_common.py b/stanza/tests/resources/test_common.py new file mode 100644 index 00000000..3594cac9 --- /dev/null +++ b/stanza/tests/resources/test_common.py @@ -0,0 +1,19 @@ +""" +Test various resource downloading functions from resources/common.py +""" + +import pytest +import tempfile + +import stanza + +pytestmark = [pytest.mark.travis, pytest.mark.client] + + +def test_download_tokenize_mwt(): + with tempfile.TemporaryDirectory(dir=".") as test_dir: + stanza.download("en", model_dir=test_dir, processors="tokenize", package="ewt", verbose=False) + pipeline = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package="ewt") + assert isinstance(pipeline, stanza.Pipeline) + # mwt should be added to the list + assert len(pipeline.loaded_processors) == 2 diff --git a/stanza/tests/resources/test_installation.py b/stanza/tests/resources/test_installation.py index 03fff24d..73e907f6 100644 --- a/stanza/tests/resources/test_installation.py +++ b/stanza/tests/resources/test_installation.py @@ -37,11 +37,3 @@ def test_download_corenlp_models(): dest_file = os.path.join(test_dir, f"stanford-corenlp-{version}-models-{model_name}.jar") assert os.path.isfile(dest_file), "Downloaded model file not found." 
- -def test_download_tokenize_mwt(): - with tempfile.TemporaryDirectory(dir=".") as test_dir: - stanza.download("en", model_dir=test_dir, processors="tokenize", package="ewt", verbose=False) - pipeline = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package="ewt") - assert isinstance(pipeline, stanza.Pipeline) - # mwt should be added to the list - assert len(pipeline.loaded_processors) == 2 -- cgit v1.2.3 From 7f4b3f0c2b60d6cc1038e62722fb0ff9a55fdc04 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 8 Sep 2021 14:36:37 -0700 Subject: Update results to match a new depparse model --- stanza/tests/test_tokenizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stanza/tests/test_tokenizer.py b/stanza/tests/test_tokenizer.py index fc5c96b8..b7956277 100644 --- a/stanza/tests/test_tokenizer.py +++ b/stanza/tests/test_tokenizer.py @@ -166,12 +166,12 @@ ZH_DOC1_GOLD_TOKENS=""" ]> ]> ]> -]> +]> ]> ]> ]> -]> +]> ]> ]> ]> @@ -180,8 +180,8 @@ ZH_DOC1_GOLD_TOKENS=""" ]> ]> ]> -]> -]> +]> +]> """.strip() ZH_DOC_GOLD_NOSSPLIT_TOKENS = """ -- cgit v1.2.3 From 8a09877f222e11422c2dcebe49a391e1601093c8 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 8 Sep 2021 14:43:54 -0700 Subject: Update test for new tokenizer args after the dictionary was added --- stanza/tests/test_tokenize_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stanza/tests/test_tokenize_data.py b/stanza/tests/test_tokenize_data.py index 1bd2ae39..c37ace4e 100644 --- a/stanza/tests/test_tokenize_data.py +++ b/stanza/tests/test_tokenize_data.py @@ -23,6 +23,7 @@ FAKE_PROPERTIES = { "lang":"de", 'feat_funcs': ("space_before","capitalized"), 'max_seqlen': 300, + 'use_dictionary': False, } def test_has_mwt(): -- cgit v1.2.3 From 6d28575bb48640fc566e467059ed77a9bd58b895 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 5 Sep 2021 20:43:33 -0700 Subject: Add an option to wrap long pos processes in a tqdm --- stanza/pipeline/pos_processor.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/stanza/pipeline/pos_processor.py b/stanza/pipeline/pos_processor.py index da918fdf..89658ee2 100644 --- a/stanza/pipeline/pos_processor.py +++ b/stanza/pipeline/pos_processor.py @@ -4,12 +4,14 @@ Processor for performing part-of-speech tagging from stanza.models.common import doc from stanza.models.common.pretrain import Pretrain -from stanza.models.common.utils import unsort +from stanza.models.common.utils import get_tqdm, unsort from stanza.models.pos.data import DataLoader from stanza.models.pos.trainer import Trainer from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor +tqdm = get_tqdm() + @register_processor(name=POS) class POSProcessor(UDProcessor): @@ -23,14 +25,21 @@ class POSProcessor(UDProcessor): self._pretrain = Pretrain(config['pretrain_path']) if 'pretrain_path' in config else None # set up trainer self._trainer = Trainer(pretrain=self.pretrain, model_file=config['model_path'], use_cuda=use_gpu) + self._tqdm = 'tqdm' in config and config['tqdm'] def process(self, document): batch = DataLoader( document, self.config['batch_size'], self.config, self.pretrain, vocab=self.vocab, evaluation=True, sort_during_eval=True) preds = [] - for i, b in enumerate(batch): - preds += self.trainer.predict(b) + + if self._tqdm: + for i, b in enumerate(tqdm(batch)): + preds += self.trainer.predict(b) + else: + for i, b in enumerate(batch): + preds += self.trainer.predict(b) + preds = unsort(preds, 
batch.data_orig_idx) batch.doc.set([doc.UPOS, doc.XPOS, doc.FEATS], [y for x in preds for y in x]) return batch.doc -- cgit v1.2.3 From 81a83ebf950248c321cf1e1dc4e8227636792868 Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 10 Sep 2021 01:35:47 -0700 Subject: add demo workflow --- .github/workflows/demo.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .github/workflows/demo.yml diff --git a/.github/workflows/demo.yml b/.github/workflows/demo.yml new file mode 100644 index 00000000..6fdfe5ae --- /dev/null +++ b/.github/workflows/demo.yml @@ -0,0 +1,29 @@ +name: Run Stanza Tests +on: [push] +jobs: + Run-Stanza-Tests: + runs-on: self-hosted + steps: + - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event." + - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!" + - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}." + - name: Check out repository code + uses: actions/checkout@v2 + - run: echo "💡 The ${{ github.repository }} repository has been cloned to the runner." + - run: echo "🖥️ The workflow is now ready to test your code on the runner." + - name: Run demo + run: | + # set up environment + bash + . /home/stanzabuild/miniconda3/etc/profile.d/conda.sh + # install from stanza repo being evaluated + pwd + pip install -e . + # set up for tests + source stanza/tests/setup_test.sh + # run tests + echo "Running tests..." + pytest stanza/tests/test_depparse.py + echo "This currently works!" + + - run: echo "🍏 This job's status is ${{ job.status }}." -- cgit v1.2.3 From 11021e650da5e64010f6c06d642e6e5af58e202d Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 10 Sep 2021 01:47:00 -0700 Subject: Update demo.yml --- .github/workflows/demo.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/demo.yml b/.github/workflows/demo.yml index 6fdfe5ae..70038ed6 100644 --- a/.github/workflows/demo.yml +++ b/.github/workflows/demo.yml @@ -20,10 +20,13 @@ jobs: pwd pip install -e . # set up for tests + rm -rf /home/stanzabuild/stanza-github-actions/actions-runner/_work/stanza/stanza/stanza_test source stanza/tests/setup_test.sh # run tests echo "Running tests..." + export CUDA_VISIBLE_DEVICES=2 pytest stanza/tests/test_depparse.py + pytest stanza/tests/test_ner_tagger.py echo "This currently works!" - run: echo "🍏 This job's status is ${{ job.status }}." 
-- cgit v1.2.3 From 85bca1ce4bbc010879bac64721bfe9de0af2effc Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 10 Sep 2021 02:20:22 -0700 Subject: specify models dir --- stanza/pipeline/multilingual.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/stanza/pipeline/multilingual.py b/stanza/pipeline/multilingual.py index a6c958ba..55056c77 100644 --- a/stanza/pipeline/multilingual.py +++ b/stanza/pipeline/multilingual.py @@ -7,6 +7,7 @@ import torch from stanza.models.common.doc import Document from stanza.pipeline.core import Pipeline from stanza.pipeline._constants import * +from stanza.resources.common import DEFAULT_MODEL_DIR class MultilingualPipeline: @@ -17,6 +18,7 @@ class MultilingualPipeline: def __init__( self, + model_dir: str = DEFAULT_MODEL_DIR, lang_id_config: dict = None, lang_configs: dict = None, ld_batch_size: int = 64, @@ -24,6 +26,7 @@ class MultilingualPipeline: use_gpu: bool = None ): # set up configs and cache for various language pipelines + self.model_dir = model_dir self.lang_id_config = {} if lang_id_config is None else lang_id_config self.lang_configs = {} if lang_configs is None else lang_configs self.max_cache_size = max_cache_size @@ -37,8 +40,8 @@ class MultilingualPipeline: self.use_gpu = use_gpu # build language id pipeline - self.lang_id_pipeline = Pipeline(lang='multilingual', processors="langid", use_gpu=self.use_gpu, - **self.lang_id_config) + self.lang_id_pipeline = Pipeline(dir=self.model_dir, lang='multilingual', processors="langid", + use_gpu=self.use_gpu, **self.lang_id_config) def _update_pipeline_cache(self, lang): """ @@ -62,7 +65,7 @@ class MultilingualPipeline: lru_lang = self.lang_request_history[0] self.pipeline_cache.remove(lru_lang) self.lang_request_history.remove(lru_lang) - self.pipeline_cache[lang] = Pipeline(**self.lang_configs[lang]) + self.pipeline_cache[lang] = Pipeline(dir=self.model_dir, **self.lang_configs[lang]) def process(self, doc): """ -- cgit v1.2.3 From 7202fdc3fc387adc3cedee2c15488e4b8a9b5992 Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 10 Sep 2021 02:33:02 -0700 Subject: set model dir path --- stanza/tests/test_langid.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/stanza/tests/test_langid.py b/stanza/tests/test_langid.py index 3f8cc1a7..5dd36125 100644 --- a/stanza/tests/test_langid.py +++ b/stanza/tests/test_langid.py @@ -7,8 +7,9 @@ import pytest from stanza.models.common.doc import Document from stanza.pipeline.core import Pipeline from stanza.pipeline.multilingual import MultilingualPipeline +from stanza.tests import * -pytestmark = pytest.mark.skip +#pytestmark = pytest.mark.skip def test_langid(): """ @@ -18,7 +19,7 @@ def test_langid(): french_text = "C'est une phrase française." 
docs = [english_text, french_text] - nlp = Pipeline(lang='multilingual', processors="langid") + nlp = Pipeline(dir=TEST_MODELS_DIR, lang='multilingual', processors="langid") docs = [Document([], text=text) for text in docs] nlp(docs) predictions = [doc.lang for doc in docs] @@ -530,7 +531,7 @@ def test_langid_benchmark(): {"text": "Například Pedagogická fakulta Univerzity Karlovy", "label": "cs"}, {"text": "nostris ut eriperet nos de praesenti saeculo", "label": "la"}] - nlp = Pipeline(lang="multilingual", processors="langid") + nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid") docs = [Document([], text=example["text"]) for example in examples] gold_labels = [example["label"] for example in examples] nlp(docs) @@ -546,11 +547,11 @@ def test_text_cleaning(): "Bonjour le monde! https://t.co/U0Zjp3tusD"] docs = [Document([], text=text) for text in docs] - nlp = Pipeline(lang="multilingual", processors="langid") + nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid") nlp(docs) assert [doc.lang for doc in docs] == ["it", "it"] - nlp = Pipeline(lang="multilingual", processors="langid", langid_clean_text=True) + nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_clean_text=True) assert nlp.processors["langid"]._clean_text nlp(docs) assert [doc.lang for doc in docs] == ["fr", "fr"] @@ -563,16 +564,16 @@ def test_lang_subset(): "Bonjour le monde! https://t.co/U0Zjp3tusD"] docs = [Document([], text=text) for text in docs] - nlp = Pipeline(lang="multilingual", processors="langid") + nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid") nlp(docs) assert [doc.lang for doc in docs] == ["it", "it"] - nlp = Pipeline(lang="multilingual", processors="langid", langid_lang_subset=["en","fr"]) + nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en","fr"]) assert nlp.processors["langid"]._model.lang_subset == ["en", "fr"] nlp(docs) assert [doc.lang for doc in docs] == ["fr", "fr"] - nlp = Pipeline(lang="multilingual", processors="langid", langid_lang_subset=["en"]) + nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en"]) assert nlp.processors["langid"]._model.lang_subset == ["en"] nlp(docs) assert [doc.lang for doc in docs] == ["en", "en"] @@ -601,7 +602,7 @@ def test_multilingual_pipeline(): "('.', 4, 'punct')" )) - nlp = MultilingualPipeline() + nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR) docs = [english_text, french_text] docs = nlp(docs) -- cgit v1.2.3 From 4105c23f34c7d4efd836585c304304113c5c5ddb Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 10 Sep 2021 02:37:38 -0700 Subject: Create stanza-tests.yaml --- .github/workflows/stanza-tests.yaml | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 .github/workflows/stanza-tests.yaml diff --git a/.github/workflows/stanza-tests.yaml b/.github/workflows/stanza-tests.yaml new file mode 100644 index 00000000..604ae915 --- /dev/null +++ b/.github/workflows/stanza-tests.yaml @@ -0,0 +1,33 @@ +name: Run Stanza Tests +on: [push] +jobs: + Run-Stanza-Tests: + runs-on: self-hosted + steps: + - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event." + - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!" + - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}." 
+ - name: Check out repository code + uses: actions/checkout@v2 + - run: echo "💡 The ${{ github.repository }} repository has been cloned to the runner." + - run: echo "🖥️ The workflow is now ready to test your code on the runner." + - name: Run demo + run: | + # set up environment + bash + . /home/stanzabuild/miniconda3/etc/profile.d/conda.sh + # install from stanza repo being evaluated + pwd + pip install -e . + # set up for tests + rm -rf /home/stanzabuild/stanza-github-actions/actions-runner/_work/stanza/stanza/stanza_test + source stanza/tests/setup_test.sh + # run tests + echo "Running tests..." + export CUDA_VISIBLE_DEVICES=2 + pytest stanza/tests/test_depparse.py + pytest stanza/tests/test_ner_tagger.py + pytest stanza/tests/test_langid.py + echo "This currently works!" + + - run: echo "🍏 This job's status is ${{ job.status }}." -- cgit v1.2.3 From 46253376dfc3ae6ecef0cc84044922d7093ea17c Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 10 Sep 2021 02:47:14 -0700 Subject: download multilingual during setup --- stanza/tests/setup_test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/stanza/tests/setup_test.sh b/stanza/tests/setup_test.sh index a9d4bbf2..aec98367 100644 --- a/stanza/tests/setup_test.sh +++ b/stanza/tests/setup_test.sh @@ -23,6 +23,7 @@ mkdir -p $models_dir $PYTHON -c "import stanza; stanza.download(lang='en', model_dir='${models_dir}', logging_level='info')" || echo "failed to download english model" $PYTHON -c "import stanza; stanza.download(lang='fr', model_dir='${models_dir}', logging_level='info')" || echo "failed to download french model" $PYTHON -c "import stanza; stanza.download(lang='zh', model_dir='${models_dir}', logging_level='info')" || echo "failed to download chinese model" +$PYTHON -c "import stanza; stanza.download(lang='multilingual', model_dir='${models_dir}', logging_level='info')" || echo "failed to download chinese model" echo "Models downloaded to ${models_dir}." export STANZA_TEST_HOME=$test_dir -- cgit v1.2.3 From b4f9251bbfc451d059cb6a2674ca0563ca012a4b Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 10 Sep 2021 02:55:56 -0700 Subject: Delete demo.yml --- .github/workflows/demo.yml | 32 -------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 .github/workflows/demo.yml diff --git a/.github/workflows/demo.yml b/.github/workflows/demo.yml deleted file mode 100644 index 70038ed6..00000000 --- a/.github/workflows/demo.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: Run Stanza Tests -on: [push] -jobs: - Run-Stanza-Tests: - runs-on: self-hosted - steps: - - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event." - - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!" - - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}." - - name: Check out repository code - uses: actions/checkout@v2 - - run: echo "💡 The ${{ github.repository }} repository has been cloned to the runner." - - run: echo "🖥️ The workflow is now ready to test your code on the runner." - - name: Run demo - run: | - # set up environment - bash - . /home/stanzabuild/miniconda3/etc/profile.d/conda.sh - # install from stanza repo being evaluated - pwd - pip install -e . - # set up for tests - rm -rf /home/stanzabuild/stanza-github-actions/actions-runner/_work/stanza/stanza/stanza_test - source stanza/tests/setup_test.sh - # run tests - echo "Running tests..." 
- export CUDA_VISIBLE_DEVICES=2 - pytest stanza/tests/test_depparse.py - pytest stanza/tests/test_ner_tagger.py - echo "This currently works!" - - - run: echo "🍏 This job's status is ${{ job.status }}." -- cgit v1.2.3 From 111fb7b429ef3b5a700a3d04354370b642f7c048 Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 10 Sep 2021 02:57:22 -0700 Subject: Update stanza-tests.yaml --- .github/workflows/stanza-tests.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/stanza-tests.yaml b/.github/workflows/stanza-tests.yaml index 604ae915..34953c78 100644 --- a/.github/workflows/stanza-tests.yaml +++ b/.github/workflows/stanza-tests.yaml @@ -16,6 +16,7 @@ jobs: # set up environment bash . /home/stanzabuild/miniconda3/etc/profile.d/conda.sh + export CORENLP_HOME=/home/stanzabuild/stanford-corenlp-4.2.2 # install from stanza repo being evaluated pwd pip install -e . @@ -25,9 +26,6 @@ jobs: # run tests echo "Running tests..." export CUDA_VISIBLE_DEVICES=2 - pytest stanza/tests/test_depparse.py - pytest stanza/tests/test_ner_tagger.py - pytest stanza/tests/test_langid.py - echo "This currently works!" + pytest stanza/tests - run: echo "🍏 This job's status is ${{ job.status }}." -- cgit v1.2.3 From fd73b22803f2443896cc7ed189775b2d1171297d Mon Sep 17 00:00:00 2001 From: J38 Date: Fri, 10 Sep 2021 04:21:43 -0700 Subject: Update stanza-tests.yaml --- .github/workflows/stanza-tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/stanza-tests.yaml b/.github/workflows/stanza-tests.yaml index 34953c78..d40eb274 100644 --- a/.github/workflows/stanza-tests.yaml +++ b/.github/workflows/stanza-tests.yaml @@ -17,6 +17,7 @@ jobs: bash . /home/stanzabuild/miniconda3/etc/profile.d/conda.sh export CORENLP_HOME=/home/stanzabuild/stanford-corenlp-4.2.2 + export CLASSPATH=/home/stanzabuild/stanford-corenlp-4.2.2/*: # install from stanza repo being evaluated pwd pip install -e . 
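A minimal sketch (an editorial illustration, not part of the patches above) of why this
workflow exports CORENLP_HOME before running the test suite: Stanza's CoreNLP client
locates the server jars through that environment variable. The install path below is
simply the one used on this self-hosted runner and is otherwise arbitrary.

    import os
    from stanza.server import CoreNLPClient

    # The client builds its Java classpath from the jars found under CORENLP_HOME.
    os.environ.setdefault("CORENLP_HOME", "/home/stanzabuild/stanford-corenlp-4.2.2")
    with CoreNLPClient(annotators=["tokenize", "ssplit"], be_quiet=True) as client:
        ann = client.annotate("Stanza also wraps the CoreNLP server.")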
-- cgit v1.2.3 From c610c4b9690158fbcb3cb23d80e44ab771b850a7 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 10 Sep 2021 13:08:59 -0700 Subject: Update HighwayLSTM for unsorted PackedSequences --- stanza/models/common/hlstm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stanza/models/common/hlstm.py b/stanza/models/common/hlstm.py index bfddb3e4..124c6ad5 100644 --- a/stanza/models/common/hlstm.py +++ b/stanza/models/common/hlstm.py @@ -99,14 +99,14 @@ class HighwayLSTM(nn.Module): for l in range(self.num_layers): if l > 0: - input = PackedSequence(self.drop(input.data), input.batch_sizes) + input = PackedSequence(self.drop(input.data), input.batch_sizes, input.sorted_indices, input.unsorted_indices) layer_hx = (hx[0][l * self.num_directions:(l+1)*self.num_directions], hx[1][l * self.num_directions:(l+1)*self.num_directions]) if hx is not None else None h, (ht, ct) = self.lstm[l](input, seqlens, layer_hx) hs.append(ht) cs.append(ct) - input = PackedSequence(h.data + torch.sigmoid(self.gate[l](input.data)) * highway_func(self.highway[l](input.data)), input.batch_sizes) + input = PackedSequence(h.data + torch.sigmoid(self.gate[l](input.data)) * highway_func(self.highway[l](input.data)), input.batch_sizes, input.sorted_indices, input.unsorted_indices) if self.pad: input = pad_packed_sequence(input, batch_first=self.batch_first)[0] -- cgit v1.2.3 From e232f67f3850a32a1b4f3a99e9eb4f5c5580c019 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 22 Sep 2021 17:23:38 -0700 Subject: Addresses issue #804: don't remove all text when simplifying text in the sentiment processor --- stanza/models/classifiers/data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stanza/models/classifiers/data.py b/stanza/models/classifiers/data.py index 4922414a..72c04b9c 100644 --- a/stanza/models/classifiers/data.py +++ b/stanza/models/classifiers/data.py @@ -11,6 +11,10 @@ def update_text(sentence, wordvec_type): # stanford sentiment dataset has a lot of random - and / sentence = sentence.replace("-", " ") sentence = sentence.replace("/", " ") + sentence = sentence.strip() + if sentence == "": + # removed too much + sentence = "-" sentence = sentence.split() # our current word vectors are all entirely lowercased sentence = [word.lower() for word in sentence] -- cgit v1.2.3 From 182dfd17f4812cc306d77c1509348335aa6b2613 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 22 Sep 2021 17:27:21 -0700 Subject: Add a test for the empty text and -- text for the sentiment processor --- stanza/tests/test_pipeline_sentiment_processor.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/stanza/tests/test_pipeline_sentiment_processor.py b/stanza/tests/test_pipeline_sentiment_processor.py index b46eedf4..d78dbbb6 100644 --- a/stanza/tests/test_pipeline_sentiment_processor.py +++ b/stanza/tests/test_pipeline_sentiment_processor.py @@ -36,3 +36,12 @@ def test_multiple_sentences(pipeline): results = [sentence.sentiment for sentence in doc.sentences] assert EXPECTED == results +def test_empty_text(pipeline): + """ + Test empty text and a text which might get reduced to empty text by removing dashes + """ + doc = pipeline("") + assert len(doc.sentences) == 0 + + doc = pipeline("--") + assert len(doc.sentences) == 1 -- cgit v1.2.3 From 0596044e8d5699b0aded723b46da94422975c7e9 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 4 Jun 2021 10:20:57 -0700 Subject: Add a data structure to represent a parse tree Also, a tree reader which reads a list of trees from a string or file. 
Includes an interface to parse multiple trees from the same file. Pieces of the input are read as a TokenIterator so we can keep track of line numbers --- stanza/models/constituency/__init__.py | 0 stanza/models/constituency/parse_tree.py | 81 ++++++++++++++++++ stanza/models/constituency/tree_reader.py | 117 ++++++++++++++++++++++++++ stanza/tests/test_constituency_parse_tree.py | 33 ++++++++ stanza/tests/test_constituency_tree_reader.py | 61 ++++++++++++++ 5 files changed, 292 insertions(+) create mode 100644 stanza/models/constituency/__init__.py create mode 100644 stanza/models/constituency/parse_tree.py create mode 100644 stanza/models/constituency/tree_reader.py create mode 100644 stanza/tests/test_constituency_parse_tree.py create mode 100644 stanza/tests/test_constituency_tree_reader.py diff --git a/stanza/models/constituency/__init__.py b/stanza/models/constituency/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/stanza/models/constituency/parse_tree.py b/stanza/models/constituency/parse_tree.py new file mode 100644 index 00000000..1ca95f8d --- /dev/null +++ b/stanza/models/constituency/parse_tree.py @@ -0,0 +1,81 @@ +""" +Tree datastructure +""" + +from collections import deque +from io import StringIO + +from stanza.models.common.doc import StanzaObject + +class Tree(StanzaObject): + """ + A data structure to represent a parse tree + """ + def __init__(self, label=None, children=None): + if children is None: + self.children = [] + elif isinstance(children, Tree): + self.children = (children,) + else: + self.children = children + + self.label = label + + def is_leaf(self): + return len(self.children) == 0 + + def is_preterminal(self): + return len(self.children) == 1 and len(self.children[0].children) == 0 + + def yield_preterminals(self): + if self.is_leaf(): + pass + elif self.is_preterminal(): + yield self + else: + for child in self.children: + for preterminal in child.yield_preterminals(): + yield preterminal + + def __repr__(self): + """ + Turn the tree into a string representing the tree + + Note that this is not a recursive traversal + Otherwise, a tree too deep might blow up the call stack + """ + with StringIO() as buf: + stack = deque() + stack.append(self) + while len(stack) > 0: + node = stack.pop() + if node == ')' or node == ' ': + buf.write(node) + continue + if not node.children: + buf.write(node.label) + continue + buf.write("(") + buf.write(node.label) + stack.append(')') + for child in reversed(node.children): + stack.append(child) + stack.append(' ') + buf.seek(0) + return buf.read() + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, Tree): + return False + if self.label != other.label: + return False + if self.children != other.children: + return False + return True + + def depth(self): + if not self.children: + return 0 + return 1 + max(x.depth() for x in self.children) diff --git a/stanza/models/constituency/tree_reader.py b/stanza/models/constituency/tree_reader.py new file mode 100644 index 00000000..adfe2b55 --- /dev/null +++ b/stanza/models/constituency/tree_reader.py @@ -0,0 +1,117 @@ +from stanza.models.common import utils +from stanza.models.constituency.parse_tree import Tree + +tqdm = utils.get_tqdm() + +OPEN_PAREN = "(" +CLOSE_PAREN = ")" + +def recursive_open_tree(token_iterator, at_root): + # TODO: unwind the recursion + text = [] + children = [] + + token = next(token_iterator, None) + while token != None: + if token is OPEN_PAREN: + 
children.append(recursive_open_tree(token_iterator, at_root=False)) + elif token is CLOSE_PAREN: + if len(text) == 0: + if at_root: + return Tree(label="ROOT", children=children) + raise ValueError("Found a tree with no label on a node! Line number %d" % token_iterator.line_num) + + pieces = " ".join(text).split() + if len(pieces) == 1: + return Tree(label=pieces[0], children=children) + if len(children) > 0: + raise ValueError("Found a tree with both text children and bracketed children! Line number %d" % token_iterator.line_num) + label = pieces[0] + child_label = " ".join(pieces[1:]) + return Tree(label=label, children=Tree(label=child_label)) + else: + text.append(token) + token = next(token_iterator, None) + +def recursive_read_trees(token_iterator): + """ + TODO: some of the error cases we hit can be recovered from + also, just in general it would be good to unwind the recursion + """ + trees = [] + token = next(token_iterator, None) + while token: + if token is OPEN_PAREN: + trees.append(recursive_open_tree(token_iterator, at_root=True)) + token = next(token_iterator, None) + continue + + if token is CLOSE_PAREN: + raise ValueError("Tree document had too many close parens! Line number %d" % token_iterator.line_num) + else: + raise ValueError("Tree document had text between trees! Line number %d" % token_iterator.line_num) + + return trees + +class TokenIterator: + """ + A specific iterator for reading trees from a tree file + + The idea is that this will keep track of which line + we are processing, so that an error can be logged + from the correct line + """ + def __init__(self, text): + self.lines = text.split("\n") + self.line_num = -1 + self.token_iterator = iter([]) + + def __iter__(self): + return self + + def __next__(self): + n = next(self.token_iterator, None) + while n is None: + self.line_num = self.line_num + 1 + if self.line_num >= len(self.lines): + raise StopIteration + + line = self.lines[self.line_num].strip() + if not line: + continue + + pieces = [] + open_pieces = line.split(OPEN_PAREN) + for o_idx, open_piece in enumerate(open_pieces): + if open_piece: + close_pieces = open_piece.split(CLOSE_PAREN) + for c_idx, close_piece in enumerate(close_pieces): + close_piece = close_piece.strip() + if close_piece: + pieces.append(close_piece) + if c_idx != len(close_pieces) - 1: + pieces.append(CLOSE_PAREN) + if o_idx != len(open_pieces) - 1: + pieces.append(OPEN_PAREN) + self.token_iterator = iter(pieces) + n = next(self.token_iterator, None) + + return n + +def read_trees(text): + """ + Reads multiple trees from the text + """ + token_iterator = TokenIterator(text) + trees = recursive_read_trees(token_iterator) + return trees + +def read_tree_file(filename): + with open(filename) as fin: + trees = read_trees(fin.read()) + return trees + +if __name__ == '__main__': + text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" + trees = read_trees(text) + print(trees) diff --git a/stanza/tests/test_constituency_parse_tree.py b/stanza/tests/test_constituency_parse_tree.py new file mode 100644 index 00000000..a4f81180 --- /dev/null +++ b/stanza/tests/test_constituency_parse_tree.py @@ -0,0 +1,33 @@ +import pytest + +from stanza.models.constituency.parse_tree import Tree + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def test_leaf_preterminal(): + foo = Tree(label="foo") + assert foo.is_leaf() + assert not foo.is_preterminal() + assert len(foo.children) == 0 + assert str(foo) == 'foo' + + bar = Tree(label="bar", children=foo) + assert not bar.is_leaf() + assert bar.is_preterminal() + assert len(bar.children) == 1 + assert str(bar) == "(bar foo)" + + baz = Tree(label="baz", children=[bar]) + assert not baz.is_leaf() + assert not baz.is_preterminal() + assert len(baz.children) == 1 + assert str(baz) == "(baz (bar foo))" + + +def test_depth(): + text = "(foo) ((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" + trees = tree_reader.read_trees(text) + assert trees[0].depth() == 0 + assert trees[1].depth() == 4 diff --git a/stanza/tests/test_constituency_tree_reader.py b/stanza/tests/test_constituency_tree_reader.py new file mode 100644 index 00000000..feee74fa --- /dev/null +++ b/stanza/tests/test_constituency_tree_reader.py @@ -0,0 +1,61 @@ +import pytest +from stanza.models.constituency import tree_reader + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def test_simple(): + """ + Tests reading two simple trees from the same text + """ + text = "(VB Unban) (NNP Opal)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + assert trees[0].is_preterminal() + assert trees[0].label == 'VB' + assert trees[0].children[0].label == 'Unban' + assert trees[1].is_preterminal() + assert trees[1].label == 'NNP' + assert trees[1].children[0].label == 'Opal' + +def test_newlines(): + """ + The same test should work if there are newlines + """ + text = "(VB Unban)\n\n(NNP Opal)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + +def test_complicated(): + """ + A more complicated tree that should successfully read + """ + text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + trees = tree_reader.read_trees(text) + assert len(trees) == 1 + tree = trees[0] + assert not tree.is_leaf() + assert not tree.is_preterminal() + assert tree.label == 'ROOT' + assert len(tree.children) == 1 + assert tree.children[0].label == 'SBARQ' + assert len(tree.children[0].children) == 3 + assert [x.label for x in tree.children[0].children] == ['WHNP', 'SQ', '.'] + # etc etc + +def test_one_word(): + """ + Check that one node trees are correctly read + + probably not super relevant for the parsing use case + """ + text="(FOO) (BAR)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + + assert trees[0].is_leaf() + assert trees[0].label == 'FOO' + + assert trees[1].is_leaf() + assert trees[1].label == 'BAR' -- cgit v1.2.3 From 90318023432d584c62986123ef414a1fa93683ca Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 18 Jun 2021 12:37:14 -0700 Subject: Constituency parser based on word embeddings to create Trees out of a sequence of words. 
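Before the parser internals below, a short usage sketch (an editorial illustration, not
part of either patch) of the Tree / tree_reader interface introduced in the previous
commit, following the unit tests shown above:

    from stanza.models.constituency import tree_reader

    trees = tree_reader.read_trees("(VB Unban) (NNP Opal)")
    assert len(trees) == 2
    assert trees[0].label == "VB"                 # preterminal node
    assert trees[0].children[0].label == "Unban"  # its leaf child
    assert trees[0].is_preterminal()
    print(trees[1])                               # prints "(NNP Opal)"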
This is a squash of what was originally a long list of changes See 2f7db846e14ce73ca95416172b1ba5ba512821f5 for a original sequence Primary methods are either top-down or in-order transition sequences, as per ths paper: In-Order Transition-based Constituent Parsing Jiangming Liu and Yue Zhang Parser eval interface which calls the CoreNLP parser eval Model is based on LSTMs. Includes a treebank evaluation request to CoreNLP via a protobuf Has options to use a variety of small modifications to the models. Constraints on the transitions hopefully prevent the parser from getting stuck. Allow either adadelta or sgd as optimizer Allow choice of relu or tanh for nonlinearity Includes a bunch of tests Move the constituency tests into their own directory Defaults are set to reasonable values for the WSJ PTB Lots of effort put into bulk operations instead of doing single transitions at a time Saves the optimizer state when saving a model. Makes the model much larger, but allows for restarting training from the same optimizer Also, a mode to remove the optimizer from a model (which shrinks it). Uses a mechanism similar to the original implementation to avoid too many "unary" transitions, eg an open immediately followed by a close. However, some training trees have too many unary transitions for the original limit=3 to be sufficient Charlm integration, including batching, although that didn't seem to help Also has some doc on things which didn't help --- doc/CoreNLP.proto | 32 ++ stanza/models/constituency/base_model.py | 205 +++++++ stanza/models/constituency/lstm_model.py | 543 +++++++++++++++++++ stanza/models/constituency/parse_transitions.py | 603 +++++++++++++++++++++ stanza/models/constituency/parse_tree.py | 260 ++++++++- stanza/models/constituency/trainer.py | 570 +++++++++++++++++++ stanza/models/constituency/transition_sequence.py | 112 ++++ stanza/models/constituency/tree_reader.py | 56 +- stanza/models/constituency/tree_stack.py | 52 ++ stanza/models/constituency/utils.py | 58 ++ stanza/models/constituency_parser.py | 290 ++++++++++ stanza/pipeline/_constants.py | 1 + stanza/pipeline/constituency_processor.py | 52 ++ stanza/pipeline/core.py | 1 + stanza/protobuf/CoreNLP_pb2.py | 253 ++++++++- stanza/server/java_protobuf_requests.py | 92 ++++ stanza/server/parser_eval.py | 41 ++ stanza/tests/constituency/test_lstm_model.py | 208 +++++++ .../tests/constituency/test_parse_transitions.py | 412 ++++++++++++++ stanza/tests/constituency/test_parse_tree.py | 196 +++++++ .../tests/constituency/test_transition_sequence.py | 87 +++ stanza/tests/constituency/test_tree_reader.py | 61 +++ stanza/tests/constituency/test_tree_stack.py | 50 ++ stanza/tests/constituency/test_utils.py | 68 +++ stanza/tests/test_constituency_parse_tree.py | 33 -- stanza/tests/test_constituency_tree_reader.py | 61 --- stanza/tests/test_java_protobuf_requests.py | 23 + stanza/tests/test_parser_eval.py | 40 ++ 28 files changed, 4329 insertions(+), 131 deletions(-) create mode 100644 stanza/models/constituency/base_model.py create mode 100644 stanza/models/constituency/lstm_model.py create mode 100644 stanza/models/constituency/parse_transitions.py create mode 100644 stanza/models/constituency/trainer.py create mode 100644 stanza/models/constituency/transition_sequence.py create mode 100644 stanza/models/constituency/tree_stack.py create mode 100644 stanza/models/constituency/utils.py create mode 100644 stanza/models/constituency_parser.py create mode 100644 stanza/pipeline/constituency_processor.py create mode 100644 
stanza/server/parser_eval.py create mode 100644 stanza/tests/constituency/test_lstm_model.py create mode 100644 stanza/tests/constituency/test_parse_transitions.py create mode 100644 stanza/tests/constituency/test_parse_tree.py create mode 100644 stanza/tests/constituency/test_transition_sequence.py create mode 100644 stanza/tests/constituency/test_tree_reader.py create mode 100644 stanza/tests/constituency/test_tree_stack.py create mode 100644 stanza/tests/constituency/test_utils.py delete mode 100644 stanza/tests/test_constituency_parse_tree.py delete mode 100644 stanza/tests/test_constituency_tree_reader.py create mode 100644 stanza/tests/test_java_protobuf_requests.py create mode 100644 stanza/tests/test_parser_eval.py diff --git a/doc/CoreNLP.proto b/doc/CoreNLP.proto index 7fbff6dd..18b56ec0 100644 --- a/doc/CoreNLP.proto +++ b/doc/CoreNLP.proto @@ -698,3 +698,35 @@ message DependencyEnhancerRequest { string relativePronouns = 3; } } + +// A version of ParseTree with a flattened structure so that deep trees +// don't exceed the protobuf stack depth +message FlattenedParseTree { + message Node { + oneof contents { + bool openNode = 1; + bool closeNode = 2; + string value = 3; + } + + optional double score = 4; + } + + repeated Node nodes = 1; +} + +// A protobuf for calling the java constituency parser evaluator from elsewhere +message EvaluateParserRequest { + message ParseResult { + required FlattenedParseTree gold = 1; + // repeated so you can send in kbest parses, if your parser handles that + // note that this already includes a score field + repeated FlattenedParseTree predicted = 2; + } + + repeated ParseResult treebank = 1; +} + +message EvaluateParserResponse { + required double f1 = 1; +} diff --git a/stanza/models/constituency/base_model.py b/stanza/models/constituency/base_model.py new file mode 100644 index 00000000..0a4ee102 --- /dev/null +++ b/stanza/models/constituency/base_model.py @@ -0,0 +1,205 @@ +""" +The BaseModel is passed to the transitions so that the transitions +can operate on a parsing state without knowing the exact +representation used in the model. + +For example, a SimpleModel simply looks at the top of the various stacks in the state. + +A model with LSTM representations for the different transitions may +attach the hidden and output states of the LSTM to the word / +constituent / transition stacks. + +Reminder: the parsing state is a list of words to parse, the +transitions used to build a (possibly incomplete) parse, and the +constituent(s) built so far by those transitions. Each of these +components are represented using stacks to improve the efficiency +of operations such as "combine the most recent 4 constituents" +or "turn the next input word into a constituent" +""" + +from abc import ABC, abstractmethod + +from stanza.models.constituency.parse_transitions import TransitionScheme +from stanza.models.constituency.parse_tree import Tree +from stanza.models.constituency.tree_stack import TreeStack + +class BaseModel(ABC): + """ + This base class defines abstract methods for manipulating a State. + + Applying transitions may change important metadata about a State + such as the vectors associated with LSTM hidden states, for example. 
+ """ + @abstractmethod + def initial_word_queues(self, tagged_word_lists): + """ + For each list of tagged words, builds a TreeStack of word nodes + + The word lists should be backwards so that the first word is the last word put on the stack (LIFO) + """ + + @abstractmethod + def initial_transitions(self): + """ + Builds an initial transition stack with whatever values need to go into first position + """ + + @abstractmethod + def initial_constituents(self): + """ + Builds an initial constituent stack with whatever values need to go into first position + """ + + @abstractmethod + def get_word(self, word_node): + """ + Get the word corresponding to this position in the word queue + """ + + @abstractmethod + def transform_word_to_constituent(self, state): + """ + Transform the top node of word_queue to something that can push on the constituent stack + """ + + @abstractmethod + def dummy_constituent(self, dummy): + """ + When using a dummy node as a sentinel, transform it to something usable by this model + """ + + @abstractmethod + def unary_transform(self, constituents, labels): + """ + Transform the top of the constituent stack using a unary transform to the new label + """ + + @abstractmethod + def build_constituents(self, labels, children_lists): + """ + Build multiple constituents at once. This gives the opportunity for batching operations + """ + + @abstractmethod + def push_constituents(self, constituent_stacks, constituents): + """ + Add a multiple constituents to multiple constituent_stacks + + Useful to factor this out in case batching will help + """ + + @abstractmethod + def get_top_constituent(self, constituents): + """ + Get the first constituent from the constituent stack + + For example, a model might want to remove embeddings and LSTM state vectors + """ + + @abstractmethod + def push_transitions(self, transition_stacks, transitions): + """ + Add a multiple transitions to multiple transition_stacks + + Useful to factor this out in case batching will help + """ + + @abstractmethod + def get_top_transition(self, transitions): + """ + Get the first transition from the transition stack + + For example, a model might want to remove transition embeddings before returning the transition + """ + + def get_root_labels(self): + """ + Return ROOT labels for this model. 
Probably ROOT, TOP, or both + """ + return ("ROOT",) + + @abstractmethod + def transition_scheme(self): + """ + Transition scheme used - see parse_transitions + """ + + @abstractmethod + def has_unary_transitions(self): + """ + Whether or not this model uses unary transitions, based on transition_scheme + """ + + @abstractmethod + def is_top_down(self): + """ + Whether or not this model is TOP_DOWN + """ + +class SimpleModel(BaseModel): + """ + This model allows pushing and popping with no extra data + """ + def __init__(self, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): + self._transition_scheme = transition_scheme + + def initial_word_queues(self, tagged_word_lists): + word_queues = [] + for tagged_words in tagged_word_lists: + word_queue = [tag_node for tag_node in tagged_words] + word_queue.reverse() + word_queue.append(None) + word_queues.append(word_queue) + return word_queues + + def initial_transitions(self): + return TreeStack(value=None, parent=None, length=1) + + def initial_constituents(self): + return TreeStack(value=None, parent=None, length=1) + + def get_word(self, word_node): + return word_node + + def transform_word_to_constituent(self, state): + return state.word_queue[state.word_position] + + def dummy_constituent(self, dummy): + return dummy + + def unary_transform(self, constituents, labels): + top_constituent = constituents.value + for label in reversed(labels): + top_constituent = Tree(label=label, children=[top_constituent]) + return top_constituent + + def build_constituents(self, labels, children_lists): + constituents = [] + for label, children in zip(labels, children_lists): + if isinstance(label, str): + label = (label,) + for value in reversed(label): + children = Tree(label=value, children=children) + constituents.append(children) + return constituents + + def push_constituents(self, constituent_stacks, constituents): + return [stack.push(constituent) for stack, constituent in zip(constituent_stacks, constituents)] + + def get_top_constituent(self, constituents): + return constituents.value + + def push_transitions(self, transition_stacks, transitions): + return [stack.push(transition) for stack, transition in zip(transition_stacks, transitions)] + + def get_top_transition(self, transitions): + return transitions.value + + def transition_scheme(self): + return self._transition_scheme + + def has_unary_transitions(self): + return self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY + + def is_top_down(self): + return self._transition_scheme in (TransitionScheme.TOP_DOWN, TransitionScheme.TOP_DOWN_UNARY, TransitionScheme.TOP_DOWN_COMPOUND) diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py new file mode 100644 index 00000000..2939403e --- /dev/null +++ b/stanza/models/constituency/lstm_model.py @@ -0,0 +1,543 @@ +""" +A version of the BaseModel which uses LSTMs to predict the correct next transition +based on the current known state. + +The primary purpose of this class is to implement the prediction of the next +transition, which is done by concatenating the output of an LSTM operated over +previous transitions, the words, and the partially built constituents. 
+""" + +from collections import namedtuple +import logging +from operator import itemgetter +import random +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence + +from stanza.models.common.data import get_long_tensor +from stanza.models.common.utils import unsort +from stanza.models.common.vocab import PAD_ID, UNK_ID +from stanza.models.constituency.base_model import BaseModel +from stanza.models.constituency.parse_transitions import TransitionScheme +from stanza.models.constituency.parse_tree import Tree +from stanza.models.constituency.tree_stack import TreeStack + +logger = logging.getLogger('stanza') + +WordNode = namedtuple("WordNode", ['value', 'hx']) +TransitionNode = namedtuple("TransitionNode", ['value', 'output', 'hx', 'cx']) + +# Invariant: the output at the top of the constituency stack will have a +# single dimension +# We do this to maintain consistency between the different operations, +# which sometimes result in different shapes +# This will be unsqueezed in order to put into the next layer if needed +# hx & cx are the hidden & cell states of the LSTM going across constituents +ConstituentNode = namedtuple("ConstituentNode", ['value', 'output', 'hx', 'cx']) +Constituent = namedtuple("Constituent", ['value', 'hx']) + + +class LSTMModel(BaseModel, nn.Module): + def __init__(self, pretrain, forward_charlm, backward_charlm, transitions, constituents, tags, words, rare_words, root_labels, open_nodes, args): + """ + pretrain: a Pretrain object + transitions: a list of all possible transitions which will be + used to build trees + constituents: a list of all possible constituents in the treebank + tags: a list of all possible tags in the treebank + words: a list of all known words, used for a delta word embedding. + note that there will be an attempt made to learn UNK words as well, + and tags by themselves may help UNK words + rare_words: a list of rare words, used to occasionally replace with UNK + root_labels: probably ROOT, although apparently some treebanks like TOP + open_nodes: a list of all possible open nodes which will go on the stack + - this might be different from constituents if there are nodes + which represent multiple constituents at once + args: hidden_size, transition_hidden_size, etc as gotten from + constituency_parser.py + + Note that it might look like a hassle to pass all of this in + when it can be collected directly from the trees themselves. + However, that would only work at train time. At eval or + pipeline time we will load the lists from the saved model. 
+ """ + super().__init__() + self.args = args + self.unsaved_modules = [] + + emb_matrix = pretrain.emb + self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix), freeze=True)) + + self.vocab_map = { word: i for i, word in enumerate(pretrain.vocab) } + # precompute tensors for the word indices + # the tensors should be put on the GPU if needed with a call to cuda() + self.register_buffer('vocab_tensors', torch.tensor(range(len(pretrain.vocab)), requires_grad=False)) + self.vocab_size = emb_matrix.shape[0] + self.embedding_dim = emb_matrix.shape[1] + + self.root_labels = sorted(list(root_labels)) + self.constituents = sorted(list(constituents)) + self.constituent_map = { x: i for (i, x) in enumerate(self.constituents) } + # precompute tensors for the constituents + self.register_buffer('constituent_tensors', torch.tensor(range(len(self.constituent_map)), requires_grad=False)) + + self.hidden_size = self.args['hidden_size'] + self.transition_hidden_size = self.args['transition_hidden_size'] + self.tag_embedding_dim = self.args['tag_embedding_dim'] + self.transition_embedding_dim = self.args['transition_embedding_dim'] + self.delta_embedding_dim = self.args['delta_embedding_dim'] + self.word_input_size = self.embedding_dim + self.tag_embedding_dim + self.delta_embedding_dim + + if forward_charlm is not None: + self.add_unsaved_module('forward_charlm', forward_charlm) + self.add_unsaved_module('forward_charlm_vocab', forward_charlm.char_vocab()) + self.word_input_size += self.forward_charlm.hidden_dim() + else: + self.forward_charlm = None + if backward_charlm is not None: + self.add_unsaved_module('backward_charlm', backward_charlm) + self.add_unsaved_module('backward_charlm_vocab', backward_charlm.char_vocab()) + self.word_input_size += self.backward_charlm.hidden_dim() + else: + self.backward_charlm = None + + # TODO: add a max_norm? 
+ self.delta_words = sorted(list(words)) + self.delta_word_map = { word: i+2 for i, word in enumerate(self.delta_words) } + assert PAD_ID == 0 + assert UNK_ID == 1 + self.delta_embedding = nn.Embedding(num_embeddings = len(self.delta_words)+2, + embedding_dim = self.delta_embedding_dim, + padding_idx = 0) + self.register_buffer('delta_tensors', torch.tensor(range(len(self.delta_words) + 2), requires_grad=False)) + + self.rare_words = set(rare_words) + + self.tags = sorted(list(tags)) + if self.tag_embedding_dim > 0: + self.tag_map = { t: i for i, t in enumerate(self.tags) } + self.tag_embedding = nn.Embedding(num_embeddings = len(tags), + embedding_dim = self.tag_embedding_dim) + self.register_buffer('tag_tensors', torch.tensor(range(len(self.tags)), requires_grad=False)) + + self.transitions = sorted(list(transitions)) + self.transition_map = { t: i for i, t in enumerate(self.transitions) } + # precompute tensors for the transitions + self.register_buffer('transition_tensors', torch.tensor(range(len(transitions)), requires_grad=False)) + self.transition_embedding = nn.Embedding(num_embeddings = len(transitions), + embedding_dim = self.transition_embedding_dim) + + self.num_layers = self.args['num_lstm_layers'] + self.lstm_layer_dropout = self.args['lstm_layer_dropout'] + + # also register a buffer of zeros so that we can always get zeros on the appropriate device + self.register_buffer('zeros', torch.zeros(self.hidden_size)) + self.register_buffer('transition_zeros', torch.zeros(self.num_layers, 1, self.transition_hidden_size)) + self.register_buffer('constituent_zeros', torch.zeros(self.num_layers, 1, self.hidden_size)) + + self.word_lstm = nn.LSTM(input_size=self.word_input_size, hidden_size=self.hidden_size, num_layers=self.num_layers, bidirectional=True, dropout=self.lstm_layer_dropout) + + # after putting the word_delta_tag input through the word_lstm, we get back + # hidden_size * 2 output with the front and back lstms concatenated. + # this transforms it into hidden_size with the values mixed together + self.word_to_constituent = nn.Linear(self.hidden_size * 2, self.hidden_size) + + self.transition_lstm = nn.LSTM(input_size=self.transition_embedding_dim, hidden_size=self.transition_hidden_size, num_layers=self.num_layers, dropout=self.lstm_layer_dropout) + # input_size is hidden_size - could introduce a new constituent_size instead if we liked + self.constituent_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.lstm_layer_dropout) + + self._transition_scheme = args['transition_scheme'] + if self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY: + unary_transforms = {} + for constituent in self.constituent_map: + unary_transforms[constituent] = nn.Linear(self.hidden_size, self.hidden_size) + self.unary_transforms = nn.ModuleDict(unary_transforms) + + self.open_nodes = sorted(list(open_nodes)) + # an embedding for the spot on the constituent LSTM taken up by the Open transitions + # the pattern when condensing constituents is embedding - con1 - con2 - con3 - embedding + # TODO: try the two ends have different embeddings? 
+ self.open_node_map = { x: i for (i, x) in enumerate(self.open_nodes) } + self.open_node_embedding = nn.Embedding(num_embeddings = len(self.open_node_map), + embedding_dim = self.hidden_size) + + # TODO: remove this `get` once it's not needed + if args.get('combined_dummy_embedding', False): + self.dummy_embedding = self.open_node_embedding + else: + self.dummy_embedding = nn.Embedding(num_embeddings = len(self.open_node_map), + embedding_dim = self.hidden_size) + self.register_buffer('open_node_tensors', torch.tensor(range(len(open_nodes)), requires_grad=False)) + + # forward and backward pieces for crunching several + # constituents into one, combined into a bi-lstm + # TODO: make the hidden size here an option? + self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_layers, bidirectional=True, dropout=self.lstm_layer_dropout) + # affine transformation from bi-lstm reduce to a new hidden layer + self.reduce_linear = nn.Linear(self.hidden_size * 2, self.hidden_size) + + if self.args['nonlinearity'] == 'tanh': + self.nonlinearity = nn.Tanh() + elif self.args['nonlinearity'] == 'relu': + self.nonlinearity = nn.ReLU() + elif self.args['nonlinearity'] == 'gelu': + self.nonlinearity = nn.GELU() + else: + raise ValueError('Chosen value of nonlinearity, "%s", not handled' % self.args['nonlinearity']) + + self.word_dropout = nn.Dropout(self.args['word_dropout']) + self.predict_dropout = nn.Dropout(self.args['predict_dropout']) + self.lstm_input_dropout = nn.Dropout(self.args['lstm_input_dropout']) + + # matrix for predicting the next transition using word/constituent/transition queues + # word size + constituency size + transition size + middle_layers = self.args['num_output_layers'] - 1 + predict_input_size = [self.hidden_size * 2 + self.transition_hidden_size] + [self.hidden_size] * middle_layers + predict_output_size = [self.hidden_size] * middle_layers + [len(transitions)] + self.output_layers = nn.ModuleList([nn.Linear(input_size, output_size) + for input_size, output_size in zip(predict_input_size, predict_output_size)]) + + self.constituency_lstm = self.args['constituency_lstm'] + + def add_unsaved_module(self, name, module): + """ + Adds a module which will not be saved to disk + + Best used for large models such as pretrained word embeddings + """ + self.unsaved_modules += [name] + setattr(self, name, module) + + def get_root_labels(self): + return self.root_labels + + def build_char_representation(self, all_word_labels, device, forward): + CHARLM_START = "\n" + CHARLM_END = " " + + if forward: + charlm = self.forward_charlm + vocab = self.forward_charlm_vocab + else: + charlm = self.backward_charlm + vocab = self.backward_charlm_vocab + + all_data = [] + for idx, word_labels in enumerate(all_word_labels): + if forward: + word_labels = reversed(word_labels) + else: + word_labels = [x[::-1] for x in word_labels] + + chars = [CHARLM_START] + offsets = [] + for w in word_labels: + chars.extend(w) + chars.append(CHARLM_END) + offsets.append(len(chars) - 1) + if not forward: + offsets.reverse() + chars = vocab.map(chars) + all_data.append((chars, offsets, len(chars), len(all_data))) + + all_data.sort(key=itemgetter(2), reverse=True) + chars, char_offsets, char_lens, orig_idx = tuple(zip(*all_data)) + chars = get_long_tensor(chars, len(all_data), pad_id=vocab.unit2id(' ')).to(device=device) + + # TODO: surely this should be stuffed in the charlm model itself rather than done here + with torch.no_grad(): + output, _, _ = 
charlm.forward(chars, char_lens) + res = [output[i, offsets] for i, offsets in enumerate(char_offsets)] + res = unsort(res, orig_idx) + + return res + + def initial_word_queues(self, tagged_word_lists): + """ + Produce initial word queues out of the model's LSTMs for use in the tagged word lists. + + Operates in a batched fashion to reduce the runtime for the LSTM operations + """ + device = next(self.parameters()).device + + all_word_inputs = [] + all_word_labels = [] + for sentence_idx, tagged_words in enumerate(tagged_word_lists): + word_idx = torch.stack([self.vocab_tensors[self.vocab_map.get(word.children[0].label, UNK_ID)] for word in tagged_words]) + word_input = self.embedding(word_idx) + + # this occasionally learns UNK at train time + word_labels = [word.children[0].label for word in tagged_words] + if self.training: + delta_labels = [None if word in self.rare_words and random.random() < self.args['rare_word_unknown_frequency'] else word + for word in word_labels] + else: + delta_labels = word_labels + delta_idx = torch.stack([self.delta_tensors[self.delta_word_map.get(word, UNK_ID)] for word in delta_labels]) + + delta_input = self.delta_embedding(delta_idx) + + word_inputs = [word_input, delta_input] + + if self.tag_embedding_dim > 0: + try: + tag_idx = torch.stack([self.tag_tensors[self.tag_map[word.label]] for word in tagged_words]) + tag_input = self.tag_embedding(tag_idx) + word_inputs.append(tag_input) + except KeyError as e: + raise KeyError("Constituency parser not trained with tag {}".format(str(e))) from e + + all_word_labels.append(word_labels) + all_word_inputs.append(word_inputs) + + if self.forward_charlm is not None: + all_forward_chars = self.build_char_representation(all_word_labels, device, forward=True) + for word_inputs, forward_chars in zip(all_word_inputs, all_forward_chars): + word_inputs.append(forward_chars) + if self.backward_charlm is not None: + all_backward_chars = self.build_char_representation(all_word_labels, device, forward=False) + for word_inputs, backward_chars in zip(all_word_inputs, all_backward_chars): + word_inputs.append(backward_chars) + + word_lstm_input = torch.zeros((max(len(x) for x in tagged_word_lists), len(tagged_word_lists), self.word_input_size), device=device) + + for sentence_idx, word_inputs in enumerate(all_word_inputs): + # now of size sentence x input + word_input = torch.cat(word_inputs, dim=1) + word_input = self.word_dropout(word_input) + + word_lstm_input[:word_input.shape[0], sentence_idx, :] = word_input + + packed_word_input = torch.nn.utils.rnn.pack_padded_sequence(word_lstm_input, [len(x) for x in tagged_word_lists], enforce_sorted=False) + word_output, _ = self.word_lstm(packed_word_input) + # would like to do word_to_constituent here, but it seems PackedSequence doesn't support Linear + # word_output will now be sentence x batch x 2*hidden_size + word_output, word_output_lens = torch.nn.utils.rnn.pad_packed_sequence(word_output) + # now sentence x batch x hidden_size + + word_queues = [] + for sentence_idx, tagged_words in enumerate(tagged_word_lists): + sentence_output = word_output[:len(tagged_words), sentence_idx, :] + sentence_output = self.word_to_constituent(sentence_output) + sentence_output = self.nonlinearity(sentence_output) + # TODO: this makes it so constituents downstream are + # build with the outputs of the LSTM, not the word + # embeddings themselves. 
It is possible we want to + # transform the word_input to hidden_size in some way + # and use that instead + word_queue = [WordNode(tag_node, sentence_output[idx, :]) + for idx, tag_node in enumerate(tagged_words)] + word_queue.reverse() + word_queue.append(WordNode(None, self.zeros)) + + word_queues.append(word_queue) + + return word_queues + + def initial_transitions(self): + """ + Return an initial TreeStack with no transitions + """ + return TreeStack(value=TransitionNode(None, self.transition_zeros[-1, 0, :], self.transition_zeros, self.transition_zeros), parent=None, length=1) + + def initial_constituents(self): + """ + Return an initial TreeStack with no constituents + """ + return TreeStack(value=ConstituentNode(None, self.constituent_zeros[-1, 0, :], self.constituent_zeros, self.constituent_zeros), parent=None, length=1) + + def get_word(self, word_node): + return word_node.value + + def transform_word_to_constituent(self, state): + word_node = state.word_queue[state.word_position] + word = word_node.value + return Constituent(value=word, hx=word_node.hx) + + def dummy_constituent(self, dummy): + label = dummy.label + open_index = self.open_node_tensors[self.open_node_map[label]] + hx = self.dummy_embedding(open_index) + return Constituent(value=dummy, hx=hx) + + def unary_transform(self, constituents, labels): + top_constituent = constituents.value + node = top_constituent.value + hx = top_constituent.output + for label in reversed(labels): + node = Tree(label=label, children=[node]) + hx = self.unary_transforms[label](hx) + # non-linearity after the unary transform + hx = self.nonlinearity(hx) + top_constituent = Constituent(value=node, hx=hx) + return top_constituent + + def build_constituents(self, labels, children_lists): + label_hx = [self.open_node_embedding(self.open_node_tensors[self.open_node_map[label]]) for label in labels] + + max_length = max(len(children) for children in children_lists) + zeros = torch.zeros(self.hidden_size, device=label_hx[0].device) + node_hx = [[child.output for child in children] for children in children_lists] + # weirdly, this is faster than using pack_sequence + unpacked_hx = [[lhx] + nhx + [lhx] + [zeros] * (max_length - len(nhx)) for lhx, nhx in zip(label_hx, node_hx)] + unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in unpacked_hx] + packed_hx = torch.stack(unpacked_hx, axis=1) + packed_hx = torch.nn.utils.rnn.pack_padded_sequence(packed_hx, [len(x)+2 for x in children_lists], enforce_sorted=False) + lstm_output = self.constituent_reduce_lstm(packed_hx) + # take just the output of the final layer + # result of lstm is ouput, (hx, cx) + # so [1][0] gets hx + # [1][0][-1] is the final output + # will be shape len(children_lists) * 2, hidden_size for bidirectional + # where forward outputs are -2 and backwards are -1 + lstm_output = lstm_output[1][0] + forward_hx = lstm_output[-2, :] + backward_hx = lstm_output[-1, :] + + hx = self.reduce_linear(torch.cat((forward_hx, backward_hx), axis=1)) + hx = self.nonlinearity(hx) + + constituents = [] + for idx, (label, children) in enumerate(zip(labels, children_lists)): + children = [child.value for child in children] + if isinstance(label, str): + node = Tree(label=label, children=children) + else: + for value in reversed(label): + node = Tree(label=value, children=children) + children = node + constituents.append(Constituent(value=node, hx=hx[idx, :])) + return constituents + + def push_constituents(self, constituent_stacks, constituents): + current_nodes = [stack.value for stack 
in constituent_stacks] + + constituent_input = torch.stack([x.hx for x in constituents]) + constituent_input = constituent_input.unsqueeze(0) + constituent_input = self.lstm_input_dropout(constituent_input) + + hx = torch.cat([current_node.hx for current_node in current_nodes], axis=1) + cx = torch.cat([current_node.cx for current_node in current_nodes], axis=1) + output, (hx, cx) = self.constituent_lstm(constituent_input, (hx, cx)) + if self.constituency_lstm: + new_stacks = [stack.push(ConstituentNode(constituent.value, output[0, i, :], hx[:, i:i+1, :], cx[:, i:i+1, :])) + for i, (stack, constituent) in enumerate(zip(constituent_stacks, constituents))] + else: + new_stacks = [stack.push(ConstituentNode(constituent.value, constituents[i].hx, hx[:, i:i+1, :], cx[:, i:i+1, :])) + for i, (stack, constituent) in enumerate(zip(constituent_stacks, constituents))] + return new_stacks + + def get_top_constituent(self, constituents): + """ + Extract only the top constituent from a state's constituent + sequence, even though it has multiple addition pieces of + information + """ + constituent_node = constituents.value + return constituent_node.value + + def push_transitions(self, transition_stacks, transitions): + transition_idx = torch.stack([self.transition_tensors[self.transition_map[transition]] for transition in transitions]) + transition_input = self.transition_embedding(transition_idx).unsqueeze(0) + transition_input = self.lstm_input_dropout(transition_input) + + hx = torch.cat([t.value.hx for t in transition_stacks], axis=1) + cx = torch.cat([t.value.cx for t in transition_stacks], axis=1) + output, (hx, cx) = self.transition_lstm(transition_input, (hx, cx)) + new_stacks = [stack.push(TransitionNode(transition, output[0, i, :], hx[:, i:i+1, :], cx[:, i:i+1, :])) + for i, (stack, transition) in enumerate(zip(transition_stacks, transitions))] + return new_stacks + + def get_top_transition(self, transitions): + """ + Extract only the top transition from a state's transition + sequence, even though it has multiple addition pieces of + information + """ + transition_node = transitions.value + return transition_node.value + + def transition_scheme(self): + return self._transition_scheme + + def has_unary_transitions(self): + return self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY + + def is_top_down(self): + return self._transition_scheme in (TransitionScheme.TOP_DOWN, TransitionScheme.TOP_DOWN_UNARY, TransitionScheme.TOP_DOWN_COMPOUND) + + def forward(self, states): + """ + Return logits for a prediction of what transition to make next + + We've basically done all the work analyzing the state as + part of applying the transitions, so this method is very simple + """ + word_hx = torch.stack([state.word_queue[state.word_position].hx for state in states]) + transition_hx = torch.stack([state.transitions.value.output for state in states]) + # note that we use hx instead of output from the constituents + # this way, we can, as an option, NOT include the constituents to the left + # when building the current vector for a constituent + # and the vector used for inference will still incorporate the entire LSTM + constituent_hx = torch.stack([state.constituents.value.hx[-1, 0, :] for state in states]) + + hx = torch.cat((word_hx, transition_hx, constituent_hx), axis=1) + for idx, output_layer in enumerate(self.output_layers): + hx = self.predict_dropout(hx) + if idx < len(self.output_layers) - 1: + hx = self.nonlinearity(hx) + hx = output_layer(hx) + return hx + + # TODO: merge this with 
forward? + def predict(self, states, is_legal=False): + """ + Generate and return predictions, along with the transitions those predictions represent + + If is_legal is set to True, will only return legal transitions. + This means returning None if there are no legal transitions. + Hopefully the constraints prevent that from happening + """ + predictions = self.forward(states) + pred_max = torch.argmax(predictions, axis=1) + + pred_trans = [self.transitions[pred_max[idx]] for idx in range(len(states))] + if is_legal: + for idx, (state, trans) in enumerate(zip(states, pred_trans)): + if not trans.is_legal(state, self): + _, indices = predictions[idx, :].sort(descending=True) + for index in indices: + if self.transitions[index].is_legal(state, self): + pred_trans[idx] = self.transitions[index] + break + else: # yeah, else on a for loop, deal with it + pred_trans[idx] = None + + return predictions, pred_trans + + def get_params(self, skip_modules=True): + """ + Get a dictionary for saving the model + """ + model_state = self.state_dict() + # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file + if skip_modules: + skipped = [k for k in model_state.keys() if k.split('.')[0] in self.unsaved_modules] + for k in skipped: + del model_state[k] + params = { + 'model': model_state, + 'model_type': "LSTM", + 'config': self.args, + 'transitions': self.transitions, + 'constituents': self.constituents, + 'tags': self.tags, + 'words': self.delta_words, + 'rare_words': self.rare_words, + 'root_labels': self.root_labels, + 'open_nodes': self.open_nodes, + } + + return params + diff --git a/stanza/models/constituency/parse_transitions.py b/stanza/models/constituency/parse_transitions.py new file mode 100644 index 00000000..ec815caa --- /dev/null +++ b/stanza/models/constituency/parse_transitions.py @@ -0,0 +1,603 @@ +""" +Defines a series of transitions (open a constituent, close a constituent, etc + +Also defines a State which holds the various data needed to build +a parse tree out of tagged words. +""" + +from abc import ABC, abstractmethod +from collections import defaultdict, namedtuple +from enum import Enum +import functools +import logging + +from stanza.models.constituency.parse_tree import Tree + +logger = logging.getLogger('stanza') + +class TransitionScheme(Enum): + TOP_DOWN = 1 + TOP_DOWN_COMPOUND = 2 + TOP_DOWN_UNARY = 3 + + IN_ORDER = 4 + +UNARY_LIMIT = 4 + +class State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'gold_tree', 'gold_sequence', + 'sentence_length', 'num_opens', 'word_position'])): + """ + Represents a partially completed transition parse + + Includes stack/buffers for unused words, already executed transitions, and partially build constituents + At training time, also keeps track of the gold data we are reparsing + + num_opens is useful for tracking + 1) if the parser is in a stuck state where it is making infinite opens + 2) if a close transition is impossible because there are no previous opens + + sentence_length tracks how long the sentence is so we abort if we go infinite + + non-stack information such as sentence_length and num_opens + will be copied from the original_state if possible, with the + exact arguments overriding the values in the original_state + + gold_tree: the original tree, if made from a gold tree. 
might be None + gold_sequence: the original transition sequence, if available + Note that at runtime, gold values will not be available + + word_position tracks where in the word queue we are. cheaper than + manipulating the list itself. this can be handled differently + from transitions and constituents as it is processed once + at the start of parsing + """ + def empty_word_queue(self): + # the first element of each stack is a sentinel with no value + # and no parent + return self.word_position == self.sentence_length + + def empty_transitions(self): + # the first element of each stack is a sentinel with no value + # and no parent + return self.transitions.parent is None + + def has_one_constituent(self): + # a length of 1 represents no constituents + return len(self.constituents) == 2 + + def num_constituents(self): + return len(self.constituents) - 1 + + def num_transitions(self): + # -1 for the sentinel value + return len(self.transitions) - 1 + + def finished(self, model): + return self.empty_word_queue() and self.has_one_constituent() and model.get_top_constituent(self.constituents).label in model.get_root_labels() + + def get_tree(self, model): + return model.get_top_constituent(self.constituents) + + def all_transitions(self, model): + # TODO: rewrite this to be nicer / faster? or just refactor? + all_transitions = [] + transitions = self.transitions + while transitions.parent is not None: + all_transitions.append(model.get_top_transition(transitions)) + transitions = transitions.parent + return list(reversed(all_transitions)) + + def all_constituents(self, model): + # TODO: rewrite this to be nicer / faster? + all_constituents = [] + constituents = self.constituents + while constituents.parent is not None: + all_constituents.append(model.get_top_constituent(constituents)) + constituents = constituents.parent + return list(reversed(all_constituents)) + + def all_words(self, model): + return [model.get_word(x) for x in self.word_queue] + + def to_string(self, model): + return "State(\n buffer:%s\n transitions:%s\n constituents:%s)" % (str(self.all_words(model)), str(self.all_transitions(model)), str(self.all_constituents(model))) + + def __str__(self): + return "State(\n buffer:%s\n transitions:%s\n constituents:%s)" % (str(self.word_queue), str(self.transitions), str(self.constituents)) + +def initial_state_from_preterminals(preterminal_lists, model, gold_trees): + """ + what is passed in should be a list of list of preterminals + """ + word_queues = model.initial_word_queues(preterminal_lists) + # this is the bottom of the TreeStack and will be the same for each State + transitions=model.initial_transitions() + constituents=model.initial_constituents() + states = [State(sentence_length=len(wq)-1, # -1 because it ends with a sentinel + num_opens=0, + word_queue=wq, + gold_tree=None, + gold_sequence=None, + transitions=transitions, + constituents=constituents, + word_position=0) + for idx, wq in enumerate(word_queues)] + if gold_trees: + states = [state._replace(gold_tree=gold_tree) for gold_tree, state in zip(gold_trees, states)] + return states + +def initial_state_from_words(word_lists, model): + # TODO: stop reversing the words + preterminal_lists = [] + for words in word_lists: + preterminals = [] + for word, tag in reversed(words): + word_node = Tree(label=word) + tag_node = Tree(label=tag, children=[word_node]) + preterminals.append(tag_node) + preterminal_lists.append(preterminals) + return initial_state_from_preterminals(preterminal_lists, model, gold_trees=None) + +def 
initial_state_from_gold_trees(trees, model): + # reversed so we put the words on the stack backwards + preterminal_lists = [[Tree(label=pt.label, children=Tree(label=pt.children[0].label)) + for pt in tree.yield_reversed_preterminals()] + for tree in trees] + return initial_state_from_preterminals(preterminal_lists, model, gold_trees=trees) + +@functools.total_ordering +class Transition(ABC): + """ + model is passed in as a dependency injection + for example, an LSTM model can update hidden & output vectors when transitioning + """ + @abstractmethod + def update_state(self, state, model): + """ + update the word queue position, possibly remove old pieces from the constituents state, and return the new constituent + + the return value should be a tuple: + updated word_position + updated constituents + new constituent to put on the queue and None + - note that the constituent shouldn't be on the queue yet + that allows putting it on as a batch operation, which + saves a significant amount of time in an LSTM, for example + OR + data used to make a new constituent and the method used + - for example, CloseConstituent can return the children needed + and itself. this allows a batch operation to build + the constituent + """ + pass + + def delta_opens(self): + return 0 + + def apply(self, state, model): + """ + return a new State transformed via this transition + """ + word_position, constituents, new_constituent, callback = self.update_state(state, model) + if callback is not None: + new_constituent = callback.build_constituents(model, [new_constituent])[0] + constituents = model.push_constituents([constituents], [new_constituent])[0] + + return state._replace(num_opens=state.num_opens + self.delta_opens(), + word_position=word_position, + transitions=model.push_transitions([state.transitions], [self])[0], + constituents=constituents) + + @abstractmethod + def is_legal(self, state, model): + """ + assess whether or not this transition is legal in this state + + at parse time, the parser might choose a transition which cannot be made + """ + pass + + def __lt__(self, other): + # put the Shift at the front of a list, and otherwise sort alphabetically + if self == other: + return False + if isinstance(self, Shift): + return True + if isinstance(other, Shift): + return False + return str(self) < str(other) + +class Shift(Transition): + def update_state(self, state, model): + """ + This will handle all aspects of a shift transition + + - push the top element of the word queue onto constituents + - pop the top element of the word queue + """ + new_constituent = model.transform_word_to_constituent(state) + return state.word_position+1, state.constituents, new_constituent, None + + def is_legal(self, state, model): + """ + Disallow shifting when the word queue is empty or there are no opens to eventually eat this word + """ + if state.empty_word_queue(): + return False + if model.is_top_down(): + # top down transition sequences cannot shift if there are currently no + # Open transitions on the stack. 
in such a case, the new constituent + # will never be reduced + if state.num_opens == 0: + return False + if state.num_opens == 1: + # there must be at least one transition, since there is an open + assert state.transitions.parent is not None + if state.transitions.parent.parent is None: + # only one transition + trans = model.get_top_transition(state.transitions) + # must be an Open, since there is one open and one transitions + # note that an S, FRAG, etc could happen if we're using unary + # and ROOT-S is possible in the case of compound Open + # in both cases, Shift is legal + # Note that the corresponding problem of shifting after the ROOT-S + # has been closed to just ROOT is handled in CloseConstituent + if len(trans.label) == 1 and trans.top_label in model.get_root_labels(): + # don't shift a word at the very start of a parse + # we want there to be an extra layer below ROOT + return False + else: + # in-order k==1 (the only other option currently) + # can shift ONCE, but note that there is no way to consume + # two items in a row if there is no Open on the stack. + # As long as there is one or more open transitions, + # everything can be eaten + if state.num_opens == 0: + if state.num_constituents() > 0: + return False + return True + + def __repr__(self): + return "Shift" + + def __eq__(self, other): + if self is other: + return True + if isinstance(other, Shift): + return True + return False + + def __hash__(self): + return hash(37) + +class CompoundUnary(Transition): + # TODO: run experiments to see if this is actually useful + def __init__(self, labels): + # the FIRST label will be the top of the tree + # so CompoundUnary that results in root will have root as labels[0], for example + if isinstance(labels, str): + self.labels = (labels,) + else: + self.labels = tuple(labels) + + def update_state(self, state, model): + # remove the top constituent + # apply the labels + # put the constituent back on the state + constituents = state.constituents + new_constituent = model.unary_transform(state.constituents, self.labels) + constituents = constituents.pop() + return state.word_position, constituents, new_constituent, None + + def is_legal(self, state, model): + """ + Disallow consecutive CompoundUnary transitions, force final transition to go to ROOT + """ + # can't unary transition nothing + if model.get_top_constituent(state.constituents) is None: + return False + # don't unary transition a dummy, dummy + # and don't stack CompoundUnary transitions + if isinstance(model.get_top_transition(state.transitions), (CompoundUnary, OpenConstituent)): + return False + is_root = self.labels[0] in model.get_root_labels() + if not state.empty_word_queue() or not state.has_one_constituent(): + return not is_root + else: + return is_root + + def __repr__(self): + return "CompoundUnary(%s)" % ",".join(self.labels) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, CompoundUnary): + return False + if self.labels == other.labels: + return True + return False + + def __hash__(self): + return hash(self.labels) + +class Dummy(): + """ + Takes a space on the constituent stack to represent where an Open transition occurred + """ + def __init__(self, label): + self.label = label + + def __str__(self): + return "Dummy({})".format(self.label) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, Dummy): + return False + if self.label == other.label: + return True + return False + + def __hash__(self): + return hash(self.label) + 
+def too_many_unary_nodes(tree): + """ + Return True iff there are UNARY_LIMIT unary nodes in a tree in a row + + helps prevent infinite open/close patterns + otherwise, the model can get stuck in essentially an infinite loop + """ + if tree is None: + return False + for _ in range(UNARY_LIMIT + 1): + if len(tree.children) != 1: + return False + tree = tree.children[0] + return True + +class OpenConstituent(Transition): + def __init__(self, *label): + self.label = tuple(label) + self.top_label = self.label[0] + + def delta_opens(self): + return 1 + + def update_state(self, state, model): + # open a new constituent which can later be closed + # puts a DUMMY constituent on the stack to mark where the constituents end + return state.word_position, state.constituents, model.dummy_constituent(Dummy(self.label)), None + + def is_legal(self, state, model): + """ + disallow based on the length of the sentence + """ + if state.num_opens > state.sentence_length + 5: + # fudge a bit so we don't miss root nodes etc in very small trees + return False + if model.is_top_down(): + # If the model is top down, you can't Open if there are + # no word to eventually eat + if state.empty_word_queue(): + return False + # Also, you can only Open a ROOT iff it is at the root position + # The assumption in the unary scheme is there will be no + # root open transitions + if not model.has_unary_transitions(): + # TODO: maybe cache this value if this is an expensive operation + is_root = self.top_label in model.get_root_labels() + if is_root: + return state.empty_transitions() + else: + return not state.empty_transitions() + else: + # in-order nodes can Open as long as there is at least one thing + # on the constituency stack + # since closing the in-order involves removing one more + # item before the open, and it can close at any time + # (a close immediately after the open represents a unary) + if state.num_constituents() == 0: + return False + if isinstance(model.get_top_transition(state.transitions), OpenConstituent): + # consecutive Opens don't make sense in the context of in-order + return False + # one other restriction - we assume all parse trees + # start with (ROOT (first_real_con ...)) + # therefore ROOT can only occur via Open after everything + # else has been pushed and processed + # there are no further restrictions + is_root = self.top_label in model.get_root_labels() + if is_root: + # can't make a root node if it will be in the middle of the parse + # can't make a root node if there's still words to eat + # note that the second assumption wouldn't work, + # except we are assuming there will never be multiple + # nodes under one root + return state.num_opens == 0 and state.empty_word_queue() + else: + if (state.num_opens > 0 or state.empty_word_queue()) and too_many_unary_nodes(model.get_top_constituent(state.constituents)): + # looks like we've been in a loop of lots of unary transitions + # note that we check `num_opens > 0` because otherwise we might wind up stuck + # in a state where the only legal transition is open, such as if the + # constituent stack is otherwise empty, but the open is illegal because + # it causes too many unaries + # in such a case we can forbid the corresponding close instead... 
+ # if empty_word_queue, that means it is trying to make infinitiely many + # non-ROOT Open transitions instead of just transitioning ROOT + return False + return True + return True + + def __repr__(self): + return "OpenConstituent({})".format(self.label) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, OpenConstituent): + return False + if self.label == other.label: + return True + return False + + def __hash__(self): + return hash(self.label) + +class CloseConstituent(Transition): + def delta_opens(self): + return -1 + + def update_state(self, state, model): + # pop constituents until we are done + children = [] + constituents = state.constituents + while not isinstance(model.get_top_constituent(constituents), Dummy): + # keep the entire value from the stack - the model may need + # the whole thing to transform the children into a new node + children.append(constituents.value) + constituents = constituents.pop() + # the Dummy has the label on it + label = model.get_top_constituent(constituents).label + # pop past the Dummy as well + constituents = constituents.pop() + if not model.is_top_down(): + # the alternative to TOP_DOWN_... is IN_ORDER + # in which case we want to pop one more constituent + children.append(constituents.value) + constituents = constituents.pop() + # the children are in the opposite order of what we expect + children.reverse() + + return state.word_position, constituents, (label, children), CloseConstituent + + @staticmethod + def build_constituents(model, data): + labels, children_lists = list(map(list, zip(*data))) + new_constituents = model.build_constituents(labels, children_lists) + return new_constituents + + + def is_legal(self, state, model): + """ + Disallow if there is no Open on the stack yet + in TOP_DOWN, if the previous transition was the Open (nothing built yet) + in IN_ORDER, previous transition does not matter, except for one small corner case + """ + if state.num_opens <= 0: + return False + if model.is_top_down(): + if isinstance(model.get_top_transition(state.transitions), OpenConstituent): + return False + if state.num_opens <= 1 and not state.empty_word_queue(): + # don't close the last open until all words have been used + return False + if model.transition_scheme() == TransitionScheme.TOP_DOWN_COMPOUND: + # when doing TOP_DOWN_COMPOUND, we assume all transitions + # at the ROOT level have an S, SQ, FRAG, etc underneath + # this is checked when the model is first trained + if state.num_opens == 1 and not state.empty_word_queue(): + return False + elif not model.has_unary_transitions(): + # in fact, we have to leave the top level constituent + # under the ROOT open if unary transitions are not possible + if state.num_opens == 2 and not state.empty_word_queue(): + return False + else: + if not isinstance(model.get_top_transition(state.transitions), OpenConstituent): + # we're not stuck in a loop of unaries + return True + if state.num_opens > 1 or state.empty_word_queue(): + # in either of these cases, the corresponding Open should be eliminated + # if we're stuck in a loop of unaries + return True + node = model.get_top_constituent(state.constituents.pop()) + if too_many_unary_nodes(node): + # at this point, we are in a situation where + # - multiple unaries have happened in a row + # - there is stuff on the word_queue, so a ROOT open isn't legal + # - there's only one constituent on the stack, so the only legal + # option once there are no opens left will be an open + # this means we'll be stuck having 
to open again if we do close + # this node, so instead we make the Close illegal + return False + return True + + def __repr__(self): + return "CloseConstituent" + + def __eq__(self, other): + if self is other: + return True + if isinstance(other, CloseConstituent): + return True + return False + + def __hash__(self): + return hash(93) + +def bulk_apply(model, tree_batch, transitions, fail=False, max_transitions=1000): + remove = set() + + word_positions = [] + constituents = [] + new_constituents = [] + callbacks = defaultdict(list) + + for idx, (tree, transition) in enumerate(zip(tree_batch, transitions)): + if not transition: + error = "Got stuck and couldn't find a legal transition on the following gold tree:\n{}\n\nFinal state:\n{}".format(tree.gold_tree, tree.to_string(model)) + if fail: + raise ValueError(error) + else: + logger.error(error) + remove.add(idx) + continue + + if max_transitions and tree.num_transitions() >= max_transitions: + # too many transitions + if tree.gold_tree: + error = "Went infinite on the following gold tree:\n{}\n\nFinal state:\n{}".format(tree.gold_tree, tree.to_string(model)) + else: + error = "Went infinite!:\nFinal state:\n{}".format(tree.to_string(model)) + if fail: + raise ValueError(error) + else: + logger.error(error) + remove.add(idx) + continue + + wq, c, nc, callback = transition.update_state(tree, model) + + word_positions.append(wq) + constituents.append(c) + new_constituents.append(nc) + if callback: + # not `idx` in case something was removed + callbacks[callback].append(len(new_constituents)-1) + + for key, idxs in callbacks.items(): + data = [new_constituents[x] for x in idxs] + callback_constituents = key.build_constituents(model, data) + for idx, constituent in zip(idxs, callback_constituents): + new_constituents[idx] = constituent + + tree_batch = [tree for idx, tree in enumerate(tree_batch) if idx not in remove] + transitions = [trans for idx, trans in enumerate(transitions) if idx not in remove] + + if len(tree_batch) == 0: + return tree_batch + + new_transitions = model.push_transitions([tree.transitions for tree in tree_batch], transitions) + new_constituents = model.push_constituents(constituents, new_constituents) + + tree_batch = [state._replace(num_opens=state.num_opens + transition.delta_opens(), + word_position=word_position, + transitions=transition_stack, + constituents=constituents) + for (state, transition, word_position, transition_stack, constituents) + in zip(tree_batch, transitions, word_positions, new_transitions, new_constituents)] + + return tree_batch diff --git a/stanza/models/constituency/parse_tree.py b/stanza/models/constituency/parse_tree.py index 1ca95f8d..6eb24717 100644 --- a/stanza/models/constituency/parse_tree.py +++ b/stanza/models/constituency/parse_tree.py @@ -2,18 +2,28 @@ Tree datastructure """ -from collections import deque +from collections import deque, Counter from io import StringIO +import re from stanza.models.common.doc import StanzaObject +# useful more for the "is" functionality than the time savings +CLOSE_PAREN = ')' +SPACE_SEPARATOR = ' ' +OPEN_PAREN = '(' + +EMPTY_CHILDREN = () + +CONSTITUENT_SPLIT = re.compile("[-=#]") + class Tree(StanzaObject): """ A data structure to represent a parse tree """ def __init__(self, label=None, children=None): if children is None: - self.children = [] + self.children = EMPTY_CHILDREN elif isinstance(children, Tree): self.children = (children,) else: @@ -27,15 +37,37 @@ class Tree(StanzaObject): def is_preterminal(self): return len(self.children) == 1 
and len(self.children[0].children) == 0 - def yield_preterminals(self): - if self.is_leaf(): - pass - elif self.is_preterminal(): - yield self - else: - for child in self.children: - for preterminal in child.yield_preterminals(): - yield preterminal + def yield_reversed_preterminals(self): + """ + Yield the preterminals one at a time in BACKWARDS order + + This is done reversed as it is a frequently used method in the + parser, so this is a tiny optimization + """ + nodes = deque() + nodes.append(self) + while len(nodes) > 0: + node = nodes.pop() + if len(node.children) == 0: + raise ValueError("Got called with an unexpected tree layout: {}".format(self)) + elif node.is_preterminal(): + yield node + else: + nodes.extend(node.children) + + def leaf_labels(self): + """ + Get the labels of the leaves + + Not optimized whatsoever - current not an important part of + the parser + """ + preterminals = reversed([x for x in self.yield_reversed_preterminals()]) + words = [x.children[0].label for x in preterminals] + return words + + def preterminals(self): + return list(reversed(list(self.yield_reversed_preterminals()))) def __repr__(self): """ @@ -49,18 +81,21 @@ class Tree(StanzaObject): stack.append(self) while len(stack) > 0: node = stack.pop() - if node == ')' or node == ' ': + # note that == can recursively call == in some circumstances! + if node is CLOSE_PAREN or node is SPACE_SEPARATOR: buf.write(node) continue - if not node.children: - buf.write(node.label) + if len(node.children) == 0: + if node.label is not None: + buf.write(node.label) continue - buf.write("(") - buf.write(node.label) - stack.append(')') + buf.write(OPEN_PAREN) + if node.label is not None: + buf.write(node.label) + stack.append(CLOSE_PAREN) for child in reversed(node.children): stack.append(child) - stack.append(' ') + stack.append(SPACE_SEPARATOR) buf.seek(0) return buf.read() @@ -71,7 +106,9 @@ class Tree(StanzaObject): return False if self.label != other.label: return False - if self.children != other.children: + if len(self.children) != len(other.children): + return False + if any(c1 != c2 for c1, c2 in zip(self.children, other.children)): return False return True @@ -79,3 +116,188 @@ class Tree(StanzaObject): if not self.children: return 0 return 1 + max(x.depth() for x in self.children) + + def visit_preorder(self, internal=None, preterminal=None, leaf=None): + """ + Visit the tree in a preorder order + + Applies the given functions to each node. + internal: if not None, applies this function to each non-leaf, non-preterminal node + preterminal: if not None, applies this functiion to each preterminal + leaf: if not None, applies this function to each leaf + + The functions should *not* destructively alter the trees. + There is no attempt to interpret the results of calling these functions. + Rather, you can use visit_preorder to collect stats on trees, etc. 
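As a minimal illustration of this visitor (an editor's sketch, not part of the patch), the three callbacks can collect statistics in a single pass over a toy tree; this is the same pattern the get_unique_* and get_rare_words helpers below rely on:

from collections import Counter

from stanza.models.constituency.tree_reader import read_trees

# toy input, assumed purely for illustration
tree = read_trees("((S (NP (DT the) (NN dog)) (VP (VBZ barks))))")[0]
constituent_counts = Counter()
tag_counts = Counter()
word_counts = Counter()
tree.visit_preorder(internal=lambda node: constituent_counts.update([node.label]),
                    preterminal=lambda node: tag_counts.update([node.label]),
                    leaf=lambda node: word_counts.update([node.label]))
print(constituent_counts)   # ROOT, S, NP, VP each seen once
print(tag_counts)           # DT, NN, VBZ
print(word_counts)          # the, dog, barks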
+ """ + if self.is_leaf(): + if leaf: + leaf(self) + elif self.is_preterminal(): + if preterminal: + preterminal(self) + else: + if internal: + internal(self) + for child in self.children: + child.visit_preorder(internal, preterminal, leaf) + + @staticmethod + def get_unique_constituent_labels(trees): + """ + Walks over all of the trees and gets all of the unique constituent names from the trees + """ + if isinstance(trees, Tree): + trees = [trees] + + constituents = set() + for tree in trees: + tree.visit_preorder(internal = lambda x: constituents.add(x.label)) + return sorted(constituents) + + @staticmethod + def get_unique_tags(trees): + """ + Walks over all of the trees and gets all of the unique tags from the trees + """ + if isinstance(trees, Tree): + trees = [trees] + + tags = set() + for tree in trees: + tree.visit_preorder(preterminal = lambda x: tags.add(x.label)) + return sorted(tags) + + @staticmethod + def get_unique_words(trees): + """ + Walks over all of the trees and gets all of the unique words from the trees + """ + if isinstance(trees, Tree): + trees = [trees] + + words = set() + for tree in trees: + tree.visit_preorder(leaf = lambda x: words.add(x.label)) + return sorted(words) + + @staticmethod + def get_rare_words(trees, threshold=0.05): + """ + Walks over all of the trees and gets the least frequently occurring words. + + threshold: choose the bottom X percent + """ + if isinstance(trees, Tree): + trees = [trees] + + words = Counter() + for tree in trees: + tree.visit_preorder(leaf = lambda x: words.update([x.label])) + threshold = max(int(len(words) * threshold), 1) + return sorted(x[0] for x in words.most_common()[:-threshold-1:-1]) + + @staticmethod + def get_root_labels(trees): + return sorted(set(x.label for x in trees)) + + @staticmethod + def get_compound_constituents(trees): + constituents = set() + stack = deque() + for tree in trees: + stack.append(tree) + while len(stack) > 0: + node = stack.pop() + if node.is_leaf() or node.is_preterminal(): + continue + labels = [node.label] + while len(node.children) == 1 and not node.children[0].is_preterminal(): + node = node.children[0] + labels.append(node.label) + constituents.add(tuple(labels)) + for child in node.children: + stack.append(child) + return sorted(constituents) + + # TODO: test different pattern + def simplify_labels(self, pattern=CONSTITUENT_SPLIT): + """ + Return a copy of the tree with the -=# removed + + Leaves the text of the leaves alone. + """ + new_label = self.label + # check len(new_label) just in case it's a tag of - or = + if new_label and not self.is_leaf() and len(new_label) > 1 and new_label not in ('-LRB-', '-RRB-'): + new_label = pattern.split(new_label)[0] + new_children = [child.simplify_labels(pattern) for child in self.children] + return Tree(new_label, new_children) + + def remap_constituent_labels(self, label_map): + """ + Copies the tree with some labels replaced. + + Labels in the map are replaced with the mapped value. + Labels not in the map are unchanged. + """ + if self.is_leaf(): + return Tree(self.label) + if self.is_preterminal(): + return Tree(self.label, Tree(self.children[0].label)) + new_label = label_map.get(self.label, self.label) + return Tree(new_label, [child.remap_constituent_labels(label_map) for child in self.children]) + + def remap_words(self, word_map): + """ + Copies the tree with some labels replaced. + + Labels in the map are replaced with the mapped value. + Labels not in the map are unchanged. 
+ """ + if self.is_leaf(): + new_label = word_map.get(self.label, self.label) + return Tree(new_label) + if self.is_preterminal(): + return Tree(self.label, self.children[0].remap_words(word_map)) + return Tree(self.label, [child.remap_words(word_map) for child in self.children]) + + def replace_words(self, words): + """ + Replace all leaf words with the words in the given list (or iterable) + + Returns a new tree + """ + word_iterator = iter(words) + def recursive_replace_words(subtree): + if subtree.is_leaf(): + word = next(word_iterator, None) + if word is None: + raise ValueError("Not enough words to replace all leaves") + return Tree(word) + return Tree(subtree.label, [recursive_replace_words(x) for x in subtree.children]) + + new_tree = recursive_replace_words(self) + if any(True for _ in word_iterator): + raise ValueError("Too many tags for the given tree") + return new_tree + + + def prune_none(self): + """ + Return a copy of the tree, eliminating all nodes which are in one of two categories: + they are a preterminal -NONE-, such as appears in PTB + they have been pruned to 0 children by the recursive call + """ + if self.is_leaf(): + return Tree(self.label) + if self.is_preterminal(): + if self.label == '-NONE-': + return None + return Tree(self.label, Tree(self.children[0].label)) + # must be internal node + new_children = [child.prune_none() for child in self.children] + new_children = [child for child in new_children if child is not None] + if len(new_children) == 0: + return None + return Tree(self.label, new_children) diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py new file mode 100644 index 00000000..e8d49bb5 --- /dev/null +++ b/stanza/models/constituency/trainer.py @@ -0,0 +1,570 @@ +""" +This file includes a variety of methods needed to train new +constituency parsers. It also includes a method to load an +already-trained parser. + +See the `train` method for the code block which starts from + raw treebank and returns a new parser. +`evaluate` reads a treebank and gives a score for those trees. +`parse_tagged_words` is useful at Pipeline time - + it takes words & tags and processes that into trees. +""" + +import logging +import random +import os + +import torch +from torch import nn +from torch import optim + +from stanza.models.common import pretrain +from stanza.models.common import utils +from stanza.models.common.char_model import CharacterLanguageModel +from stanza.models.constituency import base_model +from stanza.models.constituency import parse_transitions +from stanza.models.constituency import parse_tree +from stanza.models.constituency import transition_sequence +from stanza.models.constituency import tree_reader +from stanza.models.constituency.lstm_model import LSTMModel +from stanza.models.constituency.parse_transitions import State, TransitionScheme +from stanza.models.constituency.utils import retag_trees +from stanza.server.parser_eval import EvaluateParser + +tqdm = utils.get_tqdm() + +logger = logging.getLogger('stanza') + + +class Trainer: + """ + Stores a constituency model and its optimizer + + Not inheriting from common/trainer.py because there's no concept of change_lr (yet?) 
+ """ + def __init__(self, model=None, optimizer=None): + self.model = model + self.optimizer = optimizer + + def save(self, filename, save_optimizer=True): + """ + Save the model (and by default the optimizer) to the given path + """ + params = self.model.get_params() + checkpoint = { + 'params': params, + 'model_type': 'LSTM', + } + if save_optimizer and self.optimizer is not None: + checkpoint['optimizer_state_dict'] = self.optimizer.state_dict() + torch.save(checkpoint, filename, _use_new_zipfile_serialization=False) + logger.info("Model saved to %s", filename) + + + @staticmethod + def load(filename, pt, forward_charlm, backward_charlm, use_gpu, args=None, load_optimizer=False): + """ + Load back a model and possibly its optimizer. + + pt: a Pretrain word embedding + """ + if args is None: + args = {} + + try: + checkpoint = torch.load(filename, lambda storage, loc: storage) + except BaseException: + logger.exception("Cannot load model from %s", filename) + raise + logger.debug("Loaded model from %s", filename) + + model_type = checkpoint['model_type'] + params = checkpoint.get('params', checkpoint) + + if model_type == 'LSTM': + model = LSTMModel(pretrain=pt, + forward_charlm=forward_charlm, + backward_charlm=backward_charlm, + transitions=params['transitions'], + constituents=params['constituents'], + tags=params['tags'], + words=params['words'], + rare_words=params['rare_words'], + root_labels=params['root_labels'], + open_nodes=params['open_nodes'], + args=params['config']) + else: + raise ValueError("Unknown model type {}".format(model_type)) + model.load_state_dict(params['model'], strict=False) + + if use_gpu: + model.cuda() + + if load_optimizer: + optimizer_args = dict(params['config']) + optimizer_args.update(args) + optimizer = build_optimizer(optimizer_args, model) + + if checkpoint.get('optimizer_state_dict', None) is not None: + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + else: + logger.info("Attempted to load optimizer to resume training, but optimizer not saved. 
Creating new optimizer") + else: + optimizer = None + + logger.debug("-- MODEL CONFIG --") + for k in model.args.keys(): + logger.debug(" --%s: %s", k, model.args[k]) + + return Trainer(model=model, optimizer=optimizer) + + +def build_optimizer(args, model): + """ + Build an optimizer based on the arguments given + """ + if args['optim'].lower() == 'sgd': + optimizer = optim.SGD(model.parameters(), lr=args['learning_rate'], momentum=0.9, weight_decay=args['weight_decay']) + elif args['optim'].lower() == 'adadelta': + optimizer = optim.Adadelta(model.parameters(), lr=args['learning_rate'], weight_decay=args['weight_decay']) + elif args['optim'].lower() == 'adamw': + optimizer = optim.AdamW(model.parameters(), lr=args['learning_rate'], weight_decay=args['weight_decay']) + else: + raise ValueError("Unknown optimizer: %s" % args.optim) + return optimizer + +def load_pretrain(args): + """ + Loads a pretrain based on the paths in the arguments + """ + pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang']) + if os.path.exists(pretrain_file): + vec_file = None + else: + vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand']) + pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab']) + return pt + +def load_charlm(charlm_file): + if charlm_file: + logger.debug("Loading charlm from %s", charlm_file) + return CharacterLanguageModel.load(charlm_file, finetune=False) + return None + +def read_treebank(filename): + """ + Read a treebank and alter the trees to be a simpler format for learning to parse + """ + logger.info("Reading trees from %s", filename) + trees = tree_reader.read_tree_file(filename) + trees = [t.prune_none().simplify_labels() for t in trees] + + illegal_trees = [t for t in trees if len(t.children) > 1] + if len(illegal_trees) > 0: + raise ValueError("Found {} tree(s) which had non-unary transitions at the ROOT. First illegal tree: {}".format(len(illegal_trees), illegal_trees[0])) + + return trees + +def verify_transitions(trees, sequences, transition_scheme): + """ + Given a list of trees and their transition sequences, verify that the sequences rebuild the trees + """ + model = base_model.SimpleModel(transition_scheme) + logger.info("Verifying the transition sequences for %d trees", len(trees)) + for tree, sequence in tqdm(zip(trees, sequences), total=len(trees)): + state = parse_transitions.initial_state_from_gold_trees([tree], model)[0] + for idx, trans in enumerate(sequence): + if not trans.is_legal(state, model): + raise RuntimeError("Transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(idx, trans, tree, sequence)) + state = trans.apply(state, model) + result = model.get_top_constituent(state.constituents) + if tree != result: + raise RuntimeError("Transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree, sequence, result)) + +def evaluate(args, model_file, retag_pipeline): + """ + Loads the given model file and tests the eval_file treebank. 
+ + May retag the trees using retag_pipeline + Uses a subprocess to run the Java EvalB code + """ + pt = load_pretrain(args) + forward_charlm = load_charlm(args['charlm_forward_file']) + backward_charlm = load_charlm(args['charlm_backward_file']) + trainer = Trainer.load(model_file, pt, forward_charlm, backward_charlm, args['cuda']) + + treebank = read_treebank(args['eval_file']) + logger.info("Read %d trees for evaluation", len(treebank)) + + if retag_pipeline is not None: + logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package']) + treebank = retag_trees(treebank, retag_pipeline, args['retag_xpos']) + logger.info("Retagging finished") + + f1 = run_dev_set(trainer.model, treebank, args) + logger.info("F1 score on %s: %f", args['eval_file'], f1) + +def build_treebank(trees, transition_scheme): + """ + Convert a set of trees into the corresponding treebank based on the args + + Currently only supports top-down transitions, but more may be added in the future, especially bottom up + """ + return transition_sequence.build_treebank(trees, transition_scheme=transition_scheme) + +def get_open_nodes(trees, args): + """ + Return a list of all open nodes in the given dataset. + Depending on the parameters, may be single or compound open transitions. + """ + if args['transition_scheme'] is TransitionScheme.TOP_DOWN_COMPOUND: + return parse_tree.Tree.get_compound_constituents(trees) + else: + return [(x,) for x in parse_tree.Tree.get_unique_constituent_labels(trees)] + +def print_args(args): + """ + For record keeping purposes, print out the arguments when training + """ + keys = sorted(args.keys()) + log_lines = ['%s: %s' % (k, args[k]) for k in keys] + logger.info('ARGS USED AT TRAINING TIME:\n%s\n', '\n'.join(log_lines)) + +def remove_optimizer(args, model_save_file, model_load_file): + """ + A utility method to remove the optimizer from a save file + + Will make the save file a lot smaller + """ + # TODO: kind of overkill to load in the pretrain rather than + # change the load/save to work without it, but probably this + # functionality isn't used that often anyway + pt = load_pretrain(args) + forward_charlm = load_charlm(args['charlm_forward_file']) + backward_charlm = load_charlm(args['charlm_backward_file']) + trainer = Trainer.load(model_load_file, pt, forward_charlm, backward_charlm, use_gpu=False, load_optimizer=False) + trainer.save(model_save_file) + +def train(args, model_save_file, model_load_file, model_save_latest_file, retag_pipeline): + """ + Build a model, train it using the requested train & dev files + """ + print_args(args) + + utils.ensure_dir(args['save_dir']) + + train_trees = read_treebank(args['train_file']) + logger.info("Read %d trees for the training set", len(train_trees)) + + dev_trees = read_treebank(args['eval_file']) + logger.info("Read %d trees for the dev set", len(dev_trees)) + + if retag_pipeline is not None: + logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package']) + train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos']) + dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos']) + logger.info("Retagging finished") + + train_constituents = parse_tree.Tree.get_unique_constituent_labels(train_trees) + dev_constituents = parse_tree.Tree.get_unique_constituent_labels(dev_trees) + logger.info("Unique constituents in training set: %s", train_constituents) + for con in dev_constituents: + if con not in 
train_constituents: + raise RuntimeError("Found label {} in the dev set which don't exist in the train set".format(con)) + + logger.info("Building training transition sequences") + train_sequences = build_treebank(tqdm(train_trees), args['transition_scheme']) + train_transitions = transition_sequence.all_transitions(train_sequences) + + logger.info("Building dev transition sequences") + dev_sequences = build_treebank(tqdm(dev_trees), args['transition_scheme']) + dev_transitions = transition_sequence.all_transitions(dev_sequences) + + logger.info("Total unique transitions in train set: %d", len(train_transitions)) + for trans in dev_transitions: + if trans not in train_transitions: + raise RuntimeError("Found transition {} in the dev set which don't exist in the train set".format(trans)) + + verify_transitions(train_trees, train_sequences, args['transition_scheme']) + verify_transitions(dev_trees, dev_sequences, args['transition_scheme']) + + root_labels = parse_tree.Tree.get_root_labels(train_trees) + for root_state in parse_tree.Tree.get_root_labels(dev_trees): + if root_state not in root_labels: + raise RuntimeError("Found root state {} in the dev set which is not a ROOT state in the train set".format(root_state)) + + tags = parse_tree.Tree.get_unique_tags(train_trees) + logger.info("Unique tags in training set: %s", tags) + for tag in parse_tree.Tree.get_unique_tags(dev_trees): + if tag not in tags: + raise RuntimeError("Found tag {} in the dev set which is not a tag in the train set".format(tag)) + + # we don't check against the words in the dev set as it is + # expected there will be some UNK words + words = parse_tree.Tree.get_unique_words(train_trees) + rare_words = parse_tree.Tree.get_rare_words(train_trees, args['rare_word_threshold']) + # also, it's not actually an error if there is a pattern of + # compound unary or compound open nodes which doesn't exist in the + # train set. 
it just means we probably won't ever get that right + open_nodes = get_open_nodes(train_trees, args) + + pt = load_pretrain(args) + forward_charlm = load_charlm(args['charlm_forward_file']) + backward_charlm = load_charlm(args['charlm_backward_file']) + + # at this point we have: + # pretrain + # train_trees, dev_trees + # lists of transitions, internal nodes, and root states the parser needs to be aware of + + if args['finetune'] or (args['maybe_finetune'] and os.path.exists(model_load_file)): + logger.info("Loading model to continue training from %s", model_load_file) + trainer = Trainer.load(model_load_file, pt, forward_charlm, backward_charlm, args['cuda'], args, load_optimizer=True) + else: + model = LSTMModel(pt, forward_charlm, backward_charlm, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, args) + if args['cuda']: + model.cuda() + + optimizer = build_optimizer(args, model) + + trainer = Trainer(model, optimizer) + + iterate_training(trainer, train_trees, train_sequences, train_transitions, dev_trees, args, model_save_file, model_save_latest_file) + +def iterate_training(trainer, train_trees, train_sequences, transitions, dev_trees, args, model_filename, model_latest_filename): + """ + Given an initialized model, a processed dataset, and a secondary dev dataset, train the model + + The training is iterated in the following loop: + extract a batch of trees of the same length from the training set + convert those trees into initial parsing states + repeat until trees are done: + batch predict the model's interpretation of the current states + add the errors to the list of things to backprop + advance the parsing state for each of the trees + + Currently the only method implemented for advancing the parsing state + is to use the gold transition. + + TODO: add a dynamic oracle which can adjust the future expected + parsing decisions after the parser makes an error. This way, + the parser will have "experienced" what the correct decision + to make is when it gets into incorrect states at runtime. 
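Stripped of batching-by-length, shuffling, and logging, the loop described above condenses to roughly the following (editor's sketch of the implementation below, not an exact copy):

import torch

from stanza.models.constituency import parse_transitions

def train_one_batch(model, batch, transition_tensors, loss_function, optimizer):
    # batch: State objects with .gold_sequence attached, as in the full loop below
    all_errors = []
    all_answers = []
    while len(batch) > 0:
        outputs, _ = model.predict(batch)
        gold = [state.gold_sequence[state.num_transitions()] for state in batch]
        all_errors.append(outputs)
        all_answers.extend(transition_tensors[g] for g in gold)
        # keep only trees which still have transitions left after this gold step
        keep = [(s, g) for s, g in zip(batch, gold)
                if s.num_transitions() + 1 < len(s.gold_sequence)]
        batch = parse_transitions.bulk_apply(model, [s for s, _ in keep], [g for _, g in keep],
                                             fail=True, max_transitions=None)
    loss = loss_function(torch.cat(all_errors), torch.cat(all_answers))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()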
+ """ + model = trainer.model + optimizer = trainer.optimizer + + loss_function = nn.CrossEntropyLoss(reduction='sum') + if args['cuda']: + loss_function.cuda() + + device = next(model.parameters()).device + transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0) + for (y, x) in enumerate(transitions)} + + model.train() + + train_data = list(zip(train_trees, train_sequences)) + leftover_training_data = [] + best_f1 = 0.0 + best_epoch = 0 + for epoch in range(1, args['epochs']+1): + model.train() + logger.info("Starting epoch %d", epoch) + epoch_data = leftover_training_data + while len(epoch_data) < args['eval_interval']: + random.shuffle(train_data) + epoch_data.extend(train_data) + leftover_training_data = epoch_data[args['eval_interval']:] + epoch_data = epoch_data[:args['eval_interval']] + epoch_data.sort(key=lambda x: len(x[1])) + interval_starts = list(range(0, len(epoch_data), args['train_batch_size'])) + random.shuffle(interval_starts) + + epoch_loss = 0.0 + + transitions_correct = 0 + transitions_incorrect = 0 + + for interval_start in tqdm(interval_starts, postfix="Batch"): + batch = epoch_data[interval_start:interval_start+args['train_batch_size']] + # the batch will be empty when all trees from this epoch are trained + # now we add the state to the trees in the batch + initial_states = parse_transitions.initial_state_from_gold_trees([tree for tree, _ in batch], model) + batch = [state._replace(gold_sequence=sequence) + for (tree, sequence), state in zip(batch, initial_states)] + + all_errors = [] + all_answers = [] + + while len(batch) > 0: + outputs, pred_transitions = model.predict(batch) + gold_transitions = [x.gold_sequence[x.num_transitions()] for x in batch] + trans_tensor = [transition_tensors[gold_transition] for gold_transition in gold_transitions] + all_errors.append(outputs) + all_answers.extend(trans_tensor) + + for pred_transition, gold_transition in zip(pred_transitions, gold_transitions): + if pred_transition != gold_transition: + transitions_incorrect = transitions_incorrect + 1 + else: + transitions_correct = transitions_correct + 1 + + # eliminate finished trees, keeping only the transitions we will use + zipped_batch = [x for x in zip(batch, gold_transitions) if x[0].num_transitions() + 1 < len(x[0].gold_sequence)] + batch = [x[0] for x in zipped_batch] + gold_transitions = [x[1] for x in zipped_batch] + + if len(batch) > 0: + # bulk update states + batch = parse_transitions.bulk_apply(model, batch, gold_transitions, fail=True, max_transitions=None) + + errors = torch.cat(all_errors) + answers = torch.cat(all_answers) + tree_loss = loss_function(errors, answers) + tree_loss.backward() + epoch_loss += tree_loss.item() + + optimizer.step() + optimizer.zero_grad() + + # print statistics + f1 = run_dev_set(model, dev_trees, args) + if f1 > best_f1: + logger.info("New best dev score: %.5f > %.5f", f1, best_f1) + best_f1 = f1 + best_epoch = epoch + trainer.save(model_filename, save_optimizer=True) + if model_latest_filename: + trainer.save(model_latest_filename, save_optimizer=True) + logger.info("Epoch {} finished\nTransitions correct: {} Transitions incorrect: {}\n Total loss for epoch: {}\n Dev score ({:5}): {}\n Best dev score ({:5}): {}".format(epoch, transitions_correct, transitions_incorrect, epoch_loss, epoch, f1, best_epoch, best_f1)) + +def build_batch_from_trees(batch_size, data_iterator, model): + """ + Read from the data_iterator batch_size trees and turn them into new parsing states + """ + tree_batch = [] + for _ in 
range(batch_size): + gold_tree = next(data_iterator, None) + if gold_tree is None: + break + tree_batch.append(gold_tree) + + if len(tree_batch) > 0: + tree_batch = parse_transitions.initial_state_from_gold_trees(tree_batch, model) + return tree_batch + +def build_batch_from_tagged_words(batch_size, data_iterator, model): + """ + Read from the data_iterator batch_size tagged sentences and turn them into new parsing states + """ + tree_batch = [] + for _ in range(batch_size): + sentence = next(data_iterator, None) + if sentence is None: + break + tree_batch.append(sentence) + + if len(tree_batch) > 0: + tree_batch = parse_transitions.initial_state_from_words(tree_batch, model) + return tree_batch + +def parse_sentences(data_iterator, build_batch_fn, batch_size, model): + """ + Given an iterator over the data and a method for building batches, returns a bunch of parse trees. + + The data_iterator should be anything which returns the data for a parse task via next() + build_batch_fn is a function that turns that data into State objects + This will be called to generate batches of size batch_size until the data is exhausted + + The return is a list of tuples: (gold_tree, [(predicted, score) ...]) + gold_tree will be left blank if the data did not include gold trees + currently score is always 1.0, but the interface may be expanded to get a score from the result of the parsing + """ + treebank = [] + tree_batch = build_batch_fn(batch_size, data_iterator, model) + horizon_iterator = iter([]) + + while len(tree_batch) > 0: + _, transitions = model.predict(tree_batch, is_legal=True) + tree_batch = parse_transitions.bulk_apply(model, tree_batch, transitions) + + remove = set() + for idx, tree in enumerate(tree_batch): + if tree.finished(model): + predicted_tree = tree.get_tree(model) + gold_tree = tree.gold_tree + # TODO: put an actual score here? + treebank.append((gold_tree, [(predicted_tree, 1.0)])) + remove.add(idx) + + tree_batch = [tree for idx, tree in enumerate(tree_batch) if idx not in remove] + + for _ in range(batch_size - len(tree_batch)): + horizon_tree = next(horizon_iterator, None) + if not horizon_tree: + horizon_batch = build_batch_fn(batch_size, data_iterator, model) + if len(horizon_batch) == 0: + break + horizon_iterator = iter(horizon_batch) + horizon_tree = next(horizon_iterator, None) + + tree_batch.append(horizon_tree) + + return treebank + +def parse_tagged_words(model, words, batch_size): + """ + This parses tagged words and returns a list of trees. + + The tagged words should be represented: + one list per sentence + each sentence is a list of (word, tag) + The return value is a list of ParseTree objects + """ + logger.debug("Processing %d sentences", len(words)) + model.eval() + + sentence_iterator = iter(words) + treebank = parse_sentences(sentence_iterator, build_batch_from_tagged_words, batch_size, model) + + results = [t[1][0][0] for t in treebank] + return results + +def run_dev_set(model, dev_trees, args): + """ + This reparses a treebank and executes the CoreNLP Java EvalB code. + + It only works if CoreNLP 4.3.0 or higher is in the classpath. 
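The scoring step at the end of this function can also be driven by hand (editor's sketch, not part of the patch), using the same (gold, [(predicted, score)]) layout that parse_sentences produces; it assumes CoreNLP 4.3.0+ is on $CLASSPATH:

from stanza.models.constituency.tree_reader import read_trees
from stanza.server.parser_eval import EvaluateParser

gold = read_trees("((S (NP (DT the) (NN dog)) (VP (VBZ barks))))")[0]
pred = read_trees("((S (NP (DT the) (NN dog)) (VP (VBZ barks))))")[0]
treebank = [(gold, [(pred, 1.0)])]
with EvaluateParser(classpath="$CLASSPATH") as evaluator:
    response = evaluator.process(treebank)
print(response.f1)   # should be 1.0 for identical gold and predicted trees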
+ """ + logger.info("Processing %d trees from %s", len(dev_trees), args['eval_file']) + model.eval() + + tree_iterator = iter(tqdm(dev_trees)) + treebank = parse_sentences(tree_iterator, build_batch_from_trees, args['eval_batch_size'], model) + + if len(treebank) < len(dev_trees): + logger.warning("Only evaluating %d trees instead of %d", len(treebank), len(dev_trees)) + + if args['mode'] == 'predict' and args['predict_file']: + utils.ensure_dir(args['predict_dir'], verbose=False) + pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".pred.mrg") + orig_file = os.path.join(args['predict_dir'], args['predict_file'] + ".orig.mrg") + if os.path.exists(pred_file): + logger.warning("Cowardly refusing to overwrite {}".format(pred_file)) + elif os.path.exists(orig_file): + logger.warning("Cowardly refusing to overwrite {}".format(orig_file)) + else: + with open(pred_file, 'w') as fout: + for tree in treebank: + fout.write(str(tree[1][0][0])) + fout.write("\n") + + with open(orig_file, 'w') as fout: + for tree in treebank: + fout.write(str(tree[0])) + fout.write("\n") + + with EvaluateParser(classpath="$CLASSPATH") as evaluator: + response = evaluator.process(treebank) + return response.f1 diff --git a/stanza/models/constituency/transition_sequence.py b/stanza/models/constituency/transition_sequence.py new file mode 100644 index 00000000..fe60c527 --- /dev/null +++ b/stanza/models/constituency/transition_sequence.py @@ -0,0 +1,112 @@ +""" +Build a transition sequence from parse trees. + +Supports multiple transition schemes - TOP_DOWN and variants, IN_ORDER +""" + +from stanza.models.constituency.parse_transitions import Shift, CompoundUnary, OpenConstituent, CloseConstituent, TransitionScheme +from stanza.models.constituency.tree_reader import read_trees + +def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): + """ + For tree (X A B C D), yield Open(X) A B C D Close + + The details are in how to treat unary transitions + Three possibilities handled by this method: + TOP_DOWN_UNARY: (Y (X ...)) -> Open(X) ... Close Unary(Y) + TOP_DOWN_COMPOUND: (Y (X ...)) -> Open(Y, X) ... Close + TOP_DOWN: (Y (X ...)) -> Open(Y) Open(X) ... 
Close Close + """ + if tree.is_preterminal(): + yield Shift() + return + + if tree.is_leaf(): + return + + if transition_scheme is TransitionScheme.TOP_DOWN_UNARY: + if len(tree.children) == 1: + labels = [] + while not tree.is_preterminal() and len(tree.children) == 1: + labels.append(tree.label) + tree = tree.children[0] + for transition in yield_top_down_sequence(tree, transition_scheme): + yield transition + yield CompoundUnary(labels) + return + + if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND: + labels = [tree.label] + while len(tree.children) == 1 and not tree.children[0].is_preterminal(): + tree = tree.children[0] + labels.append(tree.label) + yield OpenConstituent(*labels) + else: + yield OpenConstituent(tree.label) + for child in tree.children: + for transition in yield_top_down_sequence(child, transition_scheme): + yield transition + yield CloseConstituent() + +def yield_in_order_sequence(tree): + """ + For tree (X A B C D), yield A Open(X) B C D Close + """ + if tree.is_preterminal(): + yield Shift() + return + + if tree.is_leaf(): + return + + for transition in yield_in_order_sequence(tree.children[0]): + yield transition + + yield OpenConstituent(tree.label) + + for child in tree.children[1:]: + for transition in yield_in_order_sequence(child): + yield transition + + yield CloseConstituent() + +def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): + """ + Turn a single tree into a list of transitions based on the TransitionScheme + """ + if transition_scheme is TransitionScheme.IN_ORDER: + return list(yield_in_order_sequence(tree)) + else: + return list(yield_top_down_sequence(tree, transition_scheme)) + +def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY): + """ + Turn each of the trees in the treebank into a list of transitions based on the TransitionScheme + """ + return [build_sequence(tree, transition_scheme) for tree in trees] + +def all_transitions(transition_lists): + """ + Given a list of transition lists, combine them all into a list of unique transitions. + """ + transitions = set() + for trans_list in transition_lists: + for trans in trans_list: + transitions.add(trans) + return sorted(transitions) + +def main(): + """ + Convert a sample tree and print its transitions + """ + text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + #text = "(WP Who)" + + tree = read_trees(text)[0] + + print(tree) + transitions = build_sequence(tree) + print(transitions) + +if __name__ == '__main__': + main() diff --git a/stanza/models/constituency/tree_reader.py b/stanza/models/constituency/tree_reader.py index adfe2b55..05f5e848 100644 --- a/stanza/models/constituency/tree_reader.py +++ b/stanza/models/constituency/tree_reader.py @@ -1,3 +1,10 @@ +""" +Reads ParseTree objects from a file, string, or similar input + +Works by first splitting the input into (, ), and all other tokens, +then recursively processing those tokens into trees. 
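The splitting step can be pictured with a simplified stand-in tokenizer (editor's sketch; the real TokenIterator below additionally tracks line numbers for error reporting):

import re

def simple_tokens(text):
    # parens become their own tokens; everything else splits on whitespace
    return re.findall(r"\(|\)|[^()\s]+", text)

print(simple_tokens("(NP (DT the) (NN dog))"))
# ['(', 'NP', '(', 'DT', 'the', ')', '(', 'NN', 'dog', ')', ')']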
+""" + from stanza.models.common import utils from stanza.models.constituency.parse_tree import Tree @@ -6,35 +13,50 @@ tqdm = utils.get_tqdm() OPEN_PAREN = "(" CLOSE_PAREN = ")" -def recursive_open_tree(token_iterator, at_root): +def recursive_open_tree(token_iterator, at_root, broken_ok): + """ + Build a tree from the tokens in the token_iterator + """ # TODO: unwind the recursion text = [] children = [] token = next(token_iterator, None) - while token != None: + while token is not None: if token is OPEN_PAREN: - children.append(recursive_open_tree(token_iterator, at_root=False)) + children.append(recursive_open_tree(token_iterator, at_root=False, broken_ok=broken_ok)) elif token is CLOSE_PAREN: if len(text) == 0: if at_root: return Tree(label="ROOT", children=children) - raise ValueError("Found a tree with no label on a node! Line number %d" % token_iterator.line_num) + elif broken_ok: + return Tree(label=None, children=children) + else: + raise ValueError("Found a tree with no label on a node! Line number %d" % token_iterator.line_num) pieces = " ".join(text).split() if len(pieces) == 1: return Tree(label=pieces[0], children=children) - if len(children) > 0: - raise ValueError("Found a tree with both text children and bracketed children! Line number %d" % token_iterator.line_num) + + # the assumption here is that a language such as VI may + # have spaces in the words, but it still represents + # just one child label = pieces[0] child_label = " ".join(pieces[1:]) + if len(children) > 0: + if broken_ok: + return Tree(label=label, children=children + [Tree(label=child_label)]) + else: + raise ValueError("Found a tree with both text children and bracketed children! Line number %d" % token_iterator.line_num) return Tree(label=label, children=Tree(label=child_label)) else: text.append(token) token = next(token_iterator, None) -def recursive_read_trees(token_iterator): +def recursive_read_trees(token_iterator, broken_ok): """ + Read all of the trees from the token_iterator + TODO: some of the error cases we hit can be recovered from also, just in general it would be good to unwind the recursion """ @@ -42,7 +64,7 @@ def recursive_read_trees(token_iterator): token = next(token_iterator, None) while token: if token is OPEN_PAREN: - trees.append(recursive_open_tree(token_iterator, at_root=True)) + trees.append(recursive_open_tree(token_iterator, at_root=True, broken_ok=broken_ok)) token = next(token_iterator, None) continue @@ -63,6 +85,7 @@ class TokenIterator: """ def __init__(self, text): self.lines = text.split("\n") + self.num_lines = len(self.lines) self.line_num = -1 self.token_iterator = iter([]) @@ -98,20 +121,31 @@ class TokenIterator: return n -def read_trees(text): +def read_trees(text, broken_ok=False): """ Reads multiple trees from the text """ token_iterator = TokenIterator(text) - trees = recursive_read_trees(token_iterator) + if token_iterator.num_lines > 1000: + token_iterator = iter(tqdm(token_iterator)) + trees = recursive_read_trees(token_iterator, broken_ok=broken_ok) return trees def read_tree_file(filename): + """ + Read all of the trees in the given file + """ with open(filename) as fin: trees = read_trees(fin.read()) return trees -if __name__ == '__main__': +def main(): + """ + Reads a sample tree + """ text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" trees = read_trees(text) print(trees) + +if __name__ == '__main__': + main() diff --git a/stanza/models/constituency/tree_stack.py b/stanza/models/constituency/tree_stack.py new file mode 100644 index 00000000..afab61b7 --- /dev/null +++ b/stanza/models/constituency/tree_stack.py @@ -0,0 +1,52 @@ +""" +A utilitiy class for keeping track of intermediate parse states +""" + +from collections import namedtuple + +class TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])): + """ + A stack which can branch in several directions, as long as you + keep track of the branching heads + + An example usage is when K constituents are removed at once + to create a new constituent, and then the LSTM which tracks the + values of the constituents is updated starting from the Kth + output of the LSTM with the new value. + + We don't simply keep track of a single stack object using a deque + because versions of the parser which use a beam will want to be + able to branch in different directions from the same base stack + + Another possible usage is if an oracle is used for training + in a manner where some fraction of steps are non-gold steps, + but we also want to take a gold step from the same state. + Eg, parser gets to state X, wants to make incorrect transition T + instead of gold transition G, and so we continue training both + X+G and X+T. If we only represent the state X with standard + python stacks, it would not be possible to track both of these + states at the same time without copying the entire thing. + + Value can be as transition, a word, or a partially built constituent + + Implemented as a namedtuple to make it a bit more efficient + """ + def pop(self): + return self.parent + + def push(self, value): + # returns a new StackNode which points to this + return TreeStack(value, parent=self, length=self.length+1) + + def __iter__(self): + stack = self + while stack.parent is not None: + yield stack.value + stack = stack.parent + yield stack.value + + def __str__(self): + return "TreeStack(%s)" % ", ".join([str(x) for x in self]) + + def __len__(self): + return self.length diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py new file mode 100644 index 00000000..7fd4648d --- /dev/null +++ b/stanza/models/constituency/utils.py @@ -0,0 +1,58 @@ +""" +Collects a few of the conparser utility methods which don't belong elsewhere +""" + +from collections import deque +import copy + +from stanza.models.common.doc import TEXT, Document + +def replace_tags(tree, tags): + if tree.is_leaf(): + raise ValueError("Must call replace_tags with non-leaf") + + tag_iterator = iter(tags) + + new_tree = copy.deepcopy(tree) + queue = deque() + queue.append(new_tree) + while len(queue) > 0: + next_node = queue.pop() + if next_node.is_preterminal(): + try: + label = next(tag_iterator) + except StopIteration: + raise ValueError("Not enough tags in sentence for given tree") + next_node.label = label + elif next_node.is_leaf(): + raise ValueError("Got a badly structured tree: {}".format(tree)) + else: + queue.extend(reversed(next_node.children)) + + if any(True for _ in tag_iterator): + raise ValueError("Too many tags for the given tree") + + return new_tree + + +def retag_trees(trees, pipeline, xpos=True): + """ + Retag all of the trees using the given processor + + Returns a list of new trees + """ + sentences = [] + for tree in trees: + tokens = [{TEXT: pt.children[0].label} for pt in tree.preterminals()] + sentences.append(tokens) + + doc = Document(sentences) + doc 
= pipeline(doc) + if xpos: + tag_lists = [[x.xpos for x in sentence.words] for sentence in doc.sentences] + else: + tag_lists = [[x.upos for x in sentence.words] for sentence in doc.sentences] + + new_trees = [replace_tags(tree, tags) for tree, tags in zip(trees, tag_lists)] + return new_trees + diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py new file mode 100644 index 00000000..62421822 --- /dev/null +++ b/stanza/models/constituency_parser.py @@ -0,0 +1,290 @@ +"""A command line interface to a shift reduce constituency parser. + +This follows the work of +Recurrent neural network grammars by Dyer et al +In-Order Transition-based Constituent Parsing by Liu & Zhang + +The general outline is: + + Train a model by taking a list of trees, converting them to + transition sequences, and learning a model which can predict the + next transition given a current state + Then, at inference time, repeatedly predict the next transition until parsing is complete + +The "transitions" are variations on shift/reduce as per an +intro-to-compilers class. The idea is that you can treat all of the +words in a sentence as a buffer of tokens, then either "shift" them to +represent a new constituent, or "reduce" one or more constituents to +form a new constituent. + +In order to make the runtime a more competitive speed, effort is taken +to batch the transitions and apply multiple transitions at once. At +train time, batches are groups together by length, and at inference +time, new trees are added to the batch as previous trees on the batch +finish their inference. + +There are two minor differences in the model: + - The word input is a bi-lstm, not a uni-lstm. + This gave a small increase in accuracy. + - The combination of several constituents into one constituent is done + via a single bi-lstm rather than two separate lstms. This increases + speed without a noticeable effect on accuracy. 
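To make the shift/reduce outline above concrete (editor's sketch, not part of the patch), the transition_sequence module added in this patch can print the transition view of a toy tree:

from stanza.models.constituency.tree_reader import read_trees
from stanza.models.constituency.transition_sequence import build_sequence
from stanza.models.constituency.parse_transitions import TransitionScheme

tree = read_trees("((S (NP (DT the) (NN dog)) (VP (VBZ barks))))")[0]
print(build_sequence(tree, TransitionScheme.IN_ORDER))
# conceptually: Shift, Open(NP), Shift, Close, Open(S), Shift, Open(VP), Close, Close, Open(ROOT), Close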
+ +A couple experiments which have been tried with little noticeable impact: + - Combining constituents using the method in the paper (only a trained + vector at the start instead of both ends) did not affect results + and is a little slower + - Using multiple layers of LSTM hidden state for the input to the final + classification layers didn't help + +The code breakdown is as follows: + + this file: main interface for training or evaluating models + constituency/trainer.py: contains the training & evaluation code + + constituency/parse_tree.py: a data structure for representing a parse tree and utility methods + constituency/tree_reader.py: a module which can read trees from a string or input file + + constituency/tree_stack.py: a linked list which can branch in + different directions, which will be useful when implementing beam + search or a dynamic oracle + + constituency/parse_transitions.py: transitions and a State data structure to store them + constituency/transition_sequence.py: turns ParseTree objects into + the transition sequences needed to make them + + constituency/base_model.py: operates on the transitions to turn them in to constituents, + eventually forming one final parse tree composed of all of the constituents + constituency/lstm_model.py: adds LSTM features to the constituents to predict what the + correct transition to make is, allowing for predictions on previously unseen text + + stanza/pipeline/constituency_processor.py: interface between this model and the Pipeline +""" + +import argparse +import logging +import os + +import torch + +from stanza import Pipeline +from stanza.models.common import utils +from stanza.models.constituency import trainer +from stanza.models.constituency.parse_transitions import TransitionScheme + +logger = logging.getLogger('stanza') + +def parse_args(args=None): + """ + Adds the arguments for building the con parser + + For the most part, defaults are set to cross-validated values, at least for WSJ + """ + parser = argparse.ArgumentParser() + + parser.add_argument('--data_dir', type=str, default='data/constituency', help='Directory of constituency data.') + + parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors') + parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors') + parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') + parser.add_argument('--pretrain_max_vocab', type=int, default=250000) + + # for whatever reason, this feature was not helpful + parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") + parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm") + + parser.add_argument('--tag_embedding_dim', type=int, default=20, help="Embedding size for a tag. 
0 turns off the feature") + # Smaller values also seem to work + # For example, after 700 iterations: + # 32: 0.9174 + # 50: 0.9183 + # 72: 0.9176 + # 100: 0.9185 + # not a huge difference regardless + # (these numbers were without retagging) + parser.add_argument('--delta_embedding_dim', type=int, default=100, help="Embedding size for a delta embedding") + + parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.') + parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.') + parser.add_argument('--mode', default='train', choices=['train', 'predict', 'remove_optimizer']) + parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. The orig file is important as the results will be shuffled') + parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions') + + parser.add_argument('--lang', type=str, help='Language') + parser.add_argument('--shorthand', type=str, help="Treebank shorthand") + + parser.add_argument('--transition_embedding_dim', type=int, default=20, help="Embedding size for a transition") + parser.add_argument('--transition_hidden_size', type=int, default=20, help="Embedding size for transition stack") + # larger was more effective, up to a point + parser.add_argument('--hidden_size', type=int, default=128, help="Size of the output layers for constituency stack and word queue") + + parser.add_argument('--epochs', type=int, default=200) + parser.add_argument('--eval_interval', type=int, default=5000) + # 30 is slightly slower than 50, for example, but seems to train a bit better on WSJ + # earlier version of the model (less accurate overall) had the following results with adadelta: + # 30: 0.9085 + # 50: 0.9070 + # 75: 0.9010 + # 150: 0.8985 + # as another data point, running a newer version with better constituency lstm behavior had: + # 30: 0.9111 + # 50: 0.9094 + # checking smaller batch sizes to see how this works, at 135 epochs, the values are + # 10: 0.8919 + # 20: 0.9072 + # 30: 0.9121 + # obviously these experiments aren't the complete story, but it + # looks like 30 trees per batch is the best value for WSJ + # note that these numbers are for adadelta and might not apply + # to other optimizers + # eval batch should generally be faster the bigger the batch, + # up to a point, as it allows for more batching of the LSTM + # operations and the prediction step + parser.add_argument('--train_batch_size', type=int, default=30, help='How many trees to train before taking an optimizer step') + parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval') + + parser.add_argument('--save_dir', type=str, default='saved_models/constituency', help='Root dir for saving models.') + parser.add_argument('--save_name', type=str, default=None, help="File name to save the model") + parser.add_argument('--save_latest_name', type=str, default=None, help="Save the latest model here regardless of score. 
Useful for restarting training") + + parser.add_argument('--seed', type=int, default=1234) + parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available()) + parser.add_argument('--cpu', action='store_true', help='Ignore CUDA.') + + DEFAULT_LEARNING_RATES = { "adamw": 0.001, "adadelta": 1.0, "sgd": 0.001 } + parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate for the optimizer. Reasonable values are 1.0 for adadelta or 0.001 for SGD. None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_RATES)) + # When using adadelta, weight_decay of 0.01 to 0.001 had the best results. + # 0.1 was very clearly too high. 0.0001 might have been okay. + parser.add_argument('--weight_decay', default=0.01, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer') + parser.add_argument('--optim', default='Adadelta', help='Optimizer type: SGD, AdamW, or Adadelta') + + # When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations: + # 0.0: 0.9085 + # 0.2: 0.9165 + # 0.4: 0.9162 + # 0.5: 0.9123 + # Letting 0.2 and 0.4 run for longer, along with 0.3 as another + # trial, continued to give extremely similar results over time. + # No attempt has been made to test the different dropouts separately... + parser.add_argument('--word_dropout', default=0.2, type=float, help='Dropout on the word embedding') + parser.add_argument('--predict_dropout', default=0.2, type=float, help='Dropout on the final prediction layer') + # lstm_dropout has not been fully tested yet + # one experiment after 200 iterations (after retagging, so scores are lower than some other experiments): + # 0.0: 0.9093 + # 0.1: 0.9094 + # 0.2: 0.9094 + # 0.3: 0.9076 + # 0.4: 0.9077 + parser.add_argument('--lstm_layer_dropout', default=0.0, type=float, help='Dropout in the LSTM layers') + # one not very conclusive experiment (not long enough) came up with these numbers after ~200 iterations + # 0.0 0.9091 + # 0.1 0.9095 + # 0.2 0.9118 + # 0.3 0.9123 + # 0.4 0.9080 + parser.add_argument('--lstm_input_dropout', default=0.2, type=float, help='Dropout on the input to an LSTM') + + parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()], + help='Transition scheme to use. {}'.format(", ".join(x.name for x in TransitionScheme))) + + parser.add_argument('--constituency_lstm', default=False, action='store_true', help="Build constituents using the full LSTM instead of just the nodes below the new constituent. Doesn't match the original papers and might be slightly less effective") + + # combining dummy and open node embeddings might be a slight improvement + # for example, after 550 iterations, one experiment had + # True: 0.9154 + # False: 0.9150 + # another (with a different structure) had 850 iterations + # True: 0.9155 + # False: 0.9149 + parser.add_argument('--combined_dummy_embedding', default=False, action='store_true', help="Use the same embedding for dummy nodes and the vectors used when combining constituents") + parser.add_argument('--no_combined_dummy_embedding', dest='combined_dummy_embedding', action='store_false', help="Don't use the same embedding for dummy nodes and the vectors used when combining constituents") + + # relu gave at least 1 F1 improvement over tanh in various experiments + # relu & gelu seem roughly the same, but relu is clearly faster. 
+ # relu, 496 iterations: 0.9176 + # gelu, 467 iterations: 0.9181 + # after the same clock time on the same hardware. the two had been + # trading places in terms of accuracy over those ~500 iterations. + parser.add_argument('--nonlinearity', default='relu', choices=['tanh', 'relu', 'gelu'], help='Nonlinearity to use in the model. relu is a noticeable improvement') + + parser.add_argument('--rare_word_unknown_frequency', default=0.02, type=float, help='How often to replace a rare word with UNK when training') + parser.add_argument('--rare_word_threshold', default=0.02, type=float, help='How many words to consider as rare words as a fraction of the dataset') + + parser.add_argument('--num_lstm_layers', default=2, type=int, help='How many layers to use in the LSTMs') + parser.add_argument('--num_output_layers', default=3, type=int, help='How many layers to use at the prediction level') + + # TODO: add the ability to keep training in a different direction + # after making an error, eg, add an oracle + parser.add_argument('--train_method', default='gold_entire', choices=['gold_entire'], help='Different training methods to use') + + parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path') + parser.add_argument('--maybe_finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path if it exists. Useful for running in situations where a job is frequently being preempted') + parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file') + + parser.add_argument('--retag_package', default=None, help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time') + parser.add_argument('--retag_method', default='xpos', choices=['xpos', 'upos'], help='Which tags to use when retagging') + + args = parser.parse_args(args=args) + if not args.lang and args.shorthand and len(args.shorthand.split("_")) == 2: + args.lang = args.shorthand.split("_")[0] + if args.cpu: + args.cuda = False + if args.learning_rate is None: + args.learning_rate = DEFAULT_LEARNING_RATES.get(args.optim.lower(), None) + + args = vars(args) + + if args['retag_method'] == 'xpos': + args['retag_xpos'] = True + elif args['retag_method'] == 'upos': + args['retag_xpos'] = False + else: + raise ValueError("Unknown retag method {}".format(xpos)) + + return args + +def main(args=None): + """ + Main function for building con parser + + Processes args, calls the appropriate function for the chosen --mode + """ + args = parse_args(args=args) + + utils.set_random_seed(args['seed'], args['cuda']) + + logger.info("Running constituency parser in %s mode", args['mode']) + logger.debug("Using GPU: %s", args['cuda']) + + model_save_file = args['save_name'] if args['save_name'] else '{}_constituency.pt'.format(args['shorthand']) + model_save_file = os.path.join(args['save_dir'], model_save_file) + + model_save_latest_file = None + if args['save_latest_name']: + model_save_latest_file = os.path.join(args['save_dir'], args['save_latest_name']) + + model_load_file = model_save_file + if args['load_name']: + model_load_file = os.path.join(args['save_dir'], args['load_name']) + elif args['mode'] == 'train' and args['save_latest_name']: + model_load_file = model_save_latest_file + + if args['retag_package'] is not None: + if '_' in args['retag_package']: + lang, package = 
args['retag_package'].split('_', 1) + retag_pipeline = Pipeline(lang=lang, processors="tokenize, pos", tokenize_pretokenized=True, pos_package=package, pos_tqdm=True) + else: + lang = args['retag_package'] + retag_pipeline = Pipeline(lang=lang, processors="tokenize, pos", tokenize_pretokenized=True, pos_tqdm=True) + else: + retag_pipeline = None + + if args['mode'] == 'train': + trainer.train(args, model_save_file, model_load_file, model_save_latest_file, retag_pipeline) + elif args['mode'] == 'predict': + trainer.evaluate(args, model_load_file, retag_pipeline) + elif args['mode'] == 'remove_optimizer': + trainer.remove_optimizer(args, model_save_file, model_load_file) + +if __name__ == '__main__': + main() diff --git a/stanza/pipeline/_constants.py b/stanza/pipeline/_constants.py index db3e1cb5..eff758f8 100644 --- a/stanza/pipeline/_constants.py +++ b/stanza/pipeline/_constants.py @@ -9,3 +9,4 @@ LEMMA = 'lemma' DEPPARSE = 'depparse' NER = 'ner' SENTIMENT = 'sentiment' +CONSTITUENCY = 'constituency' diff --git a/stanza/pipeline/constituency_processor.py b/stanza/pipeline/constituency_processor.py new file mode 100644 index 00000000..c9305bf6 --- /dev/null +++ b/stanza/pipeline/constituency_processor.py @@ -0,0 +1,52 @@ +"""Processor that attaches a constituency tree to a sentence + +The model used is a generally a model trained on the Stanford +Sentiment Treebank or some similar dataset. When run, this processor +attaches a score in the form of a string to each sentence in the +document. + +TODO: a possible way to generalize this would be to make it a +ClassifierProcessor and have "sentiment" be an option. +""" + +import stanza.models.constituency.trainer as trainer + +from stanza.models.common import doc +from stanza.models.common.pretrain import Pretrain +from stanza.pipeline._constants import * +from stanza.pipeline.processor import UDProcessor, register_processor + +@register_processor(CONSTITUENCY) +class ConstituencyProcessor(UDProcessor): + # set of processor requirements this processor fulfills + PROVIDES_DEFAULT = set([CONSTITUENCY]) + # set of processor requirements for this processor + REQUIRES_DEFAULT = set([TOKENIZE, POS]) + + # default batch size, measured in sentences + DEFAULT_BATCH_SIZE = 50 + + def _set_up_model(self, config, use_gpu): + # get pretrained word vectors + pretrain_path = config.get('pretrain_path', None) + self._pretrain = Pretrain(pretrain_path) if pretrain_path else None + # set up model + charlm_forward_file = config.get('forward_charlm_path', None) + charlm_backward_file = config.get('backward_charlm_path', None) + self._model = trainer.Trainer.load(filename=config['model_path'], + pt=self._pretrain, + forward_charlm=trainer.load_charlm(charlm_forward_file), + backward_charlm=trainer.load_charlm(charlm_backward_file), + use_gpu=use_gpu) + # batch size counted as sentences + self._batch_size = config.get('batch_size', ConstituencyProcessor.DEFAULT_BATCH_SIZE) + + def process(self, document): + sentences = document.sentences + # TODO: perhaps MWT should be relevant here? 
+ # certainly parsing across an MWT boundary is an error + # TODO: maybe some constituency models are trained on UPOS not XPOS + words = [[(w.text, w.xpos) for w in s.words] for s in sentences] + trees = trainer.parse_tagged_words(self._model.model, words, self._batch_size) + document.set(CONSTITUENCY, trees, to_sentence=True) + return document diff --git a/stanza/pipeline/core.py b/stanza/pipeline/core.py index fda90e2d..fcb63d18 100644 --- a/stanza/pipeline/core.py +++ b/stanza/pipeline/core.py @@ -22,6 +22,7 @@ from stanza.pipeline.pos_processor import POSProcessor from stanza.pipeline.lemma_processor import LemmaProcessor from stanza.pipeline.depparse_processor import DepparseProcessor from stanza.pipeline.sentiment_processor import SentimentProcessor +from stanza.pipeline.constituency_processor import ConstituencyProcessor from stanza.pipeline.ner_processor import NERProcessor from stanza.resources.common import DEFAULT_MODEL_DIR, \ maintain_processor_list, add_dependencies, add_mwt, build_default_config, set_logging_level, process_pipeline_parameters, sort_processors diff --git a/stanza/protobuf/CoreNLP_pb2.py b/stanza/protobuf/CoreNLP_pb2.py index 298ed1b8..f29f4132 100644 --- a/stanza/protobuf/CoreNLP_pb2.py +++ b/stanza/protobuf/CoreNLP_pb2.py @@ -19,7 +19,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='edu.stanford.nlp.pipeline', syntax='proto2', serialized_options=b'\n\031edu.stanford.nlp.pipelineB\rCoreNLPProtos', - serialized_pb=b'\n\rCoreNLP.proto\x12\x19\x65\x64u.stanford.nlp.pipeline\"\xe1\x05\n\x08\x44ocument\x12\x0c\n\x04text\x18\x01 \x02(\t\x12\x35\n\x08sentence\x18\x02 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x39\n\ncorefChain\x18\x03 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.CorefChain\x12\r\n\x05\x64ocID\x18\x04 \x01(\t\x12\x0f\n\x07\x64ocDate\x18\x07 \x01(\t\x12\x10\n\x08\x63\x61lendar\x18\x08 \x01(\x04\x12;\n\x11sentencelessToken\x18\x05 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x33\n\tcharacter\x18\n \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12/\n\x05quote\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x37\n\x08mentions\x18\t \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12#\n\x1bhasEntityMentionsAnnotation\x18\r \x01(\x08\x12\x0e\n\x06xmlDoc\x18\x0b \x01(\x08\x12\x34\n\x08sections\x18\x0c \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Section\x12<\n\x10mentionsForCoref\x18\x0e \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12!\n\x19hasCorefMentionAnnotation\x18\x0f \x01(\x08\x12\x1a\n\x12hasCorefAnnotation\x18\x10 \x01(\x08\x12+\n#corefMentionToEntityMentionMappings\x18\x11 \x03(\x05\x12+\n#entityMentionToCorefMentionMappings\x18\x12 \x03(\x05*\x05\x08\x64\x10\x80\x02\"\xf3\x0f\n\x08Sentence\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x18\n\x10tokenOffsetBegin\x18\x02 \x02(\r\x12\x16\n\x0etokenOffsetEnd\x18\x03 \x02(\r\x12\x15\n\rsentenceIndex\x18\x04 \x01(\r\x12\x1c\n\x14\x63haracterOffsetBegin\x18\x05 \x01(\r\x12\x1a\n\x12\x63haracterOffsetEnd\x18\x06 \x01(\r\x12\x37\n\tparseTree\x18\x07 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x62inarizedParseTree\x18\x1f \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x61nnotatedParseTree\x18 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x11\n\tsentiment\x18! 
\x01(\t\x12=\n\x0fkBestParseTrees\x18\" \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x45\n\x11\x62\x61sicDependencies\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12I\n\x15\x63ollapsedDependencies\x18\t \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12T\n collapsedCCProcessedDependencies\x18\n \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12K\n\x17\x61lternativeDependencies\x18\r \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12?\n\x0copenieTriple\x18\x0e \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12<\n\tkbpTriple\x18\x10 \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12\x45\n\x10\x65ntailedSentence\x18\x0f \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12\x43\n\x0e\x65ntailedClause\x18# \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12H\n\x14\x65nhancedDependencies\x18\x11 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12P\n\x1c\x65nhancedPlusPlusDependencies\x18\x12 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x33\n\tcharacter\x18\x13 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x11\n\tparagraph\x18\x0b \x01(\r\x12\x0c\n\x04text\x18\x0c \x01(\t\x12\x12\n\nlineNumber\x18\x14 \x01(\r\x12\x1e\n\x16hasRelationAnnotations\x18\x33 \x01(\x08\x12\x31\n\x06\x65ntity\x18\x34 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x35\n\x08relation\x18\x35 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Relation\x12$\n\x1chasNumerizedTokensAnnotation\x18\x36 \x01(\x08\x12\x37\n\x08mentions\x18\x37 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12<\n\x10mentionsForCoref\x18\x38 \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12\"\n\x1ahasCorefMentionsAnnotation\x18\x39 \x01(\x08\x12\x12\n\nsentenceID\x18: \x01(\t\x12\x13\n\x0bsectionDate\x18; \x01(\t\x12\x14\n\x0csectionIndex\x18< \x01(\r\x12\x13\n\x0bsectionName\x18= \x01(\t\x12\x15\n\rsectionAuthor\x18> \x01(\t\x12\r\n\x05\x64ocID\x18? \x01(\t\x12\x15\n\rsectionQuoted\x18@ \x01(\x08\x12#\n\x1bhasEntityMentionsAnnotation\x18\x41 \x01(\x08\x12\x1f\n\x17hasKBPTriplesAnnotation\x18\x44 \x01(\x08\x12\"\n\x1ahasOpenieTriplesAnnotation\x18\x45 \x01(\x08\x12\x14\n\x0c\x63hapterIndex\x18\x42 \x01(\r\x12\x16\n\x0eparagraphIndex\x18\x43 \x01(\r\x12=\n\x10\x65nhancedSentence\x18\x46 \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x0f\n\x07speaker\x18G \x01(\t\x12\x13\n\x0bspeakerType\x18H \x01(\t*\x05\x08\x64\x10\x80\x02\"\xc2\x0c\n\x05Token\x12\x0c\n\x04word\x18\x01 \x01(\t\x12\x0b\n\x03pos\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\x12\x10\n\x08\x63\x61tegory\x18\x04 \x01(\t\x12\x0e\n\x06\x62\x65\x66ore\x18\x05 \x01(\t\x12\r\n\x05\x61\x66ter\x18\x06 \x01(\t\x12\x14\n\x0coriginalText\x18\x07 \x01(\t\x12\x0b\n\x03ner\x18\x08 \x01(\t\x12\x11\n\tcoarseNER\x18> \x01(\t\x12\x16\n\x0e\x66ineGrainedNER\x18? 
\x01(\t\x12\x15\n\rnerLabelProbs\x18\x42 \x03(\t\x12\x15\n\rnormalizedNER\x18\t \x01(\t\x12\r\n\x05lemma\x18\n \x01(\t\x12\x11\n\tbeginChar\x18\x0b \x01(\r\x12\x0f\n\x07\x65ndChar\x18\x0c \x01(\r\x12\x11\n\tutterance\x18\r \x01(\r\x12\x0f\n\x07speaker\x18\x0e \x01(\t\x12\x13\n\x0bspeakerType\x18M \x01(\t\x12\x12\n\nbeginIndex\x18\x0f \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x10 \x01(\r\x12\x17\n\x0ftokenBeginIndex\x18\x11 \x01(\r\x12\x15\n\rtokenEndIndex\x18\x12 \x01(\r\x12\x34\n\ntimexValue\x18\x13 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x15\n\rhasXmlContext\x18\x15 \x01(\x08\x12\x12\n\nxmlContext\x18\x16 \x03(\t\x12\x16\n\x0e\x63orefClusterID\x18\x17 \x01(\r\x12\x0e\n\x06\x61nswer\x18\x18 \x01(\t\x12\x15\n\rheadWordIndex\x18\x1a \x01(\r\x12\x35\n\x08operator\x18\x1b \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Operator\x12\x35\n\x08polarity\x18\x1c \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Polarity\x12\x14\n\x0cpolarity_dir\x18\' \x01(\t\x12-\n\x04span\x18\x1d \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x11\n\tsentiment\x18\x1e \x01(\t\x12\x16\n\x0equotationIndex\x18\x1f \x01(\x05\x12\x42\n\x0e\x63onllUFeatures\x18 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x11\n\tcoarseTag\x18! \x01(\t\x12\x38\n\x0f\x63onllUTokenSpan\x18\" \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x12\n\nconllUMisc\x18# \x01(\t\x12G\n\x13\x63onllUSecondaryDeps\x18$ \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x17\n\x0fwikipediaEntity\x18% \x01(\t\x12\x11\n\tisNewline\x18& \x01(\x08\x12\x0e\n\x06gender\x18\x33 \x01(\t\x12\x10\n\x08trueCase\x18\x34 \x01(\t\x12\x14\n\x0ctrueCaseText\x18\x35 \x01(\t\x12\x13\n\x0b\x63hineseChar\x18\x36 \x01(\t\x12\x12\n\nchineseSeg\x18\x37 \x01(\t\x12\x16\n\x0e\x63hineseXMLChar\x18< \x01(\t\x12\x11\n\tarabicSeg\x18L \x01(\t\x12\x13\n\x0bsectionName\x18\x38 \x01(\t\x12\x15\n\rsectionAuthor\x18\x39 \x01(\t\x12\x13\n\x0bsectionDate\x18: \x01(\t\x12\x17\n\x0fsectionEndLabel\x18; \x01(\t\x12\x0e\n\x06parent\x18= \x01(\t\x12\x19\n\x11\x63orefMentionIndex\x18@ \x03(\r\x12\x1a\n\x12\x65ntityMentionIndex\x18\x41 \x01(\r\x12\r\n\x05isMWT\x18\x43 \x01(\x08\x12\x12\n\nisFirstMWT\x18\x44 \x01(\x08\x12\x0f\n\x07mwtText\x18\x45 \x01(\t\x12\x14\n\x0cnumericValue\x18\x46 \x01(\x04\x12\x13\n\x0bnumericType\x18G \x01(\t\x12\x1d\n\x15numericCompositeValue\x18H \x01(\x04\x12\x1c\n\x14numericCompositeType\x18I \x01(\t\x12\x1c\n\x14\x63odepointOffsetBegin\x18J \x01(\r\x12\x1a\n\x12\x63odepointOffsetEnd\x18K \x01(\r*\x05\x08\x64\x10\x80\x02\"\xe4\x03\n\x05Quote\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\r\x12\x0b\n\x03\x65nd\x18\x03 \x01(\r\x12\x15\n\rsentenceBegin\x18\x05 \x01(\r\x12\x13\n\x0bsentenceEnd\x18\x06 \x01(\r\x12\x12\n\ntokenBegin\x18\x07 \x01(\r\x12\x10\n\x08tokenEnd\x18\x08 \x01(\r\x12\r\n\x05\x64ocid\x18\t \x01(\t\x12\r\n\x05index\x18\n \x01(\r\x12\x0e\n\x06\x61uthor\x18\x0b \x01(\t\x12\x0f\n\x07mention\x18\x0c \x01(\t\x12\x14\n\x0cmentionBegin\x18\r \x01(\r\x12\x12\n\nmentionEnd\x18\x0e \x01(\r\x12\x13\n\x0bmentionType\x18\x0f \x01(\t\x12\x14\n\x0cmentionSieve\x18\x10 \x01(\t\x12\x0f\n\x07speaker\x18\x11 \x01(\t\x12\x14\n\x0cspeakerSieve\x18\x12 \x01(\t\x12\x18\n\x10\x63\x61nonicalMention\x18\x13 \x01(\t\x12\x1d\n\x15\x63\x61nonicalMentionBegin\x18\x14 \x01(\r\x12\x1b\n\x13\x63\x61nonicalMentionEnd\x18\x15 \x01(\r\x12N\n\x1a\x61ttributionDependencyGraph\x18\x16 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xc7\x01\n\tParseTree\x12\x33\n\x05\x63hild\x18\x01 
\x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\r\n\x05value\x18\x02 \x01(\t\x12\x17\n\x0fyieldBeginIndex\x18\x03 \x01(\r\x12\x15\n\ryieldEndIndex\x18\x04 \x01(\r\x12\r\n\x05score\x18\x05 \x01(\x01\x12\x37\n\tsentiment\x18\x06 \x01(\x0e\x32$.edu.stanford.nlp.pipeline.Sentiment\"\x96\x03\n\x0f\x44\x65pendencyGraph\x12=\n\x04node\x18\x01 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Node\x12=\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Edge\x12\x10\n\x04root\x18\x03 \x03(\rB\x02\x10\x01\x1a\x44\n\x04Node\x12\x15\n\rsentenceIndex\x18\x01 \x02(\r\x12\r\n\x05index\x18\x02 \x02(\r\x12\x16\n\x0e\x63opyAnnotation\x18\x03 \x01(\r\x1a\xac\x01\n\x04\x45\x64ge\x12\x0e\n\x06source\x18\x01 \x02(\r\x12\x0e\n\x06target\x18\x02 \x02(\r\x12\x0b\n\x03\x64\x65p\x18\x03 \x01(\t\x12\x0f\n\x07isExtra\x18\x04 \x01(\x08\x12\x12\n\nsourceCopy\x18\x05 \x01(\r\x12\x12\n\ntargetCopy\x18\x06 \x01(\r\x12>\n\x08language\x18\x07 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.Language:\x07Unknown\"\xc6\x02\n\nCorefChain\x12\x0f\n\x07\x63hainID\x18\x01 \x02(\x05\x12\x43\n\x07mention\x18\x02 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.CorefChain.CorefMention\x12\x16\n\x0erepresentative\x18\x03 \x02(\r\x1a\xc9\x01\n\x0c\x43orefMention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x12\n\nbeginIndex\x18\x06 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x07 \x01(\r\x12\x11\n\theadIndex\x18\t \x01(\r\x12\x15\n\rsentenceIndex\x18\n \x01(\r\x12\x10\n\x08position\x18\x0b \x01(\r\"\xef\x08\n\x07Mention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x0e\n\x06person\x18\x06 \x01(\t\x12\x12\n\nstartIndex\x18\x07 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\t \x01(\r\x12\x11\n\theadIndex\x18\n \x01(\x05\x12\x12\n\nheadString\x18\x0b \x01(\t\x12\x11\n\tnerString\x18\x0c \x01(\t\x12\x13\n\x0boriginalRef\x18\r \x01(\x05\x12\x1a\n\x12goldCorefClusterID\x18\x0e \x01(\x05\x12\x16\n\x0e\x63orefClusterID\x18\x0f \x01(\x05\x12\x12\n\nmentionNum\x18\x10 \x01(\x05\x12\x0f\n\x07sentNum\x18\x11 \x01(\x05\x12\r\n\x05utter\x18\x12 \x01(\x05\x12\x11\n\tparagraph\x18\x13 \x01(\x05\x12\x11\n\tisSubject\x18\x14 \x01(\x08\x12\x16\n\x0eisDirectObject\x18\x15 \x01(\x08\x12\x18\n\x10isIndirectObject\x18\x16 \x01(\x08\x12\x1b\n\x13isPrepositionObject\x18\x17 \x01(\x08\x12\x0f\n\x07hasTwin\x18\x18 \x01(\x08\x12\x0f\n\x07generic\x18\x19 \x01(\x08\x12\x13\n\x0bisSingleton\x18\x1a \x01(\x08\x12\x1a\n\x12hasBasicDependency\x18\x1b \x01(\x08\x12\x1d\n\x15hasEnhancedDepenedncy\x18\x1c \x01(\x08\x12\x1b\n\x13hasContextParseTree\x18\x1d \x01(\x08\x12?\n\x0fheadIndexedWord\x18\x1e \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12=\n\rdependingVerb\x18\x1f \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x38\n\x08headWord\x18 \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12;\n\x0bspeakerInfo\x18! 
\x01(\x0b\x32&.edu.stanford.nlp.pipeline.SpeakerInfo\x12=\n\rsentenceWords\x18\x32 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12<\n\x0coriginalSpan\x18\x33 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x12\n\ndependents\x18\x34 \x03(\t\x12\x19\n\x11preprocessedTerms\x18\x35 \x03(\t\x12\x13\n\x0b\x61ppositions\x18\x36 \x03(\x05\x12\x1c\n\x14predicateNominatives\x18\x37 \x03(\x05\x12\x18\n\x10relativePronouns\x18\x38 \x03(\x05\x12\x13\n\x0blistMembers\x18\x39 \x03(\x05\x12\x15\n\rbelongToLists\x18: \x03(\x05\"X\n\x0bIndexedWord\x12\x13\n\x0bsentenceNum\x18\x01 \x01(\x05\x12\x12\n\ntokenIndex\x18\x02 \x01(\x05\x12\r\n\x05\x64ocID\x18\x03 \x01(\x05\x12\x11\n\tcopyCount\x18\x04 \x01(\r\"4\n\x0bSpeakerInfo\x12\x13\n\x0bspeakerName\x18\x01 \x01(\t\x12\x10\n\x08mentions\x18\x02 \x03(\x05\"\"\n\x04Span\x12\r\n\x05\x62\x65gin\x18\x01 \x02(\r\x12\x0b\n\x03\x65nd\x18\x02 \x02(\r\"w\n\x05Timex\x12\r\n\x05value\x18\x01 \x01(\t\x12\x10\n\x08\x61ltValue\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0b\n\x03tid\x18\x05 \x01(\t\x12\x12\n\nbeginPoint\x18\x06 \x01(\r\x12\x10\n\x08\x65ndPoint\x18\x07 \x01(\r\"\xdb\x01\n\x06\x45ntity\x12\x11\n\theadStart\x18\x06 \x01(\r\x12\x0f\n\x07headEnd\x18\x07 \x01(\r\x12\x13\n\x0bmentionType\x18\x08 \x01(\t\x12\x16\n\x0enormalizedName\x18\t \x01(\t\x12\x16\n\x0eheadTokenIndex\x18\n \x01(\r\x12\x0f\n\x07\x63orefID\x18\x0b \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb7\x01\n\x08Relation\x12\x0f\n\x07\x61rgName\x18\x06 \x03(\t\x12.\n\x03\x61rg\x18\x07 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x11\n\tsignature\x18\x08 \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb2\x01\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x1b\n\x13quantifierSpanBegin\x18\x02 \x02(\x05\x12\x19\n\x11quantifierSpanEnd\x18\x03 \x02(\x05\x12\x18\n\x10subjectSpanBegin\x18\x04 \x02(\x05\x12\x16\n\x0esubjectSpanEnd\x18\x05 \x02(\x05\x12\x17\n\x0fobjectSpanBegin\x18\x06 \x02(\x05\x12\x15\n\robjectSpanEnd\x18\x07 \x02(\x05\"\xa9\x04\n\x08Polarity\x12K\n\x12projectEquivalence\x18\x01 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectForwardEntailment\x18\x02 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectReverseEntailment\x18\x03 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12H\n\x0fprojectNegation\x18\x04 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12K\n\x12projectAlternation\x18\x05 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12\x45\n\x0cprojectCover\x18\x06 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12L\n\x13projectIndependence\x18\x07 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\"\xdd\x02\n\nNERMention\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12%\n\x1dtokenStartInSentenceInclusive\x18\x02 \x02(\r\x12#\n\x1btokenEndInSentenceExclusive\x18\x03 \x02(\r\x12\x0b\n\x03ner\x18\x04 \x02(\t\x12\x15\n\rnormalizedNER\x18\x05 \x01(\t\x12\x12\n\nentityType\x18\x06 \x01(\t\x12/\n\x05timex\x18\x07 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x17\n\x0fwikipediaEntity\x18\x08 \x01(\t\x12\x0e\n\x06gender\x18\t \x01(\t\x12\x1a\n\x12\x65ntityMentionIndex\x18\n 
\x01(\r\x12#\n\x1b\x63\x61nonicalEntityMentionIndex\x18\x0b \x01(\r\x12\x19\n\x11\x65ntityMentionText\x18\x0c \x01(\t\"Y\n\x10SentenceFragment\x12\x12\n\ntokenIndex\x18\x01 \x03(\r\x12\x0c\n\x04root\x18\x02 \x01(\r\x12\x14\n\x0c\x61ssumedTruth\x18\x03 \x01(\x08\x12\r\n\x05score\x18\x04 \x01(\x01\":\n\rTokenLocation\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12\x12\n\ntokenIndex\x18\x02 \x01(\r\"\x9a\x03\n\x0eRelationTriple\x12\x0f\n\x07subject\x18\x01 \x01(\t\x12\x10\n\x08relation\x18\x02 \x01(\t\x12\x0e\n\x06object\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x01\x12?\n\rsubjectTokens\x18\r \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12@\n\x0erelationTokens\x18\x0e \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12>\n\x0cobjectTokens\x18\x0f \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12\x38\n\x04tree\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0e\n\x06istmod\x18\t \x01(\x08\x12\x10\n\x08prefixBe\x18\n \x01(\x08\x12\x10\n\x08suffixBe\x18\x0b \x01(\x08\x12\x10\n\x08suffixOf\x18\x0c \x01(\x08\"-\n\x0fMapStringString\x12\x0b\n\x03key\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x03(\t\"*\n\x0cMapIntString\x12\x0b\n\x03key\x18\x01 \x03(\r\x12\r\n\x05value\x18\x02 \x03(\t\"\xfc\x01\n\x07Section\x12\x11\n\tcharBegin\x18\x01 \x02(\r\x12\x0f\n\x07\x63harEnd\x18\x02 \x02(\r\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x17\n\x0fsentenceIndexes\x18\x04 \x03(\r\x12\x10\n\x08\x64\x61tetime\x18\x05 \x01(\t\x12\x30\n\x06quotes\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x17\n\x0f\x61uthorCharBegin\x18\x07 \x01(\r\x12\x15\n\rauthorCharEnd\x18\x08 \x01(\r\x12\x30\n\x06xmlTag\x18\t \x02(\x0b\x32 .edu.stanford.nlp.pipeline.Token\"\xe4\x01\n\x0eSemgrexRequest\x12\x0f\n\x07semgrex\x18\x01 \x03(\t\x12\x45\n\x05query\x18\x02 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies\x1az\n\x0c\x44\x65pendencies\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x39\n\x05graph\x18\x02 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\x8a\x04\n\x0fSemgrexResponse\x12\x46\n\x06result\x18\x01 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult\x1a-\n\tNamedNode\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x12\n\nmatchIndex\x18\x02 \x02(\x05\x1a+\n\rNamedRelation\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0c\n\x04reln\x18\x02 \x02(\t\x1a\xa7\x01\n\x05Match\x12\x12\n\nmatchIndex\x18\x01 \x02(\x05\x12\x42\n\x04node\x18\x02 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode\x12\x46\n\x04reln\x18\x03 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation\x1aP\n\rSemgrexResult\x12?\n\x05match\x18\x01 \x03(\x0b\x32\x30.edu.stanford.nlp.pipeline.SemgrexResponse.Match\x1aW\n\x0bGraphResult\x12H\n\x06result\x18\x01 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult\"W\n\x12TokensRegexRequest\x12\x30\n\x03\x64oc\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x0f\n\x07pattern\x18\x02 \x03(\t\"\xa7\x03\n\x13TokensRegexResponse\x12J\n\x05match\x18\x01 \x03(\x0b\x32;.edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch\x1a\x39\n\rMatchLocation\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x05\x1a\xb3\x01\n\x05Match\x12\x10\n\x08sentence\x18\x01 \x02(\x05\x12K\n\x05match\x18\x02 \x02(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x12K\n\x05group\x18\x03 
\x03(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x1aS\n\x0cPatternMatch\x12\x43\n\x05match\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TokensRegexResponse.Match\"\xae\x01\n\x19\x44\x65pendencyEnhancerRequest\x12\x35\n\x08\x64ocument\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x37\n\x08language\x18\x02 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.LanguageH\x00\x12\x1a\n\x10relativePronouns\x18\x03 \x01(\tH\x00\x42\x05\n\x03ref*\xa3\x01\n\x08Language\x12\x0b\n\x07Unknown\x10\x00\x12\x07\n\x03\x41ny\x10\x01\x12\n\n\x06\x41rabic\x10\x02\x12\x0b\n\x07\x43hinese\x10\x03\x12\x0b\n\x07\x45nglish\x10\x04\x12\n\n\x06German\x10\x05\x12\n\n\x06\x46rench\x10\x06\x12\n\n\x06Hebrew\x10\x07\x12\x0b\n\x07Spanish\x10\x08\x12\x14\n\x10UniversalEnglish\x10\t\x12\x14\n\x10UniversalChinese\x10\n*h\n\tSentiment\x12\x13\n\x0fSTRONG_NEGATIVE\x10\x00\x12\x11\n\rWEAK_NEGATIVE\x10\x01\x12\x0b\n\x07NEUTRAL\x10\x02\x12\x11\n\rWEAK_POSITIVE\x10\x03\x12\x13\n\x0fSTRONG_POSITIVE\x10\x04*\x93\x01\n\x14NaturalLogicRelation\x12\x0f\n\x0b\x45QUIVALENCE\x10\x00\x12\x16\n\x12\x46ORWARD_ENTAILMENT\x10\x01\x12\x16\n\x12REVERSE_ENTAILMENT\x10\x02\x12\x0c\n\x08NEGATION\x10\x03\x12\x0f\n\x0b\x41LTERNATION\x10\x04\x12\t\n\x05\x43OVER\x10\x05\x12\x10\n\x0cINDEPENDENCE\x10\x06\x42*\n\x19\x65\x64u.stanford.nlp.pipelineB\rCoreNLPProtos' + serialized_pb=b'\n\rCoreNLP.proto\x12\x19\x65\x64u.stanford.nlp.pipeline\"\xe1\x05\n\x08\x44ocument\x12\x0c\n\x04text\x18\x01 \x02(\t\x12\x35\n\x08sentence\x18\x02 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x39\n\ncorefChain\x18\x03 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.CorefChain\x12\r\n\x05\x64ocID\x18\x04 \x01(\t\x12\x0f\n\x07\x64ocDate\x18\x07 \x01(\t\x12\x10\n\x08\x63\x61lendar\x18\x08 \x01(\x04\x12;\n\x11sentencelessToken\x18\x05 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x33\n\tcharacter\x18\n \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12/\n\x05quote\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x37\n\x08mentions\x18\t \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12#\n\x1bhasEntityMentionsAnnotation\x18\r \x01(\x08\x12\x0e\n\x06xmlDoc\x18\x0b \x01(\x08\x12\x34\n\x08sections\x18\x0c \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Section\x12<\n\x10mentionsForCoref\x18\x0e \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12!\n\x19hasCorefMentionAnnotation\x18\x0f \x01(\x08\x12\x1a\n\x12hasCorefAnnotation\x18\x10 \x01(\x08\x12+\n#corefMentionToEntityMentionMappings\x18\x11 \x03(\x05\x12+\n#entityMentionToCorefMentionMappings\x18\x12 \x03(\x05*\x05\x08\x64\x10\x80\x02\"\xf3\x0f\n\x08Sentence\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x18\n\x10tokenOffsetBegin\x18\x02 \x02(\r\x12\x16\n\x0etokenOffsetEnd\x18\x03 \x02(\r\x12\x15\n\rsentenceIndex\x18\x04 \x01(\r\x12\x1c\n\x14\x63haracterOffsetBegin\x18\x05 \x01(\r\x12\x1a\n\x12\x63haracterOffsetEnd\x18\x06 \x01(\r\x12\x37\n\tparseTree\x18\x07 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x62inarizedParseTree\x18\x1f \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x61nnotatedParseTree\x18 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x11\n\tsentiment\x18! 
\x01(\t\x12=\n\x0fkBestParseTrees\x18\" \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x45\n\x11\x62\x61sicDependencies\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12I\n\x15\x63ollapsedDependencies\x18\t \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12T\n collapsedCCProcessedDependencies\x18\n \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12K\n\x17\x61lternativeDependencies\x18\r \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12?\n\x0copenieTriple\x18\x0e \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12<\n\tkbpTriple\x18\x10 \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12\x45\n\x10\x65ntailedSentence\x18\x0f \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12\x43\n\x0e\x65ntailedClause\x18# \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12H\n\x14\x65nhancedDependencies\x18\x11 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12P\n\x1c\x65nhancedPlusPlusDependencies\x18\x12 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x33\n\tcharacter\x18\x13 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x11\n\tparagraph\x18\x0b \x01(\r\x12\x0c\n\x04text\x18\x0c \x01(\t\x12\x12\n\nlineNumber\x18\x14 \x01(\r\x12\x1e\n\x16hasRelationAnnotations\x18\x33 \x01(\x08\x12\x31\n\x06\x65ntity\x18\x34 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x35\n\x08relation\x18\x35 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Relation\x12$\n\x1chasNumerizedTokensAnnotation\x18\x36 \x01(\x08\x12\x37\n\x08mentions\x18\x37 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12<\n\x10mentionsForCoref\x18\x38 \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12\"\n\x1ahasCorefMentionsAnnotation\x18\x39 \x01(\x08\x12\x12\n\nsentenceID\x18: \x01(\t\x12\x13\n\x0bsectionDate\x18; \x01(\t\x12\x14\n\x0csectionIndex\x18< \x01(\r\x12\x13\n\x0bsectionName\x18= \x01(\t\x12\x15\n\rsectionAuthor\x18> \x01(\t\x12\r\n\x05\x64ocID\x18? \x01(\t\x12\x15\n\rsectionQuoted\x18@ \x01(\x08\x12#\n\x1bhasEntityMentionsAnnotation\x18\x41 \x01(\x08\x12\x1f\n\x17hasKBPTriplesAnnotation\x18\x44 \x01(\x08\x12\"\n\x1ahasOpenieTriplesAnnotation\x18\x45 \x01(\x08\x12\x14\n\x0c\x63hapterIndex\x18\x42 \x01(\r\x12\x16\n\x0eparagraphIndex\x18\x43 \x01(\r\x12=\n\x10\x65nhancedSentence\x18\x46 \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x0f\n\x07speaker\x18G \x01(\t\x12\x13\n\x0bspeakerType\x18H \x01(\t*\x05\x08\x64\x10\x80\x02\"\xc2\x0c\n\x05Token\x12\x0c\n\x04word\x18\x01 \x01(\t\x12\x0b\n\x03pos\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\x12\x10\n\x08\x63\x61tegory\x18\x04 \x01(\t\x12\x0e\n\x06\x62\x65\x66ore\x18\x05 \x01(\t\x12\r\n\x05\x61\x66ter\x18\x06 \x01(\t\x12\x14\n\x0coriginalText\x18\x07 \x01(\t\x12\x0b\n\x03ner\x18\x08 \x01(\t\x12\x11\n\tcoarseNER\x18> \x01(\t\x12\x16\n\x0e\x66ineGrainedNER\x18? 
\x01(\t\x12\x15\n\rnerLabelProbs\x18\x42 \x03(\t\x12\x15\n\rnormalizedNER\x18\t \x01(\t\x12\r\n\x05lemma\x18\n \x01(\t\x12\x11\n\tbeginChar\x18\x0b \x01(\r\x12\x0f\n\x07\x65ndChar\x18\x0c \x01(\r\x12\x11\n\tutterance\x18\r \x01(\r\x12\x0f\n\x07speaker\x18\x0e \x01(\t\x12\x13\n\x0bspeakerType\x18M \x01(\t\x12\x12\n\nbeginIndex\x18\x0f \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x10 \x01(\r\x12\x17\n\x0ftokenBeginIndex\x18\x11 \x01(\r\x12\x15\n\rtokenEndIndex\x18\x12 \x01(\r\x12\x34\n\ntimexValue\x18\x13 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x15\n\rhasXmlContext\x18\x15 \x01(\x08\x12\x12\n\nxmlContext\x18\x16 \x03(\t\x12\x16\n\x0e\x63orefClusterID\x18\x17 \x01(\r\x12\x0e\n\x06\x61nswer\x18\x18 \x01(\t\x12\x15\n\rheadWordIndex\x18\x1a \x01(\r\x12\x35\n\x08operator\x18\x1b \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Operator\x12\x35\n\x08polarity\x18\x1c \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Polarity\x12\x14\n\x0cpolarity_dir\x18\' \x01(\t\x12-\n\x04span\x18\x1d \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x11\n\tsentiment\x18\x1e \x01(\t\x12\x16\n\x0equotationIndex\x18\x1f \x01(\x05\x12\x42\n\x0e\x63onllUFeatures\x18 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x11\n\tcoarseTag\x18! \x01(\t\x12\x38\n\x0f\x63onllUTokenSpan\x18\" \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x12\n\nconllUMisc\x18# \x01(\t\x12G\n\x13\x63onllUSecondaryDeps\x18$ \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x17\n\x0fwikipediaEntity\x18% \x01(\t\x12\x11\n\tisNewline\x18& \x01(\x08\x12\x0e\n\x06gender\x18\x33 \x01(\t\x12\x10\n\x08trueCase\x18\x34 \x01(\t\x12\x14\n\x0ctrueCaseText\x18\x35 \x01(\t\x12\x13\n\x0b\x63hineseChar\x18\x36 \x01(\t\x12\x12\n\nchineseSeg\x18\x37 \x01(\t\x12\x16\n\x0e\x63hineseXMLChar\x18< \x01(\t\x12\x11\n\tarabicSeg\x18L \x01(\t\x12\x13\n\x0bsectionName\x18\x38 \x01(\t\x12\x15\n\rsectionAuthor\x18\x39 \x01(\t\x12\x13\n\x0bsectionDate\x18: \x01(\t\x12\x17\n\x0fsectionEndLabel\x18; \x01(\t\x12\x0e\n\x06parent\x18= \x01(\t\x12\x19\n\x11\x63orefMentionIndex\x18@ \x03(\r\x12\x1a\n\x12\x65ntityMentionIndex\x18\x41 \x01(\r\x12\r\n\x05isMWT\x18\x43 \x01(\x08\x12\x12\n\nisFirstMWT\x18\x44 \x01(\x08\x12\x0f\n\x07mwtText\x18\x45 \x01(\t\x12\x14\n\x0cnumericValue\x18\x46 \x01(\x04\x12\x13\n\x0bnumericType\x18G \x01(\t\x12\x1d\n\x15numericCompositeValue\x18H \x01(\x04\x12\x1c\n\x14numericCompositeType\x18I \x01(\t\x12\x1c\n\x14\x63odepointOffsetBegin\x18J \x01(\r\x12\x1a\n\x12\x63odepointOffsetEnd\x18K \x01(\r*\x05\x08\x64\x10\x80\x02\"\xe4\x03\n\x05Quote\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\r\x12\x0b\n\x03\x65nd\x18\x03 \x01(\r\x12\x15\n\rsentenceBegin\x18\x05 \x01(\r\x12\x13\n\x0bsentenceEnd\x18\x06 \x01(\r\x12\x12\n\ntokenBegin\x18\x07 \x01(\r\x12\x10\n\x08tokenEnd\x18\x08 \x01(\r\x12\r\n\x05\x64ocid\x18\t \x01(\t\x12\r\n\x05index\x18\n \x01(\r\x12\x0e\n\x06\x61uthor\x18\x0b \x01(\t\x12\x0f\n\x07mention\x18\x0c \x01(\t\x12\x14\n\x0cmentionBegin\x18\r \x01(\r\x12\x12\n\nmentionEnd\x18\x0e \x01(\r\x12\x13\n\x0bmentionType\x18\x0f \x01(\t\x12\x14\n\x0cmentionSieve\x18\x10 \x01(\t\x12\x0f\n\x07speaker\x18\x11 \x01(\t\x12\x14\n\x0cspeakerSieve\x18\x12 \x01(\t\x12\x18\n\x10\x63\x61nonicalMention\x18\x13 \x01(\t\x12\x1d\n\x15\x63\x61nonicalMentionBegin\x18\x14 \x01(\r\x12\x1b\n\x13\x63\x61nonicalMentionEnd\x18\x15 \x01(\r\x12N\n\x1a\x61ttributionDependencyGraph\x18\x16 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xc7\x01\n\tParseTree\x12\x33\n\x05\x63hild\x18\x01 
\x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\r\n\x05value\x18\x02 \x01(\t\x12\x17\n\x0fyieldBeginIndex\x18\x03 \x01(\r\x12\x15\n\ryieldEndIndex\x18\x04 \x01(\r\x12\r\n\x05score\x18\x05 \x01(\x01\x12\x37\n\tsentiment\x18\x06 \x01(\x0e\x32$.edu.stanford.nlp.pipeline.Sentiment\"\x96\x03\n\x0f\x44\x65pendencyGraph\x12=\n\x04node\x18\x01 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Node\x12=\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Edge\x12\x10\n\x04root\x18\x03 \x03(\rB\x02\x10\x01\x1a\x44\n\x04Node\x12\x15\n\rsentenceIndex\x18\x01 \x02(\r\x12\r\n\x05index\x18\x02 \x02(\r\x12\x16\n\x0e\x63opyAnnotation\x18\x03 \x01(\r\x1a\xac\x01\n\x04\x45\x64ge\x12\x0e\n\x06source\x18\x01 \x02(\r\x12\x0e\n\x06target\x18\x02 \x02(\r\x12\x0b\n\x03\x64\x65p\x18\x03 \x01(\t\x12\x0f\n\x07isExtra\x18\x04 \x01(\x08\x12\x12\n\nsourceCopy\x18\x05 \x01(\r\x12\x12\n\ntargetCopy\x18\x06 \x01(\r\x12>\n\x08language\x18\x07 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.Language:\x07Unknown\"\xc6\x02\n\nCorefChain\x12\x0f\n\x07\x63hainID\x18\x01 \x02(\x05\x12\x43\n\x07mention\x18\x02 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.CorefChain.CorefMention\x12\x16\n\x0erepresentative\x18\x03 \x02(\r\x1a\xc9\x01\n\x0c\x43orefMention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x12\n\nbeginIndex\x18\x06 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x07 \x01(\r\x12\x11\n\theadIndex\x18\t \x01(\r\x12\x15\n\rsentenceIndex\x18\n \x01(\r\x12\x10\n\x08position\x18\x0b \x01(\r\"\xef\x08\n\x07Mention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x0e\n\x06person\x18\x06 \x01(\t\x12\x12\n\nstartIndex\x18\x07 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\t \x01(\r\x12\x11\n\theadIndex\x18\n \x01(\x05\x12\x12\n\nheadString\x18\x0b \x01(\t\x12\x11\n\tnerString\x18\x0c \x01(\t\x12\x13\n\x0boriginalRef\x18\r \x01(\x05\x12\x1a\n\x12goldCorefClusterID\x18\x0e \x01(\x05\x12\x16\n\x0e\x63orefClusterID\x18\x0f \x01(\x05\x12\x12\n\nmentionNum\x18\x10 \x01(\x05\x12\x0f\n\x07sentNum\x18\x11 \x01(\x05\x12\r\n\x05utter\x18\x12 \x01(\x05\x12\x11\n\tparagraph\x18\x13 \x01(\x05\x12\x11\n\tisSubject\x18\x14 \x01(\x08\x12\x16\n\x0eisDirectObject\x18\x15 \x01(\x08\x12\x18\n\x10isIndirectObject\x18\x16 \x01(\x08\x12\x1b\n\x13isPrepositionObject\x18\x17 \x01(\x08\x12\x0f\n\x07hasTwin\x18\x18 \x01(\x08\x12\x0f\n\x07generic\x18\x19 \x01(\x08\x12\x13\n\x0bisSingleton\x18\x1a \x01(\x08\x12\x1a\n\x12hasBasicDependency\x18\x1b \x01(\x08\x12\x1d\n\x15hasEnhancedDepenedncy\x18\x1c \x01(\x08\x12\x1b\n\x13hasContextParseTree\x18\x1d \x01(\x08\x12?\n\x0fheadIndexedWord\x18\x1e \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12=\n\rdependingVerb\x18\x1f \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x38\n\x08headWord\x18 \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12;\n\x0bspeakerInfo\x18! 
\x01(\x0b\x32&.edu.stanford.nlp.pipeline.SpeakerInfo\x12=\n\rsentenceWords\x18\x32 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12<\n\x0coriginalSpan\x18\x33 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x12\n\ndependents\x18\x34 \x03(\t\x12\x19\n\x11preprocessedTerms\x18\x35 \x03(\t\x12\x13\n\x0b\x61ppositions\x18\x36 \x03(\x05\x12\x1c\n\x14predicateNominatives\x18\x37 \x03(\x05\x12\x18\n\x10relativePronouns\x18\x38 \x03(\x05\x12\x13\n\x0blistMembers\x18\x39 \x03(\x05\x12\x15\n\rbelongToLists\x18: \x03(\x05\"X\n\x0bIndexedWord\x12\x13\n\x0bsentenceNum\x18\x01 \x01(\x05\x12\x12\n\ntokenIndex\x18\x02 \x01(\x05\x12\r\n\x05\x64ocID\x18\x03 \x01(\x05\x12\x11\n\tcopyCount\x18\x04 \x01(\r\"4\n\x0bSpeakerInfo\x12\x13\n\x0bspeakerName\x18\x01 \x01(\t\x12\x10\n\x08mentions\x18\x02 \x03(\x05\"\"\n\x04Span\x12\r\n\x05\x62\x65gin\x18\x01 \x02(\r\x12\x0b\n\x03\x65nd\x18\x02 \x02(\r\"w\n\x05Timex\x12\r\n\x05value\x18\x01 \x01(\t\x12\x10\n\x08\x61ltValue\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0b\n\x03tid\x18\x05 \x01(\t\x12\x12\n\nbeginPoint\x18\x06 \x01(\r\x12\x10\n\x08\x65ndPoint\x18\x07 \x01(\r\"\xdb\x01\n\x06\x45ntity\x12\x11\n\theadStart\x18\x06 \x01(\r\x12\x0f\n\x07headEnd\x18\x07 \x01(\r\x12\x13\n\x0bmentionType\x18\x08 \x01(\t\x12\x16\n\x0enormalizedName\x18\t \x01(\t\x12\x16\n\x0eheadTokenIndex\x18\n \x01(\r\x12\x0f\n\x07\x63orefID\x18\x0b \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb7\x01\n\x08Relation\x12\x0f\n\x07\x61rgName\x18\x06 \x03(\t\x12.\n\x03\x61rg\x18\x07 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x11\n\tsignature\x18\x08 \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb2\x01\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x1b\n\x13quantifierSpanBegin\x18\x02 \x02(\x05\x12\x19\n\x11quantifierSpanEnd\x18\x03 \x02(\x05\x12\x18\n\x10subjectSpanBegin\x18\x04 \x02(\x05\x12\x16\n\x0esubjectSpanEnd\x18\x05 \x02(\x05\x12\x17\n\x0fobjectSpanBegin\x18\x06 \x02(\x05\x12\x15\n\robjectSpanEnd\x18\x07 \x02(\x05\"\xa9\x04\n\x08Polarity\x12K\n\x12projectEquivalence\x18\x01 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectForwardEntailment\x18\x02 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectReverseEntailment\x18\x03 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12H\n\x0fprojectNegation\x18\x04 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12K\n\x12projectAlternation\x18\x05 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12\x45\n\x0cprojectCover\x18\x06 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12L\n\x13projectIndependence\x18\x07 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\"\xdd\x02\n\nNERMention\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12%\n\x1dtokenStartInSentenceInclusive\x18\x02 \x02(\r\x12#\n\x1btokenEndInSentenceExclusive\x18\x03 \x02(\r\x12\x0b\n\x03ner\x18\x04 \x02(\t\x12\x15\n\rnormalizedNER\x18\x05 \x01(\t\x12\x12\n\nentityType\x18\x06 \x01(\t\x12/\n\x05timex\x18\x07 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x17\n\x0fwikipediaEntity\x18\x08 \x01(\t\x12\x0e\n\x06gender\x18\t \x01(\t\x12\x1a\n\x12\x65ntityMentionIndex\x18\n 
\x01(\r\x12#\n\x1b\x63\x61nonicalEntityMentionIndex\x18\x0b \x01(\r\x12\x19\n\x11\x65ntityMentionText\x18\x0c \x01(\t\"Y\n\x10SentenceFragment\x12\x12\n\ntokenIndex\x18\x01 \x03(\r\x12\x0c\n\x04root\x18\x02 \x01(\r\x12\x14\n\x0c\x61ssumedTruth\x18\x03 \x01(\x08\x12\r\n\x05score\x18\x04 \x01(\x01\":\n\rTokenLocation\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12\x12\n\ntokenIndex\x18\x02 \x01(\r\"\x9a\x03\n\x0eRelationTriple\x12\x0f\n\x07subject\x18\x01 \x01(\t\x12\x10\n\x08relation\x18\x02 \x01(\t\x12\x0e\n\x06object\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x01\x12?\n\rsubjectTokens\x18\r \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12@\n\x0erelationTokens\x18\x0e \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12>\n\x0cobjectTokens\x18\x0f \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12\x38\n\x04tree\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0e\n\x06istmod\x18\t \x01(\x08\x12\x10\n\x08prefixBe\x18\n \x01(\x08\x12\x10\n\x08suffixBe\x18\x0b \x01(\x08\x12\x10\n\x08suffixOf\x18\x0c \x01(\x08\"-\n\x0fMapStringString\x12\x0b\n\x03key\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x03(\t\"*\n\x0cMapIntString\x12\x0b\n\x03key\x18\x01 \x03(\r\x12\r\n\x05value\x18\x02 \x03(\t\"\xfc\x01\n\x07Section\x12\x11\n\tcharBegin\x18\x01 \x02(\r\x12\x0f\n\x07\x63harEnd\x18\x02 \x02(\r\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x17\n\x0fsentenceIndexes\x18\x04 \x03(\r\x12\x10\n\x08\x64\x61tetime\x18\x05 \x01(\t\x12\x30\n\x06quotes\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x17\n\x0f\x61uthorCharBegin\x18\x07 \x01(\r\x12\x15\n\rauthorCharEnd\x18\x08 \x01(\r\x12\x30\n\x06xmlTag\x18\t \x02(\x0b\x32 .edu.stanford.nlp.pipeline.Token\"\xe4\x01\n\x0eSemgrexRequest\x12\x0f\n\x07semgrex\x18\x01 \x03(\t\x12\x45\n\x05query\x18\x02 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies\x1az\n\x0c\x44\x65pendencies\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x39\n\x05graph\x18\x02 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\x8a\x04\n\x0fSemgrexResponse\x12\x46\n\x06result\x18\x01 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult\x1a-\n\tNamedNode\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x12\n\nmatchIndex\x18\x02 \x02(\x05\x1a+\n\rNamedRelation\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0c\n\x04reln\x18\x02 \x02(\t\x1a\xa7\x01\n\x05Match\x12\x12\n\nmatchIndex\x18\x01 \x02(\x05\x12\x42\n\x04node\x18\x02 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode\x12\x46\n\x04reln\x18\x03 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation\x1aP\n\rSemgrexResult\x12?\n\x05match\x18\x01 \x03(\x0b\x32\x30.edu.stanford.nlp.pipeline.SemgrexResponse.Match\x1aW\n\x0bGraphResult\x12H\n\x06result\x18\x01 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult\"W\n\x12TokensRegexRequest\x12\x30\n\x03\x64oc\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x0f\n\x07pattern\x18\x02 \x03(\t\"\xa7\x03\n\x13TokensRegexResponse\x12J\n\x05match\x18\x01 \x03(\x0b\x32;.edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch\x1a\x39\n\rMatchLocation\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x05\x1a\xb3\x01\n\x05Match\x12\x10\n\x08sentence\x18\x01 \x02(\x05\x12K\n\x05match\x18\x02 \x02(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x12K\n\x05group\x18\x03 
\x03(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x1aS\n\x0cPatternMatch\x12\x43\n\x05match\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TokensRegexResponse.Match\"\xae\x01\n\x19\x44\x65pendencyEnhancerRequest\x12\x35\n\x08\x64ocument\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x37\n\x08language\x18\x02 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.LanguageH\x00\x12\x1a\n\x10relativePronouns\x18\x03 \x01(\tH\x00\x42\x05\n\x03ref\"\xb4\x01\n\x12\x46lattenedParseTree\x12\x41\n\x05nodes\x18\x01 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.FlattenedParseTree.Node\x1a[\n\x04Node\x12\x12\n\x08openNode\x18\x01 \x01(\x08H\x00\x12\x13\n\tcloseNode\x18\x02 \x01(\x08H\x00\x12\x0f\n\x05value\x18\x03 \x01(\tH\x00\x12\r\n\x05score\x18\x04 \x01(\x01\x42\n\n\x08\x63ontents\"\xf6\x01\n\x15\x45valuateParserRequest\x12N\n\x08treebank\x18\x01 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult\x1a\x8c\x01\n\x0bParseResult\x12;\n\x04gold\x18\x01 \x02(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x12@\n\tpredicted\x18\x02 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"$\n\x16\x45valuateParserResponse\x12\n\n\x02\x66\x31\x18\x01 \x02(\x01*\xa3\x01\n\x08Language\x12\x0b\n\x07Unknown\x10\x00\x12\x07\n\x03\x41ny\x10\x01\x12\n\n\x06\x41rabic\x10\x02\x12\x0b\n\x07\x43hinese\x10\x03\x12\x0b\n\x07\x45nglish\x10\x04\x12\n\n\x06German\x10\x05\x12\n\n\x06\x46rench\x10\x06\x12\n\n\x06Hebrew\x10\x07\x12\x0b\n\x07Spanish\x10\x08\x12\x14\n\x10UniversalEnglish\x10\t\x12\x14\n\x10UniversalChinese\x10\n*h\n\tSentiment\x12\x13\n\x0fSTRONG_NEGATIVE\x10\x00\x12\x11\n\rWEAK_NEGATIVE\x10\x01\x12\x0b\n\x07NEUTRAL\x10\x02\x12\x11\n\rWEAK_POSITIVE\x10\x03\x12\x13\n\x0fSTRONG_POSITIVE\x10\x04*\x93\x01\n\x14NaturalLogicRelation\x12\x0f\n\x0b\x45QUIVALENCE\x10\x00\x12\x16\n\x12\x46ORWARD_ENTAILMENT\x10\x01\x12\x16\n\x12REVERSE_ENTAILMENT\x10\x02\x12\x0c\n\x08NEGATION\x10\x03\x12\x0f\n\x0b\x41LTERNATION\x10\x04\x12\t\n\x05\x43OVER\x10\x05\x12\x10\n\x0cINDEPENDENCE\x10\x06\x42*\n\x19\x65\x64u.stanford.nlp.pipelineB\rCoreNLPProtos' ) _LANGUAGE = _descriptor.EnumDescriptor( @@ -75,8 +75,8 @@ _LANGUAGE = _descriptor.EnumDescriptor( ], containing_type=None, serialized_options=None, - serialized_start=11149, - serialized_end=11312, + serialized_start=11619, + serialized_end=11782, ) _sym_db.RegisterEnumDescriptor(_LANGUAGE) @@ -110,8 +110,8 @@ _SENTIMENT = _descriptor.EnumDescriptor( ], containing_type=None, serialized_options=None, - serialized_start=11314, - serialized_end=11418, + serialized_start=11784, + serialized_end=11888, ) _sym_db.RegisterEnumDescriptor(_SENTIMENT) @@ -153,8 +153,8 @@ _NATURALLOGICRELATION = _descriptor.EnumDescriptor( ], containing_type=None, serialized_options=None, - serialized_start=11421, - serialized_end=11568, + serialized_start=11891, + serialized_end=12038, ) _sym_db.RegisterEnumDescriptor(_NATURALLOGICRELATION) @@ -3522,6 +3522,190 @@ _DEPENDENCYENHANCERREQUEST = _descriptor.Descriptor( serialized_end=11146, ) + +_FLATTENEDPARSETREE_NODE = _descriptor.Descriptor( + name='Node', + full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='openNode', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.openNode', index=0, + number=1, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='closeNode', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.closeNode', index=1, + number=2, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.value', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='score', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.score', index=3, + number=4, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='contents', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.contents', + index=0, containing_type=None, fields=[]), + ], + serialized_start=11238, + serialized_end=11329, +) + +_FLATTENEDPARSETREE = _descriptor.Descriptor( + name='FlattenedParseTree', + full_name='edu.stanford.nlp.pipeline.FlattenedParseTree', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='nodes', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.nodes', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_FLATTENEDPARSETREE_NODE, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=11149, + serialized_end=11329, +) + + +_EVALUATEPARSERREQUEST_PARSERESULT = _descriptor.Descriptor( + name='ParseResult', + full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='gold', full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult.gold', index=0, + number=1, type=11, cpp_type=10, label=2, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='predicted', full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult.predicted', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + 
serialized_start=11438, + serialized_end=11578, +) + +_EVALUATEPARSERREQUEST = _descriptor.Descriptor( + name='EvaluateParserRequest', + full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='treebank', full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest.treebank', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_EVALUATEPARSERREQUEST_PARSERESULT, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=11332, + serialized_end=11578, +) + + +_EVALUATEPARSERRESPONSE = _descriptor.Descriptor( + name='EvaluateParserResponse', + full_name='edu.stanford.nlp.pipeline.EvaluateParserResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='f1', full_name='edu.stanford.nlp.pipeline.EvaluateParserResponse.f1', index=0, + number=1, type=1, cpp_type=5, label=2, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=11580, + serialized_end=11616, +) + _DOCUMENT.fields_by_name['sentence'].message_type = _SENTENCE _DOCUMENT.fields_by_name['corefChain'].message_type = _COREFCHAIN _DOCUMENT.fields_by_name['sentencelessToken'].message_type = _TOKEN @@ -3619,6 +3803,21 @@ _DEPENDENCYENHANCERREQUEST.fields_by_name['language'].containing_oneof = _DEPEND _DEPENDENCYENHANCERREQUEST.oneofs_by_name['ref'].fields.append( _DEPENDENCYENHANCERREQUEST.fields_by_name['relativePronouns']) _DEPENDENCYENHANCERREQUEST.fields_by_name['relativePronouns'].containing_oneof = _DEPENDENCYENHANCERREQUEST.oneofs_by_name['ref'] +_FLATTENEDPARSETREE_NODE.containing_type = _FLATTENEDPARSETREE +_FLATTENEDPARSETREE_NODE.oneofs_by_name['contents'].fields.append( + _FLATTENEDPARSETREE_NODE.fields_by_name['openNode']) +_FLATTENEDPARSETREE_NODE.fields_by_name['openNode'].containing_oneof = _FLATTENEDPARSETREE_NODE.oneofs_by_name['contents'] +_FLATTENEDPARSETREE_NODE.oneofs_by_name['contents'].fields.append( + _FLATTENEDPARSETREE_NODE.fields_by_name['closeNode']) +_FLATTENEDPARSETREE_NODE.fields_by_name['closeNode'].containing_oneof = _FLATTENEDPARSETREE_NODE.oneofs_by_name['contents'] +_FLATTENEDPARSETREE_NODE.oneofs_by_name['contents'].fields.append( + _FLATTENEDPARSETREE_NODE.fields_by_name['value']) +_FLATTENEDPARSETREE_NODE.fields_by_name['value'].containing_oneof = _FLATTENEDPARSETREE_NODE.oneofs_by_name['contents'] +_FLATTENEDPARSETREE.fields_by_name['nodes'].message_type = _FLATTENEDPARSETREE_NODE +_EVALUATEPARSERREQUEST_PARSERESULT.fields_by_name['gold'].message_type = _FLATTENEDPARSETREE +_EVALUATEPARSERREQUEST_PARSERESULT.fields_by_name['predicted'].message_type = _FLATTENEDPARSETREE +_EVALUATEPARSERREQUEST_PARSERESULT.containing_type = _EVALUATEPARSERREQUEST +_EVALUATEPARSERREQUEST.fields_by_name['treebank'].message_type = _EVALUATEPARSERREQUEST_PARSERESULT 
DESCRIPTOR.message_types_by_name['Document'] = _DOCUMENT DESCRIPTOR.message_types_by_name['Sentence'] = _SENTENCE DESCRIPTOR.message_types_by_name['Token'] = _TOKEN @@ -3647,6 +3846,9 @@ DESCRIPTOR.message_types_by_name['SemgrexResponse'] = _SEMGREXRESPONSE DESCRIPTOR.message_types_by_name['TokensRegexRequest'] = _TOKENSREGEXREQUEST DESCRIPTOR.message_types_by_name['TokensRegexResponse'] = _TOKENSREGEXRESPONSE DESCRIPTOR.message_types_by_name['DependencyEnhancerRequest'] = _DEPENDENCYENHANCERREQUEST +DESCRIPTOR.message_types_by_name['FlattenedParseTree'] = _FLATTENEDPARSETREE +DESCRIPTOR.message_types_by_name['EvaluateParserRequest'] = _EVALUATEPARSERREQUEST +DESCRIPTOR.message_types_by_name['EvaluateParserResponse'] = _EVALUATEPARSERRESPONSE DESCRIPTOR.enum_types_by_name['Language'] = _LANGUAGE DESCRIPTOR.enum_types_by_name['Sentiment'] = _SENTIMENT DESCRIPTOR.enum_types_by_name['NaturalLogicRelation'] = _NATURALLOGICRELATION @@ -3944,6 +4146,43 @@ DependencyEnhancerRequest = _reflection.GeneratedProtocolMessageType('Dependency }) _sym_db.RegisterMessage(DependencyEnhancerRequest) +FlattenedParseTree = _reflection.GeneratedProtocolMessageType('FlattenedParseTree', (_message.Message,), { + + 'Node' : _reflection.GeneratedProtocolMessageType('Node', (_message.Message,), { + 'DESCRIPTOR' : _FLATTENEDPARSETREE_NODE, + '__module__' : 'CoreNLP_pb2' + # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.FlattenedParseTree.Node) + }) + , + 'DESCRIPTOR' : _FLATTENEDPARSETREE, + '__module__' : 'CoreNLP_pb2' + # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.FlattenedParseTree) + }) +_sym_db.RegisterMessage(FlattenedParseTree) +_sym_db.RegisterMessage(FlattenedParseTree.Node) + +EvaluateParserRequest = _reflection.GeneratedProtocolMessageType('EvaluateParserRequest', (_message.Message,), { + + 'ParseResult' : _reflection.GeneratedProtocolMessageType('ParseResult', (_message.Message,), { + 'DESCRIPTOR' : _EVALUATEPARSERREQUEST_PARSERESULT, + '__module__' : 'CoreNLP_pb2' + # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult) + }) + , + 'DESCRIPTOR' : _EVALUATEPARSERREQUEST, + '__module__' : 'CoreNLP_pb2' + # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserRequest) + }) +_sym_db.RegisterMessage(EvaluateParserRequest) +_sym_db.RegisterMessage(EvaluateParserRequest.ParseResult) + +EvaluateParserResponse = _reflection.GeneratedProtocolMessageType('EvaluateParserResponse', (_message.Message,), { + 'DESCRIPTOR' : _EVALUATEPARSERRESPONSE, + '__module__' : 'CoreNLP_pb2' + # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserResponse) + }) +_sym_db.RegisterMessage(EvaluateParserResponse) + DESCRIPTOR._options = None _DEPENDENCYGRAPH.fields_by_name['root']._options = None diff --git a/stanza/server/java_protobuf_requests.py b/stanza/server/java_protobuf_requests.py index ee036387..d3f67e91 100644 --- a/stanza/server/java_protobuf_requests.py +++ b/stanza/server/java_protobuf_requests.py @@ -1,5 +1,8 @@ +from collections import deque import subprocess +from stanza.models.constituency.parse_tree import Tree +from stanza.protobuf import FlattenedParseTree from stanza.server.client import resolve_classpath def send_request(request, response_type, java_main, classpath=None): @@ -16,6 +19,95 @@ def send_request(request, response_type, java_main, classpath=None): response.ParseFromString(pipe.stdout) return response +def add_tree_nodes(proto_tree, tree, score): + # add an open node + node = 
proto_tree.nodes.add() + node.openNode = True + if score is not None: + node.score = score + + # add the content of this node + node = proto_tree.nodes.add() + node.value = tree.label + + # add all children... + # leaves get just one node + # branches are called recursively + for child in tree.children: + if child.is_leaf(): + node = proto_tree.nodes.add() + node.value = child.label + else: + add_tree_nodes(proto_tree, child, None) + + node = proto_tree.nodes.add() + node.closeNode = True + +def build_tree(tree, score): + """ + Builds a FlattenedParseTree from CoreNLP.proto + + Populates the value field from tree.label and iterates through the + children via tree.children. Should work on any tree structure + which follows that layout + + The score will be added to the top node (if it is not None) + + Operates by recursively calling add_tree_nodes + """ + proto_tree = FlattenedParseTree() + add_tree_nodes(proto_tree, tree, score) + return proto_tree + +def from_tree(proto_tree): + """ + Convert a FlattenedParseTree back into a Tree + + returns Tree, score + (score might be None if it is missing) + """ + score = None + stack = deque() + for node in proto_tree.nodes: + if node.HasField("score") and score is None: + score = node.score + + if node.openNode: + if len(stack) > 0 and isinstance(stack[-1], FlattenedParseTree.Node) and stack[-1].openNode: + raise ValueError("Got a proto with no label on a node: {}".format(proto_tree)) + stack.append(node) + continue + if not node.closeNode: + child = Tree(label=node.value) + # TODO: do something with the score + stack.append(child) + continue + + # must be a close operation... + if len(stack) <= 1: + raise ValueError("Got a proto with too many close operations: {}".format(proto_tree)) + # on a close operation, pop until we hit the open + # then turn everything in that span into a new node + children = [] + nextNode = stack.pop() + while not isinstance(nextNode, FlattenedParseTree.Node): + children.append(nextNode) + nextNode = stack.pop() + if len(children) == 0: + raise ValueError("Got a proto with an open immediately followed by a close: {}".format(proto_tree)) + children.reverse() + label = children[0] + children = children[1:] + subtree = Tree(label=label.label, children=children) + stack.append(subtree) + + if len(stack) > 1: + raise ValueError("Got a proto which does not close all of the nodes: {}".format(proto_tree)) + tree = stack.pop() + if not isinstance(tree, Tree): + raise ValueError("Got a proto which was just one Open operation: {}".format(proto_tree)) + return tree, score + def add_token(token_list, word, token): """ Add a token to a proto request. 
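
The build_tree / from_tree helpers above flatten a parse_tree.Tree into a sequence of open/value/close nodes and rebuild it with a stack. The following is a minimal round-trip sketch, not part of the patch, assuming a checkout with these changes applied; the tree labels and the 0.5 score are arbitrary illustration values.

    # Round-trip sketch for the FlattenedParseTree helpers defined above.
    # Assumes this patch is applied; labels and the 0.5 score are examples only.
    from stanza.models.constituency.parse_tree import Tree
    from stanza.server.java_protobuf_requests import build_tree, from_tree

    # (ROOT (NP (NNP Opal)))
    tree = Tree(label="ROOT",
                children=[Tree(label="NP",
                               children=[Tree(label="NNP",
                                              children=[Tree(label="Opal")])])])

    proto_tree = build_tree(tree, 0.5)       # emits open/value/close nodes; score goes on the first node
    rebuilt, score = from_tree(proto_tree)   # stack-based reconstruction described in from_tree
    assert score == 0.5
    assert rebuilt.label == tree.label
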
diff --git a/stanza/server/parser_eval.py b/stanza/server/parser_eval.py new file mode 100644 index 00000000..c5b30f6a --- /dev/null +++ b/stanza/server/parser_eval.py @@ -0,0 +1,41 @@ + + + +import stanza +from stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse +from stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext + + +EVALUATE_JAVA = "edu.stanford.nlp.parser.metrics.EvaluateExternalParser" + +def build_request(treebank): + """ + treebank should be a list of pairs: [gold, predictions] + each predictions is a list of pairs (prediction, score) + Note that for now, only one tree is measured, but this may be extensible in the future + Trees should be in the form of a Tree from parse_tree.py + """ + request = EvaluateParserRequest() + for gold, predictions in treebank: + parse_result = request.treebank.add() + parse_result.gold.CopyFrom(build_tree(gold, None)) + for prediction, score in predictions: + parse_result.predicted.append(build_tree(prediction, score)) + + return request + + +class EvaluateParser(JavaProtobufContext): + """ + Parser evaluation context window + + This is a context window which keeps a process open. Should allow + for multiple requests without launching new java processes each time. + """ + def __init__(self, classpath=None): + super(EvaluateParser, self).__init__(classpath, EvaluateParserResponse, EVALUATE_JAVA) + + def process(self, treebank): + request = build_request(treebank) + return self.process_request(request) + diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py new file mode 100644 index 00000000..9b5d00f1 --- /dev/null +++ b/stanza/tests/constituency/test_lstm_model.py @@ -0,0 +1,208 @@ +import os +import tempfile + +import pytest + +from stanza.models import constituency_parser +from stanza.models.common import pretrain +from stanza.models.constituency import lstm_model +from stanza.models.constituency import parse_transitions +from stanza.models.constituency import parse_tree +from stanza.models.constituency import trainer +from stanza.models.constituency import transition_sequence +from stanza.models.constituency import tree_reader +from stanza.tests import * +from stanza.tests.constituency import test_parse_transitions + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +TREEBANK = """ +( (S + (VP (VBG Enjoying) + (NP (PRP$ my) (JJ favorite) (NN Friday) (NN tradition))) + (. .))) + +( (NP + (VP (VBG Sitting) + (PP (IN in) + (NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station))) + (VP (VBG waiting) + (PP (IN for) + (NP (PRP$ my) (JJ delayed) (NNP @MBTA) (NN train))))) + (. .))) + +( (S + (NP (PRP I)) + (VP + (ADVP (RB really)) + (VBP hate) + (NP (DT the) (NNP @MBTA))))) + +( (S + (S (VP (VB Seek))) + (CC and) + (S (NP (PRP ye)) + (VP (MD shall) + (VP (VB find)))) + (. .))) +""" + +@pytest.fixture(scope="module") +def pt(): + return pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False) + +def build_model(pt, *args): + # TODO: build a fake embedding some other way? 
+ trees = tree_reader.read_trees(TREEBANK) + + args = constituency_parser.parse_args(args) + + transitions = trainer.build_treebank(trees, args['transition_scheme']) + transitions = transition_sequence.all_transitions(transitions) + constituents = parse_tree.Tree.get_unique_constituent_labels(trees) + tags = parse_tree.Tree.get_unique_tags(trees) + words = parse_tree.Tree.get_unique_words(trees) + rare_words = parse_tree.Tree.get_rare_words(trees) + root_labels = parse_tree.Tree.get_root_labels(trees) + open_nodes = trainer.get_open_nodes(trees, args) + + model = lstm_model.LSTMModel(pt, None, None, transitions, constituents, tags, words, rare_words, root_labels, open_nodes, args) + return model + +@pytest.fixture(scope="module") +def unary_model(pt): + return build_model(pt, "--transition_scheme", "TOP_DOWN_UNARY") + +def test_initial_model(unary_model): + """ + does nothing, just tests that the construction went okay + """ + pass + +def test_initial_state(unary_model): + test_parse_transitions.test_initial_state(unary_model) + +def test_shift(pt): + # TODO: might be good to include some tests specifically for shift + # in the context of a model with unaries + model = build_model(pt) + test_parse_transitions.test_shift(model) + +def test_unary(unary_model): + test_parse_transitions.test_unary(unary_model) + +def test_unary_requires_root(unary_model): + test_parse_transitions.test_unary_requires_root(unary_model) + +def test_open(unary_model): + test_parse_transitions.test_open(unary_model) + +def test_compound_open(pt): + model = build_model(pt, '--transition_scheme', "TOP_DOWN_COMPOUND") + test_parse_transitions.test_compound_open(model) + +def test_in_order_open(pt): + model = build_model(pt, '--transition_scheme', "IN_ORDER") + test_parse_transitions.test_in_order_open(model) + +def test_close(unary_model): + test_parse_transitions.test_close(unary_model) + +def run_forward_checks(model): + state = test_parse_transitions.build_initial_state(model)[0] + model((state,)) + + shift = parse_transitions.Shift() + state = shift.apply(state, model) + model((state,)) + + open_transition = parse_transitions.OpenConstituent("NP") + assert open_transition.is_legal(state, model) + state = open_transition.apply(state, model) + assert state.num_opens == 1 + model((state,)) + + state = shift.apply(state, model) + model((state,)) + state = shift.apply(state, model) + model((state,)) + assert state.num_opens == 1 + # now should have "mox", "opal" on the constituents + + close_transition = parse_transitions.CloseConstituent() + assert close_transition.is_legal(state, model) + state = close_transition.apply(state, model) + assert state.num_opens == 0 + + model((state,)) + +def test_unary_forward(pt, unary_model): + """ + Checks that the forward pass doesn't crash when run after various operations + + Doesn't check the forward pass for making reasonable answers + """ + run_forward_checks(unary_model) + +def test_lstm_forward(pt): + model = build_model(pt, '--num_lstm_layers', '1') + run_forward_checks(model) + model = build_model(pt, '--num_lstm_layers', '2') + run_forward_checks(model) + model = build_model(pt, '--num_lstm_layers', '3') + run_forward_checks(model) + +def test_multiple_output_forward(pt): + """ + Test a couple different sizes of output layers + """ + model = build_model(pt, '--num_output_layers', '1', '--num_lstm_layers', '2') + run_forward_checks(model) + + model = build_model(pt, '--num_output_layers', '2', '--num_lstm_layers', '2') + run_forward_checks(model) + +def 
test_no_tag_embedding_forward(pt): + """ + Test that the model continues to work if the tag embedding is turned on or off + """ + model = build_model(pt, '--tag_embedding_dim', '20') + run_forward_checks(model) + + model = build_model(pt, '--tag_embedding_dim', '0') + run_forward_checks(model) + +def test_forward_con_lstm(pt): + """ + Tests an older version of the model + """ + model = build_model(pt, '--num_lstm_layers', '2', '--constituency_lstm') + run_forward_checks(model) + +def test_forward_combined_dummy(pt): + """ + Tests combined dummy and open node embeddings + """ + model = build_model(pt, '--combined_dummy_embedding') + run_forward_checks(model) + + model = build_model(pt, '--no_combined_dummy_embedding') + run_forward_checks(model) + +def test_save_load_model(pt, unary_model): + """ + Just tests that saving and loading works without crashs. + + Currently no test of the values themselves + """ + with tempfile.TemporaryDirectory() as tmpdirname: + tr = trainer.Trainer(model=unary_model) + + # attempt saving + filename = os.path.join(tmpdirname, "parser.pt") + tr.save(filename) + + assert os.path.exists(filename) + + # load it back in + tr.load(filename, pt, None, None, False) diff --git a/stanza/tests/constituency/test_parse_transitions.py b/stanza/tests/constituency/test_parse_transitions.py new file mode 100644 index 00000000..a28b9b19 --- /dev/null +++ b/stanza/tests/constituency/test_parse_transitions.py @@ -0,0 +1,412 @@ +import pytest + +from stanza.models.constituency import parse_transitions +from stanza.models.constituency.base_model import SimpleModel +from stanza.models.constituency.parse_transitions import TransitionScheme +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + + +def build_initial_state(model): + words = ["Unban", "Mox", "Opal"] + tags = ["VB", "NNP", "NNP"] + + state = parse_transitions.initial_state_from_words([list(zip(words, tags))], model) + assert len(state) == 1 + assert state[0].num_transitions() == 0 + return state + +def test_initial_state(model=None): + if model is None: + model = SimpleModel() + states = build_initial_state(model) + assert len(states) == 1 + state = states[0] + + assert state.sentence_length == 3 + assert state.num_opens == 0 + # each stack has a sentinel value at the end + assert len(state.word_queue) == 4 + assert len(state.constituents) == 1 + assert len(state.transitions) == 1 + assert state.word_position == 0 + +def test_shift(model=None): + if model is None: + model = SimpleModel() + state = build_initial_state(model)[0] + + open_transition = parse_transitions.OpenConstituent("ROOT") + state = open_transition.apply(state, model) + open_transition = parse_transitions.OpenConstituent("S") + state = open_transition.apply(state, model) + shift = parse_transitions.Shift() + assert shift.is_legal(state, model) + assert len(state.word_queue) == 4 + assert state.word_position == 0 + + state = shift.apply(state, model) + assert len(state.word_queue) == 4 + # 4 because of the dummy created by the opens + assert len(state.constituents) == 4 + assert len(state.transitions) == 4 + assert shift.is_legal(state, model) + assert state.word_position == 1 + assert not state.empty_word_queue() + + state = shift.apply(state, model) + assert len(state.word_queue) == 4 + assert len(state.constituents) == 5 + assert len(state.transitions) == 5 + assert shift.is_legal(state, model) + assert state.word_position == 2 + assert not state.empty_word_queue() + + state = shift.apply(state, model) + assert 
len(state.word_queue) == 4 + assert len(state.constituents) == 6 + assert len(state.transitions) == 6 + assert not shift.is_legal(state, model) + assert state.word_position == 3 + assert state.empty_word_queue() + + constituents = state.constituents + assert model.get_top_constituent(constituents).children[0].label == 'Opal' + constituents = constituents.pop() + assert model.get_top_constituent(constituents).children[0].label == 'Mox' + constituents = constituents.pop() + assert model.get_top_constituent(constituents).children[0].label == 'Unban' + +def test_initial_unary(model=None): + # it doesn't make sense to start with a CompoundUnary + if model is None: + model = SimpleModel() + + state = build_initial_state(model)[0] + unary = parse_transitions.CompoundUnary(['ROOT', 'VP']) + assert not unary.is_legal(state, model) + unary = parse_transitions.CompoundUnary(['VP']) + assert not unary.is_legal(state, model) + + +def test_unary(model=None): + if model is None: + model = SimpleModel() + state = build_initial_state(model)[0] + + shift = parse_transitions.Shift() + state = shift.apply(state, model) + + # this is technically the wrong parse but we're being lazy + unary = parse_transitions.CompoundUnary(['S', 'VP']) + assert unary.is_legal(state, model) + state = unary.apply(state, model) + assert not unary.is_legal(state, model) + + tree = model.get_top_constituent(state.constituents) + assert tree.label == 'S' + assert len(tree.children) == 1 + tree = tree.children[0] + assert tree.label == 'VP' + assert len(tree.children) == 1 + tree = tree.children[0] + assert tree.label == 'VB' + assert tree.is_preterminal() + +def test_unary_requires_root(model=None): + if model is None: + model = SimpleModel() + state = build_initial_state(model)[0] + + open_transition = parse_transitions.OpenConstituent("S") + assert open_transition.is_legal(state, model) + state = open_transition.apply(state, model) + + shift = parse_transitions.Shift() + assert shift.is_legal(state, model) + state = shift.apply(state, model) + assert shift.is_legal(state, model) + state = shift.apply(state, model) + assert shift.is_legal(state, model) + state = shift.apply(state, model) + assert not shift.is_legal(state, model) + + close_transition = parse_transitions.CloseConstituent() + assert close_transition.is_legal(state, model) + state = close_transition.apply(state, model) + assert not open_transition.is_legal(state, model) + assert not close_transition.is_legal(state, model) + + np_unary = parse_transitions.CompoundUnary("NP") + assert not np_unary.is_legal(state, model) + root_unary = parse_transitions.CompoundUnary("ROOT") + assert root_unary.is_legal(state, model) + assert not state.finished(model) + state = root_unary.apply(state, model) + assert not root_unary.is_legal(state, model) + + assert state.finished(model) + +def test_open(model=None): + if model is None: + model = SimpleModel() + state = build_initial_state(model)[0] + + shift = parse_transitions.Shift() + state = shift.apply(state, model) + state = shift.apply(state, model) + assert state.num_opens == 0 + + open_transition = parse_transitions.OpenConstituent("VP") + assert open_transition.is_legal(state, model) + state = open_transition.apply(state, model) + assert open_transition.is_legal(state, model) + assert state.num_opens == 1 + + # check that it is illegal if there are too many opens already + for i in range(20): + state = open_transition.apply(state, model) + assert not open_transition.is_legal(state, model) + assert state.num_opens == 21 + + # 
check that it is illegal if the state is out of words + state = build_initial_state(model)[0] + state = shift.apply(state, model) + state = shift.apply(state, model) + state = shift.apply(state, model) + assert not open_transition.is_legal(state, model) + +def test_compound_open(model=None): + if model is None: + model = SimpleModel() + state = build_initial_state(model)[0] + + open_transition = parse_transitions.OpenConstituent("ROOT", "S") + assert open_transition.is_legal(state, model) + shift = parse_transitions.Shift() + close_transition = parse_transitions.CloseConstituent() + + state = open_transition.apply(state, model) + state = shift.apply(state, model) + state = shift.apply(state, model) + state = shift.apply(state, model) + state = close_transition.apply(state, model) + + tree = model.get_top_constituent(state.constituents) + assert tree.label == 'ROOT' + assert len(tree.children) == 1 + tree = tree.children[0] + assert tree.label == 'S' + assert len(tree.children) == 3 + assert tree.children[0].children[0].label == 'Unban' + assert tree.children[1].children[0].label == 'Mox' + assert tree.children[2].children[0].label == 'Opal' + +def test_in_order_open(model=None): + if model is None: + model = SimpleModel(TransitionScheme.IN_ORDER) + state = build_initial_state(model)[0] + + shift = parse_transitions.Shift() + assert shift.is_legal(state, model) + state = shift.apply(state, model) + assert not shift.is_legal(state, model) + + open_vp = parse_transitions.OpenConstituent("VP") + assert open_vp.is_legal(state, model) + state = open_vp.apply(state, model) + assert not open_vp.is_legal(state, model) + + close_trans = parse_transitions.CloseConstituent() + assert close_trans.is_legal(state, model) + state = close_trans.apply(state, model) + + open_s = parse_transitions.OpenConstituent("S") + assert open_s.is_legal(state, model) + state = open_s.apply(state, model) + assert not open_vp.is_legal(state, model) + + # check that root transitions won't happen in the middle of a parse + open_root = parse_transitions.OpenConstituent("ROOT") + assert not open_root.is_legal(state, model) + + # build (NP (NNP Mox) (NNP Opal)) + open_np = parse_transitions.OpenConstituent("NP") + assert shift.is_legal(state, model) + state = shift.apply(state, model) + assert open_np.is_legal(state, model) + # make sure root can't happen in places where an arbitrary open is legal + assert not open_root.is_legal(state, model) + state = open_np.apply(state, model) + assert shift.is_legal(state, model) + state = shift.apply(state, model) + assert close_trans.is_legal(state, model) + state = close_trans.apply(state, model) + + assert close_trans.is_legal(state, model) + state = close_trans.apply(state, model) + + assert open_root.is_legal(state, model) + state = open_root.apply(state, model) + +def test_too_many_unaries_close(): + """ + This tests rejecting Close at the start of a sequence after too many unary transitions + + The model should reject doing multiple "unaries" - eg, Open then Close - in an IN_ORDER sequence + """ + model = SimpleModel(TransitionScheme.IN_ORDER) + state = build_initial_state(model)[0] + + shift = parse_transitions.Shift() + assert shift.is_legal(state, model) + state = shift.apply(state, model) + + open_np = parse_transitions.OpenConstituent("NP") + close_trans = parse_transitions.CloseConstituent() + for _ in range(parse_transitions.UNARY_LIMIT): + assert open_np.is_legal(state, model) + state = open_np.apply(state, model) + + assert close_trans.is_legal(state, model) + state = 
close_trans.apply(state, model) + + assert open_np.is_legal(state, model) + state = open_np.apply(state, model) + assert not close_trans.is_legal(state, model) + +def test_too_many_unaries_open(): + """ + This tests rejecting Open in the middle of a sequence after too many unary transitions + + The model should reject doing multiple "unaries" - eg, Open then Close - in an IN_ORDER sequence + """ + model = SimpleModel(TransitionScheme.IN_ORDER) + state = build_initial_state(model)[0] + + shift = parse_transitions.Shift() + assert shift.is_legal(state, model) + state = shift.apply(state, model) + + open_np = parse_transitions.OpenConstituent("NP") + close_trans = parse_transitions.CloseConstituent() + + assert open_np.is_legal(state, model) + state = open_np.apply(state, model) + assert not open_np.is_legal(state, model) + assert shift.is_legal(state, model) + state = shift.apply(state, model) + + for _ in range(parse_transitions.UNARY_LIMIT): + assert open_np.is_legal(state, model) + state = open_np.apply(state, model) + + assert close_trans.is_legal(state, model) + state = close_trans.apply(state, model) + + assert not open_np.is_legal(state, model) + +def test_close(model=None): + if model is None: + model = SimpleModel() + # this one actually tests an entire subtree building + state = build_initial_state(model)[0] + + shift = parse_transitions.Shift() + state = shift.apply(state, model) + + open_transition = parse_transitions.OpenConstituent("NP") + assert open_transition.is_legal(state, model) + state = open_transition.apply(state, model) + assert state.num_opens == 1 + + state = shift.apply(state, model) + state = shift.apply(state, model) + assert state.num_opens == 1 + # now should have "mox", "opal" on the constituents + + close_transition = parse_transitions.CloseConstituent() + assert close_transition.is_legal(state, model) + state = close_transition.apply(state, model) + assert state.num_opens == 0 + assert not close_transition.is_legal(state, model) + + tree = model.get_top_constituent(state.constituents) + assert tree.label == 'NP' + assert len(tree.children) == 2 + assert tree.children[0].is_preterminal() + assert tree.children[1].is_preterminal() + assert tree.children[0].children[0].label == 'Mox' + assert tree.children[1].children[0].label == 'Opal' + + assert len(state.constituents) == 3 + + assert state.all_transitions(model) == [shift, open_transition, shift, shift, close_transition] + +def test_hashes(): + transitions = set() + + shift = parse_transitions.Shift() + assert shift not in transitions + transitions.add(shift) + assert shift in transitions + shift = parse_transitions.Shift() + assert shift in transitions + + for i in range(5): + transitions.add(shift) + assert len(transitions) == 1 + + unary = parse_transitions.CompoundUnary("asdf") + assert unary not in transitions + transitions.add(unary) + assert unary in transitions + + unary = parse_transitions.CompoundUnary(["asdf", "zzzz"]) + assert unary not in transitions + transitions.add(unary) + transitions.add(unary) + transitions.add(unary) + unary = parse_transitions.CompoundUnary(["asdf", "zzzz"]) + assert unary in transitions + + # check that the str and the list constructors result in the same item + assert len(transitions) == 3 + unary = parse_transitions.CompoundUnary(["asdf"]) + assert unary in transitions + + oc = parse_transitions.OpenConstituent("asdf") + assert oc not in transitions + transitions.add(oc) + assert oc in transitions + transitions.add(oc) + transitions.add(oc) + assert len(transitions) 
== 4 + assert parse_transitions.OpenConstituent("asdf") in transitions + + cc = parse_transitions.CloseConstituent() + assert cc not in transitions + transitions.add(cc) + transitions.add(cc) + transitions.add(cc) + assert cc in transitions + cc = parse_transitions.CloseConstituent() + assert cc in transitions + assert len(transitions) == 5 + + +def test_sort(): + expected = [] + + expected.append(parse_transitions.Shift()) + expected.append(parse_transitions.CloseConstituent()) + expected.append(parse_transitions.CompoundUnary(["NP"])) + expected.append(parse_transitions.CompoundUnary(["NP", "VP"])) + expected.append(parse_transitions.OpenConstituent("mox")) + expected.append(parse_transitions.OpenConstituent("opal")) + expected.append(parse_transitions.OpenConstituent("unban")) + + transitions = set(expected) + transitions = sorted(transitions) + assert transitions == expected diff --git a/stanza/tests/constituency/test_parse_tree.py b/stanza/tests/constituency/test_parse_tree.py new file mode 100644 index 00000000..959e9936 --- /dev/null +++ b/stanza/tests/constituency/test_parse_tree.py @@ -0,0 +1,196 @@ +import pytest + +from stanza.models.constituency.parse_tree import Tree +from stanza.models.constituency import tree_reader + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def test_leaf_preterminal(): + foo = Tree(label="foo") + assert foo.is_leaf() + assert not foo.is_preterminal() + assert len(foo.children) == 0 + assert str(foo) == 'foo' + + bar = Tree(label="bar", children=foo) + assert not bar.is_leaf() + assert bar.is_preterminal() + assert len(bar.children) == 1 + assert str(bar) == "(bar foo)" + + baz = Tree(label="baz", children=[bar]) + assert not baz.is_leaf() + assert not baz.is_preterminal() + assert len(baz.children) == 1 + assert str(baz) == "(baz (bar foo))" + + +def test_yield_preterminals(): + text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" + trees = tree_reader.read_trees(text) + + preterminals = trees[0].preterminals() + assert len(preterminals) == 3 + assert str(preterminals) == "[(VB Unban), (NNP Mox), (NNP Opal)]" + +def test_depth(): + text = "(foo) ((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" + trees = tree_reader.read_trees(text) + assert trees[0].depth() == 0 + assert trees[1].depth() == 4 + +def test_unique_labels(): + """ + Test getting the unique labels from a tree + + Assumes tree_reader works, which should be fine since it is tested elsewhere + """ + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + + trees = tree_reader.read_trees(text) + + labels = Tree.get_unique_constituent_labels(trees) + expected = ['NP', 'PP', 'ROOT', 'SBARQ', 'SQ', 'VP', 'WHNP'] + assert labels == expected + +def test_unique_tags(): + """ + Test getting the unique tags from a tree + """ + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + + trees = tree_reader.read_trees(text) + + tags = Tree.get_unique_tags(trees) + expected = ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP'] + assert tags == expected + + +def test_unique_words(): + """ + Test getting the unique words from a tree + """ + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" + + trees = tree_reader.read_trees(text) + + words = Tree.get_unique_words(trees) + expected = ['?', 'Who', 'in', 'seat', 'sits', 'this'] + assert words == expected + +def test_rare_words(): + """ + Test getting the unique words from a tree + """ + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))" + + trees = tree_reader.read_trees(text) + + words = Tree.get_rare_words(trees, 0.5) + expected = ['Who', 'in', 'sits'] + assert words == expected + +def test_root_labels(): + text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + trees = tree_reader.read_trees(text) + assert ["ROOT"] == Tree.get_root_labels(trees) + + text=("( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + + "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + + "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))") + trees = tree_reader.read_trees(text) + assert ["ROOT"] == Tree.get_root_labels(trees) + + text="(FOO) (BAR)" + trees = tree_reader.read_trees(text) + assert ["BAR", "FOO"] == Tree.get_root_labels(trees) + +def test_prune_none(): + text=["((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (-NONE- in) (NP (DT this) (NN seat))))) (. ?)))", # test one dead node + "((SBARQ (WHNP (-NONE- Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))", # test recursive dead nodes + "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (-NONE- this) (-NONE- seat))))) (. ?)))"] # test all children dead + expected=["(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (NP (DT this) (NN seat))))) (. ?)))", + "(ROOT (SBARQ (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))", + "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"] + + for t, e in zip(text, expected): + trees = tree_reader.read_trees(t) + assert len(trees) == 1 + tree = trees[0].prune_none() + assert e == str(tree) + +def test_simplify_labels(): + text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))" + expected = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))" + trees = tree_reader.read_trees(text) + trees = [t.simplify_labels() for t in trees] + assert len(trees) == 1 + assert expected == str(trees[0]) + +def test_remap_constituent_labels(): + text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))" + expected="(ROOT (FOO (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))" + + label_map = { "SBARQ": "FOO" } + trees = tree_reader.read_trees(text) + trees = [t.remap_constituent_labels(label_map) for t in trees] + assert len(trees) == 1 + assert expected == str(trees[0]) + +def test_remap_constituent_words(): + text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))" + expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))" + + word_map = { "Who": "unban", "sits": "mox", "in": "opal" } + trees = tree_reader.read_trees(text) + trees = [t.remap_words(word_map) for t in trees] + assert len(trees) == 1 + assert expected == str(trees[0]) + +def test_replace_words(): + text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. 
?)))" + expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))" + new_words = ["unban", "mox", "opal", "?"] + + trees = tree_reader.read_trees(text) + assert len(trees) == 1 + tree = trees[0] + new_tree = tree.replace_words(new_words) + assert expected == str(new_tree) + + +def test_compound_constituents(): + # TODO: add skinny trees like this to the various transition tests + text="((VP (VB Unban)))" + trees = tree_reader.read_trees(text) + assert Tree.get_compound_constituents(trees) == [('ROOT', 'VP')] + + text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))" + trees = tree_reader.read_trees(text) + assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('SQ', 'VP'), ('WHNP',)] + + text="((VP (VB Unban))) (ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))" + trees = tree_reader.read_trees(text) + assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('ROOT', 'VP'), ('SQ', 'VP'), ('WHNP',)] + +def test_equals(): + """ + Check one tree from the actual dataset for == + + when built with compound Open, this didn't work because of a silly bug + """ + text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))" + + trees = tree_reader.read_trees(text) + assert len(trees) == 1 + tree = trees[0] + + assert tree == tree + + trees2 = tree_reader.read_trees(text) + tree2 = trees2[0] + + assert tree is not tree2 + assert tree == tree2 diff --git a/stanza/tests/constituency/test_transition_sequence.py b/stanza/tests/constituency/test_transition_sequence.py new file mode 100644 index 00000000..6c77db3f --- /dev/null +++ b/stanza/tests/constituency/test_transition_sequence.py @@ -0,0 +1,87 @@ +import pytest +from stanza.models.constituency import parse_transitions +from stanza.models.constituency import transition_sequence +from stanza.models.constituency import tree_reader +from stanza.models.constituency.base_model import SimpleModel +from stanza.models.constituency.parse_transitions import * + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def check_reproduce_tree(transition_scheme): + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" + trees = tree_reader.read_trees(text) + + model = SimpleModel(transition_scheme) + transitions = transition_sequence.build_sequence(trees[0], transition_scheme) + states = parse_transitions.initial_state_from_gold_trees(trees, model) + assert(len(states)) == 1 + state = states[0] + assert state.num_transitions() == 0 + + for t in transitions: + assert t.is_legal(state, model) + state = t.apply(state, model) + + # one item for the final tree + # one item for the sentinel at the end + assert len(state.constituents) == 2 + # the transition sequence should put all of the words + # from the buffer onto the tree + # one spot left for the sentinel value + assert len(state.word_queue) == 7 + assert state.sentence_length == 6 + assert state.word_position == state.sentence_length + assert len(state.transitions) == len(transitions) + 1 + + result_tree = state.constituents.value + assert result_tree == trees[0] + +def test_top_down_unary(): + check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN_UNARY) + +def test_top_down_no_unary(): + check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN) + +def test_in_order(): + check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER) + +def test_all_transitions(): + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + trees = tree_reader.read_trees(text) + model = SimpleModel() + transitions = transition_sequence.build_treebank(trees) + + expected = [Shift(), CloseConstituent(), CompoundUnary("ROOT"), CompoundUnary("SQ"), CompoundUnary("WHNP"), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("SBARQ"), OpenConstituent("VP")] + assert transition_sequence.all_transitions(transitions) == expected + + +def test_all_transitions_no_unary(): + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + trees = tree_reader.read_trees(text) + model = SimpleModel() + transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN) + + expected = [Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("ROOT"), OpenConstituent("SBARQ"), OpenConstituent("SQ"), OpenConstituent("VP"), OpenConstituent("WHNP")] + assert transition_sequence.all_transitions(transitions) == expected + +def test_top_down_compound_unary(): + text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. 
.)))" + + trees = tree_reader.read_trees(text) + assert len(trees) == 1 + + model = SimpleModel() + transitions = transition_sequence.build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN_COMPOUND) + + states = parse_transitions.initial_state_from_gold_trees(trees, model) + assert len(states) == 1 + state = states[0] + + for t in transitions: + assert t.is_legal(state, model) + state = t.apply(state, model) + + result = model.get_top_constituent(state.constituents) + assert trees[0] == result diff --git a/stanza/tests/constituency/test_tree_reader.py b/stanza/tests/constituency/test_tree_reader.py new file mode 100644 index 00000000..feee74fa --- /dev/null +++ b/stanza/tests/constituency/test_tree_reader.py @@ -0,0 +1,61 @@ +import pytest +from stanza.models.constituency import tree_reader + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def test_simple(): + """ + Tests reading two simple trees from the same text + """ + text = "(VB Unban) (NNP Opal)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + assert trees[0].is_preterminal() + assert trees[0].label == 'VB' + assert trees[0].children[0].label == 'Unban' + assert trees[1].is_preterminal() + assert trees[1].label == 'NNP' + assert trees[1].children[0].label == 'Opal' + +def test_newlines(): + """ + The same test should work if there are newlines + """ + text = "(VB Unban)\n\n(NNP Opal)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + +def test_complicated(): + """ + A more complicated tree that should successfully read + """ + text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + trees = tree_reader.read_trees(text) + assert len(trees) == 1 + tree = trees[0] + assert not tree.is_leaf() + assert not tree.is_preterminal() + assert tree.label == 'ROOT' + assert len(tree.children) == 1 + assert tree.children[0].label == 'SBARQ' + assert len(tree.children[0].children) == 3 + assert [x.label for x in tree.children[0].children] == ['WHNP', 'SQ', '.'] + # etc etc + +def test_one_word(): + """ + Check that one node trees are correctly read + + probably not super relevant for the parsing use case + """ + text="(FOO) (BAR)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + + assert trees[0].is_leaf() + assert trees[0].label == 'FOO' + + assert trees[1].is_leaf() + assert trees[1].label == 'BAR' diff --git a/stanza/tests/constituency/test_tree_stack.py b/stanza/tests/constituency/test_tree_stack.py new file mode 100644 index 00000000..e7859a3b --- /dev/null +++ b/stanza/tests/constituency/test_tree_stack.py @@ -0,0 +1,50 @@ +import pytest + +from stanza.models.constituency.tree_stack import TreeStack + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def test_simple(): + stack = TreeStack(value=5, parent=None, length=1) + stack = stack.push(3) + stack = stack.push(1) + + expected_values = [1, 3, 5] + for value in expected_values: + assert stack.value == value + stack = stack.pop() + assert stack is None + +def test_iter(): + stack = TreeStack(value=5, parent=None, length=1) + stack = stack.push(3) + stack = stack.push(1) + + stack_list = list(stack) + assert list(stack) == [1, 3, 5] + +def test_str(): + stack = TreeStack(value=5, parent=None, length=1) + stack = stack.push(3) + stack = stack.push(1) + + assert str(stack) == "TreeStack(1, 3, 5)" + +def test_len(): + stack = TreeStack(value=5, parent=None, length=1) + assert len(stack) == 1 + + stack = 
stack.push(3) + stack = stack.push(1) + assert len(stack) == 3 + +def test_long_len(): + """ + Original stack had a bug where this took exponential time... + """ + stack = TreeStack(value=0, parent=None, length=1) + for i in range(1, 40): + stack = stack.push(i) + assert len(stack) == 40 diff --git a/stanza/tests/constituency/test_utils.py b/stanza/tests/constituency/test_utils.py new file mode 100644 index 00000000..38a9f37e --- /dev/null +++ b/stanza/tests/constituency/test_utils.py @@ -0,0 +1,68 @@ +import pytest + +from stanza import Pipeline +from stanza.models.constituency import tree_reader +from stanza.models.constituency import utils + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + + +@pytest.fixture(scope="module") +def pipeline(): + return Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True) + + + +def test_xpos_retag(pipeline): + """ + Test using the English tagger that trees will be correctly retagged by read_trees using xpos + """ + text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))" + expected = "((S (VP (VB Find)) (NP (NNP Mox) (NNP Opal)))) ((S (NP (NNP Ragavan)) (VP (VBZ steals) (NP (JJ important) (NNS cards)))))" + + trees = tree_reader.read_trees(text) + + new_trees = utils.retag_trees(trees, pipeline, xpos=True) + assert new_trees == tree_reader.read_trees(expected) + + + +def test_upos_retag(pipeline): + """ + Test using the English tagger that trees will be correctly retagged by read_trees using upos + """ + text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))" + expected = "((S (VP (VERB Find)) (NP (PROPN Mox) (PROPN Opal)))) ((S (NP (PROPN Ragavan)) (VP (VERB steals) (NP (ADJ important) (NOUN cards)))))" + + trees = tree_reader.read_trees(text) + + new_trees = utils.retag_trees(trees, pipeline, xpos=False) + assert new_trees == tree_reader.read_trees(expected) + + +def test_replace_tags(): + """ + Test the underlying replace_tags method + + Also tests that the method throws exceptions when it is supposed to + """ + text = "((S (VP (X Find)) (NP (X Mox) (X Opal))))" + expected = "((S (VP (A Find)) (NP (B Mox) (C Opal))))" + + trees = tree_reader.read_trees(text) + + new_tags = ["A", "B", "C"] + new_tree = utils.replace_tags(trees[0], new_tags) + + assert new_tree == tree_reader.read_trees(expected)[0] + + with pytest.raises(ValueError): + new_tags = ["A", "B"] + new_tree = utils.replace_tags(trees[0], new_tags) + + with pytest.raises(ValueError): + new_tags = ["A", "B", "C", "D"] + new_tree = utils.replace_tags(trees[0], new_tags) + diff --git a/stanza/tests/test_constituency_parse_tree.py b/stanza/tests/test_constituency_parse_tree.py deleted file mode 100644 index a4f81180..00000000 --- a/stanza/tests/test_constituency_parse_tree.py +++ /dev/null @@ -1,33 +0,0 @@ -import pytest - -from stanza.models.constituency.parse_tree import Tree - -from stanza.tests import * - -pytestmark = [pytest.mark.pipeline, pytest.mark.travis] - -def test_leaf_preterminal(): - foo = Tree(label="foo") - assert foo.is_leaf() - assert not foo.is_preterminal() - assert len(foo.children) == 0 - assert str(foo) == 'foo' - - bar = Tree(label="bar", children=foo) - assert not bar.is_leaf() - assert bar.is_preterminal() - assert len(bar.children) == 1 - assert str(bar) == "(bar foo)" - - baz = Tree(label="baz", children=[bar]) - assert not baz.is_leaf() - assert not baz.is_preterminal() - assert 
len(baz.children) == 1 - assert str(baz) == "(baz (bar foo))" - - -def test_depth(): - text = "(foo) ((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" - trees = tree_reader.read_trees(text) - assert trees[0].depth() == 0 - assert trees[1].depth() == 4 diff --git a/stanza/tests/test_constituency_tree_reader.py b/stanza/tests/test_constituency_tree_reader.py deleted file mode 100644 index feee74fa..00000000 --- a/stanza/tests/test_constituency_tree_reader.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest -from stanza.models.constituency import tree_reader - -from stanza.tests import * - -pytestmark = [pytest.mark.pipeline, pytest.mark.travis] - -def test_simple(): - """ - Tests reading two simple trees from the same text - """ - text = "(VB Unban) (NNP Opal)" - trees = tree_reader.read_trees(text) - assert len(trees) == 2 - assert trees[0].is_preterminal() - assert trees[0].label == 'VB' - assert trees[0].children[0].label == 'Unban' - assert trees[1].is_preterminal() - assert trees[1].label == 'NNP' - assert trees[1].children[0].label == 'Opal' - -def test_newlines(): - """ - The same test should work if there are newlines - """ - text = "(VB Unban)\n\n(NNP Opal)" - trees = tree_reader.read_trees(text) - assert len(trees) == 2 - -def test_complicated(): - """ - A more complicated tree that should successfully read - """ - text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" - trees = tree_reader.read_trees(text) - assert len(trees) == 1 - tree = trees[0] - assert not tree.is_leaf() - assert not tree.is_preterminal() - assert tree.label == 'ROOT' - assert len(tree.children) == 1 - assert tree.children[0].label == 'SBARQ' - assert len(tree.children[0].children) == 3 - assert [x.label for x in tree.children[0].children] == ['WHNP', 'SQ', '.'] - # etc etc - -def test_one_word(): - """ - Check that one node trees are correctly read - - probably not super relevant for the parsing use case - """ - text="(FOO) (BAR)" - trees = tree_reader.read_trees(text) - assert len(trees) == 2 - - assert trees[0].is_leaf() - assert trees[0].label == 'FOO' - - assert trees[1].is_leaf() - assert trees[1].label == 'BAR' diff --git a/stanza/tests/test_java_protobuf_requests.py b/stanza/tests/test_java_protobuf_requests.py new file mode 100644 index 00000000..0c7ee7d8 --- /dev/null +++ b/stanza/tests/test_java_protobuf_requests.py @@ -0,0 +1,23 @@ +import tempfile + +import pytest + +from stanza.models.constituency import tree_reader +from stanza.server import java_protobuf_requests +from stanza.tests import * + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +def check_tree(proto_tree, py_tree, py_score): + tree, tree_score = java_protobuf_requests.from_tree(proto_tree) + assert tree_score == py_score + assert tree == py_tree + +def test_build_tree(): + text="((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))\n( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + + for tree in trees: + proto_tree = java_protobuf_requests.build_tree(trees[0], 1.0) + check_tree(proto_tree, trees[0], 1.0) diff --git a/stanza/tests/test_parser_eval.py b/stanza/tests/test_parser_eval.py new file mode 100644 index 00000000..d633c529 --- /dev/null +++ b/stanza/tests/test_parser_eval.py @@ -0,0 +1,40 @@ +""" +Test the parser eval interface +""" + +import pytest +import stanza +from stanza.models.constituency import tree_reader +from stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse +from stanza.server.parser_eval import build_request, EvaluateParser +from stanza.tests.test_java_protobuf_requests import check_tree + +from stanza.tests import * + +pytestmark = [pytest.mark.travis, pytest.mark.client] + +def build_one_tree_treebank(): + text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" + trees = tree_reader.read_trees(text) + assert len(trees) == 1 + gold = trees[0] + prediction = (gold, 1.0) + treebank = [(gold, [prediction])] + return treebank + +def test_build_request_one_tree(): + treebank = build_one_tree_treebank() + request = build_request(treebank) + + assert len(request.treebank) == 1 + check_tree(request.treebank[0].gold, treebank[0][0], None) + assert len(request.treebank[0].predicted) == 1 + check_tree(request.treebank[0].predicted[0], treebank[0][1][0][0], treebank[0][1][0][1]) + + +def test_score_one_tree(): + treebank = build_one_tree_treebank() + + with EvaluateParser(classpath="$CLASSPATH") as ep: + response = ep.process(treebank) + assert response.f1 == pytest.approx(1.0) -- cgit v1.2.3 From 6347e75befbe3a4684d2982a06c7041ef8660058 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sat, 18 Sep 2021 00:09:57 -0700 Subject: Processing script for the IT Turin treebank Remaps and/or replaces some weird constituency labels Refactor some, get a bunch of the filtering applied to the test set. Turns out there's at least one miswritten test tree as well Prune NONE ... turns out a lot of those are showing up. Also double check that no train trees are in the test set, although that was not a problem Splits train into train & dev Remaps words as well so that [] is back the way it should be Skips a broken tree based on its preterminal Resplits a bunch of the tokens when processing it_turin --- .../datasets/constituency/convert_it_turin.py | 322 +++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 stanza/utils/datasets/constituency/convert_it_turin.py diff --git a/stanza/utils/datasets/constituency/convert_it_turin.py b/stanza/utils/datasets/constituency/convert_it_turin.py new file mode 100644 index 00000000..018073e3 --- /dev/null +++ b/stanza/utils/datasets/constituency/convert_it_turin.py @@ -0,0 +1,322 @@ +""" +Converts Turin's constituency dataset + +Turin University put out a freely available constituency dataset in 2011. +It is not as large as VIT or ISST, but it is free, which is nice. + +The 2011 parsing task combines trees from several sources: +http://www.di.unito.it/~tutreeb/evalita-parsingtask-11.html + +There is another site for Turin treebanks: +http://www.di.unito.it/~tutreeb/treebanks.html + +Weirdly, the most recent versions of the Evalita trees are not there. +The most relevant parts are the ParTUT downloads. As of Sep. 
2021: + +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/JRCAcquis_It.pen +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/UDHR_It.pen +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/CC_It.pen +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/FB_It.pen +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/WIT3_It.pen + +We can't simply cat all these files together as there are a bunch of +asterisks as comments and the files may have some duplicates. For +example, the JRCAcquis piece has many duplicates. Also, some don't +pass validation for one reason or another. + +One oddity of these data files is that the MWT are denoted by doubling +the token. The token is not split as would be expected, though. We try +to use stanza's MWT tokenizer for IT to split the tokens, with some +rules added by hand in BIWORD_SPLITS. Two are still unsplit, though... +""" + +import glob +import os +import re +import sys + +import stanza +from stanza.models.constituency import parse_tree +from stanza.models.constituency import tree_reader + +def load_without_asterisks(in_file, encoding='utf-8'): + with open(in_file, encoding=encoding) as fin: + new_lines = [x if x.find("********") < 0 else "\n" for x in fin.readlines()] + if len(new_lines) > 0 and not new_lines[-1].endswith("\n"): + new_lines[-1] = new_lines[-1] + "\n" + return new_lines + +CONSTITUENT_SPLIT = re.compile("[-=#+0-9]") + +# JRCA is almost entirely duplicates +# WIT3 follows a different annotation scheme +FILES_TO_ELIMINATE = ["JRCAcquis_It.pen", "WIT3_It.pen"] + +# assuming this is a typo +REMAP_NODES = { "Sbar" : "SBAR" } + +REMAP_WORDS = { "-LSB-": "[", "-RSB-": "]" } + +# these mostly seem to be mistakes +# maybe Vbar and ADVbar should be converted to something else? 
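+# trees containing any of these labels are skipped entirely in load_trees below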
+NODES_TO_ELIMINATE = ["C", "PHRASP", "PRDT", "Vbar", "parte", "ADVbar"] + +UNKNOWN_SPLITS = set() + +# a map of splits that the tokenizer or MWT doesn't handle well +BIWORD_SPLITS = { "offertogli": ("offerto", "gli"), + "offertegli": ("offerte", "gli"), + "formatasi": ("formata", "si"), + "formatosi": ("formato", "si"), + "multiplexarlo": ("multiplexar", "lo"), + "esibirsi": ("esibir", "si"), + "pagarne": ("pagar", "ne"), + "recarsi": ("recar", "si"), + "trarne": ("trar", "ne"), + "esserci": ("esser", "ci"), + "aprirne": ("aprir", "ne"), + "farle": ("far", "le"), + "disporne": ("dispor", "ne"), + "andargli": ("andar", "gli"), + "CONSIDERARSI": ("CONSIDERAR", "SI"), + "conferitegli": ("conferite", "gli"), + "formatasi": ("formata", "si"), + "formatosi": ("formato", "si"), + "Formatisi": ("Formati", "si"), + "multiplexarlo": ("multiplexar", "lo"), + "esibirsi": ("esibir", "si"), + "pagarne": ("pagar", "ne"), + "recarsi": ("recar", "si"), + "trarne": ("trar", "ne"), + "temerne": ("temer", "ne"), + "esserci": ("esser", "ci"), + "esservi": ("esser", "vi"), + "restituirne": ("restituir", "ne"), + "col": ("con", "il"), + "cogli": ("con", "gli"), + "dirgli": ("dir", "gli"), + "opporgli": ("oppor", "gli"), + "eccolo": ("ecco", "lo"), + "Eccolo": ("Ecco", "lo"), + "Eccole": ("Ecco", "le"), + "farci": ("far", "ci"), + "farli": ("far", "li"), + "farne": ("far", "ne"), + "farsi": ("far", "si"), + "farvi": ("far", "vi"), + "Connettiti": ("Connetti", "ti"), + "APPLICARSI": ("APPLICAR", "SI"), + # This is not always two words, but if it IS two words, + # it gets split like this + "assicurati": ("assicura", "ti"), + "Fatti": ("Fai", "te"), + "ai": ("a", "i"), + "Ai": ("A", "i"), + "AI": ("A", "I"), + "al": ("a", "il"), + "Al": ("A", "il"), + "AL": ("A", "IL"), + "coi": ("con", "i"), + "colla": ("con", "la"), + "colle": ("con", "le"), + "dal": ("da", "il"), + "Dal": ("Da", "il"), + "DAL": ("DA", "IL"), + "dei": ("di", "i"), + "Dei": ("Di", "i"), + "DEI": ("DI", "I"), + "del": ("di", "il"), + "Del": ("Di", "il"), + "DEL": ("DI", "IL"), + "nei": ("in", "i"), + "NEI": ("IN", "I"), + "nel": ("in", "il"), + "Nel": ("In", "il"), + "NEL": ("IN", "IL"), + "pel": ("per", "il"), + "sui": ("su", "i"), + "Sui": ("Su", "i"), + "sul": ("su", "il"), + "Sul": ("Su", "il"), + ",": (",", ","), + ".": (".", "."), + '"': ('"', '"'), + '-': ('-', '-'), + '-LRB-': ('-LRB-', '-LRB-'), + "garantirne": ("garantir", "ne"), + "aprirvi": ("aprir", "vi"), + "esimersi": ("esimer", "si"), + "opporsi": ("oppor", "si"), +} + +CAP_BIWORD = re.compile("[A-Z]+_[A-Z]+") + +def split_mwe(tree, pipeline): + words = list(tree.leaf_labels()) + found = False + for idx, word in enumerate(words[:-3]): + if word == words[idx+1] and word == words[idx+2] and word == words[idx+3]: + raise ValueError("Oh no, 4 consecutive words") + + for idx, word in enumerate(words[:-2]): + if word == words[idx+1] and word == words[idx+2]: + doc = pipeline(word) + assert len(doc.sentences) == 1 + if len(doc.sentences[0].words) != 3: + raise RuntimeError("Word {} not tokenized into 3 parts... 
thought all 3 part words were handled!".format(word)) + words[idx] = doc.sentences[0].words[0].text + words[idx+1] = doc.sentences[0].words[1].text + words[idx+2] = doc.sentences[0].words[2].text + found = True + + for idx, word in enumerate(words[:-1]): + if word == words[idx+1]: + if word in BIWORD_SPLITS: + first_word = BIWORD_SPLITS[word][0] + second_word = BIWORD_SPLITS[word][1] + elif CAP_BIWORD.match(word): + first_word, second_word = word.split("_") + else: + doc = pipeline(word) + assert len(doc.sentences) == 1 + if len(doc.sentences[0].words) == 2: + first_word = doc.sentences[0].words[0].text + second_word = doc.sentences[0].words[1].text + else: + if word not in UNKNOWN_SPLITS: + UNKNOWN_SPLITS.add(word) + print("Could not figure out how to split {}\n {}\n {}".format(word, " ".join(words), tree)) + continue + + words[idx] = first_word + words[idx+1] = second_word + found = True + + if found: + tree = tree.replace_words(words) + return tree + + +def load_trees(filename, pipeline): + # some of the files are in latin-1 encoding rather than utf-8 + try: + raw_text = load_without_asterisks(filename, "utf-8") + except UnicodeDecodeError: + raw_text = load_without_asterisks(filename, "latin-1") + + # also, some have messed up validation (it will be logged) + # hence the broken_ok=True argument + trees = tree_reader.read_trees("".join(raw_text), broken_ok=True) + + filtered_trees = [] + for tree in trees: + if tree.children[0].label is None: + print("Skipping a broken tree (missing label) in {}: {}".format(filename, tree)) + continue + + try: + words = tuple(tree.leaf_labels()) + except ValueError: + print("Skipping a broken tree (missing preterminal) in {}: {}".format(filename, tree)) + continue + + if any('www.facebook' in pt.label for pt in tree.preterminals()): + print("Skipping a tree with a weird preterminal label in {}: {}".format(filename, tree)) + continue + + tree = tree.prune_none().simplify_labels(CONSTITUENT_SPLIT) + tree = tree.remap_constituent_labels(REMAP_NODES) + tree = tree.remap_words(REMAP_WORDS) + + tree = split_mwe(tree, pipeline) + if tree is None: + continue + + constituents = set(parse_tree.Tree.get_unique_constituent_labels(tree)) + for weird_label in NODES_TO_ELIMINATE: + if weird_label in constituents: + break + else: + weird_label = None + if weird_label is not None: + print("Skipping a tree with a weird label {} in {}: {}".format(weird_label, filename, tree)) + continue + + filtered_trees.append(tree) + + return filtered_trees + +def save_trees(out_file, trees): + print("Saving {} trees to {}".format(len(trees), out_file)) + with open(out_file, "w", encoding="utf-8") as fout: + for tree in trees: + fout.write(str(tree)) + fout.write("\n") + +def main(): + pipeline = stanza.Pipeline("it", processors="tokenize, mwt", tokenize_no_ssplit=True) + + input_path = sys.argv[1] + output_path = sys.argv[2] + + os.makedirs(output_path, exist_ok=True) + + evalita_dir = os.path.join(input_path, "evalita") + + evalita_test = os.path.join(evalita_dir, "evalita11_TESTgold_CONPARSE.penn") + it_test = os.path.join(output_path, "it_turin_test.mrg") + test_trees = load_trees(evalita_test, pipeline) + save_trees(it_test, test_trees) + + known_text = set() + for tree in test_trees: + words = tuple(tree.leaf_labels()) + assert words not in known_text + known_text.add(words) + + evalita_train = os.path.join(output_path, "it_turin_train.mrg") + evalita_files = glob.glob(os.path.join(evalita_dir, "*2011*penn")) + turin_files = glob.glob(os.path.join(input_path, "turin", "*pen")) + 
filenames = evalita_files + turin_files + filtered_trees = [] + for filename in filenames: + if os.path.split(filename)[1] in FILES_TO_ELIMINATE: + continue + + trees = load_trees(filename, pipeline) + file_trees = [] + + for tree in trees: + words = tuple(tree.leaf_labels()) + if words in known_text: + print("Skipping a duplicate in {}: {}".format(filename, tree)) + continue + + known_text.add(words) + + file_trees.append(tree) + + filtered_trees.append((filename, file_trees)) + + print("{} contains {} usable trees".format(evalita_test, len(test_trees))) + print(" Unique constituents in {}: {}".format(evalita_test, parse_tree.Tree.get_unique_constituent_labels(test_trees))) + + train_trees = [] + dev_trees = [] + for filename, file_trees in filtered_trees: + print("{} contains {} usable trees".format(filename, len(file_trees))) + print(" Unique constituents in {}: {}".format(filename, parse_tree.Tree.get_unique_constituent_labels(file_trees))) + for tree in file_trees: + if len(train_trees) <= len(dev_trees) * 9: + train_trees.append(tree) + else: + dev_trees.append(tree) + + it_train = os.path.join(output_path, "it_turin_train.mrg") + save_trees(it_train, train_trees) + + it_dev = os.path.join(output_path, "it_turin_dev.mrg") + save_trees(it_dev, dev_trees) + +if __name__ == '__main__': + main() -- cgit v1.2.3 From 84e6d7f203a33bb0225d4d4c699975502782612c Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 24 Sep 2021 19:00:26 -0700 Subject: Include six in the setup --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 890088da..c3f9be88 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ setup( # your project is installed. For an analysis of "install_requires" vs pip's # requirements files see: # https://packaging.python.org/en/latest/requirements.html - install_requires=['emoji', 'numpy', 'protobuf', 'requests', 'torch>=1.3.0', 'tqdm'], + install_requires=['emoji', 'numpy', 'protobuf', 'requests', 'six', 'torch>=1.3.0', 'tqdm'], # List required Python versions python_requires='>=3.6', -- cgit v1.2.3 From 8ce1cc0118d335fa39d6ae6bf87e5ca3c9efb58d Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 24 Sep 2021 19:01:52 -0700 Subject: No longer need to use classpath by default - demo version of 4.3.0 is now available --- stanza/models/constituency/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index e8d49bb5..058506cf 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -565,6 +565,6 @@ def run_dev_set(model, dev_trees, args): fout.write(str(tree[0])) fout.write("\n") - with EvaluateParser(classpath="$CLASSPATH") as evaluator: + with EvaluateParser() as evaluator: response = evaluator.process(treebank) return response.f1 -- cgit v1.2.3 From b7570d37ca51dad2bad27c48a4bae248f7737a14 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 27 Sep 2021 19:13:30 -0700 Subject: Refactor the part of trainer.py which builds a trainer using the information in the train/dev treebank. As part of this, factor out the logging levels so that the tests don't spam tqdm logs. 
Also, refactor loading the charlm and add a test specifically of charlm Rename zeros -> word_zeros for better readability --- stanza/models/constituency/lstm_model.py | 4 +- stanza/models/constituency/trainer.py | 80 +++++++++++++++++----------- stanza/tests/constituency/test_lstm_model.py | 37 +++++++++---- 3 files changed, 76 insertions(+), 45 deletions(-) diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py index 2939403e..ec9b1a8f 100644 --- a/stanza/models/constituency/lstm_model.py +++ b/stanza/models/constituency/lstm_model.py @@ -132,7 +132,7 @@ class LSTMModel(BaseModel, nn.Module): self.lstm_layer_dropout = self.args['lstm_layer_dropout'] # also register a buffer of zeros so that we can always get zeros on the appropriate device - self.register_buffer('zeros', torch.zeros(self.hidden_size)) + self.register_buffer('word_zeros', torch.zeros(self.hidden_size)) self.register_buffer('transition_zeros', torch.zeros(self.num_layers, 1, self.transition_hidden_size)) self.register_buffer('constituent_zeros', torch.zeros(self.num_layers, 1, self.hidden_size)) @@ -329,7 +329,7 @@ class LSTMModel(BaseModel, nn.Module): word_queue = [WordNode(tag_node, sentence_output[idx, :]) for idx, tag_node in enumerate(tagged_words)] word_queue.reverse() - word_queue.append(WordNode(None, self.zeros)) + word_queue.append(WordNode(None, self.word_zeros)) word_queues.append(word_queue) diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index 058506cf..5437e833 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -33,8 +33,7 @@ from stanza.server.parser_eval import EvaluateParser tqdm = utils.get_tqdm() -logger = logging.getLogger('stanza') - +logger = logging.getLogger('stanza.constituency.trainer') class Trainer: """ @@ -171,7 +170,12 @@ def verify_transitions(trees, sequences, transition_scheme): """ model = base_model.SimpleModel(transition_scheme) logger.info("Verifying the transition sequences for %d trees", len(trees)) - for tree, sequence in tqdm(zip(trees, sequences), total=len(trees)): + + data = zip(trees, sequences) + if logger.getEffectiveLevel() <= logging.INFO: + data = tqdm(zip(trees, sequences), total=len(trees)) + + for tree, sequence in data: state = parse_transitions.initial_state_from_gold_trees([tree], model)[0] for idx, trans in enumerate(sequence): if not trans.is_legal(state, model): @@ -245,26 +249,18 @@ def remove_optimizer(args, model_save_file, model_load_file): trainer = Trainer.load(model_load_file, pt, forward_charlm, backward_charlm, use_gpu=False, load_optimizer=False) trainer.save(model_save_file) -def train(args, model_save_file, model_load_file, model_save_latest_file, retag_pipeline): +def convert_trees_to_sequences(trees, tree_type, transition_scheme): + logger.info("Building {} transition sequences".format(tree_type)) + if logger.getEffectiveLevel() <= logging.INFO: + trees = tqdm(trees) + sequences = build_treebank(trees, transition_scheme) + transitions = transition_sequence.all_transitions(sequences) + return sequences, transitions + +def build_trainer(args, train_trees, dev_trees, pt, forward_charlm, backward_charlm): """ - Build a model, train it using the requested train & dev files + Builds a Trainer (with model) and the train_sequences and transitions for the given trees. 
""" - print_args(args) - - utils.ensure_dir(args['save_dir']) - - train_trees = read_treebank(args['train_file']) - logger.info("Read %d trees for the training set", len(train_trees)) - - dev_trees = read_treebank(args['eval_file']) - logger.info("Read %d trees for the dev set", len(dev_trees)) - - if retag_pipeline is not None: - logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package']) - train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos']) - dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos']) - logger.info("Retagging finished") - train_constituents = parse_tree.Tree.get_unique_constituent_labels(train_trees) dev_constituents = parse_tree.Tree.get_unique_constituent_labels(dev_trees) logger.info("Unique constituents in training set: %s", train_constituents) @@ -272,13 +268,8 @@ def train(args, model_save_file, model_load_file, model_save_latest_file, retag_ if con not in train_constituents: raise RuntimeError("Found label {} in the dev set which don't exist in the train set".format(con)) - logger.info("Building training transition sequences") - train_sequences = build_treebank(tqdm(train_trees), args['transition_scheme']) - train_transitions = transition_sequence.all_transitions(train_sequences) - - logger.info("Building dev transition sequences") - dev_sequences = build_treebank(tqdm(dev_trees), args['transition_scheme']) - dev_transitions = transition_sequence.all_transitions(dev_sequences) + train_sequences, train_transitions = convert_trees_to_sequences(train_trees, "training", args['transition_scheme']) + dev_sequences, dev_transitions = convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme']) logger.info("Total unique transitions in train set: %d", len(train_transitions)) for trans in dev_transitions: @@ -308,10 +299,6 @@ def train(args, model_save_file, model_load_file, model_save_latest_file, retag_ # train set. 
it just means we probably won't ever get that right open_nodes = get_open_nodes(train_trees, args) - pt = load_pretrain(args) - forward_charlm = load_charlm(args['charlm_forward_file']) - backward_charlm = load_charlm(args['charlm_backward_file']) - # at this point we have: # pretrain # train_trees, dev_trees @@ -329,8 +316,37 @@ def train(args, model_save_file, model_load_file, model_save_latest_file, retag_ trainer = Trainer(model, optimizer) + return trainer, train_sequences, train_transitions + +def train(args, model_save_file, model_load_file, model_save_latest_file, retag_pipeline): + """ + Build a model, train it using the requested train & dev files + """ + print_args(args) + + utils.ensure_dir(args['save_dir']) + + train_trees = read_treebank(args['train_file']) + logger.info("Read %d trees for the training set", len(train_trees)) + + dev_trees = read_treebank(args['eval_file']) + logger.info("Read %d trees for the dev set", len(dev_trees)) + + if retag_pipeline is not None: + logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package']) + train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos']) + dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos']) + logger.info("Retagging finished") + + pt = load_pretrain(args) + forward_charlm = load_charlm(args['charlm_forward_file']) + backward_charlm = load_charlm(args['charlm_backward_file']) + + trainer, train_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, pt, forward_charlm, backward_charlm) + iterate_training(trainer, train_trees, train_sequences, train_transitions, dev_trees, args, model_save_file, model_save_latest_file) + def iterate_training(trainer, train_trees, train_sequences, transitions, dev_trees, args, model_filename, model_latest_filename): """ Given an initialized model, a processed dataset, and a secondary dev dataset, train the model diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py index 9b5d00f1..b84278c8 100644 --- a/stanza/tests/constituency/test_lstm_model.py +++ b/stanza/tests/constituency/test_lstm_model.py @@ -1,3 +1,4 @@ +import logging import os import tempfile @@ -16,6 +17,9 @@ from stanza.tests.constituency import test_parse_transitions pytestmark = [pytest.mark.pipeline, pytest.mark.travis] +logger = logging.getLogger('stanza.constituency.trainer') +logger.setLevel(logging.WARNING) + TREEBANK = """ ( (S (VP (VBG Enjoying) @@ -53,20 +57,16 @@ def pt(): def build_model(pt, *args): # TODO: build a fake embedding some other way? 
- trees = tree_reader.read_trees(TREEBANK) + train_trees = tree_reader.read_trees(TREEBANK) + dev_trees = train_trees[-1:] args = constituency_parser.parse_args(args) + forward_charlm = trainer.load_charlm(args['charlm_forward_file']) + backward_charlm = trainer.load_charlm(args['charlm_backward_file']) - transitions = trainer.build_treebank(trees, args['transition_scheme']) - transitions = transition_sequence.all_transitions(transitions) - constituents = parse_tree.Tree.get_unique_constituent_labels(trees) - tags = parse_tree.Tree.get_unique_tags(trees) - words = parse_tree.Tree.get_unique_words(trees) - rare_words = parse_tree.Tree.get_rare_words(trees) - root_labels = parse_tree.Tree.get_root_labels(trees) - open_nodes = trainer.get_open_nodes(trees, args) - - model = lstm_model.LSTMModel(pt, None, None, transitions, constituents, tags, words, rare_words, root_labels, open_nodes, args) + model, _, _ = trainer.build_trainer(args, train_trees, dev_trees, pt, forward_charlm, backward_charlm) + model = model.model + assert isinstance(model, lstm_model.LSTMModel) return model @pytest.fixture(scope="module") @@ -189,6 +189,21 @@ def test_forward_combined_dummy(pt): model = build_model(pt, '--no_combined_dummy_embedding') run_forward_checks(model) +def test_forward_charlm(pt): + """ + Tests loading and running a charlm + + Note that this doesn't test the results of the charlm itself, + just that the model is shaped correctly + """ + forward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "1billion.pt") + backward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "1billion.pt") + assert os.path.exists(forward_charlm_path), "Need to download en test models (or update path to the forward charlm)" + assert os.path.exists(backward_charlm_path), "Need to download en test models (or update path to the backward charlm)" + + model = build_model(pt, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path) + run_forward_checks(model) + def test_save_load_model(pt, unary_model): """ Just tests that saving and loading works without crashs. 
-- cgit v1.2.3 From dc57b589241727a6787ffad6b76187531e8dbcf8 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 29 Sep 2021 17:06:29 -0700 Subject: Move the tqdm inside the iterator so it can have an informed max length --- stanza/models/constituency/tree_reader.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/stanza/models/constituency/tree_reader.py b/stanza/models/constituency/tree_reader.py index 05f5e848..a8d2ed5d 100644 --- a/stanza/models/constituency/tree_reader.py +++ b/stanza/models/constituency/tree_reader.py @@ -87,6 +87,10 @@ class TokenIterator: self.lines = text.split("\n") self.num_lines = len(self.lines) self.line_num = -1 + if self.num_lines > 1000: + self.line_iterator = iter(tqdm(self.lines)) + else: + self.line_iterator = iter(self.lines) self.token_iterator = iter([]) def __iter__(self): @@ -99,7 +103,7 @@ class TokenIterator: if self.line_num >= len(self.lines): raise StopIteration - line = self.lines[self.line_num].strip() + line = next(self.line_iterator, "").strip() if not line: continue @@ -126,8 +130,6 @@ def read_trees(text, broken_ok=False): Reads multiple trees from the text """ token_iterator = TokenIterator(text) - if token_iterator.num_lines > 1000: - token_iterator = iter(tqdm(token_iterator)) trees = recursive_read_trees(token_iterator, broken_ok=broken_ok) return trees -- cgit v1.2.3 From 5c44de09a7ec1a34476ac7f349d8cd9d65555dd1 Mon Sep 17 00:00:00 2001 From: hungbui0411 <86261282+hungbui0411@users.noreply.github.com> Date: Thu, 30 Sep 2021 14:38:27 +0700 Subject: add script for converting trees to vi_treebank (#814) * Add a script that goes through a directory and convert all the trees in the directory's files into the correct tree format for VLSP 09 constituency dataset * Add a script called vtb_split.py that can be used after vtb_convert.py to split the files in the directory into a train/dev/test set with a split of 70/15/15 --- stanza/utils/datasets/constituency/vtb_convert.py | 75 +++++++++++++ stanza/utils/datasets/constituency/vtb_split.py | 130 ++++++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 stanza/utils/datasets/constituency/vtb_convert.py create mode 100644 stanza/utils/datasets/constituency/vtb_split.py diff --git a/stanza/utils/datasets/constituency/vtb_convert.py b/stanza/utils/datasets/constituency/vtb_convert.py new file mode 100644 index 00000000..d9b92e0e --- /dev/null +++ b/stanza/utils/datasets/constituency/vtb_convert.py @@ -0,0 +1,75 @@ +""" +Script for processing the VTB files and turning their trees into the desired tree syntax + +The VTB original trees are stored in the directory: +VietTreebank_VLSP_SP73/Kho ngu lieu 10000 cay cu phap + +The script requires two arguments: +1. Original directory storing the original trees +2. 
New directory storing the converted trees +""" + + +import os +import argparse + + +def convert_file(org_dir, new_dir): + """ + :param org_dir: original directory storing original trees + :param new_dir: new directory storing formatted constituency trees + + This function writes new trees to the corresponding files in new_dir + """ + with open(org_dir, 'r') as reader, open(new_dir, 'w') as writer: + content = reader.readlines() + for line in content: + line = ' '.join(line.split()) + if line == '': + continue + elif line == '': + writer.write('(ROOT ') + elif line == '': + writer.write(')\n') + else: + writer.write(line) + + +def main(): + """ + Main function for the script + + Process args, loop through each file in the directory and convert + to the desired tree format + """ + parser = argparse.ArgumentParser( + description="Script that converts a VTB Tree into the desired format", + ) + parser.add_argument( + 'org_dir', + help='The location of the original directory storing original trees ' + ) + parser.add_argument( + 'new_dir', + help='The location of new directory storing the new formatted trees' + ) + + args = parser.parse_args() + + org_dir = args.org_dir + new_dir = args.new_dir + + for filename in os.listdir(org_dir): + file_name, file_extension = os.path.splitext(filename) + # Only convert .prd files, skip the .raw files + if file_extension == '.raw': + continue + file_path = os.path.join(org_dir, filename) + new_path = os.path.join(new_dir, file_name) + new_file_path = f'{new_path}.mrg' + # Convert the tree and write to new_file_path + convert_file(file_path, new_file_path) + + +if __name__ == '__main__': + main() diff --git a/stanza/utils/datasets/constituency/vtb_split.py b/stanza/utils/datasets/constituency/vtb_split.py new file mode 100644 index 00000000..27a7161b --- /dev/null +++ b/stanza/utils/datasets/constituency/vtb_split.py @@ -0,0 +1,130 @@ +""" +From a directory of files with VTB Trees, split into train/dev/test set +with a split of 70/15/15 + +The script requires two arguments +1. org_dir: the original directory obtainable from running vtb_convert.py +2. split_dir: the directory where the train/dev/test splits will be stored +""" + +import os +import argparse +import random + + +def create_shuffle_list(org_dir): + """ + This function creates the random order with which we use to loop through the files + :param org_dir: original directory storing the files that store the trees + :return: list of file names randomly shuffled + """ + file_names = [] + for filename in os.listdir(org_dir): + file_names.append(filename) + random.shuffle(file_names) + + return file_names + + +def create_paths(split_dir): + """ + This function creates the necessary paths for the train/dev/test splits + :param split_dir: directory that stores the splits + :return: train path, dev path, test path + """ + train_path = os.path.join(split_dir, 'train.mrg') + dev_path = os.path.join(split_dir, 'dev.mrg') + test_path = os.path.join(split_dir, 'test.mrg') + + return train_path, dev_path, test_path + + +def get_num_samples(org_dir, file_names): + """ + Function for obtaining the number of samples + :param org_dir: original directory storing the tree files + :param file_names: list of file names in the directory + :return: number of samples + """ + count = 0 + # Loop through the files, which then loop through the trees + for filename in file_names: + # Skip files that are not .mrg + if not filename.endswith('.mrg'): + continue + # File is .mrg. 
Start processing + file_dir = os.path.join(org_dir, filename) + with open(file_dir, 'r') as reader: + content = reader.readlines() + for _ in content: + count += 1 + + return count + + +def main(): + """ + Main function for the script + + Process args, loop through each tree in each file in the directory + and write the trees to the train/dev/test split with a split of + 70/15/15 + """ + parser = argparse.ArgumentParser( + description="Script that splits a list of files of vtb trees into train/dev/test sets", + ) + parser.add_argument( + 'org_dir', + help='The location of the original directory storing correctly formatted vtb trees ' + ) + parser.add_argument( + 'split_dir', + help='The location of new directory storing the train/dev/test set' + ) + + args = parser.parse_args() + + org_dir = args.org_dir + split_dir = args.split_dir + + random.seed(1234) + + # Create a random shuffle list of the file names in the original directory + file_names = create_shuffle_list(org_dir) + + # Create train_path, dev_path, test_path + train_path, dev_path, test_path = create_paths(split_dir) + + # Set up the number of samples for each train/dev/test set + num_samples = get_num_samples(org_dir, file_names) + stop_train = int(num_samples * 0.7) + stop_dev = int(num_samples * 0.85) + + # Write directory and write count + write_dir = train_path + count = 0 + + # Loop through the files, which then loop through the trees and write to write_dir + for filename in file_names: + # Skip files that are not .mrg + if not filename.endswith('.mrg'): + continue + # File is .mrg. Start processing + file_dir = os.path.join(org_dir, filename) + with open(file_dir, 'r') as reader, open(write_dir, 'a') as writer: + content = reader.readlines() + for line in content: + # Write to write_dir + writer.write(line) + # Check current count to switch write_dir + count += 1 + # Switch to writing dev set + if count > stop_train: + write_dir = dev_path + # Switch to writing test set + if count > stop_dev: + write_dir = test_path + + +if __name__ == '__main__': + main() -- cgit v1.2.3 From c3cd0a2bbc7bc20a4b4003b1c96377b9f3b10505 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 30 Sep 2021 09:14:21 -0700 Subject: Add missing pytest to the test conda --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c3f9be88..042e4502 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,7 @@ setup( # $ pip install -e .[dev,test] extras_require={ 'dev': ['check-manifest'], - 'test': ['coverage'], + 'test': ['coverage', 'pytest'], }, # If there are data files included in your packages that need to be -- cgit v1.2.3 From bbd553775a955b4322927214083c75dd2f690e05 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 29 Sep 2021 21:48:02 -0700 Subject: Ask for one more line so that the line tqdm can stop --- stanza/models/constituency/tree_reader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stanza/models/constituency/tree_reader.py b/stanza/models/constituency/tree_reader.py index a8d2ed5d..65c9250c 100644 --- a/stanza/models/constituency/tree_reader.py +++ b/stanza/models/constituency/tree_reader.py @@ -101,6 +101,7 @@ class TokenIterator: while n is None: self.line_num = self.line_num + 1 if self.line_num >= len(self.lines): + next(self.line_iterator, "") raise StopIteration line = next(self.line_iterator, "").strip() -- cgit v1.2.3 From 9c99dee9310f16419ff6152008599c3034c20776 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 29 Sep 2021 01:13:43 -0700 Subject: Split a couple tests into a 
separate Trainer test --- stanza/tests/constituency/test_lstm_model.py | 86 +-------------------------- stanza/tests/constituency/test_trainer.py | 89 ++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 83 deletions(-) create mode 100644 stanza/tests/constituency/test_trainer.py diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py index b84278c8..b9bb6f80 100644 --- a/stanza/tests/constituency/test_lstm_model.py +++ b/stanza/tests/constituency/test_lstm_model.py @@ -1,84 +1,22 @@ -import logging import os -import tempfile import pytest -from stanza.models import constituency_parser -from stanza.models.common import pretrain -from stanza.models.constituency import lstm_model from stanza.models.constituency import parse_transitions -from stanza.models.constituency import parse_tree -from stanza.models.constituency import trainer -from stanza.models.constituency import transition_sequence -from stanza.models.constituency import tree_reader from stanza.tests import * from stanza.tests.constituency import test_parse_transitions +from stanza.tests.constituency.test_trainer import build_trainer, pt pytestmark = [pytest.mark.pipeline, pytest.mark.travis] -logger = logging.getLogger('stanza.constituency.trainer') -logger.setLevel(logging.WARNING) - -TREEBANK = """ -( (S - (VP (VBG Enjoying) - (NP (PRP$ my) (JJ favorite) (NN Friday) (NN tradition))) - (. .))) - -( (NP - (VP (VBG Sitting) - (PP (IN in) - (NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station))) - (VP (VBG waiting) - (PP (IN for) - (NP (PRP$ my) (JJ delayed) (NNP @MBTA) (NN train))))) - (. .))) - -( (S - (NP (PRP I)) - (VP - (ADVP (RB really)) - (VBP hate) - (NP (DT the) (NNP @MBTA))))) - -( (S - (S (VP (VB Seek))) - (CC and) - (S (NP (PRP ye)) - (VP (MD shall) - (VP (VB find)))) - (. .))) -""" - -@pytest.fixture(scope="module") -def pt(): - return pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False) - def build_model(pt, *args): - # TODO: build a fake embedding some other way? - train_trees = tree_reader.read_trees(TREEBANK) - dev_trees = train_trees[-1:] - - args = constituency_parser.parse_args(args) - forward_charlm = trainer.load_charlm(args['charlm_forward_file']) - backward_charlm = trainer.load_charlm(args['charlm_backward_file']) - - model, _, _ = trainer.build_trainer(args, train_trees, dev_trees, pt, forward_charlm, backward_charlm) - model = model.model - assert isinstance(model, lstm_model.LSTMModel) - return model + trainer = build_trainer(pt, *args) + return trainer.model @pytest.fixture(scope="module") def unary_model(pt): return build_model(pt, "--transition_scheme", "TOP_DOWN_UNARY") -def test_initial_model(unary_model): - """ - does nothing, just tests that the construction went okay - """ - pass - def test_initial_state(unary_model): test_parse_transitions.test_initial_state(unary_model) @@ -203,21 +141,3 @@ def test_forward_charlm(pt): model = build_model(pt, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path) run_forward_checks(model) - -def test_save_load_model(pt, unary_model): - """ - Just tests that saving and loading works without crashs. 
- - Currently no test of the values themselves - """ - with tempfile.TemporaryDirectory() as tmpdirname: - tr = trainer.Trainer(model=unary_model) - - # attempt saving - filename = os.path.join(tmpdirname, "parser.pt") - tr.save(filename) - - assert os.path.exists(filename) - - # load it back in - tr.load(filename, pt, None, None, False) diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py new file mode 100644 index 00000000..d9fd18f6 --- /dev/null +++ b/stanza/tests/constituency/test_trainer.py @@ -0,0 +1,89 @@ +import logging +import tempfile + +import pytest + +from stanza.models import constituency_parser +from stanza.models.common import pretrain +from stanza.models.constituency import lstm_model +from stanza.models.constituency import trainer +from stanza.models.constituency import tree_reader +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +logger = logging.getLogger('stanza.constituency.trainer') +logger.setLevel(logging.WARNING) + +TREEBANK = """ +( (S + (VP (VBG Enjoying) + (NP (PRP$ my) (JJ favorite) (NN Friday) (NN tradition))) + (. .))) + +( (NP + (VP (VBG Sitting) + (PP (IN in) + (NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station))) + (VP (VBG waiting) + (PP (IN for) + (NP (PRP$ my) (JJ delayed) (NNP @MBTA) (NN train))))) + (. .))) + +( (S + (NP (PRP I)) + (VP + (ADVP (RB really)) + (VBP hate) + (NP (DT the) (NNP @MBTA))))) + +( (S + (S (VP (VB Seek))) + (CC and) + (S (NP (PRP ye)) + (VP (MD shall) + (VP (VB find)))) + (. .))) +""" + +@pytest.fixture(scope="module") +def pt(): + return pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False) + +def build_trainer(pt, *args): + # TODO: build a fake embedding some other way? + train_trees = tree_reader.read_trees(TREEBANK) + dev_trees = train_trees[-1:] + + args = constituency_parser.parse_args(args) + forward_charlm = trainer.load_charlm(args['charlm_forward_file']) + backward_charlm = trainer.load_charlm(args['charlm_backward_file']) + + model, _, _ = trainer.build_trainer(args, train_trees, dev_trees, pt, forward_charlm, backward_charlm) + assert isinstance(model.model, lstm_model.LSTMModel) + return model + +def test_initial_model(pt): + """ + does nothing, just tests that the construction went okay + """ + build_trainer(pt) + + +def test_save_load_model(pt): + """ + Just tests that saving and loading works without crashs. + + Currently no test of the values themselves + """ + with tempfile.TemporaryDirectory() as tmpdirname: + tr = build_trainer(pt) + + # attempt saving + filename = os.path.join(tmpdirname, "parser.pt") + tr.save(filename) + + assert os.path.exists(filename) + + # load it back in + tr.load(filename, pt, None, None, False) -- cgit v1.2.3 From be5b0cb853497f7e78c1b12e2e7654993597c8e2 Mon Sep 17 00:00:00 2001 From: J38 Date: Sun, 3 Oct 2021 18:39:55 -0700 Subject: Update stanza-tests.yaml --- .github/workflows/stanza-tests.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stanza-tests.yaml b/.github/workflows/stanza-tests.yaml index d40eb274..3ce9d00d 100644 --- a/.github/workflows/stanza-tests.yaml +++ b/.github/workflows/stanza-tests.yaml @@ -16,8 +16,8 @@ jobs: # set up environment bash . 
/home/stanzabuild/miniconda3/etc/profile.d/conda.sh - export CORENLP_HOME=/home/stanzabuild/stanford-corenlp-4.2.2 - export CLASSPATH=/home/stanzabuild/stanford-corenlp-4.2.2/*: + export CORENLP_HOME=/home/stanzabuild/stanford-corenlp-4.3.0 + export CLASSPATH=/home/stanzabuild/stanford-corenlp-4.3.0/*: # install from stanza repo being evaluated pwd pip install -e . -- cgit v1.2.3 From 069cb8ce5d16796dad00c1b775ed74259d587325 Mon Sep 17 00:00:00 2001 From: J38 Date: Sun, 3 Oct 2021 19:47:00 -0700 Subject: try to fix test --- stanza/tests/constituency/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/tests/constituency/test_utils.py b/stanza/tests/constituency/test_utils.py index 38a9f37e..7e3b5d9e 100644 --- a/stanza/tests/constituency/test_utils.py +++ b/stanza/tests/constituency/test_utils.py @@ -11,7 +11,7 @@ pytestmark = [pytest.mark.pipeline, pytest.mark.travis] @pytest.fixture(scope="module") def pipeline(): - return Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True) + return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos", tokenize_pretokenized=True) -- cgit v1.2.3 From a10e98510f0ad770cc4c7faa79c700e8ffab7e8f Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sat, 2 Oct 2021 21:53:01 -0700 Subject: Refactor usage of the dictionary so we can include it in the pipeline --- stanza/models/tokenization/trainer.py | 9 ++++++++- stanza/models/tokenizer.py | 31 +++++++++++++++++-------------- stanza/pipeline/tokenize_processor.py | 2 +- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py index 37fe66da..c40b70d2 100644 --- a/stanza/models/tokenization/trainer.py +++ b/stanza/models/tokenization/trainer.py @@ -5,6 +5,7 @@ import torch.nn as nn import torch.optim as optim from stanza.models.common.trainer import Trainer as BaseTrainer +from stanza.models.tokenization.utils import create_dictionary from .model import Tokenizer from .vocab import Vocab @@ -12,7 +13,7 @@ from .vocab import Vocab logger = logging.getLogger('stanza') class Trainer(BaseTrainer): - def __init__(self, args=None, vocab=None, lexicon=None, model_file=None, use_cuda=False): + def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, use_cuda=False): self.use_cuda = use_cuda if model_file is not None: # load everything from file @@ -22,6 +23,7 @@ class Trainer(BaseTrainer): self.args = args self.vocab = vocab self.lexicon = lexicon + self.dictionary = dictionary self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout']) self.criterion = nn.CrossEntropyLoss(ignore_index=-1) if use_cuda: @@ -97,3 +99,8 @@ class Trainer(BaseTrainer): self.model.load_state_dict(checkpoint['model']) self.vocab = Vocab.load_state_dict(checkpoint['vocab']) self.lexicon = checkpoint['lexicon'] + + if self.lexicon is not None: + self.dictionary = create_dictionary(self.lexicon) + else: + self.dictionary = None diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index c51720ee..e88b746e 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -58,7 +58,7 @@ def parse_args(args=None): parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well") parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN") 
parser.add_argument('--rnn_layers', type=int, default=1, help="Layers of RNN in the tokenizer") - parser.add_argument('--use_dictionary', action='store_true', help="Use dictionary feature. The lexicon is created using the training data and external dict (if any) expected to be found under the same folder of training dataset, formatted as SHORTHAND-externaldict.txt where each line in this file is a word. For example, data/zh_gsdsimp-externaldict.txt") + parser.add_argument('--use_dictionary', action='store_true', help="Use dictionary feature. The lexicon is created using the training data and external dict (if any) expected to be found under the same folder of training dataset, formatted as SHORTHAND-externaldict.txt where each line in this file is a word. For example, data/tokenize/zh_gsdsimp-externaldict.txt") parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm to clip to") parser.add_argument('--anneal', type=float, default=.999, help="Anneal the learning rate by this amount when dev performance deteriorate") @@ -109,24 +109,30 @@ def main(args=None): utils.ensure_dir(args['save_dir']) if args['mode'] == 'train': - #load lexicon - args['lexicon'], args['num_dict_feat'] = (None, None) if not args["use_dictionary"] else load_lexicon(args) - #create the dictionary - args['dictionary'] = None if not args["use_dictionary"] else create_dictionary(args['lexicon']) - #adjust the feat_dim - args['feat_dim'] += args['num_dict_feat']*2 if args["use_dictionary"] else 0 train(args) else: evaluate(args) def train(args): + if args['use_dictionary']: + #load lexicon + lexicon, args['num_dict_feat'] = load_lexicon(args) + #create the dictionary + dictionary = create_dictionary(lexicon) + #adjust the feat_dim + args['feat_dim'] += args['num_dict_feat']*2 + else: + args['num_dict_feat'] = 0 + lexicon=None + dictionary=None + mwt_dict = load_mwt_dict(args['mwt_json_file']) train_input_files = { 'txt': args['txt_file'], 'label': args['label_file'] } - train_batches = DataLoader(args, input_files=train_input_files, dictionary=args["dictionary"]) + train_batches = DataLoader(args, input_files=train_input_files, dictionary=dictionary) vocab = train_batches.vocab args['vocab_size'] = len(vocab) @@ -135,13 +141,13 @@ def train(args): 'txt': args['dev_txt_file'], 'label': args['dev_label_file'] } - dev_batches = DataLoader(args, input_files=dev_input_files, vocab=vocab, evaluation=True, dictionary=args["dictionary"]) + dev_batches = DataLoader(args, input_files=dev_input_files, vocab=vocab, evaluation=True, dictionary=dictionary) if args['use_mwt'] is None: args['use_mwt'] = train_batches.has_mwt() logger.info("Found {}mwts in the training data. 
Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt'])) - trainer = Trainer(args=args, vocab=vocab, lexicon=args['lexicon'], use_cuda=args['cuda']) + trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, use_cuda=args['cuda']) if args['load_name'] is not None: load_name = os.path.join(args['save_dir'], args['load_name']) @@ -204,16 +210,13 @@ def evaluate(args): if not k.endswith('_file') and k not in ['cuda', 'mode', 'save_dir', 'load_name', 'save_name']: args[k] = loaded_args[k] - args['lexicon'] = None if not args['use_dictionary'] else trainer.lexicon - args['dictionary'] = None if not args['use_dictionary'] else create_dictionary(lexicon) - eval_input_files = { 'txt': args['txt_file'], 'label': args['label_file'] } - batches = DataLoader(args, input_files=eval_input_files, vocab=vocab, evaluation=True, dictionary=args['dictionary']) + batches = DataLoader(args, input_files=eval_input_files, vocab=vocab, evaluation=True, dictionary=trainer.dictionary) oov_count, N, _, _ = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen']) diff --git a/stanza/pipeline/tokenize_processor.py b/stanza/pipeline/tokenize_processor.py index 79b17a54..3421f90b 100644 --- a/stanza/pipeline/tokenize_processor.py +++ b/stanza/pipeline/tokenize_processor.py @@ -82,7 +82,7 @@ class TokenizeProcessor(UDProcessor): raw_text = '\n\n'.join(document) if isinstance(document, list) else document # set up batches - batches = DataLoader(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True) + batches = DataLoader(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True, dictionary=self.trainer.dictionary) # get dict data _, _, _, document = output_predictions(None, self.trainer, batches, self.vocab, None, self.config.get('max_seqlen', TokenizeProcessor.MAX_SEQ_LENGTH_DEFAULT), -- cgit v1.2.3 From 15ed89c75ca0a67403d36d6aa3bfe6065d978e37 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 4 Oct 2021 14:54:04 -0700 Subject: Slightly update the BSNLP description --- stanza/utils/datasets/ner/prepare_ner_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py index c664a105..f248748b 100644 --- a/stanza/utils/datasets/ner/prepare_ner_dataset.py +++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py @@ -62,10 +62,12 @@ The two Hungarian datasets can be combined with hu_combined BSNLP publishes NER datasets for Eastern European languages. - In 2019 they published BG, CS, PL, RU. + - http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html - In 2021 they added some more data, but the test sets were not publicly available as of April 2021. Therefore, currently the model is made from 2019. - - http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html + In 2021, the link to the 2021 task is here: + http://bsnlp.cs.helsinki.fi/shared-task.html - The below method processes the 2019 version of the corpus. It has specific adjustments for the BG section, which has quite a few typos or mis-annotations in it. 
Other languages -- cgit v1.2.3 From 60f1087f68d0dc72fba2752b5bd41975ab57ed1f Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 4 Oct 2021 15:12:36 -0700 Subject: Add constituency parser to the model build script --- stanza/resources/prepare_resources.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py index 543c14ee..23568527 100644 --- a/stanza/resources/prepare_resources.py +++ b/stanza/resources/prepare_resources.py @@ -164,6 +164,11 @@ default_sentiment = { "zh-hans": "ren", } +# also, a few languages (very few, currently) have constituency parser models +default_constituency = { + "en": "wsj", +} + allowed_empty_languages = [ # we don't have a lot of Thai support yet "th" @@ -178,6 +183,7 @@ processor_to_ending = { "depparse": "parser", "ner": "nertagger", "sentiment": "sentiment", + "constituency": "constituency", "pretrain": "pretrain", "forward_charlm": "forward_charlm", "backward_charlm": "backward_charlm", @@ -338,6 +344,11 @@ def process_dirs(args): # sentiment models use the default pretrain for the language pretrain_package = default_treebanks[lang] dependencies = [{'model': 'pretrain', 'package': pretrain_package}] + elif processor == 'constituency': + # so far, this invariant is true: + # constituency models use the default pretrain for the language + pretrain_package = default_treebanks[lang] + dependencies = [{'model': 'pretrain', 'package': pretrain_package}] else: dependencies = None # maintain resources @@ -369,6 +380,8 @@ def process_defaults(args): charlm_package = default_charlms[lang] if lang in default_sentiment: sentiment_package = default_sentiment[lang] + if lang in default_constituency: + constituency_package = default_constituency[lang] if lang in default_ners and lang in default_charlms: ner_dependencies = get_ner_dependencies(lang, ner_package) @@ -377,6 +390,9 @@ def process_defaults(args): if lang in default_sentiment: # All of the sentiment models created so far have used the default pretrain default_dependencies['sentiment'] = [{'model': 'pretrain', 'package': ud_package}] + if lang in default_constituency: + # All of the constituency models created so far also use the default pretrain + default_dependencies['constituency'] = [{'model': 'pretrain', 'package': ud_package}] processors = ['tokenize', 'mwt', 'lemma', 'pos', 'depparse', 'pretrain'] if lang in default_ners: @@ -385,6 +401,8 @@ def process_defaults(args): processors.extend(['forward_charlm', 'backward_charlm']) if lang in default_sentiment: processors.append('sentiment') + if lang in default_constituency: + processors.append('constituency') if lang == 'multilingual': processors = ['langid'] @@ -395,6 +413,7 @@ def process_defaults(args): if processor == 'ner': package = ner_package elif processor in ['forward_charlm', 'backward_charlm']: package = charlm_package elif processor == 'sentiment': package = sentiment_package + elif processor == 'constituency': package = constituency_package elif processor == 'langid': package = 'ud' else: package = ud_package @@ -402,7 +421,7 @@ def process_defaults(args): if os.path.exists(filename): print(" Model {} package {}: file {}".format(processor, package, filename)) - if processor in ['tokenize', 'mwt', 'lemma', 'pos', 'depparse', 'ner', 'sentiment', 'langid']: + if processor in ['tokenize', 'mwt', 'lemma', 'pos', 'depparse', 'ner', 'sentiment', 'constituency', 'langid']: default_processors[processor] = package 
zipf.write(os.path.join(processor, package + '.pt')) elif lang in allowed_empty_languages: -- cgit v1.2.3 From 53da9b7585ec7f9f5ba8134ec966a4a1ee6c76f0 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 4 Oct 2021 15:25:27 -0700 Subject: Update tokens for new ZH tokenizer with dictionary --- stanza/tests/test_tokenizer.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/stanza/tests/test_tokenizer.py b/stanza/tests/test_tokenizer.py index b7956277..9a48eb54 100644 --- a/stanza/tests/test_tokenizer.py +++ b/stanza/tests/test_tokenizer.py @@ -171,17 +171,16 @@ ZH_DOC1_GOLD_TOKENS=""" ]> ]> -]> +]> ]> -]> -]> -]> -]> -]> -]> -]> -]> -]> +]> +]> +]> +]> +]> +]> +]> +]> """.strip() ZH_DOC_GOLD_NOSSPLIT_TOKENS = """ -- cgit v1.2.3 From 2ef65a9bb491d9687f0d51d05d85a0851f7c732f Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 4 Oct 2021 15:42:05 -0700 Subject: Add a simple test that the con parser is working in the pipeline, sneak in a test of -- to make sure that doesn't crash sentiment --- stanza/tests/test_english_pipeline.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/stanza/tests/test_english_pipeline.py b/stanza/tests/test_english_pipeline.py index f270c1d4..3e3729ac 100644 --- a/stanza/tests/test_english_pipeline.py +++ b/stanza/tests/test_english_pipeline.py @@ -166,6 +166,7 @@ def test_dependency_parse(processed_doc): def test_empty(pipeline): # make sure that various models handle the degenerate empty case pipeline("") + pipeline("--") @pytest.fixture(scope="module") def processed_multidoc(pipeline): @@ -200,3 +201,8 @@ def processed_multidoc_variant(): def test_dependency_parse_multidoc_variant(processed_multidoc_variant): assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc_variant for sent in processed_doc.sentences]) == \ EN_DOC_DEPENDENCY_PARSES_GOLD + +def test_constituency_parser(): + nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency") + doc = nlp("This is a test") + assert str(doc.sentences[0].constituency) == '(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))' -- cgit v1.2.3
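Taken together, the final patches above register the constituency parser as a default package and exercise it through the standard pipeline. A minimal end-to-end sketch, assuming the default English models (including the new constituency package from prepare_resources.py) have already been fetched with stanza.download("en"):

    import stanza

    # assumes stanza.download("en") has already pulled the default English models,
    # including the constituency package registered in prepare_resources.py
    nlp = stanza.Pipeline(lang="en", processors="tokenize,pos,constituency")
    doc = nlp("This is a test")

    # each sentence now exposes its parse, mirroring the check in test_english_pipeline.py
    print(doc.sentences[0].constituency)
    # e.g. (ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))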