github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>  2021-10-05 05:15:34 +0300
committer  GitHub <noreply@github.com>     2021-10-05 05:15:34 +0300
commit     f91ca215e175d4f7b202259fe789374db7829395 (patch)
tree       b6fdbd6b418cada65511fd9fe8b74c0a5400351e
parent     c457a9309ad15c522e94230f919c25d1e7aebf64 (diff)
parent     3ef977a6bb34f07acd7237225f3d8d5a7d8b5eea (diff)
Merge pull request #817 from stanfordnlp/dev (tag v1.3.0)
Update CoreNLP version in tests
-rw-r--r--  .github/workflows/stanza-tests.yaml  32
-rw-r--r--  doc/CoreNLP.proto  32
-rwxr-xr-x  scripts/sentiment/process_sst.sh  19
-rwxr-xr-x  scripts/treebank_to_shorthand.sh  4
-rw-r--r--  setup.py  4
-rw-r--r--  stanza/__init__.py  1
-rw-r--r--  stanza/_version.py  4
-rw-r--r--  stanza/models/classifier.py  6
-rw-r--r--  stanza/models/classifiers/cnn_classifier.py  4
-rw-r--r--  stanza/models/classifiers/data.py  4
-rw-r--r--  stanza/models/common/constant.py  21
-rw-r--r--  stanza/models/common/count_ner_coverage.py  38
-rw-r--r--  stanza/models/common/doc.py  11
-rw-r--r--  stanza/models/common/hlstm.py  4
-rw-r--r--  stanza/models/common/utils.py  31
-rw-r--r--  stanza/models/constituency/__init__.py  0
-rw-r--r--  stanza/models/constituency/base_model.py  205
-rw-r--r--  stanza/models/constituency/lstm_model.py  543
-rw-r--r--  stanza/models/constituency/parse_transitions.py  603
-rw-r--r--  stanza/models/constituency/parse_tree.py  303
-rw-r--r--  stanza/models/constituency/trainer.py  586
-rw-r--r--  stanza/models/constituency/transition_sequence.py  112
-rw-r--r--  stanza/models/constituency/tree_reader.py  154
-rw-r--r--  stanza/models/constituency/tree_stack.py  52
-rw-r--r--  stanza/models/constituency/utils.py  58
-rw-r--r--  stanza/models/constituency_parser.py  290
-rw-r--r--  stanza/models/lang_identifier.py  226
-rw-r--r--  stanza/models/langid/__init__.py  0
-rw-r--r--  stanza/models/langid/create_ud_data.py  205
-rw-r--r--  stanza/models/langid/data.py  136
-rw-r--r--  stanza/models/langid/model.py  120
-rw-r--r--  stanza/models/langid/trainer.py  53
-rw-r--r--  stanza/models/ner/model.py  5
-rw-r--r--  stanza/models/ner_tagger.py  5
-rw-r--r--  stanza/models/parser.py  1
-rw-r--r--  stanza/models/pos/xpos_vocab_factory.py  2
-rw-r--r--  stanza/models/tokenization/data.py  63
-rw-r--r--  stanza/models/tokenization/model.py  7
-rw-r--r--  stanza/models/tokenization/trainer.py  16
-rw-r--r--  stanza/models/tokenization/utils.py  125
-rw-r--r--  stanza/models/tokenizer.py  45
-rw-r--r--  stanza/pipeline/_constants.py  2
-rw-r--r--  stanza/pipeline/constituency_processor.py  52
-rw-r--r--  stanza/pipeline/core.py  2
-rw-r--r--  stanza/pipeline/langid_processor.py  126
-rw-r--r--  stanza/pipeline/multilingual.py  109
-rw-r--r--  stanza/pipeline/pos_processor.py  15
-rw-r--r--  stanza/pipeline/tokenize_processor.py  2
-rw-r--r--  stanza/protobuf/CoreNLP_pb2.py  253
-rw-r--r--  stanza/resources/common.py  8
-rw-r--r--  stanza/resources/installation.py  36
-rw-r--r--  stanza/resources/prepare_resources.py  71
-rw-r--r--  stanza/server/java_protobuf_requests.py  92
-rw-r--r--  stanza/server/parser_eval.py  41
-rw-r--r--  stanza/server/ud_enhancer.py  2
-rw-r--r--  stanza/tests/constituency/test_lstm_model.py  143
-rw-r--r--  stanza/tests/constituency/test_parse_transitions.py  412
-rw-r--r--  stanza/tests/constituency/test_parse_tree.py  196
-rw-r--r--  stanza/tests/constituency/test_trainer.py  89
-rw-r--r--  stanza/tests/constituency/test_transition_sequence.py  87
-rw-r--r--  stanza/tests/constituency/test_tree_reader.py  61
-rw-r--r--  stanza/tests/constituency/test_tree_stack.py  50
-rw-r--r--  stanza/tests/constituency/test_utils.py  68
-rw-r--r--  stanza/tests/resources/test_common.py  19
-rw-r--r--  stanza/tests/resources/test_installation.py (renamed from stanza/tests/test_installation.py)  2
-rw-r--r--  stanza/tests/setup_test.sh  1
-rw-r--r--  stanza/tests/test_constant.py  35
-rw-r--r--  stanza/tests/test_english_pipeline.py  6
-rw-r--r--  stanza/tests/test_java_protobuf_requests.py  23
-rw-r--r--  stanza/tests/test_langid.py  613
-rw-r--r--  stanza/tests/test_parser_eval.py  40
-rw-r--r--  stanza/tests/test_pipeline_sentiment_processor.py  9
-rw-r--r--  stanza/tests/test_tokenization_lst20.py  236
-rw-r--r--  stanza/tests/test_tokenization_orchid.py  107
-rw-r--r--  stanza/tests/test_tokenize_data.py  1
-rw-r--r--  stanza/tests/test_tokenizer.py  21
-rw-r--r--  stanza/utils/charlm/make_lm_data.py  3
-rw-r--r--  stanza/utils/datasets/common.py  1
-rw-r--r--  stanza/utils/datasets/constituency/convert_it_turin.py  322
-rw-r--r--  stanza/utils/datasets/constituency/vtb_convert.py  75
-rw-r--r--  stanza/utils/datasets/constituency/vtb_split.py  130
-rw-r--r--  stanza/utils/datasets/ner/convert_bsf_to_beios.py  63
-rw-r--r--  stanza/utils/datasets/ner/convert_fire_2013.py  12
-rw-r--r--  stanza/utils/datasets/ner/prepare_ner_dataset.py  22
-rw-r--r--  stanza/utils/datasets/ner/prepare_ner_file.py  14
-rwxr-xr-x  stanza/utils/datasets/prepare_tokenizer_treebank.py  42
-rw-r--r--  stanza/utils/datasets/process_thai_tokenization.py  66
-rw-r--r--  stanza/utils/datasets/thai_syllable_dict_generator.py  53
-rw-r--r--  stanza/utils/datasets/tokenization/__init__.py  0
-rw-r--r--  stanza/utils/datasets/tokenization/convert_th_best.py (renamed from stanza/utils/datasets/process_best.py)  136
-rw-r--r--  stanza/utils/datasets/tokenization/convert_th_lst20.py  131
-rw-r--r--  stanza/utils/datasets/tokenization/convert_th_orchid.py (renamed from stanza/utils/datasets/process_orchid.py)  26
-rw-r--r--  stanza/utils/datasets/tokenization/convert_vi_vlsp.py  153
-rw-r--r--  stanza/utils/datasets/tokenization/process_thai_tokenization.py  187
-rw-r--r--  stanza/utils/default_paths.py  5
95 files changed, 8329 insertions, 272 deletions
diff --git a/.github/workflows/stanza-tests.yaml b/.github/workflows/stanza-tests.yaml
new file mode 100644
index 00000000..3ce9d00d
--- /dev/null
+++ b/.github/workflows/stanza-tests.yaml
@@ -0,0 +1,32 @@
+name: Run Stanza Tests
+on: [push]
+jobs:
+ Run-Stanza-Tests:
+ runs-on: self-hosted
+ steps:
+ - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
+ - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!"
+ - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}."
+ - name: Check out repository code
+ uses: actions/checkout@v2
+ - run: echo "💡 The ${{ github.repository }} repository has been cloned to the runner."
+ - run: echo "🖥️ The workflow is now ready to test your code on the runner."
+ - name: Run demo
+ run: |
+ # set up environment
+ bash
+ . /home/stanzabuild/miniconda3/etc/profile.d/conda.sh
+ export CORENLP_HOME=/home/stanzabuild/stanford-corenlp-4.3.0
+ export CLASSPATH=/home/stanzabuild/stanford-corenlp-4.3.0/*:
+ # install from stanza repo being evaluated
+ pwd
+ pip install -e .
+ # set up for tests
+ rm -rf /home/stanzabuild/stanza-github-actions/actions-runner/_work/stanza/stanza/stanza_test
+ source stanza/tests/setup_test.sh
+ # run tests
+ echo "Running tests..."
+ export CUDA_VISIBLE_DEVICES=2
+ pytest stanza/tests
+
+ - run: echo "🍏 This job's status is ${{ job.status }}."
diff --git a/doc/CoreNLP.proto b/doc/CoreNLP.proto
index 7fbff6dd..18b56ec0 100644
--- a/doc/CoreNLP.proto
+++ b/doc/CoreNLP.proto
@@ -698,3 +698,35 @@ message DependencyEnhancerRequest {
string relativePronouns = 3;
}
}
+
+// A version of ParseTree with a flattened structure so that deep trees
+// don't exceed the protobuf stack depth
+message FlattenedParseTree {
+ message Node {
+ oneof contents {
+ bool openNode = 1;
+ bool closeNode = 2;
+ string value = 3;
+ }
+
+ optional double score = 4;
+ }
+
+ repeated Node nodes = 1;
+}
+
+// A protobuf for calling the java constituency parser evaluator from elsewhere
+message EvaluateParserRequest {
+ message ParseResult {
+ required FlattenedParseTree gold = 1;
+ // repeated so you can send in kbest parses, if your parser handles that
+ // note that this already includes a score field
+ repeated FlattenedParseTree predicted = 2;
+ }
+
+ repeated ParseResult treebank = 1;
+}
+
+message EvaluateParserResponse {
+ required double f1 = 1;
+}
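
The FlattenedParseTree message above linearizes a tree into a sequence of open/value/close nodes so that arbitrarily deep trees never recurse through nested protobuf messages. As a rough illustration only (not the project's serializer, and the exact node ordering is an assumption), a pre-order walk in Python might produce that sequence like this:

    # Hypothetical sketch: linearize a nested (label, [children...]) tree into the
    # open/value/close node sequence that FlattenedParseTree describes.
    CLOSE = object()  # sentinel meaning "emit a closeNode here"

    def flatten_tree(tree):
        """tree is (label, [children, ...]) for internal nodes or a bare str for leaves."""
        nodes = []
        stack = [tree]
        while stack:
            item = stack.pop()
            if item is CLOSE:
                nodes.append({"closeNode": True})
            elif isinstance(item, str):
                nodes.append({"value": item})
            else:
                label, children = item
                nodes.append({"openNode": True})
                nodes.append({"value": label})
                stack.append(CLOSE)
                stack.extend(reversed(children))
        return nodes

    # flatten_tree(("ROOT", [("S", [("NP", ["word"])])])) yields
    # open, ROOT, open, S, open, NP, word, close, close, close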
diff --git a/scripts/sentiment/process_sst.sh b/scripts/sentiment/process_sst.sh
index 7ee7fb67..ac33990f 100755
--- a/scripts/sentiment/process_sst.sh
+++ b/scripts/sentiment/process_sst.sh
@@ -39,6 +39,12 @@ java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt >
echo $OUTPUT_DIR/fiveclass/test-phrases.txt
java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/test.txt > $OUTPUT_DIR/fiveclass/test-phrases.txt
+echo $OUTPUT_DIR/fiveclass/extra-train-phrases.txt
+java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/extra-train.txt > $OUTPUT_DIR/fiveclass/extra-train-phrases.txt
+
+echo $OUTPUT_DIR/fiveclass/checked-extra-train-phrases.txt
+java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/checked-extra-train.txt > $OUTPUT_DIR/fiveclass/checked-extra-train-phrases.txt
+
echo $OUTPUT_DIR/fiveclass/train-roots.txt
java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/train.txt -root_only > $OUTPUT_DIR/fiveclass/train-roots.txt
@@ -59,6 +65,12 @@ java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt -
echo $OUTPUT_DIR/binary/test-binary-phrases.txt
java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/test.txt -ignore_labels 2 -remap_labels "1=0,2=-1,3=1,4=1" > $OUTPUT_DIR/binary/test-binary-phrases.txt
+echo $OUTPUT_DIR/binary/extra-train-binary-phrases.txt
+java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/extra-train.txt -ignore_labels 2 -remap_labels "1=0,2=-1,3=1,4=1" > $OUTPUT_DIR/binary/extra-train-binary-phrases.txt
+
+echo $OUTPUT_DIR/binary/checked-extra-train-binary-phrases.txt
+java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/checked-extra-train.txt -ignore_labels 2 -remap_labels "1=0,2=-1,3=1,4=1" > $OUTPUT_DIR/binary/checked-extra-train-binary-phrases.txt
+
echo $OUTPUT_DIR/binary/dev-binary-roots.txt
java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt -root_only -ignore_labels 2 -remap_labels "1=0,2=-1,3=1,4=1" > $OUTPUT_DIR/binary/dev-binary-roots.txt
@@ -76,6 +88,13 @@ java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt -
echo $OUTPUT_DIR/threeclass/test-threeclass-phrases.txt
java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/test.txt -remap_labels "0=0,1=0,2=1,3=2,4=2" > $OUTPUT_DIR/threeclass/test-threeclass-phrases.txt
+echo $OUTPUT_DIR/threeclass/extra-train-threeclass-phrases.txt
+java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/extra-train.txt -remap_labels "0=0,1=0,2=1,3=2,4=2" > $OUTPUT_DIR/threeclass/extra-train-threeclass-phrases.txt
+
+echo $OUTPUT_DIR/threeclass/checked-extra-train-threeclass-phrases.txt
+java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/checked-extra-train.txt -remap_labels "0=0,1=0,2=1,3=2,4=2" > $OUTPUT_DIR/threeclass/checked-extra-train-threeclass-phrases.txt
+
+
echo $OUTPUT_DIR/threeclass/dev-threeclass-roots.txt
java edu.stanford.nlp.trees.OutputSubtrees -input $INPUT_DIR/fiveclass/dev.txt -root_only -remap_labels "0=0,1=0,2=1,3=2,4=2" > $OUTPUT_DIR/threeclass/dev-threeclass-roots.txt
diff --git a/scripts/treebank_to_shorthand.sh b/scripts/treebank_to_shorthand.sh
index bb6f1793..f5395ec6 100755
--- a/scripts/treebank_to_shorthand.sh
+++ b/scripts/treebank_to_shorthand.sh
@@ -19,10 +19,10 @@ lang=`echo $treebank | sed -e 's#-.*$##g' -e 's#^[^_]*_##g'`
lcode=${lang2lcode[$lang]}
if [ -z "$lcode" ]; then
if [ $lang == "Chinese" ]; then
- if [ $tbname == "gsdsimp" ]; then
+ if [ $tbname == "gsdsimp" -o $tbname == "cfl" ]; then
# TODO why not zh-hans?
lcode=zh
- elif [ $tbname == "gsd" -o $tbname == "hk" -o $tbname == "cfl" -o $tbname == "pud" ]; then
+ elif [ $tbname == "gsd" -o $tbname == "hk" -o $tbname == "pud" ]; then
lcode=zh-hant
fi
elif [ $lang == "Norwegian" ]; then
diff --git a/setup.py b/setup.py
index 70e20fe3..042e4502 100644
--- a/setup.py
+++ b/setup.py
@@ -76,7 +76,7 @@ setup(
# your project is installed. For an analysis of "install_requires" vs pip's
# requirements files see:
# https://packaging.python.org/en/latest/requirements.html
- install_requires=['numpy', 'protobuf', 'requests', 'torch>=1.3.0', 'tqdm'],
+ install_requires=['emoji', 'numpy', 'protobuf', 'requests', 'six', 'torch>=1.3.0', 'tqdm'],
# List required Python versions
python_requires='>=3.6',
@@ -87,7 +87,7 @@ setup(
# $ pip install -e .[dev,test]
extras_require={
'dev': ['check-manifest'],
- 'test': ['coverage'],
+ 'test': ['coverage', 'pytest'],
},
# If there are data files included in your packages that need to be
diff --git a/stanza/__init__.py b/stanza/__init__.py
index 76f04fd9..25c6fd13 100644
--- a/stanza/__init__.py
+++ b/stanza/__init__.py
@@ -1,4 +1,5 @@
from stanza.pipeline.core import Pipeline
+from stanza.pipeline.multilingual import MultilingualPipeline
from stanza.models.common.doc import Document
from stanza.resources.common import download
from stanza.resources.installation import install_corenlp, download_corenlp_models
diff --git a/stanza/_version.py b/stanza/_version.py
index 082bbc48..87647499 100644
--- a/stanza/_version.py
+++ b/stanza/_version.py
@@ -1,4 +1,4 @@
""" Single source of truth for version number """
-__version__ = "1.2.3"
-__resources_version__ = '1.2.2'
+__version__ = "1.3.0"
+__resources_version__ = '1.3.0'
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index 2bbe81d9..ed834843 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -22,7 +22,7 @@ import stanza.models.classifiers.classifier_args as classifier_args
import stanza.models.classifiers.cnn_classifier as cnn_classifier
import stanza.models.classifiers.data as data
-from stanza.utils.confusion impmort format_confusion
+from stanza.utils.confusion import format_confusion
class Loss(Enum):
@@ -80,7 +80,7 @@ python3 -u -m stanza.models.classifier --wordvec_type google --wordvec_dir exte
To train models on combined 3 class datasets:
-nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_3class --extra_wordvec_method CONCAT --extra_wordvec_dim 200 --train_file extern_data/sentiment/sst-processed/threeclass/train-threeclass-phrases.txt,extern_data/sentiment/MELD/train.txt,extern_data/sentiment/slsd/train.txt,extern_data/sentiment/arguana/train.txt,extern_data/sentiment/airline/train.txt,extern_data/sentiment/sst-processed/threeclass/extra-train-threeclass-phrases.txt,extern_data/sentiment/sst-processed/threeclass/checked-extra-threeclass-phrases.txt --dev_file extern_data/sentiment/sst-processed/threeclass/dev-threeclass-roots.txt --test_file extern_data/sentiment/sst-processed/threeclass/test-threeclass-roots.txt > FC41_3class.out 2>&1 &
+nohup python3 -u -m stanza.models.classifier --max_epochs 400 --filter_channels 1000 --fc_shapes 400,100 --base_name FC41_3class --extra_wordvec_method CONCAT --extra_wordvec_dim 200 --train_file extern_data/sentiment/sst-processed/threeclass/train-threeclass-phrases.txt,extern_data/sentiment/MELD/train.txt,extern_data/sentiment/slsd/train.txt,extern_data/sentiment/arguana/train.txt,extern_data/sentiment/airline/train.txt,extern_data/sentiment/sst-processed/threeclass/extra-train-threeclass-phrases.txt,extern_data/sentiment/sst-processed/threeclass/checked-extra-train-threeclass-phrases.txt --dev_file extern_data/sentiment/sst-processed/threeclass/dev-threeclass-roots.txt --test_file extern_data/sentiment/sst-processed/threeclass/test-threeclass-roots.txt > FC41_3class.out 2>&1 &
This tests that model:
@@ -488,7 +488,7 @@ def train_model(model, model_file, args, train_set, dev_set, labels):
# Add any leftover loss to the epoch_loss
epoch_loss += running_loss
- logger.info("Finished epoch %d" % (epoch + 1))
+ logger.info("Finished epoch %d Total loss %.3f" % (epoch + 1, epoch_loss))
dev_score = score_dev_set(model, dev_set, args.dev_eval_scoring)
if args.save_intermediate_models:
checkpoint_file = checkpoint_name(model_file, epoch + 1, args.dev_eval_scoring, dev_score)
diff --git a/stanza/models/classifiers/cnn_classifier.py b/stanza/models/classifiers/cnn_classifier.py
index fa5160bf..9cb1bd45 100644
--- a/stanza/models/classifiers/cnn_classifier.py
+++ b/stanza/models/classifiers/cnn_classifier.py
@@ -112,6 +112,8 @@ class CNNClassifier(nn.Module):
self.extra_vocab = list(extra_vocab)
self.extra_vocab_map = { word: i for i, word in enumerate(self.extra_vocab) }
# TODO: possibly add regularization specifically on the extra embedding?
+ # note: it looks like a bug that this doesn't add UNK or PAD, but actually
+ # those are expected to already be the first two entries
self.extra_embedding = nn.Embedding(num_embeddings = len(extra_vocab),
embedding_dim = self.config.extra_wordvec_dim,
max_norm = self.config.extra_wordvec_max_norm,
@@ -367,6 +369,8 @@ class CNNClassifier(nn.Module):
for fc in self.fc_layers[:-1]:
previous_layer = self.dropout(F.relu(fc(previous_layer)))
out = self.fc_layers[-1](previous_layer)
+ # note that we return the raw logits rather than use a softmax
+ # https://discuss.pytorch.org/t/multi-class-cross-entropy-loss-and-softmax-in-pytorch/24920/4
return out
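
The comment added above relies on a standard PyTorch convention: nn.CrossEntropyLoss applies log-softmax internally, so the classifier hands it raw logits and only applies a softmax when probabilities are actually needed. A minimal sketch (not code from this repo):

    import torch
    import torch.nn as nn

    logits = torch.randn(4, 3)                     # batch of 4, 3 sentiment classes
    labels = torch.tensor([0, 2, 1, 2])
    loss = nn.CrossEntropyLoss()(logits, labels)   # expects raw logits
    probs = torch.softmax(logits, dim=1)           # softmax only when probabilities are needed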
diff --git a/stanza/models/classifiers/data.py b/stanza/models/classifiers/data.py
index 4922414a..72c04b9c 100644
--- a/stanza/models/classifiers/data.py
+++ b/stanza/models/classifiers/data.py
@@ -11,6 +11,10 @@ def update_text(sentence, wordvec_type):
# stanford sentiment dataset has a lot of random - and /
sentence = sentence.replace("-", " ")
sentence = sentence.replace("/", " ")
+ sentence = sentence.strip()
+ if sentence == "":
+ # removed too much
+ sentence = "-"
sentence = sentence.split()
# our current word vectors are all entirely lowercased
sentence = [word.lower() for word in sentence]
diff --git a/stanza/models/common/constant.py b/stanza/models/common/constant.py
index 3ba570ab..9b39b7c2 100644
--- a/stanza/models/common/constant.py
+++ b/stanza/models/common/constant.py
@@ -134,6 +134,7 @@ langlower2lcode = {lcode2lang[k].lower(): k.lower() for k in lcode2lang}
# additional useful code to language mapping
# added after dict invert to avoid conflict
lcode2lang['nb'] = 'Norwegian' # Norwegian Bokmall mapped to default norwegian
+lcode2lang['no'] = 'Norwegian'
lcode2lang['zh'] = 'Simplified_Chinese'
lang2lcode['Chinese'] = 'zh'
@@ -142,12 +143,12 @@ lang2lcode['Chinese'] = 'zh'
lang2lcode['Old_Russian'] = 'orv'
treebank_special_cases = {
- "UD_Chinese-GSDSimp": "zh_gsdsimp",
+ "UD_Chinese-GSDSimp": "zh-hans_gsdsimp",
"UD_Chinese-GSD": "zh-hant_gsd",
"UD_Chinese-HK": "zh-hant_hk",
- "UD_Chinese-CFL": "zh-hant_cfl",
+ "UD_Chinese-CFL": "zh-hans_cfl",
"UD_Chinese-PUD": "zh-hant_pud",
- "UD_Norwegian-Bokmaal": "nb_bokmaal",
+ "UD_Norwegian-Bokmaal": "no_bokmaal",
"UD_Norwegian-Nynorsk": "nn_nynorsk",
"UD_Norwegian-NynorskLIA": "nn_nynorsklia",
}
@@ -159,7 +160,13 @@ def treebank_to_short_name(treebank):
if treebank.startswith('UD_'):
treebank = treebank[3:]
- splits = treebank.split('-')
+ # special case starting with zh in case the input is an already-converted ZH treebank
+ if treebank.startswith("zh-hans") or treebank.startswith("zh-hant"):
+ splits = (treebank[:len("zh-hans")], treebank[len("zh-hans")+1:])
+ else:
+ splits = treebank.split('-')
+ if len(splits) == 1:
+ splits = treebank.split("_", 1)
assert len(splits) == 2, "Unable to process %s" % treebank
lang, corpus = splits
@@ -174,3 +181,9 @@ def treebank_to_short_name(treebank):
short = "{}_{}".format(lcode, corpus.lower())
return short
+
+def treebank_to_langid(treebank):
+ """ Convert treebank name to langid """
+ short_name = treebank_to_short_name(treebank)
+ return short_name.split("_")[0]
+
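
Assuming treebank_special_cases is consulted by treebank_to_short_name (the rest of that function is outside this hunk), the updated mappings behave roughly like this sketch:

    from stanza.models.common.constant import treebank_to_short_name, treebank_to_langid

    print(treebank_to_short_name("UD_Chinese-GSDSimp"))    # expected: zh-hans_gsdsimp
    print(treebank_to_short_name("UD_Norwegian-Bokmaal"))  # expected: no_bokmaal
    print(treebank_to_langid("UD_Norwegian-Bokmaal"))      # expected: no (the part before "_")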
diff --git a/stanza/models/common/count_ner_coverage.py b/stanza/models/common/count_ner_coverage.py
new file mode 100644
index 00000000..b5a592c7
--- /dev/null
+++ b/stanza/models/common/count_ner_coverage.py
@@ -0,0 +1,38 @@
+from stanza.models.common import pretrain
+import argparse
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('ners', type=str, nargs='*', help='Which treebanks to run on')
+ parser.add_argument('--pretrain', type=str, default="/home/john/stanza_resources/hi/pretrain/hdtb.pt", help='Which pretrain to use')
+ parser.set_defaults(ners=["/home/john/stanza/data/ner/hi_fire2013.train.csv",
+ "/home/john/stanza/data/ner/hi_fire2013.dev.csv"])
+ args = parser.parse_args()
+ return args
+
+
+def read_ner(filename):
+ words = []
+ for line in open(filename).readlines():
+ line = line.strip()
+ if not line:
+ continue
+ if line.split("\t")[1] == 'O':
+ continue
+ words.append(line.split("\t")[0])
+ return words
+
+def count_coverage(pretrain, words):
+ count = 0
+ for w in words:
+ if w in pretrain.vocab:
+ count = count + 1
+ return count / len(words)
+
+args = parse_args()
+pt = pretrain.Pretrain(args.pretrain)
+for dataset in args.ners:
+ words = read_ner(dataset)
+ print(dataset)
+ print(count_coverage(pt, words))
+ print()
diff --git a/stanza/models/common/doc.py b/stanza/models/common/doc.py
index d24ea966..eb1d9f2f 100644
--- a/stanza/models/common/doc.py
+++ b/stanza/models/common/doc.py
@@ -72,6 +72,7 @@ class Document(StanzaObject):
comments: A list of list of strings to use as comments on the sentences, either None or the same length as sentences
"""
self._sentences = []
+ self._lang = None
self._text = None
self._num_tokens = 0
self._num_words = 0
@@ -81,6 +82,16 @@ class Document(StanzaObject):
self._ents = []
@property
+ def lang(self):
+ """ Access the language of this document """
+ return self._lang
+
+ @lang.setter
+ def lang(self, value):
+ """ Set the language of this document """
+ self._lang = value
+
+ @property
def text(self):
""" Access the raw text for this document. """
return self._text
diff --git a/stanza/models/common/hlstm.py b/stanza/models/common/hlstm.py
index bfddb3e4..124c6ad5 100644
--- a/stanza/models/common/hlstm.py
+++ b/stanza/models/common/hlstm.py
@@ -99,14 +99,14 @@ class HighwayLSTM(nn.Module):
for l in range(self.num_layers):
if l > 0:
- input = PackedSequence(self.drop(input.data), input.batch_sizes)
+ input = PackedSequence(self.drop(input.data), input.batch_sizes, input.sorted_indices, input.unsorted_indices)
layer_hx = (hx[0][l * self.num_directions:(l+1)*self.num_directions], hx[1][l * self.num_directions:(l+1)*self.num_directions]) if hx is not None else None
h, (ht, ct) = self.lstm[l](input, seqlens, layer_hx)
hs.append(ht)
cs.append(ct)
- input = PackedSequence(h.data + torch.sigmoid(self.gate[l](input.data)) * highway_func(self.highway[l](input.data)), input.batch_sizes)
+ input = PackedSequence(h.data + torch.sigmoid(self.gate[l](input.data)) * highway_func(self.highway[l](input.data)), input.batch_sizes, input.sorted_indices, input.unsorted_indices)
if self.pad:
input = pad_packed_sequence(input, batch_first=self.batch_first)[0]
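
The fix above matters because sequences packed with enforce_sorted=False carry sorted_indices/unsorted_indices that must survive any rewrapping of the packed data. A small sketch of the same pattern (illustrative, not from this repo):

    import torch
    import torch.nn.functional as F
    from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence

    x = torch.randn(3, 5, 8)                     # batch of 3, max length 5, 8 features
    lengths = torch.tensor([5, 2, 4])
    packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    # carry the index tensors along when wrapping .data, e.g. to apply dropout
    dropped = PackedSequence(F.dropout(packed.data, p=0.5, training=True),
                             packed.batch_sizes,
                             packed.sorted_indices,
                             packed.unsorted_indices)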
diff --git a/stanza/models/common/utils.py b/stanza/models/common/utils.py
index 842bf6fd..69a5ee7d 100644
--- a/stanza/models/common/utils.py
+++ b/stanza/models/common/utils.py
@@ -6,7 +6,9 @@ import os
from collections import Counter
import random
import json
+import sys
import unicodedata
+
import torch
import numpy as np
@@ -299,3 +301,32 @@ def warn_missing_tags(known_tags, test_tags, test_set_name):
logger.warning("Found tags in {} missing from the expected tag set: {}".format(test_set_name, missing_tags))
return True
return False
+
+def get_tqdm():
+ """
+ Return a tqdm appropriate for the situation
+
+ imports tqdm depending on whether we're at a console, redirected to a file, in a notebook, etc
+
+ from @tcrimi at https://github.com/tqdm/tqdm/issues/506
+
+ This replaces `import tqdm`, so for example, you do this:
+ tqdm = utils.get_tqdm()
+ then do this when you want a scroll bar or regular iterator depending on context:
+ tqdm(list)
+ """
+ try:
+ ipy_str = str(type(get_ipython()))
+ if 'zmqshell' in ipy_str:
+ from tqdm import tqdm_notebook as tqdm
+ return tqdm
+ if 'terminal' in ipy_str:
+ from tqdm import tqdm
+ return tqdm
+ except:
+ if sys.stderr.isatty():
+ from tqdm import tqdm
+ return tqdm
+ def tqdm(iterable, **kwargs):
+ return iterable
+ return tqdm
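
Following the docstring above, callers bind the returned function to the name tqdm once and then use it as if tqdm had been imported directly, along these lines:

    from stanza.models.common import utils

    tqdm = utils.get_tqdm()
    for item in tqdm(range(100)):
        pass   # progress bar at a terminal or notebook, plain iterator when redirected to a file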
diff --git a/stanza/models/constituency/__init__.py b/stanza/models/constituency/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/stanza/models/constituency/__init__.py
diff --git a/stanza/models/constituency/base_model.py b/stanza/models/constituency/base_model.py
new file mode 100644
index 00000000..0a4ee102
--- /dev/null
+++ b/stanza/models/constituency/base_model.py
@@ -0,0 +1,205 @@
+"""
+The BaseModel is passed to the transitions so that the transitions
+can operate on a parsing state without knowing the exact
+representation used in the model.
+
+For example, a SimpleModel simply looks at the top of the various stacks in the state.
+
+A model with LSTM representations for the different transitions may
+attach the hidden and output states of the LSTM to the word /
+constituent / transition stacks.
+
+Reminder: the parsing state is a list of words to parse, the
+transitions used to build a (possibly incomplete) parse, and the
+constituent(s) built so far by those transitions. Each of these
+components are represented using stacks to improve the efficiency
+of operations such as "combine the most recent 4 constituents"
+or "turn the next input word into a constituent"
+"""
+
+from abc import ABC, abstractmethod
+
+from stanza.models.constituency.parse_transitions import TransitionScheme
+from stanza.models.constituency.parse_tree import Tree
+from stanza.models.constituency.tree_stack import TreeStack
+
+class BaseModel(ABC):
+ """
+ This base class defines abstract methods for manipulating a State.
+
+ Applying transitions may change important metadata about a State
+ such as the vectors associated with LSTM hidden states, for example.
+ """
+ @abstractmethod
+ def initial_word_queues(self, tagged_word_lists):
+ """
+ For each list of tagged words, builds a TreeStack of word nodes
+
+ The word lists should be backwards so that the first word is the last word put on the stack (LIFO)
+ """
+
+ @abstractmethod
+ def initial_transitions(self):
+ """
+ Builds an initial transition stack with whatever values need to go into first position
+ """
+
+ @abstractmethod
+ def initial_constituents(self):
+ """
+ Builds an initial constituent stack with whatever values need to go into first position
+ """
+
+ @abstractmethod
+ def get_word(self, word_node):
+ """
+ Get the word corresponding to this position in the word queue
+ """
+
+ @abstractmethod
+ def transform_word_to_constituent(self, state):
+ """
+ Transform the top node of word_queue to something that can push on the constituent stack
+ """
+
+ @abstractmethod
+ def dummy_constituent(self, dummy):
+ """
+ When using a dummy node as a sentinel, transform it to something usable by this model
+ """
+
+ @abstractmethod
+ def unary_transform(self, constituents, labels):
+ """
+ Transform the top of the constituent stack using a unary transform to the new label
+ """
+
+ @abstractmethod
+ def build_constituents(self, labels, children_lists):
+ """
+ Build multiple constituents at once. This gives the opportunity for batching operations
+ """
+
+ @abstractmethod
+ def push_constituents(self, constituent_stacks, constituents):
+ """
+ Add multiple constituents to multiple constituent_stacks
+
+ Useful to factor this out in case batching will help
+ """
+
+ @abstractmethod
+ def get_top_constituent(self, constituents):
+ """
+ Get the first constituent from the constituent stack
+
+ For example, a model might want to remove embeddings and LSTM state vectors
+ """
+
+ @abstractmethod
+ def push_transitions(self, transition_stacks, transitions):
+ """
+ Add multiple transitions to multiple transition_stacks
+
+ Useful to factor this out in case batching will help
+ """
+
+ @abstractmethod
+ def get_top_transition(self, transitions):
+ """
+ Get the first transition from the transition stack
+
+ For example, a model might want to remove transition embeddings before returning the transition
+ """
+
+ def get_root_labels(self):
+ """
+ Return ROOT labels for this model. Probably ROOT, TOP, or both
+ """
+ return ("ROOT",)
+
+ @abstractmethod
+ def transition_scheme(self):
+ """
+ Transition scheme used - see parse_transitions
+ """
+
+ @abstractmethod
+ def has_unary_transitions(self):
+ """
+ Whether or not this model uses unary transitions, based on transition_scheme
+ """
+
+ @abstractmethod
+ def is_top_down(self):
+ """
+ Whether or not this model is TOP_DOWN
+ """
+
+class SimpleModel(BaseModel):
+ """
+ This model allows pushing and popping with no extra data
+ """
+ def __init__(self, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
+ self._transition_scheme = transition_scheme
+
+ def initial_word_queues(self, tagged_word_lists):
+ word_queues = []
+ for tagged_words in tagged_word_lists:
+ word_queue = [tag_node for tag_node in tagged_words]
+ word_queue.reverse()
+ word_queue.append(None)
+ word_queues.append(word_queue)
+ return word_queues
+
+ def initial_transitions(self):
+ return TreeStack(value=None, parent=None, length=1)
+
+ def initial_constituents(self):
+ return TreeStack(value=None, parent=None, length=1)
+
+ def get_word(self, word_node):
+ return word_node
+
+ def transform_word_to_constituent(self, state):
+ return state.word_queue[state.word_position]
+
+ def dummy_constituent(self, dummy):
+ return dummy
+
+ def unary_transform(self, constituents, labels):
+ top_constituent = constituents.value
+ for label in reversed(labels):
+ top_constituent = Tree(label=label, children=[top_constituent])
+ return top_constituent
+
+ def build_constituents(self, labels, children_lists):
+ constituents = []
+ for label, children in zip(labels, children_lists):
+ if isinstance(label, str):
+ label = (label,)
+ for value in reversed(label):
+ children = Tree(label=value, children=children)
+ constituents.append(children)
+ return constituents
+
+ def push_constituents(self, constituent_stacks, constituents):
+ return [stack.push(constituent) for stack, constituent in zip(constituent_stacks, constituents)]
+
+ def get_top_constituent(self, constituents):
+ return constituents.value
+
+ def push_transitions(self, transition_stacks, transitions):
+ return [stack.push(transition) for stack, transition in zip(transition_stacks, transitions)]
+
+ def get_top_transition(self, transitions):
+ return transitions.value
+
+ def transition_scheme(self):
+ return self._transition_scheme
+
+ def has_unary_transitions(self):
+ return self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY
+
+ def is_top_down(self):
+ return self._transition_scheme in (TransitionScheme.TOP_DOWN, TransitionScheme.TOP_DOWN_UNARY, TransitionScheme.TOP_DOWN_COMPOUND)
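
As a quick illustration of the SimpleModel contract (a sketch, not test code from the repo): a compound label such as ("ROOT", "S", "NP") is applied so that the first label ends up outermost, and the "top" of a constituent stack is simply the Tree that was pushed.

    from stanza.models.constituency.base_model import SimpleModel
    from stanza.models.constituency.parse_tree import Tree

    model = SimpleModel()
    leaf = Tree(label="NN", children=[Tree(label="dog")])
    (constituent,) = model.build_constituents([("ROOT", "S", "NP")], [[leaf]])
    stack = model.push_constituents([model.initial_constituents()], [constituent])[0]
    print(model.get_top_constituent(stack))   # the NN leaf wrapped in NP, then S, then ROOT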
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
new file mode 100644
index 00000000..ec9b1a8f
--- /dev/null
+++ b/stanza/models/constituency/lstm_model.py
@@ -0,0 +1,543 @@
+"""
+A version of the BaseModel which uses LSTMs to predict the correct next transition
+based on the current known state.
+
+The primary purpose of this class is to implement the prediction of the next
+transition, which is done by concatenating the output of an LSTM operated over
+previous transitions, the words, and the partially built constituents.
+"""
+
+from collections import namedtuple
+import logging
+from operator import itemgetter
+import random
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence
+
+from stanza.models.common.data import get_long_tensor
+from stanza.models.common.utils import unsort
+from stanza.models.common.vocab import PAD_ID, UNK_ID
+from stanza.models.constituency.base_model import BaseModel
+from stanza.models.constituency.parse_transitions import TransitionScheme
+from stanza.models.constituency.parse_tree import Tree
+from stanza.models.constituency.tree_stack import TreeStack
+
+logger = logging.getLogger('stanza')
+
+WordNode = namedtuple("WordNode", ['value', 'hx'])
+TransitionNode = namedtuple("TransitionNode", ['value', 'output', 'hx', 'cx'])
+
+# Invariant: the output at the top of the constituency stack will have a
+# single dimension
+# We do this to maintain consistency between the different operations,
+# which sometimes result in different shapes
+# This will be unsqueezed in order to put into the next layer if needed
+# hx & cx are the hidden & cell states of the LSTM going across constituents
+ConstituentNode = namedtuple("ConstituentNode", ['value', 'output', 'hx', 'cx'])
+Constituent = namedtuple("Constituent", ['value', 'hx'])
+
+
+class LSTMModel(BaseModel, nn.Module):
+ def __init__(self, pretrain, forward_charlm, backward_charlm, transitions, constituents, tags, words, rare_words, root_labels, open_nodes, args):
+ """
+ pretrain: a Pretrain object
+ transitions: a list of all possible transitions which will be
+ used to build trees
+ constituents: a list of all possible constituents in the treebank
+ tags: a list of all possible tags in the treebank
+ words: a list of all known words, used for a delta word embedding.
+ note that there will be an attempt made to learn UNK words as well,
+ and tags by themselves may help UNK words
+ rare_words: a list of rare words, used to occasionally replace with UNK
+ root_labels: probably ROOT, although apparently some treebanks like TOP
+ open_nodes: a list of all possible open nodes which will go on the stack
+ - this might be different from constituents if there are nodes
+ which represent multiple constituents at once
+ args: hidden_size, transition_hidden_size, etc as gotten from
+ constituency_parser.py
+
+ Note that it might look like a hassle to pass all of this in
+ when it can be collected directly from the trees themselves.
+ However, that would only work at train time. At eval or
+ pipeline time we will load the lists from the saved model.
+ """
+ super().__init__()
+ self.args = args
+ self.unsaved_modules = []
+
+ emb_matrix = pretrain.emb
+ self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix), freeze=True))
+
+ self.vocab_map = { word: i for i, word in enumerate(pretrain.vocab) }
+ # precompute tensors for the word indices
+ # the tensors should be put on the GPU if needed with a call to cuda()
+ self.register_buffer('vocab_tensors', torch.tensor(range(len(pretrain.vocab)), requires_grad=False))
+ self.vocab_size = emb_matrix.shape[0]
+ self.embedding_dim = emb_matrix.shape[1]
+
+ self.root_labels = sorted(list(root_labels))
+ self.constituents = sorted(list(constituents))
+ self.constituent_map = { x: i for (i, x) in enumerate(self.constituents) }
+ # precompute tensors for the constituents
+ self.register_buffer('constituent_tensors', torch.tensor(range(len(self.constituent_map)), requires_grad=False))
+
+ self.hidden_size = self.args['hidden_size']
+ self.transition_hidden_size = self.args['transition_hidden_size']
+ self.tag_embedding_dim = self.args['tag_embedding_dim']
+ self.transition_embedding_dim = self.args['transition_embedding_dim']
+ self.delta_embedding_dim = self.args['delta_embedding_dim']
+ self.word_input_size = self.embedding_dim + self.tag_embedding_dim + self.delta_embedding_dim
+
+ if forward_charlm is not None:
+ self.add_unsaved_module('forward_charlm', forward_charlm)
+ self.add_unsaved_module('forward_charlm_vocab', forward_charlm.char_vocab())
+ self.word_input_size += self.forward_charlm.hidden_dim()
+ else:
+ self.forward_charlm = None
+ if backward_charlm is not None:
+ self.add_unsaved_module('backward_charlm', backward_charlm)
+ self.add_unsaved_module('backward_charlm_vocab', backward_charlm.char_vocab())
+ self.word_input_size += self.backward_charlm.hidden_dim()
+ else:
+ self.backward_charlm = None
+
+ # TODO: add a max_norm?
+ self.delta_words = sorted(list(words))
+ self.delta_word_map = { word: i+2 for i, word in enumerate(self.delta_words) }
+ assert PAD_ID == 0
+ assert UNK_ID == 1
+ self.delta_embedding = nn.Embedding(num_embeddings = len(self.delta_words)+2,
+ embedding_dim = self.delta_embedding_dim,
+ padding_idx = 0)
+ self.register_buffer('delta_tensors', torch.tensor(range(len(self.delta_words) + 2), requires_grad=False))
+
+ self.rare_words = set(rare_words)
+
+ self.tags = sorted(list(tags))
+ if self.tag_embedding_dim > 0:
+ self.tag_map = { t: i for i, t in enumerate(self.tags) }
+ self.tag_embedding = nn.Embedding(num_embeddings = len(tags),
+ embedding_dim = self.tag_embedding_dim)
+ self.register_buffer('tag_tensors', torch.tensor(range(len(self.tags)), requires_grad=False))
+
+ self.transitions = sorted(list(transitions))
+ self.transition_map = { t: i for i, t in enumerate(self.transitions) }
+ # precompute tensors for the transitions
+ self.register_buffer('transition_tensors', torch.tensor(range(len(transitions)), requires_grad=False))
+ self.transition_embedding = nn.Embedding(num_embeddings = len(transitions),
+ embedding_dim = self.transition_embedding_dim)
+
+ self.num_layers = self.args['num_lstm_layers']
+ self.lstm_layer_dropout = self.args['lstm_layer_dropout']
+
+ # also register a buffer of zeros so that we can always get zeros on the appropriate device
+ self.register_buffer('word_zeros', torch.zeros(self.hidden_size))
+ self.register_buffer('transition_zeros', torch.zeros(self.num_layers, 1, self.transition_hidden_size))
+ self.register_buffer('constituent_zeros', torch.zeros(self.num_layers, 1, self.hidden_size))
+
+ self.word_lstm = nn.LSTM(input_size=self.word_input_size, hidden_size=self.hidden_size, num_layers=self.num_layers, bidirectional=True, dropout=self.lstm_layer_dropout)
+
+ # after putting the word_delta_tag input through the word_lstm, we get back
+ # hidden_size * 2 output with the front and back lstms concatenated.
+ # this transforms it into hidden_size with the values mixed together
+ self.word_to_constituent = nn.Linear(self.hidden_size * 2, self.hidden_size)
+
+ self.transition_lstm = nn.LSTM(input_size=self.transition_embedding_dim, hidden_size=self.transition_hidden_size, num_layers=self.num_layers, dropout=self.lstm_layer_dropout)
+ # input_size is hidden_size - could introduce a new constituent_size instead if we liked
+ self.constituent_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.lstm_layer_dropout)
+
+ self._transition_scheme = args['transition_scheme']
+ if self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY:
+ unary_transforms = {}
+ for constituent in self.constituent_map:
+ unary_transforms[constituent] = nn.Linear(self.hidden_size, self.hidden_size)
+ self.unary_transforms = nn.ModuleDict(unary_transforms)
+
+ self.open_nodes = sorted(list(open_nodes))
+ # an embedding for the spot on the constituent LSTM taken up by the Open transitions
+ # the pattern when condensing constituents is embedding - con1 - con2 - con3 - embedding
+ # TODO: try the two ends have different embeddings?
+ self.open_node_map = { x: i for (i, x) in enumerate(self.open_nodes) }
+ self.open_node_embedding = nn.Embedding(num_embeddings = len(self.open_node_map),
+ embedding_dim = self.hidden_size)
+
+ # TODO: remove this `get` once it's not needed
+ if args.get('combined_dummy_embedding', False):
+ self.dummy_embedding = self.open_node_embedding
+ else:
+ self.dummy_embedding = nn.Embedding(num_embeddings = len(self.open_node_map),
+ embedding_dim = self.hidden_size)
+ self.register_buffer('open_node_tensors', torch.tensor(range(len(open_nodes)), requires_grad=False))
+
+ # forward and backward pieces for crunching several
+ # constituents into one, combined into a bi-lstm
+ # TODO: make the hidden size here an option?
+ self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_layers, bidirectional=True, dropout=self.lstm_layer_dropout)
+ # affine transformation from bi-lstm reduce to a new hidden layer
+ self.reduce_linear = nn.Linear(self.hidden_size * 2, self.hidden_size)
+
+ if self.args['nonlinearity'] == 'tanh':
+ self.nonlinearity = nn.Tanh()
+ elif self.args['nonlinearity'] == 'relu':
+ self.nonlinearity = nn.ReLU()
+ elif self.args['nonlinearity'] == 'gelu':
+ self.nonlinearity = nn.GELU()
+ else:
+ raise ValueError('Chosen value of nonlinearity, "%s", not handled' % self.args['nonlinearity'])
+
+ self.word_dropout = nn.Dropout(self.args['word_dropout'])
+ self.predict_dropout = nn.Dropout(self.args['predict_dropout'])
+ self.lstm_input_dropout = nn.Dropout(self.args['lstm_input_dropout'])
+
+ # matrix for predicting the next transition using word/constituent/transition queues
+ # word size + constituency size + transition size
+ middle_layers = self.args['num_output_layers'] - 1
+ predict_input_size = [self.hidden_size * 2 + self.transition_hidden_size] + [self.hidden_size] * middle_layers
+ predict_output_size = [self.hidden_size] * middle_layers + [len(transitions)]
+ self.output_layers = nn.ModuleList([nn.Linear(input_size, output_size)
+ for input_size, output_size in zip(predict_input_size, predict_output_size)])
+
+ self.constituency_lstm = self.args['constituency_lstm']
+
+ def add_unsaved_module(self, name, module):
+ """
+ Adds a module which will not be saved to disk
+
+ Best used for large models such as pretrained word embeddings
+ """
+ self.unsaved_modules += [name]
+ setattr(self, name, module)
+
+ def get_root_labels(self):
+ return self.root_labels
+
+ def build_char_representation(self, all_word_labels, device, forward):
+ CHARLM_START = "\n"
+ CHARLM_END = " "
+
+ if forward:
+ charlm = self.forward_charlm
+ vocab = self.forward_charlm_vocab
+ else:
+ charlm = self.backward_charlm
+ vocab = self.backward_charlm_vocab
+
+ all_data = []
+ for idx, word_labels in enumerate(all_word_labels):
+ if forward:
+ word_labels = reversed(word_labels)
+ else:
+ word_labels = [x[::-1] for x in word_labels]
+
+ chars = [CHARLM_START]
+ offsets = []
+ for w in word_labels:
+ chars.extend(w)
+ chars.append(CHARLM_END)
+ offsets.append(len(chars) - 1)
+ if not forward:
+ offsets.reverse()
+ chars = vocab.map(chars)
+ all_data.append((chars, offsets, len(chars), len(all_data)))
+
+ all_data.sort(key=itemgetter(2), reverse=True)
+ chars, char_offsets, char_lens, orig_idx = tuple(zip(*all_data))
+ chars = get_long_tensor(chars, len(all_data), pad_id=vocab.unit2id(' ')).to(device=device)
+
+ # TODO: surely this should be stuffed in the charlm model itself rather than done here
+ with torch.no_grad():
+ output, _, _ = charlm.forward(chars, char_lens)
+ res = [output[i, offsets] for i, offsets in enumerate(char_offsets)]
+ res = unsort(res, orig_idx)
+
+ return res
+
+ def initial_word_queues(self, tagged_word_lists):
+ """
+ Produce initial word queues out of the model's LSTMs for use in the tagged word lists.
+
+ Operates in a batched fashion to reduce the runtime for the LSTM operations
+ """
+ device = next(self.parameters()).device
+
+ all_word_inputs = []
+ all_word_labels = []
+ for sentence_idx, tagged_words in enumerate(tagged_word_lists):
+ word_idx = torch.stack([self.vocab_tensors[self.vocab_map.get(word.children[0].label, UNK_ID)] for word in tagged_words])
+ word_input = self.embedding(word_idx)
+
+ # this occasionally learns UNK at train time
+ word_labels = [word.children[0].label for word in tagged_words]
+ if self.training:
+ delta_labels = [None if word in self.rare_words and random.random() < self.args['rare_word_unknown_frequency'] else word
+ for word in word_labels]
+ else:
+ delta_labels = word_labels
+ delta_idx = torch.stack([self.delta_tensors[self.delta_word_map.get(word, UNK_ID)] for word in delta_labels])
+
+ delta_input = self.delta_embedding(delta_idx)
+
+ word_inputs = [word_input, delta_input]
+
+ if self.tag_embedding_dim > 0:
+ try:
+ tag_idx = torch.stack([self.tag_tensors[self.tag_map[word.label]] for word in tagged_words])
+ tag_input = self.tag_embedding(tag_idx)
+ word_inputs.append(tag_input)
+ except KeyError as e:
+ raise KeyError("Constituency parser not trained with tag {}".format(str(e))) from e
+
+ all_word_labels.append(word_labels)
+ all_word_inputs.append(word_inputs)
+
+ if self.forward_charlm is not None:
+ all_forward_chars = self.build_char_representation(all_word_labels, device, forward=True)
+ for word_inputs, forward_chars in zip(all_word_inputs, all_forward_chars):
+ word_inputs.append(forward_chars)
+ if self.backward_charlm is not None:
+ all_backward_chars = self.build_char_representation(all_word_labels, device, forward=False)
+ for word_inputs, backward_chars in zip(all_word_inputs, all_backward_chars):
+ word_inputs.append(backward_chars)
+
+ word_lstm_input = torch.zeros((max(len(x) for x in tagged_word_lists), len(tagged_word_lists), self.word_input_size), device=device)
+
+ for sentence_idx, word_inputs in enumerate(all_word_inputs):
+ # now of size sentence x input
+ word_input = torch.cat(word_inputs, dim=1)
+ word_input = self.word_dropout(word_input)
+
+ word_lstm_input[:word_input.shape[0], sentence_idx, :] = word_input
+
+ packed_word_input = torch.nn.utils.rnn.pack_padded_sequence(word_lstm_input, [len(x) for x in tagged_word_lists], enforce_sorted=False)
+ word_output, _ = self.word_lstm(packed_word_input)
+ # would like to do word_to_constituent here, but it seems PackedSequence doesn't support Linear
+ # word_output will now be sentence x batch x 2*hidden_size
+ word_output, word_output_lens = torch.nn.utils.rnn.pad_packed_sequence(word_output)
+ # now sentence x batch x hidden_size
+
+ word_queues = []
+ for sentence_idx, tagged_words in enumerate(tagged_word_lists):
+ sentence_output = word_output[:len(tagged_words), sentence_idx, :]
+ sentence_output = self.word_to_constituent(sentence_output)
+ sentence_output = self.nonlinearity(sentence_output)
+ # TODO: this makes it so constituents downstream are
+ # built with the outputs of the LSTM, not the word
+ # embeddings themselves. It is possible we want to
+ # transform the word_input to hidden_size in some way
+ # and use that instead
+ word_queue = [WordNode(tag_node, sentence_output[idx, :])
+ for idx, tag_node in enumerate(tagged_words)]
+ word_queue.reverse()
+ word_queue.append(WordNode(None, self.word_zeros))
+
+ word_queues.append(word_queue)
+
+ return word_queues
+
+ def initial_transitions(self):
+ """
+ Return an initial TreeStack with no transitions
+ """
+ return TreeStack(value=TransitionNode(None, self.transition_zeros[-1, 0, :], self.transition_zeros, self.transition_zeros), parent=None, length=1)
+
+ def initial_constituents(self):
+ """
+ Return an initial TreeStack with no constituents
+ """
+ return TreeStack(value=ConstituentNode(None, self.constituent_zeros[-1, 0, :], self.constituent_zeros, self.constituent_zeros), parent=None, length=1)
+
+ def get_word(self, word_node):
+ return word_node.value
+
+ def transform_word_to_constituent(self, state):
+ word_node = state.word_queue[state.word_position]
+ word = word_node.value
+ return Constituent(value=word, hx=word_node.hx)
+
+ def dummy_constituent(self, dummy):
+ label = dummy.label
+ open_index = self.open_node_tensors[self.open_node_map[label]]
+ hx = self.dummy_embedding(open_index)
+ return Constituent(value=dummy, hx=hx)
+
+ def unary_transform(self, constituents, labels):
+ top_constituent = constituents.value
+ node = top_constituent.value
+ hx = top_constituent.output
+ for label in reversed(labels):
+ node = Tree(label=label, children=[node])
+ hx = self.unary_transforms[label](hx)
+ # non-linearity after the unary transform
+ hx = self.nonlinearity(hx)
+ top_constituent = Constituent(value=node, hx=hx)
+ return top_constituent
+
+ def build_constituents(self, labels, children_lists):
+ label_hx = [self.open_node_embedding(self.open_node_tensors[self.open_node_map[label]]) for label in labels]
+
+ max_length = max(len(children) for children in children_lists)
+ zeros = torch.zeros(self.hidden_size, device=label_hx[0].device)
+ node_hx = [[child.output for child in children] for children in children_lists]
+ # weirdly, this is faster than using pack_sequence
+ unpacked_hx = [[lhx] + nhx + [lhx] + [zeros] * (max_length - len(nhx)) for lhx, nhx in zip(label_hx, node_hx)]
+ unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in unpacked_hx]
+ packed_hx = torch.stack(unpacked_hx, axis=1)
+ packed_hx = torch.nn.utils.rnn.pack_padded_sequence(packed_hx, [len(x)+2 for x in children_lists], enforce_sorted=False)
+ lstm_output = self.constituent_reduce_lstm(packed_hx)
+ # take just the output of the final layer
+ # result of lstm is output, (hx, cx)
+ # so [1][0] gets hx
+ # [1][0][-1] is the final output
+ # will be shape len(children_lists) * 2, hidden_size for bidirectional
+ # where forward outputs are -2 and backwards are -1
+ lstm_output = lstm_output[1][0]
+ forward_hx = lstm_output[-2, :]
+ backward_hx = lstm_output[-1, :]
+
+ hx = self.reduce_linear(torch.cat((forward_hx, backward_hx), axis=1))
+ hx = self.nonlinearity(hx)
+
+ constituents = []
+ for idx, (label, children) in enumerate(zip(labels, children_lists)):
+ children = [child.value for child in children]
+ if isinstance(label, str):
+ node = Tree(label=label, children=children)
+ else:
+ for value in reversed(label):
+ node = Tree(label=value, children=children)
+ children = node
+ constituents.append(Constituent(value=node, hx=hx[idx, :]))
+ return constituents
+
+ def push_constituents(self, constituent_stacks, constituents):
+ current_nodes = [stack.value for stack in constituent_stacks]
+
+ constituent_input = torch.stack([x.hx for x in constituents])
+ constituent_input = constituent_input.unsqueeze(0)
+ constituent_input = self.lstm_input_dropout(constituent_input)
+
+ hx = torch.cat([current_node.hx for current_node in current_nodes], axis=1)
+ cx = torch.cat([current_node.cx for current_node in current_nodes], axis=1)
+ output, (hx, cx) = self.constituent_lstm(constituent_input, (hx, cx))
+ if self.constituency_lstm:
+ new_stacks = [stack.push(ConstituentNode(constituent.value, output[0, i, :], hx[:, i:i+1, :], cx[:, i:i+1, :]))
+ for i, (stack, constituent) in enumerate(zip(constituent_stacks, constituents))]
+ else:
+ new_stacks = [stack.push(ConstituentNode(constituent.value, constituents[i].hx, hx[:, i:i+1, :], cx[:, i:i+1, :]))
+ for i, (stack, constituent) in enumerate(zip(constituent_stacks, constituents))]
+ return new_stacks
+
+ def get_top_constituent(self, constituents):
+ """
+ Extract only the top constituent from a state's constituent
+ sequence, even though it has multiple addition pieces of
+ information
+ """
+ constituent_node = constituents.value
+ return constituent_node.value
+
+ def push_transitions(self, transition_stacks, transitions):
+ transition_idx = torch.stack([self.transition_tensors[self.transition_map[transition]] for transition in transitions])
+ transition_input = self.transition_embedding(transition_idx).unsqueeze(0)
+ transition_input = self.lstm_input_dropout(transition_input)
+
+ hx = torch.cat([t.value.hx for t in transition_stacks], axis=1)
+ cx = torch.cat([t.value.cx for t in transition_stacks], axis=1)
+ output, (hx, cx) = self.transition_lstm(transition_input, (hx, cx))
+ new_stacks = [stack.push(TransitionNode(transition, output[0, i, :], hx[:, i:i+1, :], cx[:, i:i+1, :]))
+ for i, (stack, transition) in enumerate(zip(transition_stacks, transitions))]
+ return new_stacks
+
+ def get_top_transition(self, transitions):
+ """
+ Extract only the top transition from a state's transition
+ sequence, even though it has multiple additional pieces of
+ information
+ """
+ transition_node = transitions.value
+ return transition_node.value
+
+ def transition_scheme(self):
+ return self._transition_scheme
+
+ def has_unary_transitions(self):
+ return self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY
+
+ def is_top_down(self):
+ return self._transition_scheme in (TransitionScheme.TOP_DOWN, TransitionScheme.TOP_DOWN_UNARY, TransitionScheme.TOP_DOWN_COMPOUND)
+
+ def forward(self, states):
+ """
+ Return logits for a prediction of what transition to make next
+
+ We've basically done all the work analyzing the state as
+ part of applying the transitions, so this method is very simple
+ """
+ word_hx = torch.stack([state.word_queue[state.word_position].hx for state in states])
+ transition_hx = torch.stack([state.transitions.value.output for state in states])
+ # note that we use hx instead of output from the constituents
+ # this way, we can, as an option, NOT include the constituents to the left
+ # when building the current vector for a constituent
+ # and the vector used for inference will still incorporate the entire LSTM
+ constituent_hx = torch.stack([state.constituents.value.hx[-1, 0, :] for state in states])
+
+ hx = torch.cat((word_hx, transition_hx, constituent_hx), axis=1)
+ for idx, output_layer in enumerate(self.output_layers):
+ hx = self.predict_dropout(hx)
+ if idx < len(self.output_layers) - 1:
+ hx = self.nonlinearity(hx)
+ hx = output_layer(hx)
+ return hx
+
+ # TODO: merge this with forward?
+ def predict(self, states, is_legal=False):
+ """
+ Generate and return predictions, along with the transitions those predictions represent
+
+ If is_legal is set to True, will only return legal transitions.
+ This means returning None if there are no legal transitions.
+ Hopefully the constraints prevent that from happening
+ """
+ predictions = self.forward(states)
+ pred_max = torch.argmax(predictions, axis=1)
+
+ pred_trans = [self.transitions[pred_max[idx]] for idx in range(len(states))]
+ if is_legal:
+ for idx, (state, trans) in enumerate(zip(states, pred_trans)):
+ if not trans.is_legal(state, self):
+ _, indices = predictions[idx, :].sort(descending=True)
+ for index in indices:
+ if self.transitions[index].is_legal(state, self):
+ pred_trans[idx] = self.transitions[index]
+ break
+ else: # yeah, else on a for loop, deal with it
+ pred_trans[idx] = None
+
+ return predictions, pred_trans
+
+ def get_params(self, skip_modules=True):
+ """
+ Get a dictionary for saving the model
+ """
+ model_state = self.state_dict()
+ # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
+ if skip_modules:
+ skipped = [k for k in model_state.keys() if k.split('.')[0] in self.unsaved_modules]
+ for k in skipped:
+ del model_state[k]
+ params = {
+ 'model': model_state,
+ 'model_type': "LSTM",
+ 'config': self.args,
+ 'transitions': self.transitions,
+ 'constituents': self.constituents,
+ 'tags': self.tags,
+ 'words': self.delta_words,
+ 'rare_words': self.rare_words,
+ 'root_labels': self.root_labels,
+ 'open_nodes': self.open_nodes,
+ }
+
+ return params
+
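
The indexing comment inside build_constituents above follows the usual PyTorch layout of an nn.LSTM hidden state: for a bidirectional LSTM, h_n has shape (num_layers * 2, batch, hidden_size), and its last two rows are the final layer's forward and backward states. A standalone sketch with toy sizes (not repo code):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, bidirectional=True)
    inputs = torch.randn(5, 3, 8)                  # (seq_len, batch, input_size)
    output, (h_n, c_n) = lstm(inputs)
    forward_hx = h_n[-2, :]                        # final layer, forward direction
    backward_hx = h_n[-1, :]                       # final layer, backward direction
    combined = torch.cat((forward_hx, backward_hx), dim=1)   # (batch, 2 * hidden_size)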
diff --git a/stanza/models/constituency/parse_transitions.py b/stanza/models/constituency/parse_transitions.py
new file mode 100644
index 00000000..ec815caa
--- /dev/null
+++ b/stanza/models/constituency/parse_transitions.py
@@ -0,0 +1,603 @@
+"""
+Defines a series of transitions (open a constituent, close a constituent, etc)
+
+Also defines a State which holds the various data needed to build
+a parse tree out of tagged words.
+"""
+
+from abc import ABC, abstractmethod
+from collections import defaultdict, namedtuple
+from enum import Enum
+import functools
+import logging
+
+from stanza.models.constituency.parse_tree import Tree
+
+logger = logging.getLogger('stanza')
+
+class TransitionScheme(Enum):
+ TOP_DOWN = 1
+ TOP_DOWN_COMPOUND = 2
+ TOP_DOWN_UNARY = 3
+
+ IN_ORDER = 4
+
+UNARY_LIMIT = 4
+
+class State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'gold_tree', 'gold_sequence',
+ 'sentence_length', 'num_opens', 'word_position'])):
+ """
+ Represents a partially completed transition parse
+
+ Includes stack/buffers for unused words, already executed transitions, and partially build constituents
+ At training time, also keeps track of the gold data we are reparsing
+
+ num_opens is useful for tracking
+ 1) if the parser is in a stuck state where it is making infinite opens
+ 2) if a close transition is impossible because there are no previous opens
+
+ sentence_length tracks how long the sentence is so we abort if we go infinite
+
+ non-stack information such as sentence_length and num_opens
+ will be copied from the original_state if possible, with the
+ exact arguments overriding the values in the original_state
+
+ gold_tree: the original tree, if made from a gold tree. might be None
+ gold_sequence: the original transition sequence, if available
+ Note that at runtime, gold values will not be available
+
+ word_position tracks where in the word queue we are. cheaper than
+ manipulating the list itself. this can be handled differently
+ from transitions and constituents as it is processed once
+ at the start of parsing
+ """
+ def empty_word_queue(self):
+ # the first element of each stack is a sentinel with no value
+ # and no parent
+ return self.word_position == self.sentence_length
+
+ def empty_transitions(self):
+ # the first element of each stack is a sentinel with no value
+ # and no parent
+ return self.transitions.parent is None
+
+ def has_one_constituent(self):
+ # a length of 1 represents no constituents
+ return len(self.constituents) == 2
+
+ def num_constituents(self):
+ return len(self.constituents) - 1
+
+ def num_transitions(self):
+ # -1 for the sentinel value
+ return len(self.transitions) - 1
+
+ def finished(self, model):
+ return self.empty_word_queue() and self.has_one_constituent() and model.get_top_constituent(self.constituents).label in model.get_root_labels()
+
+ def get_tree(self, model):
+ return model.get_top_constituent(self.constituents)
+
+ def all_transitions(self, model):
+ # TODO: rewrite this to be nicer / faster? or just refactor?
+ all_transitions = []
+ transitions = self.transitions
+ while transitions.parent is not None:
+ all_transitions.append(model.get_top_transition(transitions))
+ transitions = transitions.parent
+ return list(reversed(all_transitions))
+
+ def all_constituents(self, model):
+ # TODO: rewrite this to be nicer / faster?
+ all_constituents = []
+ constituents = self.constituents
+ while constituents.parent is not None:
+ all_constituents.append(model.get_top_constituent(constituents))
+ constituents = constituents.parent
+ return list(reversed(all_constituents))
+
+ def all_words(self, model):
+ return [model.get_word(x) for x in self.word_queue]
+
+ def to_string(self, model):
+ return "State(\n buffer:%s\n transitions:%s\n constituents:%s)" % (str(self.all_words(model)), str(self.all_transitions(model)), str(self.all_constituents(model)))
+
+ def __str__(self):
+ return "State(\n buffer:%s\n transitions:%s\n constituents:%s)" % (str(self.word_queue), str(self.transitions), str(self.constituents))
+
+def initial_state_from_preterminals(preterminal_lists, model, gold_trees):
+ """
+ what is passed in should be a list of lists of preterminals
+ """
+ word_queues = model.initial_word_queues(preterminal_lists)
+ # this is the bottom of the TreeStack and will be the same for each State
+ transitions=model.initial_transitions()
+ constituents=model.initial_constituents()
+ states = [State(sentence_length=len(wq)-1, # -1 because it ends with a sentinel
+ num_opens=0,
+ word_queue=wq,
+ gold_tree=None,
+ gold_sequence=None,
+ transitions=transitions,
+ constituents=constituents,
+ word_position=0)
+ for idx, wq in enumerate(word_queues)]
+ if gold_trees:
+ states = [state._replace(gold_tree=gold_tree) for gold_tree, state in zip(gold_trees, states)]
+ return states
+
+def initial_state_from_words(word_lists, model):
+ # TODO: stop reversing the words
+ preterminal_lists = []
+ for words in word_lists:
+ preterminals = []
+ for word, tag in reversed(words):
+ word_node = Tree(label=word)
+ tag_node = Tree(label=tag, children=[word_node])
+ preterminals.append(tag_node)
+ preterminal_lists.append(preterminals)
+ return initial_state_from_preterminals(preterminal_lists, model, gold_trees=None)
+
+def initial_state_from_gold_trees(trees, model):
+ # reversed so we put the words on the stack backwards
+ preterminal_lists = [[Tree(label=pt.label, children=Tree(label=pt.children[0].label))
+ for pt in tree.yield_reversed_preterminals()]
+ for tree in trees]
+ return initial_state_from_preterminals(preterminal_lists, model, gold_trees=trees)
+
+@functools.total_ordering
+class Transition(ABC):
+ """
+ model is passed in as a dependency injection
+ for example, an LSTM model can update hidden & output vectors when transitioning
+ """
+ @abstractmethod
+ def update_state(self, state, model):
+ """
+ update the word queue position, possibly remove old pieces from the constituents state, and return the new constituent
+
+ the return value should be a tuple:
+ updated word_position
+ updated constituents
+ new constituent to put on the queue and None
+ - note that the constituent shouldn't be on the queue yet
+ that allows putting it on as a batch operation, which
+ saves a significant amount of time in an LSTM, for example
+ OR
+ data used to make a new constituent and the method used
+ - for example, CloseConstituent can return the children needed
+ and itself. this allows a batch operation to build
+ the constituent
+ """
+ pass
+
+ def delta_opens(self):
+ return 0
+
+ def apply(self, state, model):
+ """
+ return a new State transformed via this transition
+ """
+ word_position, constituents, new_constituent, callback = self.update_state(state, model)
+ if callback is not None:
+ new_constituent = callback.build_constituents(model, [new_constituent])[0]
+ constituents = model.push_constituents([constituents], [new_constituent])[0]
+
+ return state._replace(num_opens=state.num_opens + self.delta_opens(),
+ word_position=word_position,
+ transitions=model.push_transitions([state.transitions], [self])[0],
+ constituents=constituents)
+
+ @abstractmethod
+ def is_legal(self, state, model):
+ """
+ assess whether or not this transition is legal in this state
+
+ at parse time, the parser might choose a transition which cannot be made
+ """
+ pass
+
+ def __lt__(self, other):
+ # put the Shift at the front of a list, and otherwise sort alphabetically
+ if self == other:
+ return False
+ if isinstance(self, Shift):
+ return True
+ if isinstance(other, Shift):
+ return False
+ return str(self) < str(other)
+
+class Shift(Transition):
+ def update_state(self, state, model):
+ """
+ This will handle all aspects of a shift transition
+
+ - push the top element of the word queue onto constituents
+ - pop the top element of the word queue
+ """
+ new_constituent = model.transform_word_to_constituent(state)
+ return state.word_position+1, state.constituents, new_constituent, None
+
+ def is_legal(self, state, model):
+ """
+ Disallow shifting when the word queue is empty or there are no opens to eventually eat this word
+ """
+ if state.empty_word_queue():
+ return False
+ if model.is_top_down():
+ # top down transition sequences cannot shift if there are currently no
+ # Open transitions on the stack. in such a case, the new constituent
+ # will never be reduced
+ if state.num_opens == 0:
+ return False
+ if state.num_opens == 1:
+ # there must be at least one transition, since there is an open
+ assert state.transitions.parent is not None
+ if state.transitions.parent.parent is None:
+ # only one transition
+ trans = model.get_top_transition(state.transitions)
+ # must be an Open, since there is one open and one transition
+ # note that an S, FRAG, etc could happen if we're using unary
+ # and ROOT-S is possible in the case of compound Open
+ # in both cases, Shift is legal
+ # Note that the corresponding problem of shifting after the ROOT-S
+ # has been closed to just ROOT is handled in CloseConstituent
+ if len(trans.label) == 1 and trans.top_label in model.get_root_labels():
+ # don't shift a word at the very start of a parse
+ # we want there to be an extra layer below ROOT
+ return False
+ else:
+ # in-order k==1 (the only other option currently)
+ # can shift ONCE, but note that there is no way to consume
+ # two items in a row if there is no Open on the stack.
+ # As long as there is one or more open transitions,
+ # everything can be eaten
+ if state.num_opens == 0:
+ if state.num_constituents() > 0:
+ return False
+ return True
+
+ def __repr__(self):
+ return "Shift"
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if isinstance(other, Shift):
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(37)
+
+class CompoundUnary(Transition):
+ # TODO: run experiments to see if this is actually useful
+ def __init__(self, labels):
+ # the FIRST label will be the top of the tree
+ # so CompoundUnary that results in root will have root as labels[0], for example
+ if isinstance(labels, str):
+ self.labels = (labels,)
+ else:
+ self.labels = tuple(labels)
+
+ def update_state(self, state, model):
+ # remove the top constituent
+ # apply the labels
+ # put the constituent back on the state
+ constituents = state.constituents
+ new_constituent = model.unary_transform(state.constituents, self.labels)
+ constituents = constituents.pop()
+ return state.word_position, constituents, new_constituent, None
+
+ def is_legal(self, state, model):
+ """
+ Disallow consecutive CompoundUnary transitions, force final transition to go to ROOT
+ """
+ # can't unary transition nothing
+ if model.get_top_constituent(state.constituents) is None:
+ return False
+ # don't unary transition a dummy, dummy
+ # and don't stack CompoundUnary transitions
+ if isinstance(model.get_top_transition(state.transitions), (CompoundUnary, OpenConstituent)):
+ return False
+ is_root = self.labels[0] in model.get_root_labels()
+ if not state.empty_word_queue() or not state.has_one_constituent():
+ return not is_root
+ else:
+ return is_root
+
+ def __repr__(self):
+ return "CompoundUnary(%s)" % ",".join(self.labels)
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if not isinstance(other, CompoundUnary):
+ return False
+ if self.labels == other.labels:
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(self.labels)
+
+class Dummy():
+ """
+ Takes a space on the constituent stack to represent where an Open transition occurred
+ """
+ def __init__(self, label):
+ self.label = label
+
+ def __str__(self):
+ return "Dummy({})".format(self.label)
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if not isinstance(other, Dummy):
+ return False
+ if self.label == other.label:
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(self.label)
+
+def too_many_unary_nodes(tree):
+ """
+ Return True iff the tree starts with more than UNARY_LIMIT unary nodes in a row
+
+ helps prevent infinite open/close patterns
+ otherwise, the model can get stuck in essentially an infinite loop
+ """
+ if tree is None:
+ return False
+ for _ in range(UNARY_LIMIT + 1):
+ if len(tree.children) != 1:
+ return False
+ tree = tree.children[0]
+ return True
+
+class OpenConstituent(Transition):
+ def __init__(self, *label):
+ self.label = tuple(label)
+ self.top_label = self.label[0]
+
+ def delta_opens(self):
+ return 1
+
+ def update_state(self, state, model):
+ # open a new constituent which can later be closed
+ # puts a DUMMY constituent on the stack to mark where the constituents end
+ return state.word_position, state.constituents, model.dummy_constituent(Dummy(self.label)), None
+
+ def is_legal(self, state, model):
+ """
+ disallow based on the length of the sentence
+ """
+ if state.num_opens > state.sentence_length + 5:
+ # fudge a bit so we don't miss root nodes etc in very small trees
+ return False
+ if model.is_top_down():
+ # If the model is top down, you can't Open if there are
+ # no words to eventually eat
+ if state.empty_word_queue():
+ return False
+ # Also, you can only Open a ROOT iff it is at the root position
+ # The assumption in the unary scheme is there will be no
+ # root open transitions
+ if not model.has_unary_transitions():
+ # TODO: maybe cache this value if this is an expensive operation
+ is_root = self.top_label in model.get_root_labels()
+ if is_root:
+ return state.empty_transitions()
+ else:
+ return not state.empty_transitions()
+ else:
+ # in-order nodes can Open as long as there is at least one thing
+ # on the constituency stack
+ # since closing the in-order involves removing one more
+ # item before the open, and it can close at any time
+ # (a close immediately after the open represents a unary)
+ if state.num_constituents() == 0:
+ return False
+ if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
+ # consecutive Opens don't make sense in the context of in-order
+ return False
+ # one other restriction - we assume all parse trees
+ # start with (ROOT (first_real_con ...))
+ # therefore ROOT can only occur via Open after everything
+ # else has been pushed and processed
+ # there are no further restrictions
+ is_root = self.top_label in model.get_root_labels()
+ if is_root:
+ # can't make a root node if it will be in the middle of the parse
+ # can't make a root node if there's still words to eat
+ # note that the second assumption wouldn't work,
+ # except we are assuming there will never be multiple
+ # nodes under one root
+ return state.num_opens == 0 and state.empty_word_queue()
+ else:
+ if (state.num_opens > 0 or state.empty_word_queue()) and too_many_unary_nodes(model.get_top_constituent(state.constituents)):
+ # looks like we've been in a loop of lots of unary transitions
+ # note that we check `num_opens > 0` because otherwise we might wind up stuck
+ # in a state where the only legal transition is open, such as if the
+ # constituent stack is otherwise empty, but the open is illegal because
+ # it causes too many unaries
+ # in such a case we can forbid the corresponding close instead...
+ # if empty_word_queue, that means it is trying to make infinitely many
+ # non-ROOT Open transitions instead of just transitioning ROOT
+ return False
+ return True
+ return True
+
+ def __repr__(self):
+ return "OpenConstituent({})".format(self.label)
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if not isinstance(other, OpenConstituent):
+ return False
+ if self.label == other.label:
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(self.label)
+
+class CloseConstituent(Transition):
+ def delta_opens(self):
+ return -1
+
+ def update_state(self, state, model):
+ # pop constituents until we are done
+ children = []
+ constituents = state.constituents
+ while not isinstance(model.get_top_constituent(constituents), Dummy):
+ # keep the entire value from the stack - the model may need
+ # the whole thing to transform the children into a new node
+ children.append(constituents.value)
+ constituents = constituents.pop()
+ # the Dummy has the label on it
+ label = model.get_top_constituent(constituents).label
+ # pop past the Dummy as well
+ constituents = constituents.pop()
+ if not model.is_top_down():
+ # the alternative to TOP_DOWN_... is IN_ORDER
+ # in which case we want to pop one more constituent
+ children.append(constituents.value)
+ constituents = constituents.pop()
+ # the children are in the opposite order of what we expect
+ children.reverse()
+
+ return state.word_position, constituents, (label, children), CloseConstituent
+
+ @staticmethod
+ def build_constituents(model, data):
+ labels, children_lists = list(map(list, zip(*data)))
+ new_constituents = model.build_constituents(labels, children_lists)
+ return new_constituents
+
+
+ def is_legal(self, state, model):
+ """
+ Disallow if there is no Open on the stack yet
+ in TOP_DOWN, if the previous transition was the Open (nothing built yet)
+ in IN_ORDER, previous transition does not matter, except for one small corner case
+ """
+ if state.num_opens <= 0:
+ return False
+ if model.is_top_down():
+ if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
+ return False
+ if state.num_opens <= 1 and not state.empty_word_queue():
+ # don't close the last open until all words have been used
+ return False
+ if model.transition_scheme() == TransitionScheme.TOP_DOWN_COMPOUND:
+ # when doing TOP_DOWN_COMPOUND, we assume all transitions
+ # at the ROOT level have an S, SQ, FRAG, etc underneath
+ # this is checked when the model is first trained
+ if state.num_opens == 1 and not state.empty_word_queue():
+ return False
+ elif not model.has_unary_transitions():
+ # in fact, we have to leave the top level constituent
+ # under the ROOT open if unary transitions are not possible
+ if state.num_opens == 2 and not state.empty_word_queue():
+ return False
+ else:
+ if not isinstance(model.get_top_transition(state.transitions), OpenConstituent):
+ # we're not stuck in a loop of unaries
+ return True
+ if state.num_opens > 1 or state.empty_word_queue():
+ # in either of these cases, the corresponding Open should be eliminated
+ # if we're stuck in a loop of unaries
+ return True
+ node = model.get_top_constituent(state.constituents.pop())
+ if too_many_unary_nodes(node):
+ # at this point, we are in a situation where
+ # - multiple unaries have happened in a row
+ # - there is stuff on the word_queue, so a ROOT open isn't legal
+ # - there's only one constituent on the stack, so the only legal
+ # option once there are no opens left will be an open
+ # this means we'll be stuck having to open again if we do close
+ # this node, so instead we make the Close illegal
+ return False
+ return True
+
+ def __repr__(self):
+ return "CloseConstituent"
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if isinstance(other, CloseConstituent):
+ return True
+ return False
+
+ def __hash__(self):
+ return hash(93)
+
+def bulk_apply(model, tree_batch, transitions, fail=False, max_transitions=1000):
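+ """
+ Apply the given transitions to the given parsing states as one batched operation.
+
+ States whose transition is None (no legal transition was found) or which have
+ already used max_transitions transitions are dropped from the batch, or raise
+ an error if fail=True. Returns the list of updated states.
+ """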
+ remove = set()
+
+ word_positions = []
+ constituents = []
+ new_constituents = []
+ callbacks = defaultdict(list)
+
+ for idx, (tree, transition) in enumerate(zip(tree_batch, transitions)):
+ if not transition:
+ error = "Got stuck and couldn't find a legal transition on the following gold tree:\n{}\n\nFinal state:\n{}".format(tree.gold_tree, tree.to_string(model))
+ if fail:
+ raise ValueError(error)
+ else:
+ logger.error(error)
+ remove.add(idx)
+ continue
+
+ if max_transitions and tree.num_transitions() >= max_transitions:
+ # too many transitions
+ if tree.gold_tree:
+ error = "Went infinite on the following gold tree:\n{}\n\nFinal state:\n{}".format(tree.gold_tree, tree.to_string(model))
+ else:
+ error = "Went infinite!:\nFinal state:\n{}".format(tree.to_string(model))
+ if fail:
+ raise ValueError(error)
+ else:
+ logger.error(error)
+ remove.add(idx)
+ continue
+
+ wq, c, nc, callback = transition.update_state(tree, model)
+
+ word_positions.append(wq)
+ constituents.append(c)
+ new_constituents.append(nc)
+ if callback:
+ # not `idx` in case something was removed
+ callbacks[callback].append(len(new_constituents)-1)
+
+ for key, idxs in callbacks.items():
+ data = [new_constituents[x] for x in idxs]
+ callback_constituents = key.build_constituents(model, data)
+ for idx, constituent in zip(idxs, callback_constituents):
+ new_constituents[idx] = constituent
+
+ tree_batch = [tree for idx, tree in enumerate(tree_batch) if idx not in remove]
+ transitions = [trans for idx, trans in enumerate(transitions) if idx not in remove]
+
+ if len(tree_batch) == 0:
+ return tree_batch
+
+ new_transitions = model.push_transitions([tree.transitions for tree in tree_batch], transitions)
+ new_constituents = model.push_constituents(constituents, new_constituents)
+
+ tree_batch = [state._replace(num_opens=state.num_opens + transition.delta_opens(),
+ word_position=word_position,
+ transitions=transition_stack,
+ constituents=constituents)
+ for (state, transition, word_position, transition_stack, constituents)
+ in zip(tree_batch, transitions, word_positions, new_transitions, new_constituents)]
+
+ return tree_batch
diff --git a/stanza/models/constituency/parse_tree.py b/stanza/models/constituency/parse_tree.py
new file mode 100644
index 00000000..6eb24717
--- /dev/null
+++ b/stanza/models/constituency/parse_tree.py
@@ -0,0 +1,303 @@
+"""
+Tree datastructure
+"""
+
+from collections import deque, Counter
+from io import StringIO
+import re
+
+from stanza.models.common.doc import StanzaObject
+
+# useful more for the "is" functionality than the time savings
+CLOSE_PAREN = ')'
+SPACE_SEPARATOR = ' '
+OPEN_PAREN = '('
+
+EMPTY_CHILDREN = ()
+
+CONSTITUENT_SPLIT = re.compile("[-=#]")
+
+class Tree(StanzaObject):
+ """
+ A data structure to represent a parse tree
+ """
+ def __init__(self, label=None, children=None):
+ if children is None:
+ self.children = EMPTY_CHILDREN
+ elif isinstance(children, Tree):
+ self.children = (children,)
+ else:
+ self.children = children
+
+ self.label = label
+
+ def is_leaf(self):
+ return len(self.children) == 0
+
+ def is_preterminal(self):
+ return len(self.children) == 1 and len(self.children[0].children) == 0
+
+ def yield_reversed_preterminals(self):
+ """
+ Yield the preterminals one at a time in BACKWARDS order
+
+ This is done reversed as it is a frequently used method in the
+ parser, so this is a tiny optimization
+ """
+ nodes = deque()
+ nodes.append(self)
+ while len(nodes) > 0:
+ node = nodes.pop()
+ if len(node.children) == 0:
+ raise ValueError("Got called with an unexpected tree layout: {}".format(self))
+ elif node.is_preterminal():
+ yield node
+ else:
+ nodes.extend(node.children)
+
+ def leaf_labels(self):
+ """
+ Get the labels of the leaves
+
+ Not optimized whatsoever - currently not an important part of
+ the parser
+ """
+ preterminals = reversed([x for x in self.yield_reversed_preterminals()])
+ words = [x.children[0].label for x in preterminals]
+ return words
+
+ def preterminals(self):
+ return list(reversed(list(self.yield_reversed_preterminals())))
+
+ def __repr__(self):
+ """
+ Turn the tree into a string representing the tree
+
+ Note that this is not a recursive traversal
+ Otherwise, a tree too deep might blow up the call stack
+ """
+ with StringIO() as buf:
+ stack = deque()
+ stack.append(self)
+ while len(stack) > 0:
+ node = stack.pop()
+ # note that == can recursively call == in some circumstances!
+ if node is CLOSE_PAREN or node is SPACE_SEPARATOR:
+ buf.write(node)
+ continue
+ if len(node.children) == 0:
+ if node.label is not None:
+ buf.write(node.label)
+ continue
+ buf.write(OPEN_PAREN)
+ if node.label is not None:
+ buf.write(node.label)
+ stack.append(CLOSE_PAREN)
+ for child in reversed(node.children):
+ stack.append(child)
+ stack.append(SPACE_SEPARATOR)
+ buf.seek(0)
+ return buf.read()
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if not isinstance(other, Tree):
+ return False
+ if self.label != other.label:
+ return False
+ if len(self.children) != len(other.children):
+ return False
+ if any(c1 != c2 for c1, c2 in zip(self.children, other.children)):
+ return False
+ return True
+
+ def depth(self):
+ if not self.children:
+ return 0
+ return 1 + max(x.depth() for x in self.children)
+
+ def visit_preorder(self, internal=None, preterminal=None, leaf=None):
+ """
+ Visit the tree in a preorder order
+
+ Applies the given functions to each node.
+ internal: if not None, applies this function to each non-leaf, non-preterminal node
+ preterminal: if not None, applies this function to each preterminal
+ leaf: if not None, applies this function to each leaf
+
+ The functions should *not* destructively alter the trees.
+ There is no attempt to interpret the results of calling these functions.
+ Rather, you can use visit_preorder to collect stats on trees, etc.
+ """
+ if self.is_leaf():
+ if leaf:
+ leaf(self)
+ elif self.is_preterminal():
+ if preterminal:
+ preterminal(self)
+ else:
+ if internal:
+ internal(self)
+ for child in self.children:
+ child.visit_preorder(internal, preterminal, leaf)
+
+ @staticmethod
+ def get_unique_constituent_labels(trees):
+ """
+ Walks over all of the trees and gets all of the unique constituent names from the trees
+ """
+ if isinstance(trees, Tree):
+ trees = [trees]
+
+ constituents = set()
+ for tree in trees:
+ tree.visit_preorder(internal = lambda x: constituents.add(x.label))
+ return sorted(constituents)
+
+ @staticmethod
+ def get_unique_tags(trees):
+ """
+ Walks over all of the trees and gets all of the unique tags from the trees
+ """
+ if isinstance(trees, Tree):
+ trees = [trees]
+
+ tags = set()
+ for tree in trees:
+ tree.visit_preorder(preterminal = lambda x: tags.add(x.label))
+ return sorted(tags)
+
+ @staticmethod
+ def get_unique_words(trees):
+ """
+ Walks over all of the trees and gets all of the unique words from the trees
+ """
+ if isinstance(trees, Tree):
+ trees = [trees]
+
+ words = set()
+ for tree in trees:
+ tree.visit_preorder(leaf = lambda x: words.add(x.label))
+ return sorted(words)
+
+ @staticmethod
+ def get_rare_words(trees, threshold=0.05):
+ """
+ Walks over all of the trees and gets the least frequently occurring words.
+
+ threshold: choose the bottom X percent
+ """
+ if isinstance(trees, Tree):
+ trees = [trees]
+
+ words = Counter()
+ for tree in trees:
+ tree.visit_preorder(leaf = lambda x: words.update([x.label]))
+ threshold = max(int(len(words) * threshold), 1)
+ return sorted(x[0] for x in words.most_common()[:-threshold-1:-1])
+
+ @staticmethod
+ def get_root_labels(trees):
+ return sorted(set(x.label for x in trees))
+
+ @staticmethod
+ def get_compound_constituents(trees):
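+ """
+ Walks over all of the trees and collects the compound Open labels:
+ for each internal node visited, the chain of labels down through unary
+ internal children, e.g. (ROOT, S) for a tree shaped like (ROOT (S ...))
+ """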
+ constituents = set()
+ stack = deque()
+ for tree in trees:
+ stack.append(tree)
+ while len(stack) > 0:
+ node = stack.pop()
+ if node.is_leaf() or node.is_preterminal():
+ continue
+ labels = [node.label]
+ while len(node.children) == 1 and not node.children[0].is_preterminal():
+ node = node.children[0]
+ labels.append(node.label)
+ constituents.add(tuple(labels))
+ for child in node.children:
+ stack.append(child)
+ return sorted(constituents)
+
+ # TODO: test different pattern
+ def simplify_labels(self, pattern=CONSTITUENT_SPLIT):
+ """
+ Return a copy of the tree with the -=# removed
+
+ Leaves the text of the leaves alone.
+ """
+ new_label = self.label
+ # check len(new_label) just in case it's a tag of - or =
+ if new_label and not self.is_leaf() and len(new_label) > 1 and new_label not in ('-LRB-', '-RRB-'):
+ new_label = pattern.split(new_label)[0]
+ new_children = [child.simplify_labels(pattern) for child in self.children]
+ return Tree(new_label, new_children)
+
+ def remap_constituent_labels(self, label_map):
+ """
+ Copies the tree with some labels replaced.
+
+ Labels in the map are replaced with the mapped value.
+ Labels not in the map are unchanged.
+ """
+ if self.is_leaf():
+ return Tree(self.label)
+ if self.is_preterminal():
+ return Tree(self.label, Tree(self.children[0].label))
+ new_label = label_map.get(self.label, self.label)
+ return Tree(new_label, [child.remap_constituent_labels(label_map) for child in self.children])
+
+ def remap_words(self, word_map):
+ """
+ Copies the tree with some leaf words replaced.
+
+ Words in the map are replaced with the mapped value.
+ Words not in the map are unchanged.
+ """
+ if self.is_leaf():
+ new_label = word_map.get(self.label, self.label)
+ return Tree(new_label)
+ if self.is_preterminal():
+ return Tree(self.label, self.children[0].remap_words(word_map))
+ return Tree(self.label, [child.remap_words(word_map) for child in self.children])
+
+ def replace_words(self, words):
+ """
+ Replace all leaf words with the words in the given list (or iterable)
+
+ Returns a new tree
+ """
+ word_iterator = iter(words)
+ def recursive_replace_words(subtree):
+ if subtree.is_leaf():
+ word = next(word_iterator, None)
+ if word is None:
+ raise ValueError("Not enough words to replace all leaves")
+ return Tree(word)
+ return Tree(subtree.label, [recursive_replace_words(x) for x in subtree.children])
+
+ new_tree = recursive_replace_words(self)
+ if any(True for _ in word_iterator):
+ raise ValueError("Too many tags for the given tree")
+ return new_tree
+
+
+ def prune_none(self):
+ """
+ Return a copy of the tree, eliminating all nodes which are in one of two categories:
+ they are a preterminal -NONE-, such as appears in PTB
+ they have been pruned to 0 children by the recursive call
+ """
+ if self.is_leaf():
+ return Tree(self.label)
+ if self.is_preterminal():
+ if self.label == '-NONE-':
+ return None
+ return Tree(self.label, Tree(self.children[0].label))
+ # must be internal node
+ new_children = [child.prune_none() for child in self.children]
+ new_children = [child for child in new_children if child is not None]
+ if len(new_children) == 0:
+ return None
+ return Tree(self.label, new_children)
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
new file mode 100644
index 00000000..5437e833
--- /dev/null
+++ b/stanza/models/constituency/trainer.py
@@ -0,0 +1,586 @@
+"""
+This file includes a variety of methods needed to train new
+constituency parsers. It also includes a method to load an
+already-trained parser.
+
+See the `train` method for the code block which starts from
+ raw treebank and returns a new parser.
+`evaluate` reads a treebank and gives a score for those trees.
+`parse_tagged_words` is useful at Pipeline time -
+ it takes words & tags and processes that into trees.
+"""
+
+import logging
+import random
+import os
+
+import torch
+from torch import nn
+from torch import optim
+
+from stanza.models.common import pretrain
+from stanza.models.common import utils
+from stanza.models.common.char_model import CharacterLanguageModel
+from stanza.models.constituency import base_model
+from stanza.models.constituency import parse_transitions
+from stanza.models.constituency import parse_tree
+from stanza.models.constituency import transition_sequence
+from stanza.models.constituency import tree_reader
+from stanza.models.constituency.lstm_model import LSTMModel
+from stanza.models.constituency.parse_transitions import State, TransitionScheme
+from stanza.models.constituency.utils import retag_trees
+from stanza.server.parser_eval import EvaluateParser
+
+tqdm = utils.get_tqdm()
+
+logger = logging.getLogger('stanza.constituency.trainer')
+
+class Trainer:
+ """
+ Stores a constituency model and its optimizer
+
+ Not inheriting from common/trainer.py because there's no concept of change_lr (yet?)
+ """
+ def __init__(self, model=None, optimizer=None):
+ self.model = model
+ self.optimizer = optimizer
+
+ def save(self, filename, save_optimizer=True):
+ """
+ Save the model (and by default the optimizer) to the given path
+ """
+ params = self.model.get_params()
+ checkpoint = {
+ 'params': params,
+ 'model_type': 'LSTM',
+ }
+ if save_optimizer and self.optimizer is not None:
+ checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()
+ torch.save(checkpoint, filename, _use_new_zipfile_serialization=False)
+ logger.info("Model saved to %s", filename)
+
+
+ @staticmethod
+ def load(filename, pt, forward_charlm, backward_charlm, use_gpu, args=None, load_optimizer=False):
+ """
+ Load back a model and possibly its optimizer.
+
+ pt: a Pretrain word embedding
+ """
+ if args is None:
+ args = {}
+
+ try:
+ checkpoint = torch.load(filename, lambda storage, loc: storage)
+ except BaseException:
+ logger.exception("Cannot load model from %s", filename)
+ raise
+ logger.debug("Loaded model from %s", filename)
+
+ model_type = checkpoint['model_type']
+ params = checkpoint.get('params', checkpoint)
+
+ if model_type == 'LSTM':
+ model = LSTMModel(pretrain=pt,
+ forward_charlm=forward_charlm,
+ backward_charlm=backward_charlm,
+ transitions=params['transitions'],
+ constituents=params['constituents'],
+ tags=params['tags'],
+ words=params['words'],
+ rare_words=params['rare_words'],
+ root_labels=params['root_labels'],
+ open_nodes=params['open_nodes'],
+ args=params['config'])
+ else:
+ raise ValueError("Unknown model type {}".format(model_type))
+ model.load_state_dict(params['model'], strict=False)
+
+ if use_gpu:
+ model.cuda()
+
+ if load_optimizer:
+ optimizer_args = dict(params['config'])
+ optimizer_args.update(args)
+ optimizer = build_optimizer(optimizer_args, model)
+
+ if checkpoint.get('optimizer_state_dict', None) is not None:
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ else:
+ logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
+ else:
+ optimizer = None
+
+ logger.debug("-- MODEL CONFIG --")
+ for k in model.args.keys():
+ logger.debug(" --%s: %s", k, model.args[k])
+
+ return Trainer(model=model, optimizer=optimizer)
+
+
+def build_optimizer(args, model):
+ """
+ Build an optimizer based on the arguments given
+ """
+ if args['optim'].lower() == 'sgd':
+ optimizer = optim.SGD(model.parameters(), lr=args['learning_rate'], momentum=0.9, weight_decay=args['weight_decay'])
+ elif args['optim'].lower() == 'adadelta':
+ optimizer = optim.Adadelta(model.parameters(), lr=args['learning_rate'], weight_decay=args['weight_decay'])
+ elif args['optim'].lower() == 'adamw':
+ optimizer = optim.AdamW(model.parameters(), lr=args['learning_rate'], weight_decay=args['weight_decay'])
+ else:
+ raise ValueError("Unknown optimizer: %s" % args.optim)
+ return optimizer
+
+def load_pretrain(args):
+ """
+ Loads a pretrain based on the paths in the arguments
+ """
+ pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
+ if os.path.exists(pretrain_file):
+ vec_file = None
+ else:
+ vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
+ pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
+ return pt
+
+def load_charlm(charlm_file):
+ if charlm_file:
+ logger.debug("Loading charlm from %s", charlm_file)
+ return CharacterLanguageModel.load(charlm_file, finetune=False)
+ return None
+
+def read_treebank(filename):
+ """
+ Read a treebank and alter the trees to be a simpler format for learning to parse
+ """
+ logger.info("Reading trees from %s", filename)
+ trees = tree_reader.read_tree_file(filename)
+ trees = [t.prune_none().simplify_labels() for t in trees]
+
+ illegal_trees = [t for t in trees if len(t.children) > 1]
+ if len(illegal_trees) > 0:
+ raise ValueError("Found {} tree(s) which had non-unary transitions at the ROOT. First illegal tree: {}".format(len(illegal_trees), illegal_trees[0]))
+
+ return trees
+
+def verify_transitions(trees, sequences, transition_scheme):
+ """
+ Given a list of trees and their transition sequences, verify that the sequences rebuild the trees
+ """
+ model = base_model.SimpleModel(transition_scheme)
+ logger.info("Verifying the transition sequences for %d trees", len(trees))
+
+ data = zip(trees, sequences)
+ if logger.getEffectiveLevel() <= logging.INFO:
+ data = tqdm(zip(trees, sequences), total=len(trees))
+
+ for tree, sequence in data:
+ state = parse_transitions.initial_state_from_gold_trees([tree], model)[0]
+ for idx, trans in enumerate(sequence):
+ if not trans.is_legal(state, model):
+ raise RuntimeError("Transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(idx, trans, tree, sequence))
+ state = trans.apply(state, model)
+ result = model.get_top_constituent(state.constituents)
+ if tree != result:
+ raise RuntimeError("Transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree, sequence, result))
+
+def evaluate(args, model_file, retag_pipeline):
+ """
+ Loads the given model file and tests the eval_file treebank.
+
+ May retag the trees using retag_pipeline
+ Uses a subprocess to run the Java EvalB code
+ """
+ pt = load_pretrain(args)
+ forward_charlm = load_charlm(args['charlm_forward_file'])
+ backward_charlm = load_charlm(args['charlm_backward_file'])
+ trainer = Trainer.load(model_file, pt, forward_charlm, backward_charlm, args['cuda'])
+
+ treebank = read_treebank(args['eval_file'])
+ logger.info("Read %d trees for evaluation", len(treebank))
+
+ if retag_pipeline is not None:
+ logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
+ treebank = retag_trees(treebank, retag_pipeline, args['retag_xpos'])
+ logger.info("Retagging finished")
+
+ f1 = run_dev_set(trainer.model, treebank, args)
+ logger.info("F1 score on %s: %f", args['eval_file'], f1)
+
+def build_treebank(trees, transition_scheme):
+ """
+ Convert a set of trees into the corresponding transition sequences for the given transition scheme
+
+ Currently supports the top-down variants and in-order transitions; more schemes, especially bottom-up, may be added in the future
+ """
+ return transition_sequence.build_treebank(trees, transition_scheme=transition_scheme)
+
+def get_open_nodes(trees, args):
+ """
+ Return a list of all open nodes in the given dataset.
+ Depending on the parameters, may be single or compound open transitions.
+ """
+ if args['transition_scheme'] is TransitionScheme.TOP_DOWN_COMPOUND:
+ return parse_tree.Tree.get_compound_constituents(trees)
+ else:
+ return [(x,) for x in parse_tree.Tree.get_unique_constituent_labels(trees)]
+
+def print_args(args):
+ """
+ For record keeping purposes, print out the arguments when training
+ """
+ keys = sorted(args.keys())
+ log_lines = ['%s: %s' % (k, args[k]) for k in keys]
+ logger.info('ARGS USED AT TRAINING TIME:\n%s\n', '\n'.join(log_lines))
+
+def remove_optimizer(args, model_save_file, model_load_file):
+ """
+ A utility method to remove the optimizer from a save file
+
+ Will make the save file a lot smaller
+ """
+ # TODO: kind of overkill to load in the pretrain rather than
+ # change the load/save to work without it, but probably this
+ # functionality isn't used that often anyway
+ pt = load_pretrain(args)
+ forward_charlm = load_charlm(args['charlm_forward_file'])
+ backward_charlm = load_charlm(args['charlm_backward_file'])
+ trainer = Trainer.load(model_load_file, pt, forward_charlm, backward_charlm, use_gpu=False, load_optimizer=False)
+ trainer.save(model_save_file)
+
+def convert_trees_to_sequences(trees, tree_type, transition_scheme):
+ logger.info("Building {} transition sequences".format(tree_type))
+ if logger.getEffectiveLevel() <= logging.INFO:
+ trees = tqdm(trees)
+ sequences = build_treebank(trees, transition_scheme)
+ transitions = transition_sequence.all_transitions(sequences)
+ return sequences, transitions
+
+def build_trainer(args, model_load_file, train_trees, dev_trees, pt, forward_charlm, backward_charlm):
+ """
+ Builds a Trainer (with model) and the train_sequences and transitions for the given trees.
+ """
+ train_constituents = parse_tree.Tree.get_unique_constituent_labels(train_trees)
+ dev_constituents = parse_tree.Tree.get_unique_constituent_labels(dev_trees)
+ logger.info("Unique constituents in training set: %s", train_constituents)
+ for con in dev_constituents:
+ if con not in train_constituents:
+ raise RuntimeError("Found label {} in the dev set which don't exist in the train set".format(con))
+
+ train_sequences, train_transitions = convert_trees_to_sequences(train_trees, "training", args['transition_scheme'])
+ dev_sequences, dev_transitions = convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'])
+
+ logger.info("Total unique transitions in train set: %d", len(train_transitions))
+ for trans in dev_transitions:
+ if trans not in train_transitions:
+ raise RuntimeError("Found transition {} in the dev set which don't exist in the train set".format(trans))
+
+ verify_transitions(train_trees, train_sequences, args['transition_scheme'])
+ verify_transitions(dev_trees, dev_sequences, args['transition_scheme'])
+
+ root_labels = parse_tree.Tree.get_root_labels(train_trees)
+ for root_state in parse_tree.Tree.get_root_labels(dev_trees):
+ if root_state not in root_labels:
+ raise RuntimeError("Found root state {} in the dev set which is not a ROOT state in the train set".format(root_state))
+
+ tags = parse_tree.Tree.get_unique_tags(train_trees)
+ logger.info("Unique tags in training set: %s", tags)
+ for tag in parse_tree.Tree.get_unique_tags(dev_trees):
+ if tag not in tags:
+ raise RuntimeError("Found tag {} in the dev set which is not a tag in the train set".format(tag))
+
+ # we don't check against the words in the dev set as it is
+ # expected there will be some UNK words
+ words = parse_tree.Tree.get_unique_words(train_trees)
+ rare_words = parse_tree.Tree.get_rare_words(train_trees, args['rare_word_threshold'])
+ # also, it's not actually an error if there is a pattern of
+ # compound unary or compound open nodes which doesn't exist in the
+ # train set. it just means we probably won't ever get that right
+ open_nodes = get_open_nodes(train_trees, args)
+
+ # at this point we have:
+ # pretrain
+ # train_trees, dev_trees
+ # lists of transitions, internal nodes, and root states the parser needs to be aware of
+
+ if args['finetune'] or (args['maybe_finetune'] and os.path.exists(model_load_file)):
+ logger.info("Loading model to continue training from %s", model_load_file)
+ trainer = Trainer.load(model_load_file, pt, forward_charlm, backward_charlm, args['cuda'], args, load_optimizer=True)
+ else:
+ model = LSTMModel(pt, forward_charlm, backward_charlm, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, args)
+ if args['cuda']:
+ model.cuda()
+
+ optimizer = build_optimizer(args, model)
+
+ trainer = Trainer(model, optimizer)
+
+ return trainer, train_sequences, train_transitions
+
+def train(args, model_save_file, model_load_file, model_save_latest_file, retag_pipeline):
+ """
+ Build a model, train it using the requested train & dev files
+ """
+ print_args(args)
+
+ utils.ensure_dir(args['save_dir'])
+
+ train_trees = read_treebank(args['train_file'])
+ logger.info("Read %d trees for the training set", len(train_trees))
+
+ dev_trees = read_treebank(args['eval_file'])
+ logger.info("Read %d trees for the dev set", len(dev_trees))
+
+ if retag_pipeline is not None:
+ logger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
+ train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos'])
+ dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos'])
+ logger.info("Retagging finished")
+
+ pt = load_pretrain(args)
+ forward_charlm = load_charlm(args['charlm_forward_file'])
+ backward_charlm = load_charlm(args['charlm_backward_file'])
+
+ trainer, train_sequences, train_transitions = build_trainer(args, model_load_file, train_trees, dev_trees, pt, forward_charlm, backward_charlm)
+
+ iterate_training(trainer, train_trees, train_sequences, train_transitions, dev_trees, args, model_save_file, model_save_latest_file)
+
+
+def iterate_training(trainer, train_trees, train_sequences, transitions, dev_trees, args, model_filename, model_latest_filename):
+ """
+ Given an initialized model, a processed dataset, and a secondary dev dataset, train the model
+
+ The training is iterated in the following loop:
+ extract a batch of trees of the same length from the training set
+ convert those trees into initial parsing states
+ repeat until trees are done:
+ batch predict the model's interpretation of the current states
+ add the errors to the list of things to backprop
+ advance the parsing state for each of the trees
+
+ Currently the only method implemented for advancing the parsing state
+ is to use the gold transition.
+
+ TODO: add a dynamic oracle which can adjust the future expected
+ parsing decisions after the parser makes an error. This way,
+ the parser will have "experienced" what the correct decision
+ to make is when it gets into incorrect states at runtime.
+ """
+ model = trainer.model
+ optimizer = trainer.optimizer
+
+ loss_function = nn.CrossEntropyLoss(reduction='sum')
+ if args['cuda']:
+ loss_function.cuda()
+
+ device = next(model.parameters()).device
+ transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0)
+ for (y, x) in enumerate(transitions)}
+
+ model.train()
+
+ train_data = list(zip(train_trees, train_sequences))
+ leftover_training_data = []
+ best_f1 = 0.0
+ best_epoch = 0
+ for epoch in range(1, args['epochs']+1):
+ model.train()
+ logger.info("Starting epoch %d", epoch)
+ epoch_data = leftover_training_data
+ while len(epoch_data) < args['eval_interval']:
+ random.shuffle(train_data)
+ epoch_data.extend(train_data)
+ leftover_training_data = epoch_data[args['eval_interval']:]
+ epoch_data = epoch_data[:args['eval_interval']]
+ epoch_data.sort(key=lambda x: len(x[1]))
+ interval_starts = list(range(0, len(epoch_data), args['train_batch_size']))
+ random.shuffle(interval_starts)
+
+ epoch_loss = 0.0
+
+ transitions_correct = 0
+ transitions_incorrect = 0
+
+ for interval_start in tqdm(interval_starts, postfix="Batch"):
+ batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
+ # the batch will be empty when all trees from this epoch are trained
+ # now we add the state to the trees in the batch
+ initial_states = parse_transitions.initial_state_from_gold_trees([tree for tree, _ in batch], model)
+ batch = [state._replace(gold_sequence=sequence)
+ for (tree, sequence), state in zip(batch, initial_states)]
+
+ all_errors = []
+ all_answers = []
+
+ while len(batch) > 0:
+ outputs, pred_transitions = model.predict(batch)
+ gold_transitions = [x.gold_sequence[x.num_transitions()] for x in batch]
+ trans_tensor = [transition_tensors[gold_transition] for gold_transition in gold_transitions]
+ all_errors.append(outputs)
+ all_answers.extend(trans_tensor)
+
+ for pred_transition, gold_transition in zip(pred_transitions, gold_transitions):
+ if pred_transition != gold_transition:
+ transitions_incorrect = transitions_incorrect + 1
+ else:
+ transitions_correct = transitions_correct + 1
+
+ # eliminate finished trees, keeping only the transitions we will use
+ zipped_batch = [x for x in zip(batch, gold_transitions) if x[0].num_transitions() + 1 < len(x[0].gold_sequence)]
+ batch = [x[0] for x in zipped_batch]
+ gold_transitions = [x[1] for x in zipped_batch]
+
+ if len(batch) > 0:
+ # bulk update states
+ batch = parse_transitions.bulk_apply(model, batch, gold_transitions, fail=True, max_transitions=None)
+
+ errors = torch.cat(all_errors)
+ answers = torch.cat(all_answers)
+ tree_loss = loss_function(errors, answers)
+ tree_loss.backward()
+ epoch_loss += tree_loss.item()
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # print statistics
+ f1 = run_dev_set(model, dev_trees, args)
+ if f1 > best_f1:
+ logger.info("New best dev score: %.5f > %.5f", f1, best_f1)
+ best_f1 = f1
+ best_epoch = epoch
+ trainer.save(model_filename, save_optimizer=True)
+ if model_latest_filename:
+ trainer.save(model_latest_filename, save_optimizer=True)
+ logger.info("Epoch {} finished\nTransitions correct: {} Transitions incorrect: {}\n Total loss for epoch: {}\n Dev score ({:5}): {}\n Best dev score ({:5}): {}".format(epoch, transitions_correct, transitions_incorrect, epoch_loss, epoch, f1, best_epoch, best_f1))
+
+def build_batch_from_trees(batch_size, data_iterator, model):
+ """
+ Read from the data_iterator batch_size trees and turn them into new parsing states
+ """
+ tree_batch = []
+ for _ in range(batch_size):
+ gold_tree = next(data_iterator, None)
+ if gold_tree is None:
+ break
+ tree_batch.append(gold_tree)
+
+ if len(tree_batch) > 0:
+ tree_batch = parse_transitions.initial_state_from_gold_trees(tree_batch, model)
+ return tree_batch
+
+def build_batch_from_tagged_words(batch_size, data_iterator, model):
+ """
+ Read from the data_iterator batch_size tagged sentences and turn them into new parsing states
+ """
+ tree_batch = []
+ for _ in range(batch_size):
+ sentence = next(data_iterator, None)
+ if sentence is None:
+ break
+ tree_batch.append(sentence)
+
+ if len(tree_batch) > 0:
+ tree_batch = parse_transitions.initial_state_from_words(tree_batch, model)
+ return tree_batch
+
+def parse_sentences(data_iterator, build_batch_fn, batch_size, model):
+ """
+ Given an iterator over the data and a method for building batches, returns a bunch of parse trees.
+
+ The data_iterator should be anything which returns the data for a parse task via next()
+ build_batch_fn is a function that turns that data into State objects
+ This will be called to generate batches of size batch_size until the data is exhausted
+
+ The return is a list of tuples: (gold_tree, [(predicted, score) ...])
+ gold_tree will be left blank if the data did not include gold trees
+ currently score is always 1.0, but the interface may be expanded to get a score from the result of the parsing
+ """
+ treebank = []
+ tree_batch = build_batch_fn(batch_size, data_iterator, model)
+ horizon_iterator = iter([])
+
+ while len(tree_batch) > 0:
+ _, transitions = model.predict(tree_batch, is_legal=True)
+ tree_batch = parse_transitions.bulk_apply(model, tree_batch, transitions)
+
+ remove = set()
+ for idx, tree in enumerate(tree_batch):
+ if tree.finished(model):
+ predicted_tree = tree.get_tree(model)
+ gold_tree = tree.gold_tree
+ # TODO: put an actual score here?
+ treebank.append((gold_tree, [(predicted_tree, 1.0)]))
+ remove.add(idx)
+
+ tree_batch = [tree for idx, tree in enumerate(tree_batch) if idx not in remove]
+
+ for _ in range(batch_size - len(tree_batch)):
+ horizon_tree = next(horizon_iterator, None)
+ if not horizon_tree:
+ horizon_batch = build_batch_fn(batch_size, data_iterator, model)
+ if len(horizon_batch) == 0:
+ break
+ horizon_iterator = iter(horizon_batch)
+ horizon_tree = next(horizon_iterator, None)
+
+ tree_batch.append(horizon_tree)
+
+ return treebank
+
+def parse_tagged_words(model, words, batch_size):
+ """
+ This parses tagged words and returns a list of trees.
+
+ The tagged words should be represented:
+ one list per sentence
+ each sentence is a list of (word, tag)
+ The return value is a list of ParseTree objects
+ """
+ logger.debug("Processing %d sentences", len(words))
+ model.eval()
+
+ sentence_iterator = iter(words)
+ treebank = parse_sentences(sentence_iterator, build_batch_from_tagged_words, batch_size, model)
+
+ results = [t[1][0][0] for t in treebank]
+ return results
+
+def run_dev_set(model, dev_trees, args):
+ """
+ This reparses a treebank and executes the CoreNLP Java EvalB code.
+
+ It only works if CoreNLP 4.3.0 or higher is in the classpath.
+ """
+ logger.info("Processing %d trees from %s", len(dev_trees), args['eval_file'])
+ model.eval()
+
+ tree_iterator = iter(tqdm(dev_trees))
+ treebank = parse_sentences(tree_iterator, build_batch_from_trees, args['eval_batch_size'], model)
+
+ if len(treebank) < len(dev_trees):
+ logger.warning("Only evaluating %d trees instead of %d", len(treebank), len(dev_trees))
+
+ if args['mode'] == 'predict' and args['predict_file']:
+ utils.ensure_dir(args['predict_dir'], verbose=False)
+ pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".pred.mrg")
+ orig_file = os.path.join(args['predict_dir'], args['predict_file'] + ".orig.mrg")
+ if os.path.exists(pred_file):
+ logger.warning("Cowardly refusing to overwrite {}".format(pred_file))
+ elif os.path.exists(orig_file):
+ logger.warning("Cowardly refusing to overwrite {}".format(orig_file))
+ else:
+ with open(pred_file, 'w') as fout:
+ for tree in treebank:
+ fout.write(str(tree[1][0][0]))
+ fout.write("\n")
+
+ with open(orig_file, 'w') as fout:
+ for tree in treebank:
+ fout.write(str(tree[0]))
+ fout.write("\n")
+
+ with EvaluateParser() as evaluator:
+ response = evaluator.process(treebank)
+ return response.f1
diff --git a/stanza/models/constituency/transition_sequence.py b/stanza/models/constituency/transition_sequence.py
new file mode 100644
index 00000000..fe60c527
--- /dev/null
+++ b/stanza/models/constituency/transition_sequence.py
@@ -0,0 +1,112 @@
+"""
+Build a transition sequence from parse trees.
+
+Supports multiple transition schemes - TOP_DOWN and variants, IN_ORDER
+"""
+
+from stanza.models.constituency.parse_transitions import Shift, CompoundUnary, OpenConstituent, CloseConstituent, TransitionScheme
+from stanza.models.constituency.tree_reader import read_trees
+
+def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
+ """
+ For tree (X A B C D), yield Open(X) A B C D Close
+
+ The details are in how to treat unary transitions
+ Three possibilities handled by this method:
+ TOP_DOWN_UNARY: (Y (X ...)) -> Open(X) ... Close Unary(Y)
+ TOP_DOWN_COMPOUND: (Y (X ...)) -> Open(Y, X) ... Close
+ TOP_DOWN: (Y (X ...)) -> Open(Y) Open(X) ... Close Close
+ """
+ if tree.is_preterminal():
+ yield Shift()
+ return
+
+ if tree.is_leaf():
+ return
+
+ if transition_scheme is TransitionScheme.TOP_DOWN_UNARY:
+ if len(tree.children) == 1:
+ labels = []
+ while not tree.is_preterminal() and len(tree.children) == 1:
+ labels.append(tree.label)
+ tree = tree.children[0]
+ for transition in yield_top_down_sequence(tree, transition_scheme):
+ yield transition
+ yield CompoundUnary(labels)
+ return
+
+ if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:
+ labels = [tree.label]
+ while len(tree.children) == 1 and not tree.children[0].is_preterminal():
+ tree = tree.children[0]
+ labels.append(tree.label)
+ yield OpenConstituent(*labels)
+ else:
+ yield OpenConstituent(tree.label)
+ for child in tree.children:
+ for transition in yield_top_down_sequence(child, transition_scheme):
+ yield transition
+ yield CloseConstituent()
+
+def yield_in_order_sequence(tree):
+ """
+ For tree (X A B C D), yield A Open(X) B C D Close
+ """
+ if tree.is_preterminal():
+ yield Shift()
+ return
+
+ if tree.is_leaf():
+ return
+
+ for transition in yield_in_order_sequence(tree.children[0]):
+ yield transition
+
+ yield OpenConstituent(tree.label)
+
+ for child in tree.children[1:]:
+ for transition in yield_in_order_sequence(child):
+ yield transition
+
+ yield CloseConstituent()
+
+def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
+ """
+ Turn a single tree into a list of transitions based on the TransitionScheme
+ """
+ if transition_scheme is TransitionScheme.IN_ORDER:
+ return list(yield_in_order_sequence(tree))
+ else:
+ return list(yield_top_down_sequence(tree, transition_scheme))
+
+def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
+ """
+ Turn each of the trees in the treebank into a list of transitions based on the TransitionScheme
+ """
+ return [build_sequence(tree, transition_scheme) for tree in trees]
+
+def all_transitions(transition_lists):
+ """
+ Given a list of transition lists, combine them all into a list of unique transitions.
+ """
+ transitions = set()
+ for trans_list in transition_lists:
+ for trans in trans_list:
+ transitions.add(trans)
+ return sorted(transitions)
+
+def main():
+ """
+ Convert a sample tree and print its transitions
+ """
+ text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+ #text = "(WP Who)"
+
+ tree = read_trees(text)[0]
+
+ print(tree)
+ transitions = build_sequence(tree)
+ print(transitions)
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/models/constituency/tree_reader.py b/stanza/models/constituency/tree_reader.py
new file mode 100644
index 00000000..65c9250c
--- /dev/null
+++ b/stanza/models/constituency/tree_reader.py
@@ -0,0 +1,154 @@
+"""
+Reads ParseTree objects from a file, string, or similar input
+
+Works by first splitting the input into (, ), and all other tokens,
+then recursively processing those tokens into trees.
+"""
+
+from stanza.models.common import utils
+from stanza.models.constituency.parse_tree import Tree
+
+tqdm = utils.get_tqdm()
+
+OPEN_PAREN = "("
+CLOSE_PAREN = ")"
+
+def recursive_open_tree(token_iterator, at_root, broken_ok):
+ """
+ Build a tree from the tokens in the token_iterator
+ """
+ # TODO: unwind the recursion
+ text = []
+ children = []
+
+ token = next(token_iterator, None)
+ while token is not None:
+ if token is OPEN_PAREN:
+ children.append(recursive_open_tree(token_iterator, at_root=False, broken_ok=broken_ok))
+ elif token is CLOSE_PAREN:
+ if len(text) == 0:
+ if at_root:
+ return Tree(label="ROOT", children=children)
+ elif broken_ok:
+ return Tree(label=None, children=children)
+ else:
+ raise ValueError("Found a tree with no label on a node! Line number %d" % token_iterator.line_num)
+
+ pieces = " ".join(text).split()
+ if len(pieces) == 1:
+ return Tree(label=pieces[0], children=children)
+
+ # the assumption here is that a language such as Vietnamese (VI)
+ # may have spaces within a single word, but the text still
+ # represents just one child
+ label = pieces[0]
+ child_label = " ".join(pieces[1:])
+ if len(children) > 0:
+ if broken_ok:
+ return Tree(label=label, children=children + [Tree(label=child_label)])
+ else:
+ raise ValueError("Found a tree with both text children and bracketed children! Line number %d" % token_iterator.line_num)
+ return Tree(label=label, children=Tree(label=child_label))
+ else:
+ text.append(token)
+ token = next(token_iterator, None)
+
+def recursive_read_trees(token_iterator, broken_ok):
+ """
+ Read all of the trees from the token_iterator
+
+ TODO: some of the error cases we hit can be recovered from
+ also, just in general it would be good to unwind the recursion
+ """
+ trees = []
+ token = next(token_iterator, None)
+ while token:
+ if token is OPEN_PAREN:
+ trees.append(recursive_open_tree(token_iterator, at_root=True, broken_ok=broken_ok))
+ token = next(token_iterator, None)
+ continue
+
+ if token is CLOSE_PAREN:
+ raise ValueError("Tree document had too many close parens! Line number %d" % token_iterator.line_num)
+ else:
+ raise ValueError("Tree document had text between trees! Line number %d" % token_iterator.line_num)
+
+ return trees
+
+class TokenIterator:
+ """
+ A specific iterator for reading trees from a tree file
+
+ The idea is that this will keep track of which line
+ we are processing, so that an error can be logged
+ from the correct line
+ """
+ def __init__(self, text):
+ self.lines = text.split("\n")
+ self.num_lines = len(self.lines)
+ self.line_num = -1
+ if self.num_lines > 1000:
+ self.line_iterator = iter(tqdm(self.lines))
+ else:
+ self.line_iterator = iter(self.lines)
+ self.token_iterator = iter([])
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ n = next(self.token_iterator, None)
+ while n is None:
+ self.line_num = self.line_num + 1
+ if self.line_num >= len(self.lines):
+ next(self.line_iterator, "")
+ raise StopIteration
+
+ line = next(self.line_iterator, "").strip()
+ if not line:
+ continue
+
+ pieces = []
+ open_pieces = line.split(OPEN_PAREN)
+ for o_idx, open_piece in enumerate(open_pieces):
+ if open_piece:
+ close_pieces = open_piece.split(CLOSE_PAREN)
+ for c_idx, close_piece in enumerate(close_pieces):
+ close_piece = close_piece.strip()
+ if close_piece:
+ pieces.append(close_piece)
+ if c_idx != len(close_pieces) - 1:
+ pieces.append(CLOSE_PAREN)
+ if o_idx != len(open_pieces) - 1:
+ pieces.append(OPEN_PAREN)
+ self.token_iterator = iter(pieces)
+ n = next(self.token_iterator, None)
+
+ return n
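+
+ # Illustrative sketch of the tokenization step described in the module
+ # docstring: a bracketed line is split into parens and text chunks before
+ # the recursive tree building runs. Note that text chunks can still contain
+ # spaces; recursive_open_tree re-splits them when building labels.
+ #
+ #   list(TokenIterator("(NP (DT a) (NN cat))"))
+ #   # -> ['(', 'NP', '(', 'DT a', ')', '(', 'NN cat', ')', ')']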
+
+def read_trees(text, broken_ok=False):
+ """
+ Reads multiple trees from the text
+ """
+ token_iterator = TokenIterator(text)
+ trees = recursive_read_trees(token_iterator, broken_ok=broken_ok)
+ return trees
+
+def read_tree_file(filename):
+ """
+ Read all of the trees in the given file
+ """
+ with open(filename) as fin:
+ trees = read_trees(fin.read())
+ return trees
+
+def main():
+ """
+ Reads a sample tree
+ """
+ text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+ trees = read_trees(text)
+ print(trees)
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/models/constituency/tree_stack.py b/stanza/models/constituency/tree_stack.py
new file mode 100644
index 00000000..afab61b7
--- /dev/null
+++ b/stanza/models/constituency/tree_stack.py
@@ -0,0 +1,52 @@
+"""
+A utility class for keeping track of intermediate parse states
+"""
+
+from collections import namedtuple
+
+class TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])):
+ """
+ A stack which can branch in several directions, as long as you
+ keep track of the branching heads
+
+ An example usage is when K constituents are removed at once
+ to create a new constituent, and then the LSTM which tracks the
+ values of the constituents is updated starting from the Kth
+ output of the LSTM with the new value.
+
+ We don't simply keep track of a single stack object using a deque
+ because versions of the parser which use a beam will want to be
+ able to branch in different directions from the same base stack.
+
+ Another possible usage is if an oracle is used for training
+ in a manner where some fraction of steps are non-gold steps,
+ but we also want to take a gold step from the same state.
+ E.g., the parser gets to state X, wants to make incorrect transition T
+ instead of gold transition G, and so we continue training both
+ X+G and X+T. If we only represent the state X with standard
+ python stacks, it would not be possible to track both of these
+ states at the same time without copying the entire thing.
+
+ A value can be a transition, a word, or a partially built constituent
+
+ Implemented as a namedtuple to make it a bit more efficient
+ """
+ def pop(self):
+ return self.parent
+
+ def push(self, value):
+ # returns a new TreeStack node which points to this one as its parent
+ return TreeStack(value, parent=self, length=self.length+1)
+
+ def __iter__(self):
+ stack = self
+ while stack.parent is not None:
+ yield stack.value
+ stack = stack.parent
+ yield stack.value
+
+ def __str__(self):
+ return "TreeStack(%s)" % ", ".join([str(x) for x in self])
+
+ def __len__(self):
+ return self.length
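+
+# A minimal usage sketch (labels are illustrative): start from a sentinel
+# node and branch two different ways off the same base without copying it.
+#
+#   base = TreeStack(value=None, parent=None, length=0)
+#   left = base.push("NP").push("VP")
+#   right = base.push("S")
+#   list(left)    # -> ['VP', 'NP', None], newest value first
+#   len(right)    # -> 1; base is shared between left and right, not copied
+#   left.pop()    # -> the TreeStack node holding 'NP'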
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
new file mode 100644
index 00000000..7fd4648d
--- /dev/null
+++ b/stanza/models/constituency/utils.py
@@ -0,0 +1,58 @@
+"""
+Collects a few of the conparser utility methods which don't belong elsewhere
+"""
+
+from collections import deque
+import copy
+
+from stanza.models.common.doc import TEXT, Document
+
+def replace_tags(tree, tags):
+ if tree.is_leaf():
+ raise ValueError("Must call replace_tags with non-leaf")
+
+ tag_iterator = iter(tags)
+
+ new_tree = copy.deepcopy(tree)
+ queue = deque()
+ queue.append(new_tree)
+ while len(queue) > 0:
+ next_node = queue.pop()
+ if next_node.is_preterminal():
+ try:
+ label = next(tag_iterator)
+ except StopIteration:
+ raise ValueError("Not enough tags in sentence for given tree")
+ next_node.label = label
+ elif next_node.is_leaf():
+ raise ValueError("Got a badly structured tree: {}".format(tree))
+ else:
+ queue.extend(reversed(next_node.children))
+
+ if any(True for _ in tag_iterator):
+ raise ValueError("Too many tags for the given tree")
+
+ return new_tree
+
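+# A minimal usage sketch for replace_tags, assuming a tree built with
+# tree_reader.read_trees (the tags here are illustrative):
+#
+#   tree = read_trees("(ROOT (S (NP (DT The) (NN dog)) (VP (VBZ barks))))")[0]
+#   retagged = replace_tags(tree, ["DET", "NOUN", "VERB"])
+#   # the preterminal labels DT, NN, VBZ become DET, NOUN, VERB;
+#   # the original tree and its words are left unchanged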
+
+def retag_trees(trees, pipeline, xpos=True):
+ """
+ Retag all of the trees using the given processor
+
+ Returns a list of new trees
+ """
+ sentences = []
+ for tree in trees:
+ tokens = [{TEXT: pt.children[0].label} for pt in tree.preterminals()]
+ sentences.append(tokens)
+
+ doc = Document(sentences)
+ doc = pipeline(doc)
+ if xpos:
+ tag_lists = [[x.xpos for x in sentence.words] for sentence in doc.sentences]
+ else:
+ tag_lists = [[x.upos for x in sentence.words] for sentence in doc.sentences]
+
+ new_trees = [replace_tags(tree, tags) for tree, tags in zip(trees, tag_lists)]
+ return new_trees
+
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
new file mode 100644
index 00000000..62421822
--- /dev/null
+++ b/stanza/models/constituency_parser.py
@@ -0,0 +1,290 @@
+"""A command line interface to a shift reduce constituency parser.
+
+This follows the work of
+Recurrent neural network grammars by Dyer et al
+In-Order Transition-based Constituent Parsing by Liu & Zhang
+
+The general outline is:
+
+ Train a model by taking a list of trees, converting them to
+ transition sequences, and learning a model which can predict the
+ next transition given a current state
+ Then, at inference time, repeatedly predict the next transition until parsing is complete
+
+The "transitions" are variations on shift/reduce as per an
+intro-to-compilers class. The idea is that you can treat all of the
+words in a sentence as a buffer of tokens, then either "shift" them to
+represent a new constituent, or "reduce" one or more constituents to
+form a new constituent.
+
+To make the runtime speed more competitive, effort is taken to batch
+the transitions and apply multiple transitions at once. At train
+time, batches are grouped together by length, and at inference time,
+new trees are added to the batch as previous trees in the batch
+finish their inference.
+
+There are two minor differences between this model and the cited papers:
+ - The word input is a bi-lstm, not a uni-lstm.
+ This gave a small increase in accuracy.
+ - The combination of several constituents into one constituent is done
+ via a single bi-lstm rather than two separate lstms. This increases
+ speed without a noticeable effect on accuracy.
+
+A couple of experiments which were tried with little noticeable impact:
+ - Combining constituents using the method in the paper (only a trained
+ vector at the start instead of both ends) did not affect results
+ and is a little slower
+ - Using multiple layers of LSTM hidden state for the input to the final
+ classification layers didn't help
+
+The code breakdown is as follows:
+
+ this file: main interface for training or evaluating models
+ constituency/trainer.py: contains the training & evaluation code
+
+ constituency/parse_tree.py: a data structure for representing a parse tree and utility methods
+ constituency/tree_reader.py: a module which can read trees from a string or input file
+
+ constituency/tree_stack.py: a linked list which can branch in
+ different directions, which will be useful when implementing beam
+ search or a dynamic oracle
+
+ constituency/parse_transitions.py: transitions and a State data structure to store them
+ constituency/transition_sequence.py: turns ParseTree objects into
+ the transition sequences needed to make them
+
+ constituency/base_model.py: operates on the transitions to turn them in to constituents,
+ eventually forming one final parse tree composed of all of the constituents
+ constituency/lstm_model.py: adds LSTM features to the constituents to predict what the
+ correct transition to make is, allowing for predictions on previously unseen text
+
+ stanza/pipeline/constituency_processor.py: interface between this model and the Pipeline
+"""
+
+import argparse
+import logging
+import os
+
+import torch
+
+from stanza import Pipeline
+from stanza.models.common import utils
+from stanza.models.constituency import trainer
+from stanza.models.constituency.parse_transitions import TransitionScheme
+
+logger = logging.getLogger('stanza')
+
+def parse_args(args=None):
+ """
+ Adds the arguments for building the con parser
+
+ For the most part, defaults are set to cross-validated values, at least for WSJ
+ """
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--data_dir', type=str, default='data/constituency', help='Directory of constituency data.')
+
+ parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors')
+ parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')
+ parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
+ parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
+
+ # for whatever reason, this feature was not helpful
+ parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
+ parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
+
+ parser.add_argument('--tag_embedding_dim', type=int, default=20, help="Embedding size for a tag. 0 turns off the feature")
+ # Smaller values also seem to work
+ # For example, after 700 iterations:
+ # 32: 0.9174
+ # 50: 0.9183
+ # 72: 0.9176
+ # 100: 0.9185
+ # not a huge difference regardless
+ # (these numbers were without retagging)
+ parser.add_argument('--delta_embedding_dim', type=int, default=100, help="Embedding size for a delta embedding")
+
+ parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
+ parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
+ parser.add_argument('--mode', default='train', choices=['train', 'predict', 'remove_optimizer'])
+ parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. The orig file is important as the results will be shuffled')
+ parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions')
+
+ parser.add_argument('--lang', type=str, help='Language')
+ parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
+
+ parser.add_argument('--transition_embedding_dim', type=int, default=20, help="Embedding size for a transition")
+ parser.add_argument('--transition_hidden_size', type=int, default=20, help="Embedding size for transition stack")
+ # larger was more effective, up to a point
+ parser.add_argument('--hidden_size', type=int, default=128, help="Size of the output layers for constituency stack and word queue")
+
+ parser.add_argument('--epochs', type=int, default=200)
+ parser.add_argument('--eval_interval', type=int, default=5000)
+ # 30 is slightly slower than 50, for example, but seems to train a bit better on WSJ
+ # earlier version of the model (less accurate overall) had the following results with adadelta:
+ # 30: 0.9085
+ # 50: 0.9070
+ # 75: 0.9010
+ # 150: 0.8985
+ # as another data point, running a newer version with better constituency lstm behavior had:
+ # 30: 0.9111
+ # 50: 0.9094
+ # checking smaller batch sizes to see how this works, at 135 epochs, the values are
+ # 10: 0.8919
+ # 20: 0.9072
+ # 30: 0.9121
+ # obviously these experiments aren't the complete story, but it
+ # looks like 30 trees per batch is the best value for WSJ
+ # note that these numbers are for adadelta and might not apply
+ # to other optimizers
+ # eval batch should generally be faster the bigger the batch,
+ # up to a point, as it allows for more batching of the LSTM
+ # operations and the prediction step
+ parser.add_argument('--train_batch_size', type=int, default=30, help='How many trees to train before taking an optimizer step')
+ parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
+
+ parser.add_argument('--save_dir', type=str, default='saved_models/constituency', help='Root dir for saving models.')
+ parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
+ parser.add_argument('--save_latest_name', type=str, default=None, help="Save the latest model here regardless of score. Useful for restarting training")
+
+ parser.add_argument('--seed', type=int, default=1234)
+ parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available())
+ parser.add_argument('--cpu', action='store_true', help='Ignore CUDA.')
+
+ DEFAULT_LEARNING_RATES = { "adamw": 0.001, "adadelta": 1.0, "sgd": 0.001 }
+ parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate for the optimizer. Reasonable values are 1.0 for adadelta or 0.001 for SGD. None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_RATES))
+ # When using adadelta, weight_decay of 0.01 to 0.001 had the best results.
+ # 0.1 was very clearly too high. 0.0001 might have been okay.
+ parser.add_argument('--weight_decay', default=0.01, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer')
+ parser.add_argument('--optim', default='Adadelta', help='Optimizer type: SGD, AdamW, or Adadelta')
+
+ # When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations:
+ # 0.0: 0.9085
+ # 0.2: 0.9165
+ # 0.4: 0.9162
+ # 0.5: 0.9123
+ # Letting 0.2 and 0.4 run for longer, along with 0.3 as another
+ # trial, continued to give extremely similar results over time.
+ # No attempt has been made to test the different dropouts separately...
+ parser.add_argument('--word_dropout', default=0.2, type=float, help='Dropout on the word embedding')
+ parser.add_argument('--predict_dropout', default=0.2, type=float, help='Dropout on the final prediction layer')
+ # lstm_dropout has not been fully tested yet
+ # one experiment after 200 iterations (after retagging, so scores are lower than some other experiments):
+ # 0.0: 0.9093
+ # 0.1: 0.9094
+ # 0.2: 0.9094
+ # 0.3: 0.9076
+ # 0.4: 0.9077
+ parser.add_argument('--lstm_layer_dropout', default=0.0, type=float, help='Dropout in the LSTM layers')
+ # one not very conclusive experiment (not long enough) came up with these numbers after ~200 iterations
+ # 0.0 0.9091
+ # 0.1 0.9095
+ # 0.2 0.9118
+ # 0.3 0.9123
+ # 0.4 0.9080
+ parser.add_argument('--lstm_input_dropout', default=0.2, type=float, help='Dropout on the input to an LSTM')
+
+ parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()],
+ help='Transition scheme to use. {}'.format(", ".join(x.name for x in TransitionScheme)))
+
+ parser.add_argument('--constituency_lstm', default=False, action='store_true', help="Build constituents using the full LSTM instead of just the nodes below the new constituent. Doesn't match the original papers and might be slightly less effective")
+
+ # combining dummy and open node embeddings might be a slight improvement
+ # for example, after 550 iterations, one experiment had
+ # True: 0.9154
+ # False: 0.9150
+ # another (with a different structure) had 850 iterations
+ # True: 0.9155
+ # False: 0.9149
+ parser.add_argument('--combined_dummy_embedding', default=False, action='store_true', help="Use the same embedding for dummy nodes and the vectors used when combining constituents")
+ parser.add_argument('--no_combined_dummy_embedding', dest='combined_dummy_embedding', action='store_false', help="Don't use the same embedding for dummy nodes and the vectors used when combining constituents")
+
+ # relu gave at least 1 F1 improvement over tanh in various experiments
+ # relu & gelu seem roughly the same, but relu is clearly faster.
+ # relu, 496 iterations: 0.9176
+ # gelu, 467 iterations: 0.9181
+ # after the same clock time on the same hardware. the two had been
+ # trading places in terms of accuracy over those ~500 iterations.
+ parser.add_argument('--nonlinearity', default='relu', choices=['tanh', 'relu', 'gelu'], help='Nonlinearity to use in the model. relu is a noticeable improvement')
+
+ parser.add_argument('--rare_word_unknown_frequency', default=0.02, type=float, help='How often to replace a rare word with UNK when training')
+ parser.add_argument('--rare_word_threshold', default=0.02, type=float, help='How many words to consider as rare words as a fraction of the dataset')
+
+ parser.add_argument('--num_lstm_layers', default=2, type=int, help='How many layers to use in the LSTMs')
+ parser.add_argument('--num_output_layers', default=3, type=int, help='How many layers to use at the prediction level')
+
+ # TODO: add the ability to keep training in a different direction
+ # after making an error, eg, add an oracle
+ parser.add_argument('--train_method', default='gold_entire', choices=['gold_entire'], help='Different training methods to use')
+
+ parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path')
+ parser.add_argument('--maybe_finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path if it exists. Useful for running in situations where a job is frequently being preempted')
+ parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file')
+
+ parser.add_argument('--retag_package', default=None, help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time')
+ parser.add_argument('--retag_method', default='xpos', choices=['xpos', 'upos'], help='Which tags to use when retagging')
+
+ args = parser.parse_args(args=args)
+ if not args.lang and args.shorthand and len(args.shorthand.split("_")) == 2:
+ args.lang = args.shorthand.split("_")[0]
+ if args.cpu:
+ args.cuda = False
+ if args.learning_rate is None:
+ args.learning_rate = DEFAULT_LEARNING_RATES.get(args.optim.lower(), None)
+
+ args = vars(args)
+
+ if args['retag_method'] == 'xpos':
+ args['retag_xpos'] = True
+ elif args['retag_method'] == 'upos':
+ args['retag_xpos'] = False
+ else:
+ raise ValueError("Unknown retag method {}".format(xpos))
+
+ return args
+
+def main(args=None):
+ """
+ Main function for building con parser
+
+ Processes args, calls the appropriate function for the chosen --mode
+ """
+ args = parse_args(args=args)
+
+ utils.set_random_seed(args['seed'], args['cuda'])
+
+ logger.info("Running constituency parser in %s mode", args['mode'])
+ logger.debug("Using GPU: %s", args['cuda'])
+
+ model_save_file = args['save_name'] if args['save_name'] else '{}_constituency.pt'.format(args['shorthand'])
+ model_save_file = os.path.join(args['save_dir'], model_save_file)
+
+ model_save_latest_file = None
+ if args['save_latest_name']:
+ model_save_latest_file = os.path.join(args['save_dir'], args['save_latest_name'])
+
+ model_load_file = model_save_file
+ if args['load_name']:
+ model_load_file = os.path.join(args['save_dir'], args['load_name'])
+ elif args['mode'] == 'train' and args['save_latest_name']:
+ model_load_file = model_save_latest_file
+
+ if args['retag_package'] is not None:
+ if '_' in args['retag_package']:
+ lang, package = args['retag_package'].split('_', 1)
+ retag_pipeline = Pipeline(lang=lang, processors="tokenize, pos", tokenize_pretokenized=True, pos_package=package, pos_tqdm=True)
+ else:
+ lang = args['retag_package']
+ retag_pipeline = Pipeline(lang=lang, processors="tokenize, pos", tokenize_pretokenized=True, pos_tqdm=True)
+ else:
+ retag_pipeline = None
+
+ if args['mode'] == 'train':
+ trainer.train(args, model_save_file, model_load_file, model_save_latest_file, retag_pipeline)
+ elif args['mode'] == 'predict':
+ trainer.evaluate(args, model_load_file, retag_pipeline)
+ elif args['mode'] == 'remove_optimizer':
+ trainer.remove_optimizer(args, model_save_file, model_load_file)
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/models/lang_identifier.py b/stanza/models/lang_identifier.py
new file mode 100644
index 00000000..ca7aa8e2
--- /dev/null
+++ b/stanza/models/lang_identifier.py
@@ -0,0 +1,226 @@
+"""
+Entry point for training and evaluating a Bi-LSTM language identifier
+"""
+
+import argparse
+import json
+import logging
+import os
+import random
+import torch
+
+from datetime import datetime
+from stanza.models.langid.data import DataLoader
+from stanza.models.langid.trainer import Trainer
+from tqdm import tqdm
+
+logger = logging.getLogger('stanza')
+
+def parse_args(args=None):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--batch-mode", help="custom settings when running in batch mode", action="store_true")
+ parser.add_argument("--batch-size", help="batch size for training", type=int, default=64)
+ parser.add_argument("--eval-length", help="length of strings to eval on", type=int, default=None)
+ parser.add_argument("--eval-set", help="eval on dev or test", default="test")
+ parser.add_argument("--data-dir", help="directory with train/dev/test data", default=None)
+ parser.add_argument("--load-model", help="path to load model from", default=None)
+ parser.add_argument("--mode", help="train or eval", default="train")
+ parser.add_argument("--num-epochs", help="number of epochs for training", type=int, default=50)
+ parser.add_argument("--randomize", help="take random substrings of samples", action="store_true")
+ parser.add_argument("--randomize-lengths-range", help="range of lengths to use when random sampling text",
+ type=randomize_lengths_range, default="5,20")
+ parser.add_argument("--merge-labels-for-eval",
+ help="merge some language labels for eval (e.g. \"zh-hans\" and \"zh-hant\" to \"zh\")",
+ action="store_true")
+ parser.add_argument("--save-best-epochs", help="save model for every epoch with new best score", action="store_true")
+ parser.add_argument("--save-name", help="where to save model", default=None)
+ parser.add_argument("--use-cpu", help="use cpu", action="store_true")
+ args = parser.parse_args(args=args)
+ args.use_gpu = True if torch.cuda.is_available() and not args.use_cpu else False
+ return args
+
+
+def randomize_lengths_range(range_list):
+ """
+ Range of lengths for random samples
+ """
+ range_boundaries = [int(x) for x in range_list.split(",")]
+ assert range_boundaries[0] < range_boundaries[1], f"Invalid range: ({range_boundaries[0]}, {range_boundaries[1]})"
+ return range_boundaries
+
+
+def main(args=None):
+ args = parse_args(args=args)
+ torch.manual_seed(0)
+ if args.mode == "train":
+ train_model(args)
+ else:
+ eval_model(args)
+
+
+def build_indexes(args):
+ tag_to_idx = {}
+ char_to_idx = {}
+ train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x]
+ for train_file in train_files:
+ with open(train_file) as curr_file:
+ lines = curr_file.read().strip().split("\n")
+ examples = [json.loads(line) for line in lines if line.strip()]
+ for example in examples:
+ label = example["label"]
+ if label not in tag_to_idx:
+ tag_to_idx[label] = len(tag_to_idx)
+ sequence = example["text"]
+ for char in list(sequence):
+ if char not in char_to_idx:
+ char_to_idx[char] = len(char_to_idx)
+ char_to_idx["UNK"] = len(char_to_idx)
+ char_to_idx["<PAD>"] = len(char_to_idx)
+
+ return tag_to_idx, char_to_idx
+
+
+def train_model(args):
+ # set up indexes
+ tag_to_idx, char_to_idx = build_indexes(args)
+ # load training data
+ train_data = DataLoader()
+ train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x]
+ train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)
+ # load dev data
+ dev_data = DataLoader()
+ dev_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "dev" in x]
+ dev_data.load_data(args.batch_size, dev_files, char_to_idx, tag_to_idx, randomize=False,
+ max_length=args.eval_length)
+ # set up trainer
+ trainer_config = {
+ "model_path": args.save_name,
+ "char_to_idx": char_to_idx,
+ "tag_to_idx": tag_to_idx,
+ "batch_size": args.batch_size,
+ "lang_weights": train_data.lang_weights
+ }
+ if args.load_model:
+ trainer_config["load_model"] = args.load_model
+ logger.info(f"{datetime.now()}\tLoading model from: {args.load_model}")
+ trainer = Trainer(trainer_config, load_model=args.load_model, use_gpu=args.use_gpu)
+ # run training
+ best_accuracy = 0.0
+ for epoch in range(1, args.num_epochs+1):
+ logger.info(f"{datetime.now()}\tEpoch {epoch}")
+ logger.info(f"{datetime.now()}\tNum training batches: {len(train_data.batches)}")
+ for train_batch in tqdm(train_data.batches, disable=args.batch_mode):
+ inputs = (train_batch["sentences"], train_batch["targets"])
+ trainer.update(inputs)
+ logger.info(f"{datetime.now()}\tEpoch complete. Evaluating on dev data.")
+ curr_dev_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \
+ eval_trainer(trainer, dev_data, batch_mode=args.batch_mode)
+ logger.info(f"{datetime.now()}\tCurrent dev accuracy: {curr_dev_accuracy}")
+ if curr_dev_accuracy > best_accuracy:
+ logger.info(f"{datetime.now()}\tNew best score. Saving model.")
+ model_label = f"epoch{epoch}" if args.save_best_epochs else None
+ trainer.save(label=model_label)
+ with open(score_log_path(args.save_name), "w") as score_log_file:
+ for score_log in [{"dev_accuracy": curr_dev_accuracy}, curr_confusion_matrix, curr_precisions,
+ curr_recalls, curr_f1s]:
+ score_log_file.write(json.dumps(score_log) + "\n")
+ best_accuracy = curr_dev_accuracy
+
+ # reload training data
+ logger.info(f"{datetime.now()}\tResampling training data.")
+ train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)
+
+
+def score_log_path(file_path):
+ """
+ Helper that determines the corresponding score log file (e.g. /path/to/demo.pt -> /path/to/demo.json)
+ """
+ model_suffix = os.path.splitext(file_path)[1]
+ if model_suffix:
+ score_log_path = f"{file_path[:-len(model_suffix)]}.json"
+ else:
+ score_log_path = f"{file_path}.json"
+ return score_log_path
+
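+# Illustrative mapping:
+#   score_log_path("/path/to/demo.pt")  ->  "/path/to/demo.json"
+#   score_log_path("/path/to/demo")     ->  "/path/to/demo.json"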
+
+def eval_model(args):
+ # set up trainer
+ trainer_config = {
+ "model_path": None,
+ "load_model": args.load_model,
+ "batch_size": args.batch_size
+ }
+ trainer = Trainer(trainer_config, load_model=True, use_gpu=args.use_gpu)
+ # load test data
+ test_data = DataLoader()
+ test_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if args.eval_set in x]
+ test_data.load_data(args.batch_size, test_files, trainer.model.char_to_idx, trainer.model.tag_to_idx,
+ randomize=False, max_length=args.eval_length)
+ curr_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \
+ eval_trainer(trainer, test_data, batch_mode=args.batch_mode, fine_grained=not args.merge_labels_for_eval)
+ logger.info(f"{datetime.now()}\t{args.eval_set} accuracy: {curr_accuracy}")
+ eval_save_path = args.save_name if args.save_name else score_log_path(args.load_model)
+ if not os.path.exists(eval_save_path) or args.save_name:
+ with open(eval_save_path, "w") as score_log_file:
+ for score_log in [{"dev_accuracy": curr_accuracy}, curr_confusion_matrix, curr_precisions,
+ curr_recalls, curr_f1s]:
+ score_log_file.write(json.dumps(score_log) + "\n")
+
+
+
+def eval_trainer(trainer, dev_data, batch_mode=False, fine_grained=True):
+ """
+ Produce dev accuracy and confusion matrix for a trainer
+ """
+
+ # set up confusion matrix
+ tag_to_idx = dev_data.tag_to_idx
+ idx_to_tag = dev_data.idx_to_tag
+ confusion_matrix = {}
+ for row_label in tag_to_idx:
+ confusion_matrix[row_label] = {}
+ for col_label in tag_to_idx:
+ confusion_matrix[row_label][col_label] = 0
+
+ # process dev batches
+ for dev_batch in tqdm(dev_data.batches, disable=batch_mode):
+ inputs = (dev_batch["sentences"], dev_batch["targets"])
+ predictions = trainer.predict(inputs)
+ for target_idx, prediction in zip(dev_batch["targets"], predictions):
+ prediction_label = idx_to_tag[prediction] if fine_grained else idx_to_tag[prediction].split("-")[0]
+ confusion_matrix[idx_to_tag[target_idx]][prediction_label] += 1
+
+ # calculate dev accuracy
+ total_examples = sum([sum([confusion_matrix[i][j] for j in confusion_matrix[i]]) for i in confusion_matrix])
+ total_correct = sum([confusion_matrix[i][i] for i in confusion_matrix])
+ dev_accuracy = float(total_correct) / float(total_examples)
+
+ # calculate precision, recall, F1
+ precision_scores = {"type": "precision"}
+ recall_scores = {"type": "recall"}
+ f1_scores = {"type": "f1"}
+ for prediction_label in tag_to_idx:
+ total = sum([confusion_matrix[k][prediction_label] for k in tag_to_idx])
+ if total != 0.0:
+ precision_scores[prediction_label] = float(confusion_matrix[prediction_label][prediction_label])/float(total)
+ else:
+ precision_scores[prediction_label] = 0.0
+ for target_label in tag_to_idx:
+ total = sum([confusion_matrix[target_label][k] for k in tag_to_idx])
+ if total != 0:
+ recall_scores[target_label] = float(confusion_matrix[target_label][target_label])/float(total)
+ else:
+ recall_scores[target_label] = 0.0
+ for label in tag_to_idx:
+ if precision_scores[label] == 0.0 and recall_scores[label] == 0.0:
+ f1_scores[label] = 0.0
+ else:
+ f1_scores[label] = \
+ 2.0 * (precision_scores[label] * recall_scores[label]) / (precision_scores[label] + recall_scores[label])
+
+ return dev_accuracy, confusion_matrix, precision_scores, recall_scores, f1_scores
+
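+# A worked example of the scoring above with two labels, assuming the
+# (hypothetical) confusion matrix
+#   {"en": {"en": 8, "fr": 2}, "fr": {"en": 1, "fr": 9}}
+# where rows are gold labels and columns are predictions:
+#   accuracy        = (8 + 9) / 20      = 0.85
+#   precision("en") = 8 / (8 + 1)       = 0.889
+#   recall("en")    = 8 / (8 + 2)       = 0.8
+#   f1("en")        = 2 * 0.889 * 0.8 / (0.889 + 0.8) ≈ 0.842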
+
+if __name__ == "__main__":
+ main()
+
diff --git a/stanza/models/langid/__init__.py b/stanza/models/langid/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/stanza/models/langid/__init__.py
diff --git a/stanza/models/langid/create_ud_data.py b/stanza/models/langid/create_ud_data.py
new file mode 100644
index 00000000..7e7d5cc5
--- /dev/null
+++ b/stanza/models/langid/create_ud_data.py
@@ -0,0 +1,205 @@
+"""
+Script for producing training/dev/test data from UD data or sentences
+
+Example output data format (one example per line):
+
+{"text": "Hello world.", "label": "en"}
+
+"""
+
+import argparse
+import json
+import logging
+import os
+import re
+import sys
+
+from pathlib import Path
+from random import randint, random, shuffle
+from string import digits
+from tqdm import tqdm
+
+from stanza.models.common.constant import treebank_to_langid
+
+logger = logging.getLogger('stanza')
+
+DEFAULT_LANGUAGES = "af,ar,be,bg,bxr,ca,cop,cs,cu,da,de,el,en,es,et,eu,fa,fi,fr,fro,ga,gd,gl,got,grc,he,hi,hr,hsb,hu,hy,id,it,ja,kk,kmr,ko,la,lt,lv,lzh,mr,mt,nl,nn,no,olo,orv,pl,pt,ro,ru,sk,sl,sme,sr,sv,swl,ta,te,tr,ug,uk,ur,vi,wo,zh-hans,zh-hant".split(",")
+
+def parse_args(args=None):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data-format", help="input data format", choices=["ud", "one-per-line"], default="ud")
+ parser.add_argument("--eval-length", help="length of eval strings", type=int, default=10)
+ parser.add_argument("--languages", help="list of languages to use, or \"all\"", default=DEFAULT_LANGUAGES)
+ parser.add_argument("--min-window", help="minimal training example length", type=int, default=10)
+ parser.add_argument("--max-window", help="maximum training example length", type=int, default=50)
+ parser.add_argument("--ud-path", help="path to ud data")
+ parser.add_argument("--save-path", help="path to save data", default=".")
+ parser.add_argument("--splits", help="size of train/dev/test splits in percentages", type=splits_from_list,
+ default="0.8,0.1,0.1")
+ args = parser.parse_args(args=args)
+ return args
+
+
+def splits_from_list(value_list):
+ return [float(x) for x in value_list.split(",")]
+
+
+def main(args=None):
+ args = parse_args(args=args)
+ if isinstance(args.languages, str):
+ args.languages = args.languages.split(",")
+ data_paths = [f"{args.save_path}/{data_split}.jsonl" for data_split in ["train", "dev", "test"]]
+ lang_to_files = collect_files(args.ud_path, args.languages, data_format=args.data_format)
+ logger.info(f"Building UD data for languages: {','.join(args.languages)}")
+ for lang_id in tqdm(lang_to_files):
+ lang_examples = generate_examples(lang_id, lang_to_files[lang_id], splits=args.splits,
+ min_window=args.min_window, max_window=args.max_window,
+ eval_length=args.eval_length, data_format=args.data_format)
+ for (data_set, save_path) in zip(lang_examples, data_paths):
+ with open(save_path, "a") as json_file:
+ for json_entry in data_set:
+ json.dump(json_entry, json_file, ensure_ascii=False)
+ json_file.write("\n")
+
+
+def collect_files(ud_path, languages, data_format="ud"):
+ """
+ Given path to UD, collect files
+ If data_format = "ud", expects files to be of form *.conllu
+ If data_format = "one-per-line", expects files to be of form "*.sentences.txt"
+ In all cases, the UD path should be a directory with subdirectories for each language
+ """
+ data_format_to_search_path = {"ud": "*/*.conllu", "one-per-line": "*/*sentences.txt"}
+ ud_files = Path(ud_path).glob(data_format_to_search_path[data_format])
+ lang_to_files = {}
+ for ud_file in ud_files:
+ if data_format == "ud":
+ lang_id = treebank_to_langid(ud_file.parent.name)
+ else:
+ lang_id = ud_file.name.split("_")[0]
+ if lang_id not in languages and "all" not in languages:
+ continue
+ if not lang_id in lang_to_files:
+ lang_to_files[lang_id] = []
+ lang_to_files[lang_id].append(ud_file)
+ return lang_to_files
+
+
+def generate_examples(lang_id, list_of_files, splits=(0.8,0.1,0.1), min_window=10, max_window=50,
+ eval_length=10, data_format="ud"):
+ """
+ Generate train/dev/test examples for a given language
+ """
+ examples = []
+ for ud_file in list_of_files:
+ sentences = sentences_from_file(ud_file, data_format=data_format)
+ for sentence in sentences:
+ sentence = clean_sentence(sentence)
+ if validate_sentence(sentence, min_window):
+ examples += sentence_to_windows(sentence, min_window=min_window, max_window=max_window)
+ shuffle(examples)
+ train_idx = int(splits[0] * len(examples))
+ train_set = [example_json(lang_id, example) for example in examples[:train_idx]]
+ dev_idx = int(splits[1] * len(examples)) + train_idx
+ dev_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[train_idx:dev_idx]]
+ test_set = [example_json(lang_id, example, eval_length=eval_length) for example in examples[dev_idx:]]
+ return train_set, dev_set, test_set
+
+
+def sentences_from_file(ud_file_path, data_format="ud"):
+ """
+ Retrieve all sentences from a UD file
+ """
+ if data_format == "ud":
+ with open(ud_file_path) as ud_file:
+ ud_file_contents = ud_file.read().strip()
+ assert "# text = " in ud_file_contents, \
+ f"{ud_file_path} does not have expected format, \"# text =\" does not appear"
+ sentences = [x[9:] for x in ud_file_contents.split("\n") if x.startswith("# text = ")]
+ elif data_format == "one-per-line":
+ with open(ud_file_path) as ud_file:
+ sentences = [x for x in ud_file.read().strip().split("\n") if x]
+ return sentences
+
+
+def sentence_to_windows(sentence, min_window, max_window):
+ """
+ Create window-sized chunks from a sentence, always starting at a word boundary
+ """
+ windows = []
+ words = sentence.split(" ")
+ curr_window = ""
+ for idx, word in enumerate(words):
+ curr_window += (" " + word)
+ curr_window = curr_window.lstrip()
+ next_word_len = len(words[idx+1]) + 1 if idx+1 < len(words) else 0
+ if len(curr_window) + next_word_len > max_window:
+ curr_window = clean_sentence(curr_window)
+ if validate_sentence(curr_window, min_window):
+ windows.append(curr_window.strip())
+ curr_window = ""
+ if len(curr_window) >= min_window:
+ windows.append(curr_window)
+ return windows
+
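+# Illustrative behavior, with small window sizes for readability:
+#
+#   sentence_to_windows("the quick brown fox jumps", min_window=5, max_window=15)
+#   # -> ["the quick brown", "fox jumps"]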
+
+def validate_sentence(current_window, min_window):
+ """
+ Sentence validation from: LSTM-LID
+ GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py
+ """
+ if len(current_window) < min_window:
+ return False
+ return True
+
+def find(s, ch):
+ """
+ Helper for clean_sentence from LSTM-LID
+ GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py
+ """
+ return [i for i, ltr in enumerate(s) if ltr == ch]
+
+
+def clean_sentence(line):
+ """
+ Sentence cleaning from LSTM-LID
+ GitHub: https://github.com/AU-DIS/LSTM_langid/blob/main/src/dataset_creator.py
+ """
+ # We remove some special characters and fix small errors in the data, to improve the quality of the data
+ line = line.replace("\n", '') #{"text": "- Mor.\n", "label": "da"}
+ line = line.replace("- ", '') #{"text": "- Mor.", "label": "da"}
+ line = line.replace("_", '') #{"text": "- Mor.", "label": "da"}
+ line = line.replace("\\", '')
+ line = line.replace("\"", '')
+ line = line.replace("  ", " ")
+ remove_digits = str.maketrans('', '', digits)
+ line = line.translate(remove_digits)
+ words = line.split()
+ new_words = []
+ # Below fixes capital I that should have been lowercase l. It does not catch everything, but it should rarely introduce new mistakes
+ for word in words:
+ clean_word = word
+ s = clean_word
+ if "I" in clean_word[1:]:
+ indices = find(clean_word, "I")
+ for indx in indices:
+ if clean_word[indx-1].islower():
+ if len(clean_word) > indx + 1:
+ if clean_word[indx+1].islower():
+ s = s[:indx] + "l" + s[indx + 1:]
+ else:
+ s = s[:indx] + "l" + s[indx + 1:]
+ new_words.append(s)
+ new_line = " ".join(new_words)
+ return new_line
+
+
+def example_json(lang_id, text, eval_length=None):
+ if eval_length is not None:
+ text = text[:eval_length]
+ return {"text": text.strip(), "label": lang_id}
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/stanza/models/langid/data.py b/stanza/models/langid/data.py
new file mode 100644
index 00000000..b5e328cc
--- /dev/null
+++ b/stanza/models/langid/data.py
@@ -0,0 +1,136 @@
+import json
+import random
+import torch
+
+
+class DataLoader:
+ """
+ Class for loading language id data and providing batches
+ """
+
+ def __init__(self, use_gpu=None):
+ self.batches = None
+ self.batches_iter = None
+ self.tag_to_idx = None
+ self.idx_to_tag = None
+ self.lang_weights = None
+ # set self.use_gpu and self.device
+ if use_gpu is None:
+ self.use_gpu = torch.cuda.is_available()
+ else:
+ self.use_gpu = use_gpu
+ if self.use_gpu:
+ self.device = torch.device("cuda")
+ else:
+ self.device = None
+
+ def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5,20),
+ max_length=None):
+ """
+ Load sequence data and labels, calculate weights for weighted cross entropy loss.
+ Data is stored in a file, 1 example per line
+ Example: {"text": "Hello world.", "label": "en"}
+ """
+
+ # set up examples from data files
+ examples = []
+ for data_file in data_files:
+ examples += [x for x in open(data_file).read().split("\n") if x.strip()]
+ random.shuffle(examples)
+ examples = [json.loads(x) for x in examples]
+
+ # add additional labels in this data set to tag index
+ tag_index = dict(tag_index)
+ new_labels = set([x["label"] for x in examples]) - set(tag_index.keys())
+ for new_label in new_labels:
+ tag_index[new_label] = len(tag_index)
+ self.tag_to_idx = tag_index
+ self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]
+
+ # set up lang counts used for weights for cross entropy loss
+ lang_counts = [0 for _ in tag_index]
+
+ # optionally limit text to max length
+ if max_length is not None:
+ examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples]
+
+ # randomize data
+ if randomize:
+ split_examples = []
+ for example in examples:
+ sequence = example["text"]
+ label = example["label"]
+ sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1],
+ lower_lim=randomize_range[0])
+ split_examples += [{"text": seq, "label": label} for seq in sequences]
+ examples = split_examples
+ random.shuffle(examples)
+
+ # break into equal length batches
+ batch_lengths = {}
+ for example in examples:
+ sequence = example["text"]
+ label = example["label"]
+ if len(sequence) not in batch_lengths:
+ batch_lengths[len(sequence)] = []
+ sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in list(sequence)]
+ batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label]))
+ lang_counts[tag_index[label]] += 1
+ for length in batch_lengths:
+ random.shuffle(batch_lengths[length])
+
+ # create final set of batches
+ batches = []
+ for length in batch_lengths:
+ for sublist in [batch_lengths[length][i:i + batch_size] for i in
+ range(0, len(batch_lengths[length]), batch_size)]:
+ batches.append(sublist)
+
+ self.batches = [self.build_batch_tensors(batch) for batch in batches]
+
+ # set up lang weights
+ most_frequent = max(lang_counts)
+ # weight is 0.0 if lang_count is 0, otherwise most_frequent / lang_count
+ lang_counts = [(most_frequent * x)/(max(1, x) ** 2) for x in lang_counts]
+ self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float)
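+ # e.g. lang_counts of [100, 50, 0] become weights [1.0, 2.0, 0.0], so that
+ # rarer languages get proportionally larger weight in the loss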
+
+ # shuffle batches to mix up lengths
+ random.shuffle(self.batches)
+ self.batches_iter = iter(self.batches)
+
+ @staticmethod
+ def randomize_data(sentences, upper_lim=20, lower_lim=5):
+ """
+ Takes the original data and creates random length examples with length between upper limit and lower limit
+ From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
+ """
+
+ new_data = []
+ for sentence in sentences:
+ remaining = sentence
+ while lower_lim < len(remaining):
+ lim = random.randint(lower_lim, upper_lim)
+ m = min(len(remaining), lim)
+ new_sentence = remaining[:m]
+ new_data.append(new_sentence)
+ split = remaining[m:].split(" ", 1)
+ if len(split) <= 1:
+ break
+ remaining = split[1]
+ random.shuffle(new_data)
+ return new_data
+
+ def build_batch_tensors(self, batch):
+ """
+ Helper to turn batches into tensors
+ """
+
+ batch_tensors = dict()
+ batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long)
+ batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long)
+
+ return batch_tensors
+
+ def next(self):
+ return next(self.batches_iter)
+
diff --git a/stanza/models/langid/model.py b/stanza/models/langid/model.py
new file mode 100644
index 00000000..799030e3
--- /dev/null
+++ b/stanza/models/langid/model.py
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+
+
+class LangIDBiLSTM(nn.Module):
+ """
+ Multi-layer BiLSTM model for language detection. A recreation of "A reproduction of Apple's bi-directional LSTM models
+ for language identification in short strings." (Toftrup et al 2021)
+
+ Arxiv: https://arxiv.org/abs/2102.06282
+ GitHub: https://github.com/AU-DIS/LSTM_langid
+ """
+
+ def __init__(self, char_to_idx, tag_to_idx, num_layers, embedding_dim, hidden_dim, batch_size=64, weights=None,
+ dropout=0.0, lang_subset=None):
+ super(LangIDBiLSTM, self).__init__()
+ self.num_layers = num_layers
+ self.embedding_dim = embedding_dim
+ self.hidden_dim = hidden_dim
+ self.char_to_idx = char_to_idx
+ self.vocab_size = len(char_to_idx)
+ self.tag_to_idx = tag_to_idx
+ self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]
+ self.lang_subset = lang_subset
+ self.padding_idx = char_to_idx["<PAD>"]
+ self.tagset_size = len(tag_to_idx)
+ self.batch_size = batch_size
+ self.loss_train = nn.CrossEntropyLoss(weight=weights)
+ self.dropout_prob = dropout
+
+ # embeddings for chars
+ self.char_embeds = nn.Embedding(
+ num_embeddings=self.vocab_size,
+ embedding_dim=self.embedding_dim,
+ padding_idx=self.padding_idx
+ )
+
+ # the bidirectional LSTM
+ self.lstm = nn.LSTM(
+ self.embedding_dim,
+ self.hidden_dim,
+ num_layers=self.num_layers,
+ bidirectional=True,
+ batch_first=True
+ )
+
+ # convert output to tag space
+ self.hidden_to_tag = nn.Linear(
+ self.hidden_dim * 2,
+ self.tagset_size
+ )
+
+ # dropout layer
+ self.dropout = nn.Dropout(p=self.dropout_prob)
+
+ def build_lang_mask(self, use_gpu=None):
+ """
+ Build language mask if a lang subset is specified (e.g. ["en", "fr"])
+ """
+ device = torch.device("cuda") if use_gpu else None
+ lang_mask_list = [int(lang in self.lang_subset) for lang in self.idx_to_tag] if self.lang_subset else \
+ [1 for lang in self.idx_to_tag]
+ self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float)
+
+ def loss(self, Y_hat, Y):
+ return self.loss_train(Y_hat, Y)
+
+ def forward(self, x):
+ # embed input
+ x = self.char_embeds(x)
+
+ # run through LSTM
+ x, _ = self.lstm(x)
+
+ # run through linear layer
+ x = self.hidden_to_tag(x)
+
+ # sum character outputs for each sequence
+ x = torch.sum(x, dim=1)
+
+ return x
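+
+ # Illustrative shape walk-through for forward(), assuming an input batch of
+ # character ids with shape (batch, seq_len):
+ #   char_embeds   -> (batch, seq_len, embedding_dim)
+ #   lstm          -> (batch, seq_len, 2 * hidden_dim)   (bidirectional)
+ #   hidden_to_tag -> (batch, seq_len, tagset_size)
+ #   sum(dim=1)    -> (batch, tagset_size), one score per language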
+
+ def prediction_scores(self, x):
+ prediction_probs = self(x)
+ if self.lang_subset:
+ prediction_batch_size = prediction_probs.size()[0]
+ batch_mask = torch.stack([self.lang_mask for _ in range(prediction_batch_size)])
+ prediction_probs = prediction_probs * batch_mask
+ return torch.argmax(prediction_probs, dim=1)
+
+ def save(self, path):
+ """ Save a model at path """
+ checkpoint = {
+ "char_to_idx": self.char_to_idx,
+ "tag_to_idx": self.tag_to_idx,
+ "num_layers": self.num_layers,
+ "embedding_dim": self.embedding_dim,
+ "hidden_dim": self.hidden_dim,
+ "model_state_dict": self.state_dict()
+ }
+ torch.save(checkpoint, path)
+
+ @classmethod
+ def load(cls, path, use_cuda=False, batch_size=64, lang_subset=None):
+ """ Load a serialized model located at path """
+ if use_cuda:
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ else:
+ device = torch.device("cpu")
+ checkpoint = torch.load(path, map_location=torch.device("cpu"))
+ weights = checkpoint["model_state_dict"]["loss_train.weight"]
+ model = cls(checkpoint["char_to_idx"], checkpoint["tag_to_idx"], checkpoint["num_layers"],
+ checkpoint["embedding_dim"], checkpoint["hidden_dim"], batch_size=batch_size, weights=weights,
+ lang_subset=lang_subset)
+ model.load_state_dict(checkpoint["model_state_dict"])
+ if use_cuda:
+ model.to(torch.device("cuda"))
+ model.build_lang_mask(use_gpu=use_cuda)
+ return model
+
diff --git a/stanza/models/langid/trainer.py b/stanza/models/langid/trainer.py
new file mode 100644
index 00000000..6491508f
--- /dev/null
+++ b/stanza/models/langid/trainer.py
@@ -0,0 +1,53 @@
+import torch
+import torch.optim as optim
+
+from stanza.models.langid.model import LangIDBiLSTM
+
+
+class Trainer:
+
+ DEFAULT_BATCH_SIZE = 64
+ DEFAULT_LAYERS = 2
+ DEFAULT_EMBEDDING_DIM = 150
+ DEFAULT_HIDDEN_DIM = 150
+
+ def __init__(self, config, load_model=False, use_gpu=None):
+ self.model_path = config["model_path"]
+ self.use_gpu = torch.cuda.is_available() if use_gpu is None else use_gpu
+ self.device = torch.device("cuda") if self.use_gpu else None
+ self.batch_size = config.get("batch_size", Trainer.DEFAULT_BATCH_SIZE)
+ if load_model:
+ self.load(config["load_model"])
+ else:
+ self.model = LangIDBiLSTM(config["char_to_idx"], config["tag_to_idx"], Trainer.DEFAULT_LAYERS,
+ Trainer.DEFAULT_EMBEDDING_DIM,
+ Trainer.DEFAULT_HIDDEN_DIM,
+ batch_size=self.batch_size,
+ weights=config["lang_weights"]).to(self.device)
+ self.optimizer = optim.AdamW(self.model.parameters())
+
+ def update(self, inputs):
+ self.model.train()
+ sentences, targets = inputs
+ self.optimizer.zero_grad()
+ y_hat = self.model.forward(sentences)
+ loss = self.model.loss(y_hat, targets)
+ loss.backward()
+ self.optimizer.step()
+
+ def predict(self, inputs):
+ self.model.eval()
+ sentences, targets = inputs
+ return torch.argmax(self.model(sentences), dim=1)
+
+ def save(self, label=None):
+ # save a copy of model with label
+ if label:
+ self.model.save(f"{self.model_path[:-3]}-{label}.pt")
+ self.model.save(self.model_path)
+
+ def load(self, model_path=None):
+ if not model_path:
+ model_path = self.model_path
+ self.model = LangIDBiLSTM.load(model_path, self.use_gpu, self.batch_size)
+
diff --git a/stanza/models/ner/model.py b/stanza/models/ner/model.py
index efad8d51..dbd93b35 100644
--- a/stanza/models/ner/model.py
+++ b/stanza/models/ner/model.py
@@ -1,3 +1,4 @@
+import os
import numpy as np
import torch
import torch.nn as nn
@@ -35,6 +36,10 @@ class NERTagger(nn.Module):
if self.args['char'] and self.args['char_emb_dim'] > 0:
if self.args['charlm']:
+ if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
+ raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']))
+ if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
+ raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))
add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False))
add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False))
input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py
index 020e2c68..dd24dbbb 100644
--- a/stanza/models/ner_tagger.py
+++ b/stanza/models/ner_tagger.py
@@ -249,8 +249,9 @@ def train(args):
logger.info("Training ended with {} steps.".format(global_step))
- best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
- logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
+ if len(dev_score_history) > 0:
+ best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
+ logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
def evaluate(args):
# file paths
diff --git a/stanza/models/parser.py b/stanza/models/parser.py
index 4d605dcb..f74c8b6d 100644
--- a/stanza/models/parser.py
+++ b/stanza/models/parser.py
@@ -115,6 +115,7 @@ def model_file_name(args):
return os.path.join(args['save_dir'], save_name)
+# TODO: refactor with everywhere
def load_pretrain(args):
pt = None
if args['pretrain']:
diff --git a/stanza/models/pos/xpos_vocab_factory.py b/stanza/models/pos/xpos_vocab_factory.py
index 39da44fd..5abbb0ec 100644
--- a/stanza/models/pos/xpos_vocab_factory.py
+++ b/stanza/models/pos/xpos_vocab_factory.py
@@ -6,7 +6,7 @@ from stanza.models.pos.vocab import WordVocab, XPOSVocab
def xpos_vocab_factory(data, shorthand):
if shorthand in ["af_afribooms", "ar_padt", "bg_btb", "cs_cac", "cs_cltt", "cs_fictree", "cs_pdt", "en_partut", "fr_partut", "gd_arcosg", "gl_ctg", "gl_treegal", "grc_perseus", "hr_set", "is_icepahc", "is_modern", "it_combined", "it_isdt", "it_partut", "it_postwita", "it_twittiro", "it_vit", "la_perseus", "la_udante", "lt_alksnis", "lv_lvtb", "ro_nonstandard", "ro_rrt", "ro_simonero", "sk_snk", "sl_ssj", "sl_sst", "sr_set", "ta_ttb", "uk_iu"]:
return XPOSVocab(data, shorthand, idx=2, sep="")
- elif shorthand in ["be_hse", "ca_ancora", "cop_scriptorium", "cu_proiel", "cy_ccg", "da_ddt", "de_gsd", "de_hdt", "el_gdt", "en_combined", "en_ewt", "en_gum", "es_ancora", "es_gsd", "et_edt", "et_ewt", "eu_bdt", "fa_perdt", "fa_seraji", "fi_tdt", "fr_gsd", "fro_srcmf", "fr_sequoia", "fr_spoken", "ga_idt", "got_proiel", "grc_proiel", "he_htb", "hi_hdtb", "hu_szeged", "hy_armtdp", "hyw_armtdp", "id_csui", "ja_gsd", "la_proiel", "lt_hse", "lzh_kyoto", "mr_ufal", "mt_mudt", "nb_bokmaal", "nn_nynorsk", "nn_nynorsklia", "orv_rnc", "orv_torot", "pcm_nsc", "pt_bosque", "pt_gsd", "qtd_sagt", "ru_gsd", "ru_syntagrus", "ru_taiga", "sa_vedic", "sme_giella", "swl_sslc", "te_mtg", "tr_boun", "tr_framenet", "tr_imst", "tr_kenet", "tr_penn", "tr_tourism", "ug_udt", "vi_vtb", "wo_wtb", "zh_gsdsimp", "zh-hant_gsd", "bxr_bdt", "hsb_ufal", "ja_bccwj", "kk_ktb", "kmr_mg", "olo_kkpp"]:
+ elif shorthand in ["be_hse", "ca_ancora", "cop_scriptorium", "cu_proiel", "cy_ccg", "da_ddt", "de_gsd", "de_hdt", "el_gdt", "en_combined", "en_ewt", "en_gum", "es_ancora", "es_gsd", "es_combined", "et_edt", "et_ewt", "eu_bdt", "fa_perdt", "fa_seraji", "fi_tdt", "fr_gsd", "fro_srcmf", "fr_sequoia", "fr_spoken", "ga_idt", "got_proiel", "grc_proiel", "he_htb", "hi_hdtb", "hu_szeged", "hy_armtdp", "hyw_armtdp", "id_csui", "ja_gsd", "la_proiel", "lt_hse", "lzh_kyoto", "mr_ufal", "mt_mudt", "nb_bokmaal", "nn_nynorsk", "nn_nynorsklia", "orv_rnc", "orv_torot", "pcm_nsc", "pt_bosque", "pt_gsd", "qtd_sagt", "ru_gsd", "ru_syntagrus", "ru_taiga", "sa_vedic", "sme_giella", "swl_sslc", "te_mtg", "tr_boun", "tr_framenet", "tr_imst", "tr_kenet", "tr_penn", "tr_tourism", "ug_udt", "vi_vtb", "wo_wtb", "zh_gsdsimp", "zh-hant_gsd", "bxr_bdt", "hsb_ufal", "ja_bccwj", "kk_ktb", "kmr_mg", "olo_kkpp"]:
return WordVocab(data, shorthand, idx=2, ignore=["_"])
elif shorthand in ["en_lines", "fo_farpahc", "sv_lines", "ur_udtb"]:
return XPOSVocab(data, shorthand, idx=2, sep="-")
diff --git a/stanza/models/tokenization/data.py b/stanza/models/tokenization/data.py
index ce059b6a..9039f818 100644
--- a/stanza/models/tokenization/data.py
+++ b/stanza/models/tokenization/data.py
@@ -1,14 +1,11 @@
from bisect import bisect_right
from copy import copy
-import json
import numpy as np
import random
import logging
import re
import torch
-
from .vocab import Vocab
-
logger = logging.getLogger('stanza')
def filter_consecutive_whitespaces(para):
@@ -26,11 +23,11 @@ NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
NUMERIC_RE = re.compile(r'^([\d]+[,\.]*)+$')
WHITESPACE_RE = re.compile(r'\s')
-
class DataLoader:
- def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, input_data=None, vocab=None, evaluation=False):
+ def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, input_data=None, vocab=None, evaluation=False, dictionary=None):
self.args = args
self.eval = evaluation
+ self.dictionary = dictionary
# get input files
txt_file = input_files['txt']
@@ -107,8 +104,6 @@ class DataLoader:
func = lambda x: 1 if x.startswith(' ') else 0
elif feat_func == 'capitalized':
func = lambda x: 1 if x[0].isupper() else 0
- elif feat_func == 'all_caps':
- func = lambda x: 1 if x.isupper() else 0
elif feat_func == 'numeric':
func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
else:
@@ -119,6 +114,40 @@ class DataLoader:
# stacking all featurize functions
composite_func = lambda x: [f(x) for f in funcs]
+ length = len(para)
+ # This function extracts dictionary features for each character
+ def extract_dict_feat(idx):
+ dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])]
+ dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])]
+ forward_word = para[idx][0]
+ backward_word = para[idx][0]
+ prefix = True
+ suffix = True
+ for window in range(1,self.args['num_dict_feat']+1):
+ # concatenate one more character and check whether the result is a dictionary word; stop once it is no longer a prefix
+ # check that idx+window is in bounds and that the prefix search has not already stopped
+ if (idx + window) <= length-1 and prefix:
+ forward_word += para[idx+window][0].lower()
+ # check the dictionary to see whether the candidate string is a complete word
+ feat = 1 if forward_word in self.dictionary["words"] else 0
+ # record whether this candidate is a valid dictionary word (1) or not (0)
+ dict_forward_feats[window-1] = feat
+ # if the candidate is not a prefix of any dictionary word, stop extending forward
+ if forward_word not in self.dictionary["prefixes"]:
+ prefix = False
+ #backward check: similar to forward
+ if (idx - window) >= 0 and suffix:
+ backward_word = para[idx-window][0].lower() + backward_word
+ feat = 1 if backward_word in self.dictionary["words"] else 0
+ dict_backward_feats[window-1] = feat
+ if backward_word not in self.dictionary["suffixes"]:
+ suffix = False
+ #once neither a longer prefix nor a longer suffix can be found, exit the loop
+ if not prefix and not suffix:
+ break
+
+ return dict_forward_feats + dict_backward_feats
+
def process_sentence(sent):
return [self.vocab.unit2id(y[0]) for y in sent], [y[1] for y in sent], [y[2] for y in sent], [y[0] for y in sent]
@@ -135,6 +164,12 @@ class DataLoader:
if use_start_of_para:
f = 1 if i == 0 else 0
feats.append(f)
+
+ #if dictionary feature is selected
+ if self.args['use_dictionary']:
+ dict_feats = extract_dict_feat(i)
+ feats = feats + dict_feats
+
current += [(unit, label, feats)]
if label1 == 2 or label1 == 4: # end of sentence
if len(current) <= self.args['max_seqlen']:
@@ -156,7 +191,7 @@ class DataLoader:
random.shuffle(para)
self.init_sent_ids()
- def next(self, eval_offsets=None, unit_dropout=0.0, old_batch=None):
+ def next(self, eval_offsets=None, unit_dropout=0.0, old_batch=None, feat_unit_dropout=0.0):
''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction. '''
feat_size = len(self.sentences[0][0][2][0])
unkid = self.vocab.unit2id('<UNK>')
@@ -277,6 +312,18 @@ class DataLoader:
if mask[i, j]:
raw_units[i][j] = '<UNK>'
+ # drop out whole per-unit feature vectors, in addition to the element-wise torch.dropout in the model.
+ # experiments showed that element-wise dropout alone hurts the model;
+ # we believe this is because the dict feature vector is mostly sparse, so it makes
+ # more sense to drop out the whole vector instead of only single elements.
+ if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
+ mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
+ mask_feat[units == padid] = 0
+ for i in range(len(raw_units)):
+ for j in range(len(raw_units[i])):
+ if mask_feat[i,j]:
+ features[i,j,:] = 0
+
units = torch.from_numpy(units)
labels = torch.from_numpy(labels)
features = torch.from_numpy(features)
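A minimal sketch, assuming a hand-made dictionary in the same words/prefixes/suffixes format, of what extract_dict_feat computes at one position; toy_dict_feats and the example strings are hypothetical, not part of the patch.

# Sketch of the dictionary feature lookup above, on a toy dictionary.
def toy_dict_feats(para, idx, dictionary, num_dict_feat=2):
    forward = [0] * num_dict_feat
    backward = [0] * num_dict_feat
    fw = bw = para[idx].lower()
    prefix = suffix = True
    for window in range(1, num_dict_feat + 1):
        if idx + window < len(para) and prefix:
            fw += para[idx + window].lower()
            forward[window - 1] = 1 if fw in dictionary["words"] else 0
            prefix = fw in dictionary["prefixes"]
        if idx - window >= 0 and suffix:
            bw = para[idx - window].lower() + bw
            backward[window - 1] = 1 if bw in dictionary["words"] else 0
            suffix = bw in dictionary["suffixes"]
        if not prefix and not suffix:
            break
    return forward + backward

# "ab" and "abc" are dictionary words; prefixes/suffixes are all their proper prefixes/suffixes
toy = {"words": {"ab", "abc"}, "prefixes": {"a", "ab"}, "suffixes": {"b", "c", "bc"}}
print(toy_dict_feats(list("abcd"), 0, toy))   # [1, 1, 0, 0]
print(toy_dict_feats(list("abcd"), 2, toy))   # [0, 0, 0, 1]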
diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py
index 8c4f3198..1f609871 100644
--- a/stanza/models/tokenization/model.py
+++ b/stanza/models/tokenization/model.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
import torch.nn as nn
class Tokenizer(nn.Module):
- def __init__(self, args, nchars, emb_dim, hidden_dim, dropout):
+ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
super().__init__()
self.args = args
@@ -37,12 +37,15 @@ class Tokenizer(nn.Module):
self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
self.dropout = nn.Dropout(dropout)
+ self.dropout_feat = nn.Dropout(feat_dropout)
+
self.toknoise = nn.Dropout(self.args['tok_noise'])
def forward(self, x, feats):
emb = self.embeddings(x)
-
emb = self.dropout(emb)
+ feats = self.dropout_feat(feats)
+
emb = torch.cat([emb, feats], 2)
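For intuition, a minimal sketch of the concatenation in forward() above, with the new feat_dropout applied to the feature channels only; the tensor shapes here are made up for illustration.

import torch
import torch.nn as nn

emb = torch.randn(2, 8, 32)    # character embeddings: batch x seq x emb_dim (toy sizes)
feats = torch.randn(2, 8, 10)  # per-character features, including dictionary features
dropout, dropout_feat = nn.Dropout(0.33), nn.Dropout(0.05)
x = torch.cat([dropout(emb), dropout_feat(feats)], dim=2)  # input to the RNN layers
print(x.shape)  # torch.Size([2, 8, 42])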
diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py
index bb0deb85..c40b70d2 100644
--- a/stanza/models/tokenization/trainer.py
+++ b/stanza/models/tokenization/trainer.py
@@ -5,6 +5,7 @@ import torch.nn as nn
import torch.optim as optim
from stanza.models.common.trainer import Trainer as BaseTrainer
+from stanza.models.tokenization.utils import create_dictionary
from .model import Tokenizer
from .vocab import Vocab
@@ -12,7 +13,7 @@ from .vocab import Vocab
logger = logging.getLogger('stanza')
class Trainer(BaseTrainer):
- def __init__(self, args=None, vocab=None, model_file=None, use_cuda=False):
+ def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, use_cuda=False):
self.use_cuda = use_cuda
if model_file is not None:
# load everything from file
@@ -21,7 +22,9 @@ class Trainer(BaseTrainer):
# build model from scratch
self.args = args
self.vocab = vocab
- self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'])
+ self.lexicon = lexicon
+ self.dictionary = dictionary
+ self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
if use_cuda:
self.model.cuda()
@@ -72,6 +75,7 @@ class Trainer(BaseTrainer):
params = {
'model': self.model.state_dict() if self.model is not None else None,
'vocab': self.vocab.state_dict(),
+ 'lexicon': self.lexicon,
'config': self.args
}
try:
@@ -91,6 +95,12 @@ class Trainer(BaseTrainer):
# Default to True as many currently saved models
# were built with mwt layers
self.args['use_mwt'] = True
- self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'])
+ self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
self.model.load_state_dict(checkpoint['model'])
self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
+ self.lexicon = checkpoint['lexicon']
+
+ if self.lexicon is not None:
+ self.dictionary = create_dictionary(self.lexicon)
+ else:
+ self.dictionary = None
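Only the compact lexicon is checkpointed; the dictionary is rebuilt from it at load time, as above. A rough sketch of that round trip, using a hypothetical two-word lexicon:

from stanza.models.tokenization.utils import create_dictionary

# hypothetical lexicon, as stored under the checkpoint's 'lexicon' key
lexicon = {"new york", "miami beach"}
dictionary = create_dictionary(lexicon)   # what Trainer.load reconstructs
print("new" in dictionary["prefixes"], "york" in dictionary["suffixes"])  # True True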
diff --git a/stanza/models/tokenization/utils.py b/stanza/models/tokenization/utils.py
index ea7bda47..28156a44 100644
--- a/stanza/models/tokenization/utils.py
+++ b/stanza/models/tokenization/utils.py
@@ -4,12 +4,137 @@ import json
import numpy as np
import re
import logging
+import os
+import stanza.utils.default_paths as default_paths
from stanza.models.common.utils import ud_scores, harmonic_mean
from stanza.utils.conll import CoNLL
from stanza.models.common.doc import *
logger = logging.getLogger('stanza')
+paths = default_paths.get_default_paths()
+
+def create_dictionary(lexicon=None):
+ """
+ Create a dictionary used to improve the tokenization model for languages with multi-syllable words,
+ such as vi, zh or th. This function takes the lexicon as input and outputs a dictionary that contains three sets:
+ words, prefixes and suffixes, where the prefixes set contains every prefix of every word in the lexicon, and similarly for suffixes.
+ The point of having the prefixes/suffixes sets in the dictionary is to make the lookups during data preparation cheap.
+
+ :param lexicon - set of words used to create the dictionary
+ :return a dictionary object that contains words and their prefixes and suffixes.
+ """
+
+ dictionary = {"words":set(), "prefixes":set(), "suffixes":set()}
+
+ def add_word(word):
+ if word not in dictionary["words"]:
+ dictionary["words"].add(word)
+ prefix = ""
+ suffix = ""
+ for i in range(0,len(word)-1):
+ prefix = prefix + word[i]
+ suffix = word[len(word) - i - 1] + suffix
+ dictionary["prefixes"].add(prefix)
+ dictionary["suffixes"].add(suffix)
+
+ for word in lexicon:
+ if len(word)>1:
+ add_word(word)
+
+ return dictionary
+def create_lexicon(shorthand=None, train_path=None, external_path=None):
+ """
+ Create a lexicon that stores all the words from the training set and the external dictionary.
+ The lexicon is saved with the model and used to rebuild the dictionary when the model is loaded.
+ Separating the lexicon and the dictionary into two phases is a tradeoff between time and space:
+ only the compact lexicon is stored, and the dictionary is derived from it at load time.
+ Note that long but infrequent words are eliminated by keeping only words whose length is at or
+ below the 95th percentile of word lengths.
+
+ :param shorthand - language and dataset, eg: vi_vlsp, zh_gsdsimp
+ :param train_path - path to conllu train file
+ :param external_path - path to external dict, expected to be inside the training dataset dir and named SHORTHAND-externaldict.txt
+ :return a set of all distinct words in the lexicon, and the number of dictionary features (window size)
+ """
+ lexicon = set()
+ length_freq = []
+ count_word = 0
+ #this regex checks whether a character is an actual Thai character, since Python's .isalpha() does not match Thai accented characters
+ pattern_thai = re.compile(r"(?:[^\d\W]+)|\s")
+
+ def check_valid_word(shorthand, word):
+ """
+ Check whether the word is a multi-syllable word and not a number.
+ For vi, whitespace is the syllable separator.
+ """
+ if shorthand.startswith("vi_"):
+ return True if len(word.split(" ")) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word)) else False
+ elif shorthand.startswith("th_"):
+ return True if len(word) > 1 and any(map(pattern_thai.match, word)) and not any(map(str.isdigit, word)) else False
+ else:
+ return True if len(word) > 1 and any(map(str.isalpha, word)) and not any(map(str.isdigit, word)) else False
+
+ #checking for words in the training set to add them to lexicon.
+ if train_path is not None:
+ if not os.path.isfile(train_path):
+ raise FileNotFoundError(f"Cannot open train set at {train_path}")
+
+ doc_conll,_ = CoNLL.conll2dict(input_file=train_path)
+
+ for sent_conll in doc_conll:
+ for token_conll in sent_conll:
+ word = token_conll['text'].lower()
+ if check_valid_word(shorthand, word) and word not in lexicon:
+ lexicon.add(word)
+ length_freq.append(len(word))
+ count_word = len(lexicon)
+ logger.info(f"Added {count_word} words from the training data to the lexicon.")
+
+ #checking for external dictionary and add them to lexicon.
+ if external_path is not None:
+ if not os.path.isfile(external_path):
+ raise FileNotFoundError(f"Cannot open external dictionary at {external_path}")
+
+ with open(external_path, "r", encoding="utf-8") as external_file:
+ lines = external_file.readlines()
+ for line in lines:
+ word = line.lower()
+ word = word.replace("\n","")
+ if check_valid_word(shorthand, word) and word not in lexicon:
+ lexicon.add(word)
+ length_freq.append(len(word))
+ logger.info(f"Added another {len(lexicon) - count_word} words from the external dict to dictionary.")
+
+
+ #automatically calculate the number of dictionary features (the window size used to look for words) based on the distribution of word lengths
+ #take the length at the 95th percentile to eliminate the longest (likely compound) words in the lexicon
+ num_dict_feat = int(np.percentile(length_freq, 95))
+ lexicon = {word for word in lexicon if len(word) <= num_dict_feat }
+ logger.info(f"Final lexicon consists of {len(lexicon)} words after getting rid of long words.")
+
+ return lexicon, num_dict_feat
+
+def load_lexicon(args):
+ """
+ Build the lexicon used during training from the training data and, if present, an external dictionary.
+ The external dictionary is expected to be inside the training dataset dir, named SHORTHAND-externaldict.txt,
+ for example vi_vlsp-externaldict.txt
+ """
+ shorthand = args["shorthand"]
+ tokenize_dir = paths["TOKENIZE_DATA_DIR"]
+ train_path = f"{tokenize_dir}/{shorthand}.train.gold.conllu"
+ external_dict_path = f"{tokenize_dir}/{shorthand}-externaldict.txt"
+ if not os.path.exists(external_dict_path):
+ logger.info("External dictionary not found! Checking training data...")
+ external_dict_path = None
+ if not os.path.exists(train_path):
+ logger.info(f"Training dataset does not exist, thus cannot create dictionary {shorthand}")
+ train_path = None
+ if train_path is None and external_dict_path is None:
+ raise FileNotFoundError(f"Cannot find training set / external dictionary at {train_path} and {external_dict_path}")
+
+ return create_lexicon(shorthand, train_path, external_dict_path)
+
def load_mwt_dict(filename):
if filename is not None:
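The window size computed in create_lexicon comes from the length distribution rather than the maximum, so a handful of very long entries cannot blow up the feature dimension. A small numeric illustration with made-up lengths:

import numpy as np

# hypothetical word lengths collected while building the lexicon
length_freq = [2, 2, 3, 3, 3, 4, 4, 5, 6, 12]
num_dict_feat = int(np.percentile(length_freq, 95))
print(num_dict_feat)  # 9 -> words longer than 9 characters (here, the 12) are dropped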
diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py
index 54bc729f..e88b746e 100644
--- a/stanza/models/tokenizer.py
+++ b/stanza/models/tokenizer.py
@@ -4,6 +4,15 @@ Entry point for training and evaluating a neural tokenizer.
This tokenizer treats tokenization and sentence segmentation as a tagging problem, and uses a combination of
recurrent and convolutional architectures.
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
+
+Updated: this version of the tokenizer model incorporates dictionary features, which are especially useful for languages
+with multi-syllable words such as Vietnamese, Chinese or Thai. In summary, a lexicon containing all unique words found in
+the training data and in an external word list (if any) is created during training and saved alongside the model. From this
+lexicon, a dictionary is built which includes "words", "prefixes" and "suffixes" sets. During data preparation, dictionary
+features are extracted at each character position by "looking ahead" and "looking backward" to see whether any string formed
+is found in the dictionary. The window size (the dictionary feature length) is set to the 95th percentile of word lengths in
+the lexicon; this eliminates the long but infrequent words and avoids a high-dimensional feature vector. The prefix and suffix
+sets are used to stop early during the window-based dictionary lookups.
"""
import argparse
@@ -13,11 +22,11 @@ import random
import numpy as np
import os
import torch
-
+import json
from stanza.models.common import utils
from stanza.models.tokenization.trainer import Trainer
from stanza.models.tokenization.data import DataLoader
-from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions
+from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary
from stanza.models import _training_logging
logger = logging.getLogger('stanza')
@@ -49,6 +58,7 @@ def parse_args(args=None):
parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well")
parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN")
parser.add_argument('--rnn_layers', type=int, default=1, help="Layers of RNN in the tokenizer")
+ parser.add_argument('--use_dictionary', action='store_true', help="Use dictionary features. The lexicon is created from the training data and an external dict (if any) expected to be found in the same folder as the training dataset, named SHORTHAND-externaldict.txt, where each line of the file is a word. For example, data/tokenize/zh_gsdsimp-externaldict.txt")
parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm to clip to")
parser.add_argument('--anneal', type=float, default=.999, help="Anneal the learning rate by this amount when dev performance deteriorate")
@@ -56,6 +66,8 @@ def parse_args(args=None):
parser.add_argument('--lr0', type=float, default=2e-3, help="Initial learning rate")
parser.add_argument('--dropout', type=float, default=0.33, help="Dropout probability")
parser.add_argument('--unit_dropout', type=float, default=0.33, help="Unit dropout probability")
+ parser.add_argument('--feat_dropout', type=float, default=0.05, help="Dropout probability applied element-wise to the feature vector")
+ parser.add_argument('--feat_unit_dropout', type=float, default=0.33, help="Probability of dropping a unit's whole feature vector")
parser.add_argument('--tok_noise', type=float, default=0.02, help="Probability to induce noise to the input of the higher RNN")
parser.add_argument('--sent_drop_prob', type=float, default=0.2, help="Probability to drop sentences at the end of batches during training uniformly at random. Idea is to fake paragraph endings.")
parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay")
@@ -90,7 +102,7 @@ def main(args=None):
args = vars(args)
logger.info("Running tokenizer in {} mode".format(args['mode']))
- args['feat_funcs'] = ['space_before', 'capitalized', 'all_caps', 'numeric']
+ args['feat_funcs'] = ['space_before', 'capitalized', 'numeric', 'end_of_para', 'start_of_para']
args['feat_dim'] = len(args['feat_funcs'])
save_name = args['save_name'] if args['save_name'] else '{}_tokenizer.pt'.format(args['shorthand'])
args['save_name'] = os.path.join(args['save_dir'], save_name)
@@ -102,27 +114,40 @@ def main(args=None):
evaluate(args)
def train(args):
+ if args['use_dictionary']:
+ #load lexicon
+ lexicon, args['num_dict_feat'] = load_lexicon(args)
+ #create the dictionary
+ dictionary = create_dictionary(lexicon)
+ #adjust the feat_dim
+ args['feat_dim'] += args['num_dict_feat']*2
+ else:
+ args['num_dict_feat'] = 0
+ lexicon=None
+ dictionary=None
+
mwt_dict = load_mwt_dict(args['mwt_json_file'])
train_input_files = {
'txt': args['txt_file'],
'label': args['label_file']
}
- train_batches = DataLoader(args, input_files=train_input_files)
+ train_batches = DataLoader(args, input_files=train_input_files, dictionary=dictionary)
vocab = train_batches.vocab
+
args['vocab_size'] = len(vocab)
dev_input_files = {
'txt': args['dev_txt_file'],
'label': args['dev_label_file']
}
- dev_batches = DataLoader(args, input_files=dev_input_files, vocab=vocab, evaluation=True)
+ dev_batches = DataLoader(args, input_files=dev_input_files, vocab=vocab, evaluation=True, dictionary=dictionary)
if args['use_mwt'] is None:
args['use_mwt'] = train_batches.has_mwt()
logger.info("Found {}mwts in the training data. Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt']))
- trainer = Trainer(args=args, vocab=vocab, use_cuda=args['cuda'])
+ trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, use_cuda=args['cuda'])
if args['load_name'] is not None:
load_name = os.path.join(args['save_dir'], args['load_name'])
@@ -138,7 +163,7 @@ def train(args):
best_dev_step = -1
for step in range(1, steps+1):
- batch = train_batches.next(unit_dropout=args['unit_dropout'])
+ batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout = args['feat_unit_dropout'])
loss = trainer.update(batch)
if step % args['report_steps'] == 0:
@@ -180,16 +205,18 @@ def evaluate(args):
use_cuda = args['cuda'] and not args['cpu']
trainer = Trainer(model_file=args['load_name'] or args['save_name'], use_cuda=use_cuda)
loaded_args, vocab = trainer.args, trainer.vocab
+
for k in loaded_args:
if not k.endswith('_file') and k not in ['cuda', 'mode', 'save_dir', 'load_name', 'save_name']:
args[k] = loaded_args[k]
-
+
eval_input_files = {
'txt': args['txt_file'],
'label': args['label_file']
}
- batches = DataLoader(args, input_files=eval_input_files, vocab=vocab, evaluation=True)
+
+ batches = DataLoader(args, input_files=eval_input_files, vocab=vocab, evaluation=True, dictionary=trainer.dictionary)
oov_count, N, _, _ = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen'])
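The external word list, if used, is just a plain text file with one word per line placed next to the training data, following the SHORTHAND-externaldict.txt naming from the --use_dictionary help text. A hypothetical way to produce one (the entries are examples only):

# hypothetical: write an external word list that --use_dictionary would pick up
words = ["hà nội", "thành phố"]   # example multi-syllable entries
with open("data/tokenize/vi_vlsp-externaldict.txt", "w", encoding="utf-8") as fout:
    fout.write("\n".join(words))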
diff --git a/stanza/pipeline/_constants.py b/stanza/pipeline/_constants.py
index 865a185e..eff758f8 100644
--- a/stanza/pipeline/_constants.py
+++ b/stanza/pipeline/_constants.py
@@ -1,6 +1,7 @@
""" Module defining constants """
# string constants for processor names
+LANGID = 'langid'
TOKENIZE = 'tokenize'
MWT = 'mwt'
POS = 'pos'
@@ -8,3 +9,4 @@ LEMMA = 'lemma'
DEPPARSE = 'depparse'
NER = 'ner'
SENTIMENT = 'sentiment'
+CONSTITUENCY = 'constituency'
diff --git a/stanza/pipeline/constituency_processor.py b/stanza/pipeline/constituency_processor.py
new file mode 100644
index 00000000..c9305bf6
--- /dev/null
+++ b/stanza/pipeline/constituency_processor.py
@@ -0,0 +1,52 @@
+"""Processor that attaches a constituency tree to a sentence
+
+The model used is generally a model trained on the Stanford
+Sentiment Treebank or some similar constituency dataset. When run,
+this processor attaches a constituency tree to each sentence in the
+document.
+
+TODO: a possible way to generalize this would be to make it a
+ClassifierProcessor and have "sentiment" be an option.
+"""
+
+import stanza.models.constituency.trainer as trainer
+
+from stanza.models.common import doc
+from stanza.models.common.pretrain import Pretrain
+from stanza.pipeline._constants import *
+from stanza.pipeline.processor import UDProcessor, register_processor
+
+@register_processor(CONSTITUENCY)
+class ConstituencyProcessor(UDProcessor):
+ # set of processor requirements this processor fulfills
+ PROVIDES_DEFAULT = set([CONSTITUENCY])
+ # set of processor requirements for this processor
+ REQUIRES_DEFAULT = set([TOKENIZE, POS])
+
+ # default batch size, measured in sentences
+ DEFAULT_BATCH_SIZE = 50
+
+ def _set_up_model(self, config, use_gpu):
+ # get pretrained word vectors
+ pretrain_path = config.get('pretrain_path', None)
+ self._pretrain = Pretrain(pretrain_path) if pretrain_path else None
+ # set up model
+ charlm_forward_file = config.get('forward_charlm_path', None)
+ charlm_backward_file = config.get('backward_charlm_path', None)
+ self._model = trainer.Trainer.load(filename=config['model_path'],
+ pt=self._pretrain,
+ forward_charlm=trainer.load_charlm(charlm_forward_file),
+ backward_charlm=trainer.load_charlm(charlm_backward_file),
+ use_gpu=use_gpu)
+ # batch size counted as sentences
+ self._batch_size = config.get('batch_size', ConstituencyProcessor.DEFAULT_BATCH_SIZE)
+
+ def process(self, document):
+ sentences = document.sentences
+ # TODO: perhaps MWT should be relevant here?
+ # certainly parsing across an MWT boundary is an error
+ # TODO: maybe some constituency models are trained on UPOS not XPOS
+ words = [[(w.text, w.xpos) for w in s.words] for s in sentences]
+ trees = trainer.parse_tagged_words(self._model.model, words, self._batch_size)
+ document.set(CONSTITUENCY, trees, to_sentence=True)
+ return document
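Assuming a constituency model has been downloaded for the language, the processor slots into a normal pipeline after tokenize and pos; a hedged usage sketch (the sentence-level attribute name follows the CONSTITUENCY constant set above):

import stanza

nlp = stanza.Pipeline("en", processors="tokenize,pos,constituency")
doc = nlp("This is a test.")
for sentence in doc.sentences:
    print(sentence.constituency)   # parse tree attached by document.set(CONSTITUENCY, ...)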
diff --git a/stanza/pipeline/core.py b/stanza/pipeline/core.py
index acdbfce0..fcb63d18 100644
--- a/stanza/pipeline/core.py
+++ b/stanza/pipeline/core.py
@@ -15,12 +15,14 @@ from stanza.pipeline._constants import *
from stanza.models.common.doc import Document
from stanza.pipeline.processor import Processor, ProcessorRequirementsException
from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES
+from stanza.pipeline.langid_processor import LangIDProcessor
from stanza.pipeline.tokenize_processor import TokenizeProcessor
from stanza.pipeline.mwt_processor import MWTProcessor
from stanza.pipeline.pos_processor import POSProcessor
from stanza.pipeline.lemma_processor import LemmaProcessor
from stanza.pipeline.depparse_processor import DepparseProcessor
from stanza.pipeline.sentiment_processor import SentimentProcessor
+from stanza.pipeline.constituency_processor import ConstituencyProcessor
from stanza.pipeline.ner_processor import NERProcessor
from stanza.resources.common import DEFAULT_MODEL_DIR, \
maintain_processor_list, add_dependencies, add_mwt, build_default_config, set_logging_level, process_pipeline_parameters, sort_processors
diff --git a/stanza/pipeline/langid_processor.py b/stanza/pipeline/langid_processor.py
new file mode 100644
index 00000000..a512196e
--- /dev/null
+++ b/stanza/pipeline/langid_processor.py
@@ -0,0 +1,126 @@
+"""
+Processor for determining language of text.
+"""
+
+import emoji
+import re
+import stanza
+import torch
+
+from stanza.models.common.doc import Document
+from stanza.models.langid.model import LangIDBiLSTM
+from stanza.pipeline._constants import *
+from stanza.pipeline.processor import UDProcessor, register_processor
+
+
+@register_processor(name=LANGID)
+class LangIDProcessor(UDProcessor):
+ """
+ Class for detecting language of text.
+ """
+
+ # set of processor requirements this processor fulfills
+ PROVIDES_DEFAULT = set([LANGID])
+
+ # set of processor requirements for this processor
+ REQUIRES_DEFAULT = set([])
+
+ # default max sequence length
+ MAX_SEQ_LENGTH_DEFAULT = 1000
+
+ def _set_up_model(self, config, use_gpu):
+ batch_size = config.get("batch_size", 64)
+ self._model = LangIDBiLSTM.load(path=config["model_path"], use_cuda=use_gpu,
+ batch_size=batch_size, lang_subset=config.get("lang_subset"))
+ self._device = torch.device("cuda") if use_gpu else None
+ self._char_index = self._model.char_to_idx
+ self._clean_text = config.get("clean_text")
+
+ def _text_to_tensor(self, docs):
+ """
+ Map a list of strings to a batch tensor. Assumes all docs are the same length.
+ """
+
+ all_docs = []
+ for doc in docs:
+ doc_chars = [self._char_index.get(c, self._char_index["UNK"]) for c in list(doc)]
+ all_docs.append(doc_chars)
+ return torch.tensor(all_docs, device=self._device, dtype=torch.long)
+
+ def _id_langs(self, batch_tensor):
+ """
+ Identify languages for each sequence in a batch tensor
+ """
+ predictions = self._model.prediction_scores(batch_tensor)
+ prediction_labels = [self._model.idx_to_tag[prediction] for prediction in predictions]
+
+ return prediction_labels
+
+ # regexes for cleaning text
+ http_regex = re.compile("https?:\/\/t\.co/[a-zA-Z0-9]+")
+ handle_regex = re.compile("@[a-zA-Z0-9_]+")
+ hashtag_regex = re.compile("#[a-zA-Z]+")
+ punctuation_regex = re.compile("[!.]+")
+ all_regexes = [http_regex, handle_regex, hashtag_regex, punctuation_regex]
+
+ @staticmethod
+ def clean_text(text):
+ """
+ Process text to improve language id performance. The main emphasis is on tweets: this method removes shortened
+ urls, hashtags, handles, punctuation and emoji.
+ """
+
+ for regex in LangIDProcessor.all_regexes:
+ text = regex.sub(" ", text)
+
+ text = emoji.get_emoji_regexp().sub(" ", text)
+
+ if text.strip():
+ text = text.strip()
+
+ return text
+
+ def _process_list(self, docs):
+ """
+ Identify language of list of strings or Documents
+ """
+
+ if len(docs) == 0:
+ # TO DO: what standard do we want for bad input, such as empty list?
+ # TO DO: more handling of bad input
+ return
+
+ if isinstance(docs[0], str):
+ docs = [Document([], text) for text in docs]
+
+ docs_by_length = {}
+ for doc in docs:
+ text = LangIDProcessor.clean_text(doc.text) if self._clean_text else doc.text
+ doc_length = len(text)
+ if doc_length not in docs_by_length:
+ docs_by_length[doc_length] = []
+ docs_by_length[doc_length].append((doc, text))
+
+ for doc_length in docs_by_length:
+ inputs = [doc[1] for doc in docs_by_length[doc_length]]
+ predictions = self._id_langs(self._text_to_tensor(inputs))
+ for doc, lang in zip(docs_by_length[doc_length], predictions):
+ doc[0].lang = lang
+
+ return docs
+
+ def process(self, doc):
+ """
+ Handle single str or Document
+ """
+
+ wrapped_doc = [doc]
+ return self._process_list(wrapped_doc)[0]
+
+ def bulk_process(self, docs):
+ """
+ Handle list of strings or Documents
+ """
+
+ return self._process_list(docs)
+
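A hedged usage sketch of the langid processor on its own, mirroring how MultilingualPipeline below drives it; it assumes the 'multilingual' langid model has already been downloaded.

import stanza
from stanza.models.common.doc import Document

nlp = stanza.Pipeline(lang="multilingual", processors="langid")
docs = [Document([], text=t) for t in ["Hello world.", "Bonjour le monde."]]
docs = nlp(docs)
print([doc.lang for doc in docs])   # e.g. ['en', 'fr']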
diff --git a/stanza/pipeline/multilingual.py b/stanza/pipeline/multilingual.py
new file mode 100644
index 00000000..55056c77
--- /dev/null
+++ b/stanza/pipeline/multilingual.py
@@ -0,0 +1,109 @@
+"""
+Class for running multilingual pipelines
+"""
+
+import torch
+
+from stanza.models.common.doc import Document
+from stanza.pipeline.core import Pipeline
+from stanza.pipeline._constants import *
+from stanza.resources.common import DEFAULT_MODEL_DIR
+
+
+class MultilingualPipeline:
+ """
+ Pipeline for handling multilingual data. Takes in text, detects language, and routes request to pipeline for that
+ language.
+ """
+
+ def __init__(
+ self,
+ model_dir: str = DEFAULT_MODEL_DIR,
+ lang_id_config: dict = None,
+ lang_configs: dict = None,
+ ld_batch_size: int = 64,
+ max_cache_size: int = 10,
+ use_gpu: bool = None
+ ):
+ # set up configs and cache for various language pipelines
+ self.model_dir = model_dir
+ self.lang_id_config = {} if lang_id_config is None else lang_id_config
+ self.lang_configs = {} if lang_configs is None else lang_configs
+ self.max_cache_size = max_cache_size
+ self.pipeline_cache = {}
+ self.lang_request_history = []
+
+ # set use_gpu
+ if use_gpu is None:
+ self.use_gpu = torch.cuda.is_available()
+ else:
+ self.use_gpu = use_gpu
+
+ # build language id pipeline
+ self.lang_id_pipeline = Pipeline(dir=self.model_dir, lang='multilingual', processors="langid",
+ use_gpu=self.use_gpu, **self.lang_id_config)
+
+ def _update_pipeline_cache(self, lang):
+ """
+ Do any necessary updates to the pipeline cache for this language. This includes building a new
+ pipeline for the lang, and possibly evicting the least recently used language from the cache.
+ """
+
+ # update request history
+ if lang in self.lang_request_history:
+ self.lang_request_history.remove(lang)
+ self.lang_request_history.append(lang)
+
+ # update language configs
+ if lang not in self.lang_configs:
+ self.lang_configs[lang] = {'lang': lang}
+
+ # update pipeline cache
+ if lang not in self.pipeline_cache:
+ # clear least recently used lang from pipeline cache
+ if len(self.pipeline_cache) == self.max_cache_size:
+ lru_lang = self.lang_request_history[0]
+ del self.pipeline_cache[lru_lang]
+ self.lang_request_history.remove(lru_lang)
+ self.pipeline_cache[lang] = Pipeline(dir=self.model_dir, **self.lang_configs[lang])
+
+ def process(self, doc):
+ """
+ Run language detection on a string, a Document, or a list of either, route to language specific pipeline
+ """
+
+ # only return a list if given a list
+ singleton_input = not isinstance(doc, list)
+ if singleton_input:
+ docs = [doc]
+ else:
+ docs = doc
+
+ if docs and isinstance(docs[0], str):
+ docs = [Document([], text=text) for text in docs]
+
+ # run language identification
+ docs_w_langid = self.lang_id_pipeline.process(docs)
+
+ # create language specific batches
+ lang_batches = {}
+ for doc in docs_w_langid:
+ if doc.lang not in lang_batches:
+ lang_batches[doc.lang] = []
+ lang_batches[doc.lang].append(doc)
+
+ # run through each language, submit a batch to the language specific pipeline
+ for lang in lang_batches.keys():
+ self._update_pipeline_cache(lang)
+ self.pipeline_cache[lang](lang_batches[lang])
+
+ # only return a list if given a list
+ if singleton_input:
+ return docs_w_langid[0]
+ else:
+ return docs_w_langid
+
+ def __call__(self, doc):
+ doc = self.process(doc)
+ return doc
+
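And a sketch of the wrapper end to end, assuming the langid model and the per-language models (here en and fr) have already been downloaded:

from stanza.pipeline.multilingual import MultilingualPipeline

nlp = MultilingualPipeline()
docs = nlp(["This is an English sentence.", "C'est une phrase française."])
for doc in docs:
    print(doc.lang, len(doc.sentences))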
diff --git a/stanza/pipeline/pos_processor.py b/stanza/pipeline/pos_processor.py
index da918fdf..89658ee2 100644
--- a/stanza/pipeline/pos_processor.py
+++ b/stanza/pipeline/pos_processor.py
@@ -4,12 +4,14 @@ Processor for performing part-of-speech tagging
from stanza.models.common import doc
from stanza.models.common.pretrain import Pretrain
-from stanza.models.common.utils import unsort
+from stanza.models.common.utils import get_tqdm, unsort
from stanza.models.pos.data import DataLoader
from stanza.models.pos.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
+tqdm = get_tqdm()
+
@register_processor(name=POS)
class POSProcessor(UDProcessor):
@@ -23,14 +25,21 @@ class POSProcessor(UDProcessor):
self._pretrain = Pretrain(config['pretrain_path']) if 'pretrain_path' in config else None
# set up trainer
self._trainer = Trainer(pretrain=self.pretrain, model_file=config['model_path'], use_cuda=use_gpu)
+ self._tqdm = 'tqdm' in config and config['tqdm']
def process(self, document):
batch = DataLoader(
document, self.config['batch_size'], self.config, self.pretrain, vocab=self.vocab, evaluation=True,
sort_during_eval=True)
preds = []
- for i, b in enumerate(batch):
- preds += self.trainer.predict(b)
+
+ if self._tqdm:
+ for i, b in enumerate(tqdm(batch)):
+ preds += self.trainer.predict(b)
+ else:
+ for i, b in enumerate(batch):
+ preds += self.trainer.predict(b)
+
preds = unsort(preds, batch.data_orig_idx)
batch.doc.set([doc.UPOS, doc.XPOS, doc.FEATS], [y for x in preds for y in x])
return batch.doc
diff --git a/stanza/pipeline/tokenize_processor.py b/stanza/pipeline/tokenize_processor.py
index 79b17a54..3421f90b 100644
--- a/stanza/pipeline/tokenize_processor.py
+++ b/stanza/pipeline/tokenize_processor.py
@@ -82,7 +82,7 @@ class TokenizeProcessor(UDProcessor):
raw_text = '\n\n'.join(document) if isinstance(document, list) else document
# set up batches
- batches = DataLoader(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True)
+ batches = DataLoader(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True, dictionary=self.trainer.dictionary)
# get dict data
_, _, _, document = output_predictions(None, self.trainer, batches, self.vocab, None,
self.config.get('max_seqlen', TokenizeProcessor.MAX_SEQ_LENGTH_DEFAULT),
diff --git a/stanza/protobuf/CoreNLP_pb2.py b/stanza/protobuf/CoreNLP_pb2.py
index 298ed1b8..f29f4132 100644
--- a/stanza/protobuf/CoreNLP_pb2.py
+++ b/stanza/protobuf/CoreNLP_pb2.py
@@ -19,7 +19,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='edu.stanford.nlp.pipeline',
syntax='proto2',
serialized_options=b'\n\031edu.stanford.nlp.pipelineB\rCoreNLPProtos',
- serialized_pb=b'\n\rCoreNLP.proto\x12\x19\x65\x64u.stanford.nlp.pipeline\"\xe1\x05\n\x08\x44ocument\x12\x0c\n\x04text\x18\x01 \x02(\t\x12\x35\n\x08sentence\x18\x02 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x39\n\ncorefChain\x18\x03 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.CorefChain\x12\r\n\x05\x64ocID\x18\x04 \x01(\t\x12\x0f\n\x07\x64ocDate\x18\x07 \x01(\t\x12\x10\n\x08\x63\x61lendar\x18\x08 \x01(\x04\x12;\n\x11sentencelessToken\x18\x05 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x33\n\tcharacter\x18\n \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12/\n\x05quote\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x37\n\x08mentions\x18\t \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12#\n\x1bhasEntityMentionsAnnotation\x18\r \x01(\x08\x12\x0e\n\x06xmlDoc\x18\x0b \x01(\x08\x12\x34\n\x08sections\x18\x0c \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Section\x12<\n\x10mentionsForCoref\x18\x0e \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12!\n\x19hasCorefMentionAnnotation\x18\x0f \x01(\x08\x12\x1a\n\x12hasCorefAnnotation\x18\x10 \x01(\x08\x12+\n#corefMentionToEntityMentionMappings\x18\x11 \x03(\x05\x12+\n#entityMentionToCorefMentionMappings\x18\x12 \x03(\x05*\x05\x08\x64\x10\x80\x02\"\xf3\x0f\n\x08Sentence\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x18\n\x10tokenOffsetBegin\x18\x02 \x02(\r\x12\x16\n\x0etokenOffsetEnd\x18\x03 \x02(\r\x12\x15\n\rsentenceIndex\x18\x04 \x01(\r\x12\x1c\n\x14\x63haracterOffsetBegin\x18\x05 \x01(\r\x12\x1a\n\x12\x63haracterOffsetEnd\x18\x06 \x01(\r\x12\x37\n\tparseTree\x18\x07 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x62inarizedParseTree\x18\x1f \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x61nnotatedParseTree\x18 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x11\n\tsentiment\x18! 
\x01(\t\x12=\n\x0fkBestParseTrees\x18\" \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x45\n\x11\x62\x61sicDependencies\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12I\n\x15\x63ollapsedDependencies\x18\t \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12T\n collapsedCCProcessedDependencies\x18\n \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12K\n\x17\x61lternativeDependencies\x18\r \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12?\n\x0copenieTriple\x18\x0e \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12<\n\tkbpTriple\x18\x10 \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12\x45\n\x10\x65ntailedSentence\x18\x0f \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12\x43\n\x0e\x65ntailedClause\x18# \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12H\n\x14\x65nhancedDependencies\x18\x11 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12P\n\x1c\x65nhancedPlusPlusDependencies\x18\x12 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x33\n\tcharacter\x18\x13 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x11\n\tparagraph\x18\x0b \x01(\r\x12\x0c\n\x04text\x18\x0c \x01(\t\x12\x12\n\nlineNumber\x18\x14 \x01(\r\x12\x1e\n\x16hasRelationAnnotations\x18\x33 \x01(\x08\x12\x31\n\x06\x65ntity\x18\x34 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x35\n\x08relation\x18\x35 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Relation\x12$\n\x1chasNumerizedTokensAnnotation\x18\x36 \x01(\x08\x12\x37\n\x08mentions\x18\x37 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12<\n\x10mentionsForCoref\x18\x38 \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12\"\n\x1ahasCorefMentionsAnnotation\x18\x39 \x01(\x08\x12\x12\n\nsentenceID\x18: \x01(\t\x12\x13\n\x0bsectionDate\x18; \x01(\t\x12\x14\n\x0csectionIndex\x18< \x01(\r\x12\x13\n\x0bsectionName\x18= \x01(\t\x12\x15\n\rsectionAuthor\x18> \x01(\t\x12\r\n\x05\x64ocID\x18? \x01(\t\x12\x15\n\rsectionQuoted\x18@ \x01(\x08\x12#\n\x1bhasEntityMentionsAnnotation\x18\x41 \x01(\x08\x12\x1f\n\x17hasKBPTriplesAnnotation\x18\x44 \x01(\x08\x12\"\n\x1ahasOpenieTriplesAnnotation\x18\x45 \x01(\x08\x12\x14\n\x0c\x63hapterIndex\x18\x42 \x01(\r\x12\x16\n\x0eparagraphIndex\x18\x43 \x01(\r\x12=\n\x10\x65nhancedSentence\x18\x46 \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x0f\n\x07speaker\x18G \x01(\t\x12\x13\n\x0bspeakerType\x18H \x01(\t*\x05\x08\x64\x10\x80\x02\"\xc2\x0c\n\x05Token\x12\x0c\n\x04word\x18\x01 \x01(\t\x12\x0b\n\x03pos\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\x12\x10\n\x08\x63\x61tegory\x18\x04 \x01(\t\x12\x0e\n\x06\x62\x65\x66ore\x18\x05 \x01(\t\x12\r\n\x05\x61\x66ter\x18\x06 \x01(\t\x12\x14\n\x0coriginalText\x18\x07 \x01(\t\x12\x0b\n\x03ner\x18\x08 \x01(\t\x12\x11\n\tcoarseNER\x18> \x01(\t\x12\x16\n\x0e\x66ineGrainedNER\x18? 
\x01(\t\x12\x15\n\rnerLabelProbs\x18\x42 \x03(\t\x12\x15\n\rnormalizedNER\x18\t \x01(\t\x12\r\n\x05lemma\x18\n \x01(\t\x12\x11\n\tbeginChar\x18\x0b \x01(\r\x12\x0f\n\x07\x65ndChar\x18\x0c \x01(\r\x12\x11\n\tutterance\x18\r \x01(\r\x12\x0f\n\x07speaker\x18\x0e \x01(\t\x12\x13\n\x0bspeakerType\x18M \x01(\t\x12\x12\n\nbeginIndex\x18\x0f \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x10 \x01(\r\x12\x17\n\x0ftokenBeginIndex\x18\x11 \x01(\r\x12\x15\n\rtokenEndIndex\x18\x12 \x01(\r\x12\x34\n\ntimexValue\x18\x13 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x15\n\rhasXmlContext\x18\x15 \x01(\x08\x12\x12\n\nxmlContext\x18\x16 \x03(\t\x12\x16\n\x0e\x63orefClusterID\x18\x17 \x01(\r\x12\x0e\n\x06\x61nswer\x18\x18 \x01(\t\x12\x15\n\rheadWordIndex\x18\x1a \x01(\r\x12\x35\n\x08operator\x18\x1b \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Operator\x12\x35\n\x08polarity\x18\x1c \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Polarity\x12\x14\n\x0cpolarity_dir\x18\' \x01(\t\x12-\n\x04span\x18\x1d \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x11\n\tsentiment\x18\x1e \x01(\t\x12\x16\n\x0equotationIndex\x18\x1f \x01(\x05\x12\x42\n\x0e\x63onllUFeatures\x18 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x11\n\tcoarseTag\x18! \x01(\t\x12\x38\n\x0f\x63onllUTokenSpan\x18\" \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x12\n\nconllUMisc\x18# \x01(\t\x12G\n\x13\x63onllUSecondaryDeps\x18$ \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x17\n\x0fwikipediaEntity\x18% \x01(\t\x12\x11\n\tisNewline\x18& \x01(\x08\x12\x0e\n\x06gender\x18\x33 \x01(\t\x12\x10\n\x08trueCase\x18\x34 \x01(\t\x12\x14\n\x0ctrueCaseText\x18\x35 \x01(\t\x12\x13\n\x0b\x63hineseChar\x18\x36 \x01(\t\x12\x12\n\nchineseSeg\x18\x37 \x01(\t\x12\x16\n\x0e\x63hineseXMLChar\x18< \x01(\t\x12\x11\n\tarabicSeg\x18L \x01(\t\x12\x13\n\x0bsectionName\x18\x38 \x01(\t\x12\x15\n\rsectionAuthor\x18\x39 \x01(\t\x12\x13\n\x0bsectionDate\x18: \x01(\t\x12\x17\n\x0fsectionEndLabel\x18; \x01(\t\x12\x0e\n\x06parent\x18= \x01(\t\x12\x19\n\x11\x63orefMentionIndex\x18@ \x03(\r\x12\x1a\n\x12\x65ntityMentionIndex\x18\x41 \x01(\r\x12\r\n\x05isMWT\x18\x43 \x01(\x08\x12\x12\n\nisFirstMWT\x18\x44 \x01(\x08\x12\x0f\n\x07mwtText\x18\x45 \x01(\t\x12\x14\n\x0cnumericValue\x18\x46 \x01(\x04\x12\x13\n\x0bnumericType\x18G \x01(\t\x12\x1d\n\x15numericCompositeValue\x18H \x01(\x04\x12\x1c\n\x14numericCompositeType\x18I \x01(\t\x12\x1c\n\x14\x63odepointOffsetBegin\x18J \x01(\r\x12\x1a\n\x12\x63odepointOffsetEnd\x18K \x01(\r*\x05\x08\x64\x10\x80\x02\"\xe4\x03\n\x05Quote\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\r\x12\x0b\n\x03\x65nd\x18\x03 \x01(\r\x12\x15\n\rsentenceBegin\x18\x05 \x01(\r\x12\x13\n\x0bsentenceEnd\x18\x06 \x01(\r\x12\x12\n\ntokenBegin\x18\x07 \x01(\r\x12\x10\n\x08tokenEnd\x18\x08 \x01(\r\x12\r\n\x05\x64ocid\x18\t \x01(\t\x12\r\n\x05index\x18\n \x01(\r\x12\x0e\n\x06\x61uthor\x18\x0b \x01(\t\x12\x0f\n\x07mention\x18\x0c \x01(\t\x12\x14\n\x0cmentionBegin\x18\r \x01(\r\x12\x12\n\nmentionEnd\x18\x0e \x01(\r\x12\x13\n\x0bmentionType\x18\x0f \x01(\t\x12\x14\n\x0cmentionSieve\x18\x10 \x01(\t\x12\x0f\n\x07speaker\x18\x11 \x01(\t\x12\x14\n\x0cspeakerSieve\x18\x12 \x01(\t\x12\x18\n\x10\x63\x61nonicalMention\x18\x13 \x01(\t\x12\x1d\n\x15\x63\x61nonicalMentionBegin\x18\x14 \x01(\r\x12\x1b\n\x13\x63\x61nonicalMentionEnd\x18\x15 \x01(\r\x12N\n\x1a\x61ttributionDependencyGraph\x18\x16 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xc7\x01\n\tParseTree\x12\x33\n\x05\x63hild\x18\x01 
\x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\r\n\x05value\x18\x02 \x01(\t\x12\x17\n\x0fyieldBeginIndex\x18\x03 \x01(\r\x12\x15\n\ryieldEndIndex\x18\x04 \x01(\r\x12\r\n\x05score\x18\x05 \x01(\x01\x12\x37\n\tsentiment\x18\x06 \x01(\x0e\x32$.edu.stanford.nlp.pipeline.Sentiment\"\x96\x03\n\x0f\x44\x65pendencyGraph\x12=\n\x04node\x18\x01 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Node\x12=\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Edge\x12\x10\n\x04root\x18\x03 \x03(\rB\x02\x10\x01\x1a\x44\n\x04Node\x12\x15\n\rsentenceIndex\x18\x01 \x02(\r\x12\r\n\x05index\x18\x02 \x02(\r\x12\x16\n\x0e\x63opyAnnotation\x18\x03 \x01(\r\x1a\xac\x01\n\x04\x45\x64ge\x12\x0e\n\x06source\x18\x01 \x02(\r\x12\x0e\n\x06target\x18\x02 \x02(\r\x12\x0b\n\x03\x64\x65p\x18\x03 \x01(\t\x12\x0f\n\x07isExtra\x18\x04 \x01(\x08\x12\x12\n\nsourceCopy\x18\x05 \x01(\r\x12\x12\n\ntargetCopy\x18\x06 \x01(\r\x12>\n\x08language\x18\x07 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.Language:\x07Unknown\"\xc6\x02\n\nCorefChain\x12\x0f\n\x07\x63hainID\x18\x01 \x02(\x05\x12\x43\n\x07mention\x18\x02 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.CorefChain.CorefMention\x12\x16\n\x0erepresentative\x18\x03 \x02(\r\x1a\xc9\x01\n\x0c\x43orefMention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x12\n\nbeginIndex\x18\x06 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x07 \x01(\r\x12\x11\n\theadIndex\x18\t \x01(\r\x12\x15\n\rsentenceIndex\x18\n \x01(\r\x12\x10\n\x08position\x18\x0b \x01(\r\"\xef\x08\n\x07Mention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x0e\n\x06person\x18\x06 \x01(\t\x12\x12\n\nstartIndex\x18\x07 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\t \x01(\r\x12\x11\n\theadIndex\x18\n \x01(\x05\x12\x12\n\nheadString\x18\x0b \x01(\t\x12\x11\n\tnerString\x18\x0c \x01(\t\x12\x13\n\x0boriginalRef\x18\r \x01(\x05\x12\x1a\n\x12goldCorefClusterID\x18\x0e \x01(\x05\x12\x16\n\x0e\x63orefClusterID\x18\x0f \x01(\x05\x12\x12\n\nmentionNum\x18\x10 \x01(\x05\x12\x0f\n\x07sentNum\x18\x11 \x01(\x05\x12\r\n\x05utter\x18\x12 \x01(\x05\x12\x11\n\tparagraph\x18\x13 \x01(\x05\x12\x11\n\tisSubject\x18\x14 \x01(\x08\x12\x16\n\x0eisDirectObject\x18\x15 \x01(\x08\x12\x18\n\x10isIndirectObject\x18\x16 \x01(\x08\x12\x1b\n\x13isPrepositionObject\x18\x17 \x01(\x08\x12\x0f\n\x07hasTwin\x18\x18 \x01(\x08\x12\x0f\n\x07generic\x18\x19 \x01(\x08\x12\x13\n\x0bisSingleton\x18\x1a \x01(\x08\x12\x1a\n\x12hasBasicDependency\x18\x1b \x01(\x08\x12\x1d\n\x15hasEnhancedDepenedncy\x18\x1c \x01(\x08\x12\x1b\n\x13hasContextParseTree\x18\x1d \x01(\x08\x12?\n\x0fheadIndexedWord\x18\x1e \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12=\n\rdependingVerb\x18\x1f \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x38\n\x08headWord\x18 \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12;\n\x0bspeakerInfo\x18! 
\x01(\x0b\x32&.edu.stanford.nlp.pipeline.SpeakerInfo\x12=\n\rsentenceWords\x18\x32 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12<\n\x0coriginalSpan\x18\x33 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x12\n\ndependents\x18\x34 \x03(\t\x12\x19\n\x11preprocessedTerms\x18\x35 \x03(\t\x12\x13\n\x0b\x61ppositions\x18\x36 \x03(\x05\x12\x1c\n\x14predicateNominatives\x18\x37 \x03(\x05\x12\x18\n\x10relativePronouns\x18\x38 \x03(\x05\x12\x13\n\x0blistMembers\x18\x39 \x03(\x05\x12\x15\n\rbelongToLists\x18: \x03(\x05\"X\n\x0bIndexedWord\x12\x13\n\x0bsentenceNum\x18\x01 \x01(\x05\x12\x12\n\ntokenIndex\x18\x02 \x01(\x05\x12\r\n\x05\x64ocID\x18\x03 \x01(\x05\x12\x11\n\tcopyCount\x18\x04 \x01(\r\"4\n\x0bSpeakerInfo\x12\x13\n\x0bspeakerName\x18\x01 \x01(\t\x12\x10\n\x08mentions\x18\x02 \x03(\x05\"\"\n\x04Span\x12\r\n\x05\x62\x65gin\x18\x01 \x02(\r\x12\x0b\n\x03\x65nd\x18\x02 \x02(\r\"w\n\x05Timex\x12\r\n\x05value\x18\x01 \x01(\t\x12\x10\n\x08\x61ltValue\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0b\n\x03tid\x18\x05 \x01(\t\x12\x12\n\nbeginPoint\x18\x06 \x01(\r\x12\x10\n\x08\x65ndPoint\x18\x07 \x01(\r\"\xdb\x01\n\x06\x45ntity\x12\x11\n\theadStart\x18\x06 \x01(\r\x12\x0f\n\x07headEnd\x18\x07 \x01(\r\x12\x13\n\x0bmentionType\x18\x08 \x01(\t\x12\x16\n\x0enormalizedName\x18\t \x01(\t\x12\x16\n\x0eheadTokenIndex\x18\n \x01(\r\x12\x0f\n\x07\x63orefID\x18\x0b \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb7\x01\n\x08Relation\x12\x0f\n\x07\x61rgName\x18\x06 \x03(\t\x12.\n\x03\x61rg\x18\x07 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x11\n\tsignature\x18\x08 \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb2\x01\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x1b\n\x13quantifierSpanBegin\x18\x02 \x02(\x05\x12\x19\n\x11quantifierSpanEnd\x18\x03 \x02(\x05\x12\x18\n\x10subjectSpanBegin\x18\x04 \x02(\x05\x12\x16\n\x0esubjectSpanEnd\x18\x05 \x02(\x05\x12\x17\n\x0fobjectSpanBegin\x18\x06 \x02(\x05\x12\x15\n\robjectSpanEnd\x18\x07 \x02(\x05\"\xa9\x04\n\x08Polarity\x12K\n\x12projectEquivalence\x18\x01 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectForwardEntailment\x18\x02 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectReverseEntailment\x18\x03 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12H\n\x0fprojectNegation\x18\x04 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12K\n\x12projectAlternation\x18\x05 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12\x45\n\x0cprojectCover\x18\x06 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12L\n\x13projectIndependence\x18\x07 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\"\xdd\x02\n\nNERMention\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12%\n\x1dtokenStartInSentenceInclusive\x18\x02 \x02(\r\x12#\n\x1btokenEndInSentenceExclusive\x18\x03 \x02(\r\x12\x0b\n\x03ner\x18\x04 \x02(\t\x12\x15\n\rnormalizedNER\x18\x05 \x01(\t\x12\x12\n\nentityType\x18\x06 \x01(\t\x12/\n\x05timex\x18\x07 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x17\n\x0fwikipediaEntity\x18\x08 \x01(\t\x12\x0e\n\x06gender\x18\t \x01(\t\x12\x1a\n\x12\x65ntityMentionIndex\x18\n 
\x01(\r\x12#\n\x1b\x63\x61nonicalEntityMentionIndex\x18\x0b \x01(\r\x12\x19\n\x11\x65ntityMentionText\x18\x0c \x01(\t\"Y\n\x10SentenceFragment\x12\x12\n\ntokenIndex\x18\x01 \x03(\r\x12\x0c\n\x04root\x18\x02 \x01(\r\x12\x14\n\x0c\x61ssumedTruth\x18\x03 \x01(\x08\x12\r\n\x05score\x18\x04 \x01(\x01\":\n\rTokenLocation\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12\x12\n\ntokenIndex\x18\x02 \x01(\r\"\x9a\x03\n\x0eRelationTriple\x12\x0f\n\x07subject\x18\x01 \x01(\t\x12\x10\n\x08relation\x18\x02 \x01(\t\x12\x0e\n\x06object\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x01\x12?\n\rsubjectTokens\x18\r \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12@\n\x0erelationTokens\x18\x0e \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12>\n\x0cobjectTokens\x18\x0f \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12\x38\n\x04tree\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0e\n\x06istmod\x18\t \x01(\x08\x12\x10\n\x08prefixBe\x18\n \x01(\x08\x12\x10\n\x08suffixBe\x18\x0b \x01(\x08\x12\x10\n\x08suffixOf\x18\x0c \x01(\x08\"-\n\x0fMapStringString\x12\x0b\n\x03key\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x03(\t\"*\n\x0cMapIntString\x12\x0b\n\x03key\x18\x01 \x03(\r\x12\r\n\x05value\x18\x02 \x03(\t\"\xfc\x01\n\x07Section\x12\x11\n\tcharBegin\x18\x01 \x02(\r\x12\x0f\n\x07\x63harEnd\x18\x02 \x02(\r\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x17\n\x0fsentenceIndexes\x18\x04 \x03(\r\x12\x10\n\x08\x64\x61tetime\x18\x05 \x01(\t\x12\x30\n\x06quotes\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x17\n\x0f\x61uthorCharBegin\x18\x07 \x01(\r\x12\x15\n\rauthorCharEnd\x18\x08 \x01(\r\x12\x30\n\x06xmlTag\x18\t \x02(\x0b\x32 .edu.stanford.nlp.pipeline.Token\"\xe4\x01\n\x0eSemgrexRequest\x12\x0f\n\x07semgrex\x18\x01 \x03(\t\x12\x45\n\x05query\x18\x02 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies\x1az\n\x0c\x44\x65pendencies\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x39\n\x05graph\x18\x02 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\x8a\x04\n\x0fSemgrexResponse\x12\x46\n\x06result\x18\x01 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult\x1a-\n\tNamedNode\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x12\n\nmatchIndex\x18\x02 \x02(\x05\x1a+\n\rNamedRelation\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0c\n\x04reln\x18\x02 \x02(\t\x1a\xa7\x01\n\x05Match\x12\x12\n\nmatchIndex\x18\x01 \x02(\x05\x12\x42\n\x04node\x18\x02 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode\x12\x46\n\x04reln\x18\x03 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation\x1aP\n\rSemgrexResult\x12?\n\x05match\x18\x01 \x03(\x0b\x32\x30.edu.stanford.nlp.pipeline.SemgrexResponse.Match\x1aW\n\x0bGraphResult\x12H\n\x06result\x18\x01 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult\"W\n\x12TokensRegexRequest\x12\x30\n\x03\x64oc\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x0f\n\x07pattern\x18\x02 \x03(\t\"\xa7\x03\n\x13TokensRegexResponse\x12J\n\x05match\x18\x01 \x03(\x0b\x32;.edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch\x1a\x39\n\rMatchLocation\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x05\x1a\xb3\x01\n\x05Match\x12\x10\n\x08sentence\x18\x01 \x02(\x05\x12K\n\x05match\x18\x02 \x02(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x12K\n\x05group\x18\x03 
\x03(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x1aS\n\x0cPatternMatch\x12\x43\n\x05match\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TokensRegexResponse.Match\"\xae\x01\n\x19\x44\x65pendencyEnhancerRequest\x12\x35\n\x08\x64ocument\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x37\n\x08language\x18\x02 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.LanguageH\x00\x12\x1a\n\x10relativePronouns\x18\x03 \x01(\tH\x00\x42\x05\n\x03ref*\xa3\x01\n\x08Language\x12\x0b\n\x07Unknown\x10\x00\x12\x07\n\x03\x41ny\x10\x01\x12\n\n\x06\x41rabic\x10\x02\x12\x0b\n\x07\x43hinese\x10\x03\x12\x0b\n\x07\x45nglish\x10\x04\x12\n\n\x06German\x10\x05\x12\n\n\x06\x46rench\x10\x06\x12\n\n\x06Hebrew\x10\x07\x12\x0b\n\x07Spanish\x10\x08\x12\x14\n\x10UniversalEnglish\x10\t\x12\x14\n\x10UniversalChinese\x10\n*h\n\tSentiment\x12\x13\n\x0fSTRONG_NEGATIVE\x10\x00\x12\x11\n\rWEAK_NEGATIVE\x10\x01\x12\x0b\n\x07NEUTRAL\x10\x02\x12\x11\n\rWEAK_POSITIVE\x10\x03\x12\x13\n\x0fSTRONG_POSITIVE\x10\x04*\x93\x01\n\x14NaturalLogicRelation\x12\x0f\n\x0b\x45QUIVALENCE\x10\x00\x12\x16\n\x12\x46ORWARD_ENTAILMENT\x10\x01\x12\x16\n\x12REVERSE_ENTAILMENT\x10\x02\x12\x0c\n\x08NEGATION\x10\x03\x12\x0f\n\x0b\x41LTERNATION\x10\x04\x12\t\n\x05\x43OVER\x10\x05\x12\x10\n\x0cINDEPENDENCE\x10\x06\x42*\n\x19\x65\x64u.stanford.nlp.pipelineB\rCoreNLPProtos'
+ serialized_pb=b'\n\rCoreNLP.proto\x12\x19\x65\x64u.stanford.nlp.pipeline\"\xe1\x05\n\x08\x44ocument\x12\x0c\n\x04text\x18\x01 \x02(\t\x12\x35\n\x08sentence\x18\x02 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x39\n\ncorefChain\x18\x03 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.CorefChain\x12\r\n\x05\x64ocID\x18\x04 \x01(\t\x12\x0f\n\x07\x64ocDate\x18\x07 \x01(\t\x12\x10\n\x08\x63\x61lendar\x18\x08 \x01(\x04\x12;\n\x11sentencelessToken\x18\x05 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x33\n\tcharacter\x18\n \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12/\n\x05quote\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x37\n\x08mentions\x18\t \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12#\n\x1bhasEntityMentionsAnnotation\x18\r \x01(\x08\x12\x0e\n\x06xmlDoc\x18\x0b \x01(\x08\x12\x34\n\x08sections\x18\x0c \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Section\x12<\n\x10mentionsForCoref\x18\x0e \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12!\n\x19hasCorefMentionAnnotation\x18\x0f \x01(\x08\x12\x1a\n\x12hasCorefAnnotation\x18\x10 \x01(\x08\x12+\n#corefMentionToEntityMentionMappings\x18\x11 \x03(\x05\x12+\n#entityMentionToCorefMentionMappings\x18\x12 \x03(\x05*\x05\x08\x64\x10\x80\x02\"\xf3\x0f\n\x08Sentence\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x18\n\x10tokenOffsetBegin\x18\x02 \x02(\r\x12\x16\n\x0etokenOffsetEnd\x18\x03 \x02(\r\x12\x15\n\rsentenceIndex\x18\x04 \x01(\r\x12\x1c\n\x14\x63haracterOffsetBegin\x18\x05 \x01(\r\x12\x1a\n\x12\x63haracterOffsetEnd\x18\x06 \x01(\r\x12\x37\n\tparseTree\x18\x07 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x62inarizedParseTree\x18\x1f \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x61nnotatedParseTree\x18 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x11\n\tsentiment\x18! 
\x01(\t\x12=\n\x0fkBestParseTrees\x18\" \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x45\n\x11\x62\x61sicDependencies\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12I\n\x15\x63ollapsedDependencies\x18\t \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12T\n collapsedCCProcessedDependencies\x18\n \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12K\n\x17\x61lternativeDependencies\x18\r \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12?\n\x0copenieTriple\x18\x0e \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12<\n\tkbpTriple\x18\x10 \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12\x45\n\x10\x65ntailedSentence\x18\x0f \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12\x43\n\x0e\x65ntailedClause\x18# \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12H\n\x14\x65nhancedDependencies\x18\x11 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12P\n\x1c\x65nhancedPlusPlusDependencies\x18\x12 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x33\n\tcharacter\x18\x13 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x11\n\tparagraph\x18\x0b \x01(\r\x12\x0c\n\x04text\x18\x0c \x01(\t\x12\x12\n\nlineNumber\x18\x14 \x01(\r\x12\x1e\n\x16hasRelationAnnotations\x18\x33 \x01(\x08\x12\x31\n\x06\x65ntity\x18\x34 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x35\n\x08relation\x18\x35 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Relation\x12$\n\x1chasNumerizedTokensAnnotation\x18\x36 \x01(\x08\x12\x37\n\x08mentions\x18\x37 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12<\n\x10mentionsForCoref\x18\x38 \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12\"\n\x1ahasCorefMentionsAnnotation\x18\x39 \x01(\x08\x12\x12\n\nsentenceID\x18: \x01(\t\x12\x13\n\x0bsectionDate\x18; \x01(\t\x12\x14\n\x0csectionIndex\x18< \x01(\r\x12\x13\n\x0bsectionName\x18= \x01(\t\x12\x15\n\rsectionAuthor\x18> \x01(\t\x12\r\n\x05\x64ocID\x18? \x01(\t\x12\x15\n\rsectionQuoted\x18@ \x01(\x08\x12#\n\x1bhasEntityMentionsAnnotation\x18\x41 \x01(\x08\x12\x1f\n\x17hasKBPTriplesAnnotation\x18\x44 \x01(\x08\x12\"\n\x1ahasOpenieTriplesAnnotation\x18\x45 \x01(\x08\x12\x14\n\x0c\x63hapterIndex\x18\x42 \x01(\r\x12\x16\n\x0eparagraphIndex\x18\x43 \x01(\r\x12=\n\x10\x65nhancedSentence\x18\x46 \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x0f\n\x07speaker\x18G \x01(\t\x12\x13\n\x0bspeakerType\x18H \x01(\t*\x05\x08\x64\x10\x80\x02\"\xc2\x0c\n\x05Token\x12\x0c\n\x04word\x18\x01 \x01(\t\x12\x0b\n\x03pos\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\x12\x10\n\x08\x63\x61tegory\x18\x04 \x01(\t\x12\x0e\n\x06\x62\x65\x66ore\x18\x05 \x01(\t\x12\r\n\x05\x61\x66ter\x18\x06 \x01(\t\x12\x14\n\x0coriginalText\x18\x07 \x01(\t\x12\x0b\n\x03ner\x18\x08 \x01(\t\x12\x11\n\tcoarseNER\x18> \x01(\t\x12\x16\n\x0e\x66ineGrainedNER\x18? 
\x01(\t\x12\x15\n\rnerLabelProbs\x18\x42 \x03(\t\x12\x15\n\rnormalizedNER\x18\t \x01(\t\x12\r\n\x05lemma\x18\n \x01(\t\x12\x11\n\tbeginChar\x18\x0b \x01(\r\x12\x0f\n\x07\x65ndChar\x18\x0c \x01(\r\x12\x11\n\tutterance\x18\r \x01(\r\x12\x0f\n\x07speaker\x18\x0e \x01(\t\x12\x13\n\x0bspeakerType\x18M \x01(\t\x12\x12\n\nbeginIndex\x18\x0f \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x10 \x01(\r\x12\x17\n\x0ftokenBeginIndex\x18\x11 \x01(\r\x12\x15\n\rtokenEndIndex\x18\x12 \x01(\r\x12\x34\n\ntimexValue\x18\x13 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x15\n\rhasXmlContext\x18\x15 \x01(\x08\x12\x12\n\nxmlContext\x18\x16 \x03(\t\x12\x16\n\x0e\x63orefClusterID\x18\x17 \x01(\r\x12\x0e\n\x06\x61nswer\x18\x18 \x01(\t\x12\x15\n\rheadWordIndex\x18\x1a \x01(\r\x12\x35\n\x08operator\x18\x1b \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Operator\x12\x35\n\x08polarity\x18\x1c \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Polarity\x12\x14\n\x0cpolarity_dir\x18\' \x01(\t\x12-\n\x04span\x18\x1d \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x11\n\tsentiment\x18\x1e \x01(\t\x12\x16\n\x0equotationIndex\x18\x1f \x01(\x05\x12\x42\n\x0e\x63onllUFeatures\x18 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x11\n\tcoarseTag\x18! \x01(\t\x12\x38\n\x0f\x63onllUTokenSpan\x18\" \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x12\n\nconllUMisc\x18# \x01(\t\x12G\n\x13\x63onllUSecondaryDeps\x18$ \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x17\n\x0fwikipediaEntity\x18% \x01(\t\x12\x11\n\tisNewline\x18& \x01(\x08\x12\x0e\n\x06gender\x18\x33 \x01(\t\x12\x10\n\x08trueCase\x18\x34 \x01(\t\x12\x14\n\x0ctrueCaseText\x18\x35 \x01(\t\x12\x13\n\x0b\x63hineseChar\x18\x36 \x01(\t\x12\x12\n\nchineseSeg\x18\x37 \x01(\t\x12\x16\n\x0e\x63hineseXMLChar\x18< \x01(\t\x12\x11\n\tarabicSeg\x18L \x01(\t\x12\x13\n\x0bsectionName\x18\x38 \x01(\t\x12\x15\n\rsectionAuthor\x18\x39 \x01(\t\x12\x13\n\x0bsectionDate\x18: \x01(\t\x12\x17\n\x0fsectionEndLabel\x18; \x01(\t\x12\x0e\n\x06parent\x18= \x01(\t\x12\x19\n\x11\x63orefMentionIndex\x18@ \x03(\r\x12\x1a\n\x12\x65ntityMentionIndex\x18\x41 \x01(\r\x12\r\n\x05isMWT\x18\x43 \x01(\x08\x12\x12\n\nisFirstMWT\x18\x44 \x01(\x08\x12\x0f\n\x07mwtText\x18\x45 \x01(\t\x12\x14\n\x0cnumericValue\x18\x46 \x01(\x04\x12\x13\n\x0bnumericType\x18G \x01(\t\x12\x1d\n\x15numericCompositeValue\x18H \x01(\x04\x12\x1c\n\x14numericCompositeType\x18I \x01(\t\x12\x1c\n\x14\x63odepointOffsetBegin\x18J \x01(\r\x12\x1a\n\x12\x63odepointOffsetEnd\x18K \x01(\r*\x05\x08\x64\x10\x80\x02\"\xe4\x03\n\x05Quote\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\r\x12\x0b\n\x03\x65nd\x18\x03 \x01(\r\x12\x15\n\rsentenceBegin\x18\x05 \x01(\r\x12\x13\n\x0bsentenceEnd\x18\x06 \x01(\r\x12\x12\n\ntokenBegin\x18\x07 \x01(\r\x12\x10\n\x08tokenEnd\x18\x08 \x01(\r\x12\r\n\x05\x64ocid\x18\t \x01(\t\x12\r\n\x05index\x18\n \x01(\r\x12\x0e\n\x06\x61uthor\x18\x0b \x01(\t\x12\x0f\n\x07mention\x18\x0c \x01(\t\x12\x14\n\x0cmentionBegin\x18\r \x01(\r\x12\x12\n\nmentionEnd\x18\x0e \x01(\r\x12\x13\n\x0bmentionType\x18\x0f \x01(\t\x12\x14\n\x0cmentionSieve\x18\x10 \x01(\t\x12\x0f\n\x07speaker\x18\x11 \x01(\t\x12\x14\n\x0cspeakerSieve\x18\x12 \x01(\t\x12\x18\n\x10\x63\x61nonicalMention\x18\x13 \x01(\t\x12\x1d\n\x15\x63\x61nonicalMentionBegin\x18\x14 \x01(\r\x12\x1b\n\x13\x63\x61nonicalMentionEnd\x18\x15 \x01(\r\x12N\n\x1a\x61ttributionDependencyGraph\x18\x16 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xc7\x01\n\tParseTree\x12\x33\n\x05\x63hild\x18\x01 
\x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\r\n\x05value\x18\x02 \x01(\t\x12\x17\n\x0fyieldBeginIndex\x18\x03 \x01(\r\x12\x15\n\ryieldEndIndex\x18\x04 \x01(\r\x12\r\n\x05score\x18\x05 \x01(\x01\x12\x37\n\tsentiment\x18\x06 \x01(\x0e\x32$.edu.stanford.nlp.pipeline.Sentiment\"\x96\x03\n\x0f\x44\x65pendencyGraph\x12=\n\x04node\x18\x01 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Node\x12=\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Edge\x12\x10\n\x04root\x18\x03 \x03(\rB\x02\x10\x01\x1a\x44\n\x04Node\x12\x15\n\rsentenceIndex\x18\x01 \x02(\r\x12\r\n\x05index\x18\x02 \x02(\r\x12\x16\n\x0e\x63opyAnnotation\x18\x03 \x01(\r\x1a\xac\x01\n\x04\x45\x64ge\x12\x0e\n\x06source\x18\x01 \x02(\r\x12\x0e\n\x06target\x18\x02 \x02(\r\x12\x0b\n\x03\x64\x65p\x18\x03 \x01(\t\x12\x0f\n\x07isExtra\x18\x04 \x01(\x08\x12\x12\n\nsourceCopy\x18\x05 \x01(\r\x12\x12\n\ntargetCopy\x18\x06 \x01(\r\x12>\n\x08language\x18\x07 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.Language:\x07Unknown\"\xc6\x02\n\nCorefChain\x12\x0f\n\x07\x63hainID\x18\x01 \x02(\x05\x12\x43\n\x07mention\x18\x02 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.CorefChain.CorefMention\x12\x16\n\x0erepresentative\x18\x03 \x02(\r\x1a\xc9\x01\n\x0c\x43orefMention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x12\n\nbeginIndex\x18\x06 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x07 \x01(\r\x12\x11\n\theadIndex\x18\t \x01(\r\x12\x15\n\rsentenceIndex\x18\n \x01(\r\x12\x10\n\x08position\x18\x0b \x01(\r\"\xef\x08\n\x07Mention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x0e\n\x06person\x18\x06 \x01(\t\x12\x12\n\nstartIndex\x18\x07 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\t \x01(\r\x12\x11\n\theadIndex\x18\n \x01(\x05\x12\x12\n\nheadString\x18\x0b \x01(\t\x12\x11\n\tnerString\x18\x0c \x01(\t\x12\x13\n\x0boriginalRef\x18\r \x01(\x05\x12\x1a\n\x12goldCorefClusterID\x18\x0e \x01(\x05\x12\x16\n\x0e\x63orefClusterID\x18\x0f \x01(\x05\x12\x12\n\nmentionNum\x18\x10 \x01(\x05\x12\x0f\n\x07sentNum\x18\x11 \x01(\x05\x12\r\n\x05utter\x18\x12 \x01(\x05\x12\x11\n\tparagraph\x18\x13 \x01(\x05\x12\x11\n\tisSubject\x18\x14 \x01(\x08\x12\x16\n\x0eisDirectObject\x18\x15 \x01(\x08\x12\x18\n\x10isIndirectObject\x18\x16 \x01(\x08\x12\x1b\n\x13isPrepositionObject\x18\x17 \x01(\x08\x12\x0f\n\x07hasTwin\x18\x18 \x01(\x08\x12\x0f\n\x07generic\x18\x19 \x01(\x08\x12\x13\n\x0bisSingleton\x18\x1a \x01(\x08\x12\x1a\n\x12hasBasicDependency\x18\x1b \x01(\x08\x12\x1d\n\x15hasEnhancedDepenedncy\x18\x1c \x01(\x08\x12\x1b\n\x13hasContextParseTree\x18\x1d \x01(\x08\x12?\n\x0fheadIndexedWord\x18\x1e \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12=\n\rdependingVerb\x18\x1f \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x38\n\x08headWord\x18 \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12;\n\x0bspeakerInfo\x18! 
\x01(\x0b\x32&.edu.stanford.nlp.pipeline.SpeakerInfo\x12=\n\rsentenceWords\x18\x32 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12<\n\x0coriginalSpan\x18\x33 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x12\n\ndependents\x18\x34 \x03(\t\x12\x19\n\x11preprocessedTerms\x18\x35 \x03(\t\x12\x13\n\x0b\x61ppositions\x18\x36 \x03(\x05\x12\x1c\n\x14predicateNominatives\x18\x37 \x03(\x05\x12\x18\n\x10relativePronouns\x18\x38 \x03(\x05\x12\x13\n\x0blistMembers\x18\x39 \x03(\x05\x12\x15\n\rbelongToLists\x18: \x03(\x05\"X\n\x0bIndexedWord\x12\x13\n\x0bsentenceNum\x18\x01 \x01(\x05\x12\x12\n\ntokenIndex\x18\x02 \x01(\x05\x12\r\n\x05\x64ocID\x18\x03 \x01(\x05\x12\x11\n\tcopyCount\x18\x04 \x01(\r\"4\n\x0bSpeakerInfo\x12\x13\n\x0bspeakerName\x18\x01 \x01(\t\x12\x10\n\x08mentions\x18\x02 \x03(\x05\"\"\n\x04Span\x12\r\n\x05\x62\x65gin\x18\x01 \x02(\r\x12\x0b\n\x03\x65nd\x18\x02 \x02(\r\"w\n\x05Timex\x12\r\n\x05value\x18\x01 \x01(\t\x12\x10\n\x08\x61ltValue\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0b\n\x03tid\x18\x05 \x01(\t\x12\x12\n\nbeginPoint\x18\x06 \x01(\r\x12\x10\n\x08\x65ndPoint\x18\x07 \x01(\r\"\xdb\x01\n\x06\x45ntity\x12\x11\n\theadStart\x18\x06 \x01(\r\x12\x0f\n\x07headEnd\x18\x07 \x01(\r\x12\x13\n\x0bmentionType\x18\x08 \x01(\t\x12\x16\n\x0enormalizedName\x18\t \x01(\t\x12\x16\n\x0eheadTokenIndex\x18\n \x01(\r\x12\x0f\n\x07\x63orefID\x18\x0b \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb7\x01\n\x08Relation\x12\x0f\n\x07\x61rgName\x18\x06 \x03(\t\x12.\n\x03\x61rg\x18\x07 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x11\n\tsignature\x18\x08 \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb2\x01\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x1b\n\x13quantifierSpanBegin\x18\x02 \x02(\x05\x12\x19\n\x11quantifierSpanEnd\x18\x03 \x02(\x05\x12\x18\n\x10subjectSpanBegin\x18\x04 \x02(\x05\x12\x16\n\x0esubjectSpanEnd\x18\x05 \x02(\x05\x12\x17\n\x0fobjectSpanBegin\x18\x06 \x02(\x05\x12\x15\n\robjectSpanEnd\x18\x07 \x02(\x05\"\xa9\x04\n\x08Polarity\x12K\n\x12projectEquivalence\x18\x01 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectForwardEntailment\x18\x02 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectReverseEntailment\x18\x03 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12H\n\x0fprojectNegation\x18\x04 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12K\n\x12projectAlternation\x18\x05 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12\x45\n\x0cprojectCover\x18\x06 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12L\n\x13projectIndependence\x18\x07 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\"\xdd\x02\n\nNERMention\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12%\n\x1dtokenStartInSentenceInclusive\x18\x02 \x02(\r\x12#\n\x1btokenEndInSentenceExclusive\x18\x03 \x02(\r\x12\x0b\n\x03ner\x18\x04 \x02(\t\x12\x15\n\rnormalizedNER\x18\x05 \x01(\t\x12\x12\n\nentityType\x18\x06 \x01(\t\x12/\n\x05timex\x18\x07 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x17\n\x0fwikipediaEntity\x18\x08 \x01(\t\x12\x0e\n\x06gender\x18\t \x01(\t\x12\x1a\n\x12\x65ntityMentionIndex\x18\n 
\x01(\r\x12#\n\x1b\x63\x61nonicalEntityMentionIndex\x18\x0b \x01(\r\x12\x19\n\x11\x65ntityMentionText\x18\x0c \x01(\t\"Y\n\x10SentenceFragment\x12\x12\n\ntokenIndex\x18\x01 \x03(\r\x12\x0c\n\x04root\x18\x02 \x01(\r\x12\x14\n\x0c\x61ssumedTruth\x18\x03 \x01(\x08\x12\r\n\x05score\x18\x04 \x01(\x01\":\n\rTokenLocation\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12\x12\n\ntokenIndex\x18\x02 \x01(\r\"\x9a\x03\n\x0eRelationTriple\x12\x0f\n\x07subject\x18\x01 \x01(\t\x12\x10\n\x08relation\x18\x02 \x01(\t\x12\x0e\n\x06object\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x01\x12?\n\rsubjectTokens\x18\r \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12@\n\x0erelationTokens\x18\x0e \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12>\n\x0cobjectTokens\x18\x0f \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12\x38\n\x04tree\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0e\n\x06istmod\x18\t \x01(\x08\x12\x10\n\x08prefixBe\x18\n \x01(\x08\x12\x10\n\x08suffixBe\x18\x0b \x01(\x08\x12\x10\n\x08suffixOf\x18\x0c \x01(\x08\"-\n\x0fMapStringString\x12\x0b\n\x03key\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x03(\t\"*\n\x0cMapIntString\x12\x0b\n\x03key\x18\x01 \x03(\r\x12\r\n\x05value\x18\x02 \x03(\t\"\xfc\x01\n\x07Section\x12\x11\n\tcharBegin\x18\x01 \x02(\r\x12\x0f\n\x07\x63harEnd\x18\x02 \x02(\r\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x17\n\x0fsentenceIndexes\x18\x04 \x03(\r\x12\x10\n\x08\x64\x61tetime\x18\x05 \x01(\t\x12\x30\n\x06quotes\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x17\n\x0f\x61uthorCharBegin\x18\x07 \x01(\r\x12\x15\n\rauthorCharEnd\x18\x08 \x01(\r\x12\x30\n\x06xmlTag\x18\t \x02(\x0b\x32 .edu.stanford.nlp.pipeline.Token\"\xe4\x01\n\x0eSemgrexRequest\x12\x0f\n\x07semgrex\x18\x01 \x03(\t\x12\x45\n\x05query\x18\x02 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies\x1az\n\x0c\x44\x65pendencies\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x39\n\x05graph\x18\x02 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\x8a\x04\n\x0fSemgrexResponse\x12\x46\n\x06result\x18\x01 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult\x1a-\n\tNamedNode\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x12\n\nmatchIndex\x18\x02 \x02(\x05\x1a+\n\rNamedRelation\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0c\n\x04reln\x18\x02 \x02(\t\x1a\xa7\x01\n\x05Match\x12\x12\n\nmatchIndex\x18\x01 \x02(\x05\x12\x42\n\x04node\x18\x02 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode\x12\x46\n\x04reln\x18\x03 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation\x1aP\n\rSemgrexResult\x12?\n\x05match\x18\x01 \x03(\x0b\x32\x30.edu.stanford.nlp.pipeline.SemgrexResponse.Match\x1aW\n\x0bGraphResult\x12H\n\x06result\x18\x01 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult\"W\n\x12TokensRegexRequest\x12\x30\n\x03\x64oc\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x0f\n\x07pattern\x18\x02 \x03(\t\"\xa7\x03\n\x13TokensRegexResponse\x12J\n\x05match\x18\x01 \x03(\x0b\x32;.edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch\x1a\x39\n\rMatchLocation\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x05\x1a\xb3\x01\n\x05Match\x12\x10\n\x08sentence\x18\x01 \x02(\x05\x12K\n\x05match\x18\x02 \x02(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x12K\n\x05group\x18\x03 
\x03(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x1aS\n\x0cPatternMatch\x12\x43\n\x05match\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TokensRegexResponse.Match\"\xae\x01\n\x19\x44\x65pendencyEnhancerRequest\x12\x35\n\x08\x64ocument\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x37\n\x08language\x18\x02 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.LanguageH\x00\x12\x1a\n\x10relativePronouns\x18\x03 \x01(\tH\x00\x42\x05\n\x03ref\"\xb4\x01\n\x12\x46lattenedParseTree\x12\x41\n\x05nodes\x18\x01 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.FlattenedParseTree.Node\x1a[\n\x04Node\x12\x12\n\x08openNode\x18\x01 \x01(\x08H\x00\x12\x13\n\tcloseNode\x18\x02 \x01(\x08H\x00\x12\x0f\n\x05value\x18\x03 \x01(\tH\x00\x12\r\n\x05score\x18\x04 \x01(\x01\x42\n\n\x08\x63ontents\"\xf6\x01\n\x15\x45valuateParserRequest\x12N\n\x08treebank\x18\x01 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult\x1a\x8c\x01\n\x0bParseResult\x12;\n\x04gold\x18\x01 \x02(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x12@\n\tpredicted\x18\x02 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"$\n\x16\x45valuateParserResponse\x12\n\n\x02\x66\x31\x18\x01 \x02(\x01*\xa3\x01\n\x08Language\x12\x0b\n\x07Unknown\x10\x00\x12\x07\n\x03\x41ny\x10\x01\x12\n\n\x06\x41rabic\x10\x02\x12\x0b\n\x07\x43hinese\x10\x03\x12\x0b\n\x07\x45nglish\x10\x04\x12\n\n\x06German\x10\x05\x12\n\n\x06\x46rench\x10\x06\x12\n\n\x06Hebrew\x10\x07\x12\x0b\n\x07Spanish\x10\x08\x12\x14\n\x10UniversalEnglish\x10\t\x12\x14\n\x10UniversalChinese\x10\n*h\n\tSentiment\x12\x13\n\x0fSTRONG_NEGATIVE\x10\x00\x12\x11\n\rWEAK_NEGATIVE\x10\x01\x12\x0b\n\x07NEUTRAL\x10\x02\x12\x11\n\rWEAK_POSITIVE\x10\x03\x12\x13\n\x0fSTRONG_POSITIVE\x10\x04*\x93\x01\n\x14NaturalLogicRelation\x12\x0f\n\x0b\x45QUIVALENCE\x10\x00\x12\x16\n\x12\x46ORWARD_ENTAILMENT\x10\x01\x12\x16\n\x12REVERSE_ENTAILMENT\x10\x02\x12\x0c\n\x08NEGATION\x10\x03\x12\x0f\n\x0b\x41LTERNATION\x10\x04\x12\t\n\x05\x43OVER\x10\x05\x12\x10\n\x0cINDEPENDENCE\x10\x06\x42*\n\x19\x65\x64u.stanford.nlp.pipelineB\rCoreNLPProtos'
)
_LANGUAGE = _descriptor.EnumDescriptor(
@@ -75,8 +75,8 @@ _LANGUAGE = _descriptor.EnumDescriptor(
],
containing_type=None,
serialized_options=None,
- serialized_start=11149,
- serialized_end=11312,
+ serialized_start=11619,
+ serialized_end=11782,
)
_sym_db.RegisterEnumDescriptor(_LANGUAGE)
@@ -110,8 +110,8 @@ _SENTIMENT = _descriptor.EnumDescriptor(
],
containing_type=None,
serialized_options=None,
- serialized_start=11314,
- serialized_end=11418,
+ serialized_start=11784,
+ serialized_end=11888,
)
_sym_db.RegisterEnumDescriptor(_SENTIMENT)
@@ -153,8 +153,8 @@ _NATURALLOGICRELATION = _descriptor.EnumDescriptor(
],
containing_type=None,
serialized_options=None,
- serialized_start=11421,
- serialized_end=11568,
+ serialized_start=11891,
+ serialized_end=12038,
)
_sym_db.RegisterEnumDescriptor(_NATURALLOGICRELATION)
@@ -3522,6 +3522,190 @@ _DEPENDENCYENHANCERREQUEST = _descriptor.Descriptor(
serialized_end=11146,
)
+
+_FLATTENEDPARSETREE_NODE = _descriptor.Descriptor(
+ name='Node',
+ full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='openNode', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.openNode', index=0,
+ number=1, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='closeNode', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.closeNode', index=1,
+ number=2, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='value', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.value', index=2,
+ number=3, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=b"".decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='score', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.score', index=3,
+ number=4, type=1, cpp_type=5, label=1,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ _descriptor.OneofDescriptor(
+ name='contents', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.Node.contents',
+ index=0, containing_type=None, fields=[]),
+ ],
+ serialized_start=11238,
+ serialized_end=11329,
+)
+
+_FLATTENEDPARSETREE = _descriptor.Descriptor(
+ name='FlattenedParseTree',
+ full_name='edu.stanford.nlp.pipeline.FlattenedParseTree',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='nodes', full_name='edu.stanford.nlp.pipeline.FlattenedParseTree.nodes', index=0,
+ number=1, type=11, cpp_type=10, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[_FLATTENEDPARSETREE_NODE, ],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=11149,
+ serialized_end=11329,
+)
+
+
+_EVALUATEPARSERREQUEST_PARSERESULT = _descriptor.Descriptor(
+ name='ParseResult',
+ full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='gold', full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult.gold', index=0,
+ number=1, type=11, cpp_type=10, label=2,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='predicted', full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult.predicted', index=1,
+ number=2, type=11, cpp_type=10, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=11438,
+ serialized_end=11578,
+)
+
+_EVALUATEPARSERREQUEST = _descriptor.Descriptor(
+ name='EvaluateParserRequest',
+ full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='treebank', full_name='edu.stanford.nlp.pipeline.EvaluateParserRequest.treebank', index=0,
+ number=1, type=11, cpp_type=10, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[_EVALUATEPARSERREQUEST_PARSERESULT, ],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=11332,
+ serialized_end=11578,
+)
+
+
+_EVALUATEPARSERRESPONSE = _descriptor.Descriptor(
+ name='EvaluateParserResponse',
+ full_name='edu.stanford.nlp.pipeline.EvaluateParserResponse',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='f1', full_name='edu.stanford.nlp.pipeline.EvaluateParserResponse.f1', index=0,
+ number=1, type=1, cpp_type=5, label=2,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=11580,
+ serialized_end=11616,
+)
+
_DOCUMENT.fields_by_name['sentence'].message_type = _SENTENCE
_DOCUMENT.fields_by_name['corefChain'].message_type = _COREFCHAIN
_DOCUMENT.fields_by_name['sentencelessToken'].message_type = _TOKEN
@@ -3619,6 +3803,21 @@ _DEPENDENCYENHANCERREQUEST.fields_by_name['language'].containing_oneof = _DEPEND
_DEPENDENCYENHANCERREQUEST.oneofs_by_name['ref'].fields.append(
_DEPENDENCYENHANCERREQUEST.fields_by_name['relativePronouns'])
_DEPENDENCYENHANCERREQUEST.fields_by_name['relativePronouns'].containing_oneof = _DEPENDENCYENHANCERREQUEST.oneofs_by_name['ref']
+_FLATTENEDPARSETREE_NODE.containing_type = _FLATTENEDPARSETREE
+_FLATTENEDPARSETREE_NODE.oneofs_by_name['contents'].fields.append(
+ _FLATTENEDPARSETREE_NODE.fields_by_name['openNode'])
+_FLATTENEDPARSETREE_NODE.fields_by_name['openNode'].containing_oneof = _FLATTENEDPARSETREE_NODE.oneofs_by_name['contents']
+_FLATTENEDPARSETREE_NODE.oneofs_by_name['contents'].fields.append(
+ _FLATTENEDPARSETREE_NODE.fields_by_name['closeNode'])
+_FLATTENEDPARSETREE_NODE.fields_by_name['closeNode'].containing_oneof = _FLATTENEDPARSETREE_NODE.oneofs_by_name['contents']
+_FLATTENEDPARSETREE_NODE.oneofs_by_name['contents'].fields.append(
+ _FLATTENEDPARSETREE_NODE.fields_by_name['value'])
+_FLATTENEDPARSETREE_NODE.fields_by_name['value'].containing_oneof = _FLATTENEDPARSETREE_NODE.oneofs_by_name['contents']
+_FLATTENEDPARSETREE.fields_by_name['nodes'].message_type = _FLATTENEDPARSETREE_NODE
+_EVALUATEPARSERREQUEST_PARSERESULT.fields_by_name['gold'].message_type = _FLATTENEDPARSETREE
+_EVALUATEPARSERREQUEST_PARSERESULT.fields_by_name['predicted'].message_type = _FLATTENEDPARSETREE
+_EVALUATEPARSERREQUEST_PARSERESULT.containing_type = _EVALUATEPARSERREQUEST
+_EVALUATEPARSERREQUEST.fields_by_name['treebank'].message_type = _EVALUATEPARSERREQUEST_PARSERESULT
DESCRIPTOR.message_types_by_name['Document'] = _DOCUMENT
DESCRIPTOR.message_types_by_name['Sentence'] = _SENTENCE
DESCRIPTOR.message_types_by_name['Token'] = _TOKEN
@@ -3647,6 +3846,9 @@ DESCRIPTOR.message_types_by_name['SemgrexResponse'] = _SEMGREXRESPONSE
DESCRIPTOR.message_types_by_name['TokensRegexRequest'] = _TOKENSREGEXREQUEST
DESCRIPTOR.message_types_by_name['TokensRegexResponse'] = _TOKENSREGEXRESPONSE
DESCRIPTOR.message_types_by_name['DependencyEnhancerRequest'] = _DEPENDENCYENHANCERREQUEST
+DESCRIPTOR.message_types_by_name['FlattenedParseTree'] = _FLATTENEDPARSETREE
+DESCRIPTOR.message_types_by_name['EvaluateParserRequest'] = _EVALUATEPARSERREQUEST
+DESCRIPTOR.message_types_by_name['EvaluateParserResponse'] = _EVALUATEPARSERRESPONSE
DESCRIPTOR.enum_types_by_name['Language'] = _LANGUAGE
DESCRIPTOR.enum_types_by_name['Sentiment'] = _SENTIMENT
DESCRIPTOR.enum_types_by_name['NaturalLogicRelation'] = _NATURALLOGICRELATION
@@ -3944,6 +4146,43 @@ DependencyEnhancerRequest = _reflection.GeneratedProtocolMessageType('Dependency
})
_sym_db.RegisterMessage(DependencyEnhancerRequest)
+FlattenedParseTree = _reflection.GeneratedProtocolMessageType('FlattenedParseTree', (_message.Message,), {
+
+ 'Node' : _reflection.GeneratedProtocolMessageType('Node', (_message.Message,), {
+ 'DESCRIPTOR' : _FLATTENEDPARSETREE_NODE,
+ '__module__' : 'CoreNLP_pb2'
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.FlattenedParseTree.Node)
+ })
+ ,
+ 'DESCRIPTOR' : _FLATTENEDPARSETREE,
+ '__module__' : 'CoreNLP_pb2'
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.FlattenedParseTree)
+ })
+_sym_db.RegisterMessage(FlattenedParseTree)
+_sym_db.RegisterMessage(FlattenedParseTree.Node)
+
+EvaluateParserRequest = _reflection.GeneratedProtocolMessageType('EvaluateParserRequest', (_message.Message,), {
+
+ 'ParseResult' : _reflection.GeneratedProtocolMessageType('ParseResult', (_message.Message,), {
+ 'DESCRIPTOR' : _EVALUATEPARSERREQUEST_PARSERESULT,
+ '__module__' : 'CoreNLP_pb2'
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult)
+ })
+ ,
+ 'DESCRIPTOR' : _EVALUATEPARSERREQUEST,
+ '__module__' : 'CoreNLP_pb2'
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserRequest)
+ })
+_sym_db.RegisterMessage(EvaluateParserRequest)
+_sym_db.RegisterMessage(EvaluateParserRequest.ParseResult)
+
+EvaluateParserResponse = _reflection.GeneratedProtocolMessageType('EvaluateParserResponse', (_message.Message,), {
+ 'DESCRIPTOR' : _EVALUATEPARSERRESPONSE,
+ '__module__' : 'CoreNLP_pb2'
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserResponse)
+ })
+_sym_db.RegisterMessage(EvaluateParserResponse)
+
DESCRIPTOR._options = None
_DEPENDENCYGRAPH.fields_by_name['root']._options = None
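The regenerated CoreNLP_pb2.py above adds three messages: FlattenedParseTree, whose Node keeps openNode, closeNode and value in a "contents" oneof (score sits outside the oneof), plus EvaluateParserRequest and EvaluateParserResponse. A minimal sketch of driving the generated classes directly, relying on the stanza.protobuf re-export that the java_protobuf_requests.py change below also uses (illustrative only, not part of the patch):

    # Sketch only: build a FlattenedParseTree proto by hand.
    from stanza.protobuf import FlattenedParseTree

    tree = FlattenedParseTree()

    node = tree.nodes.add()
    node.openNode = True        # "open bracket" marker
    node.score = 0.5            # score lives outside the 'contents' oneof

    node = tree.nodes.add()
    node.value = "ROOT"         # label node: selects the 'value' branch of the oneof

    node = tree.nodes.add()
    node.closeNode = True       # matching "close bracket"

    print(tree.nodes[0].WhichOneof("contents"))   # 'openNode'
    data = tree.SerializeToString()               # proto2 wire bytes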
diff --git a/stanza/resources/common.py b/stanza/resources/common.py
index a1b7e690..01dca151 100644
--- a/stanza/resources/common.py
+++ b/stanza/resources/common.py
@@ -424,8 +424,12 @@ def download(
logger.info(
f'Downloading default packages for language: {lang} ({lang_name})...'
)
+ # We want the URL to become, for example:
+ # https://huggingface.co/stanfordnlp/stanza-af/resolve/v1.3.0/models/default.zip
+ # so the template we start from should be:
+ # https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}
request_file(
- f'{url}/{resources_version}/{lang}/default.zip',
+ url.format(resources_version=resources_version, lang=lang, filename="default.zip"),
os.path.join(model_dir, lang, f'default.zip'),
proxies,
md5=resources[lang]['default_md5'],
@@ -448,7 +452,7 @@ def download(
for key, value in download_list:
try:
request_file(
- f'{url}/{resources_version}/{lang}/{key}/{value}.pt',
+ url.format(resources_version=resources_version, lang=lang, filename=f"{key}/{value}.pt"),
os.path.join(model_dir, lang, key, f'{value}.pt'),
proxies,
md5=resources[lang][key][value]['md5']
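As the comment above spells out, download() now fills in a URL template (taken from resources.json, see the prepare_resources.py change further down) instead of concatenating path pieces. A quick sketch of the expansion, illustrative only, using the values from that comment:

    # Sketch only: how the templated resource URL expands.
    url = "https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}"
    print(url.format(resources_version="1.3.0", lang="af", filename="default.zip"))
    # https://huggingface.co/stanfordnlp/stanza-af/resolve/v1.3.0/models/default.zip
    print(url.format(resources_version="1.3.0", lang="af", filename="pos/afribooms.pt"))
    # https://huggingface.co/stanfordnlp/stanza-af/resolve/v1.3.0/models/pos/afribooms.pt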
diff --git a/stanza/resources/installation.py b/stanza/resources/installation.py
index 7c5e5b2b..0a942bd8 100644
--- a/stanza/resources/installation.py
+++ b/stanza/resources/installation.py
@@ -12,19 +12,26 @@ from stanza.resources.common import HOME_DIR, request_file, unzip, \
logger = logging.getLogger('stanza')
+DEFAULT_CORENLP_MODEL_URL = os.getenv(
+ 'CORENLP_MODEL_URL',
+ 'https://huggingface.co/stanfordnlp/corenlp-{model}/resolve/{tag}/stanford-corenlp-models-{model}.jar'
+)
+BACKUP_CORENLP_MODEL_URL = "http://nlp.stanford.edu/software/stanford-corenlp-{version}-models-{model}.jar"
+
DEFAULT_CORENLP_URL = os.getenv(
- 'CORENLP_URL',
- "http://nlp.stanford.edu/software/"
+ 'CORENLP_URL',
+ 'https://huggingface.co/stanfordnlp/CoreNLP/resolve/{tag}/stanford-corenlp-latest.zip'
)
+
DEFAULT_CORENLP_DIR = os.getenv(
'CORENLP_HOME',
os.path.join(HOME_DIR, 'stanza_corenlp')
)
-AVAILABLE_MODELS = set(['arabic', 'chinese', 'english', 'english-kbp', 'french', 'german', 'spanish'])
+AVAILABLE_MODELS = set(['arabic', 'chinese', 'english-extra', 'english-kbp', 'french', 'german', 'spanish'])
-def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level='INFO', proxies=None):
+def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_MODEL_URL, logging_level='INFO', proxies=None):
"""
An automatic way to download the CoreNLP models.
@@ -34,11 +41,12 @@ def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT
version: the version of the model
dir: the directory to download CoreNLP model into; alternatively can be
set up with environment variable $CORENLP_HOME
- url: the link to download CoreNLP models
+ url: The link to download CoreNLP models.
+ It needs {model} and either {version} or {tag} to format the URL properly.
logging_level: logging level to use during installation
"""
dir = os.path.expanduser(dir)
- if model is None or version is None:
+ if not model or not version:
raise ValueError(
"Both model and model version should be specified."
)
@@ -49,9 +57,13 @@ def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT
f'{model} is currently not supported. '
f'Must be one of: {list(AVAILABLE_MODELS)}.'
)
+ # for example:
+ # https://huggingface.co/stanfordnlp/corenlp-french/resolve/v4.2.2/stanford-corenlp-models-french.jar
+ tag = version if version == 'main' else 'v' + version
+ download_url = url.format(tag=tag, model=model, version=version)
try:
request_file(
- url + f'stanford-corenlp-{version}-models-{model}.jar',
+ download_url,
os.path.join(dir, f'stanford-corenlp-{version}-models-{model}.jar'),
proxies
)
@@ -64,7 +76,7 @@ def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT
) from e
-def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level=None, proxies=None):
+def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level=None, proxies=None, version="main"):
"""
A fully automatic way to install and set up the CoreNLP library
to use the client functionality.
@@ -72,7 +84,8 @@ def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_le
Args:
dir: the directory to download CoreNLP model into; alternatively can be
set up with environment variable $CORENLP_HOME
- url: the link to download CoreNLP models
+ url: The link to download CoreNLP models
+ Needs a {version} or {tag} parameter to specify the version
logging_level: logging level to use during installation
"""
dir = os.path.expanduser(dir)
@@ -86,8 +99,11 @@ def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_le
logger.info(f"Installing CoreNLP package into {dir}...")
# First download the URL package
logger.debug(f"Download to destination file: {os.path.join(dir, 'corenlp.zip')}")
+ tag = version if version == 'main' else 'v' + version
+ url = url.format(version=version, tag=tag)
try:
- request_file(url + 'stanford-corenlp-latest.zip', os.path.join(dir, 'corenlp.zip'), proxies)
+ request_file(url, os.path.join(dir, 'corenlp.zip'), proxies)
+
except (KeyboardInterrupt, SystemExit):
raise
except Exception as e:
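installation.py follows the same templating idea: download_corenlp_models() and install_corenlp() format a URL with {model} plus a {tag} derived from the version ('main' is passed through, anything else gets a 'v' prefix). A sketch of the expansion using the DEFAULT_CORENLP_MODEL_URL template added above (illustrative only):

    # Sketch only: how the tag and download URL are derived for a models jar.
    url = "https://huggingface.co/stanfordnlp/corenlp-{model}/resolve/{tag}/stanford-corenlp-models-{model}.jar"
    model, version = "french", "4.2.2"
    tag = version if version == "main" else "v" + version
    print(url.format(tag=tag, model=model, version=version))   # unused kwargs are ignored by str.format
    # https://huggingface.co/stanfordnlp/corenlp-french/resolve/v4.2.2/stanford-corenlp-models-french.jar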
diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py
index adcea413..23568527 100644
--- a/stanza/resources/prepare_resources.py
+++ b/stanza/resources/prepare_resources.py
@@ -6,6 +6,16 @@ import hashlib
import shutil
import zipfile
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_dir', type=str, default="/u/nlp/software/stanza/current-models", help='Input dir for various models. Defaults to the recommended home on the nlp cluster')
+ parser.add_argument('--output_dir', type=str, default="/u/nlp/software/stanza/built-models", help='Output dir for various models.')
+ args = parser.parse_args()
+ args.input_dir = os.path.abspath(args.input_dir)
+ args.output_dir = os.path.abspath(args.output_dir)
+ return args
+
+
# default treebank for languages
default_treebanks = {
"af": "afribooms",
@@ -83,10 +93,10 @@ default_treebanks = {
"te": "mtg",
"orv": "torot",
"nn": "nynorsk",
- "mr": "ufal"
+ "mr": "ufal",
+ "multilingual": "ud"
}
-
# default ner for languages
default_ners = {
"af": "nchlt",
@@ -106,7 +116,6 @@ default_ners = {
"zh-hans": "ontonotes",
}
-
# default charlms for languages
default_charlms = {
"af": "oscar",
@@ -155,6 +164,11 @@ default_sentiment = {
"zh-hans": "ren",
}
+# also, a few languages (very few, currently) have constituency parser models
+default_constituency = {
+ "en": "wsj",
+}
+
allowed_empty_languages = [
# we don't have a lot of Thai support yet
"th"
@@ -169,9 +183,11 @@ processor_to_ending = {
"depparse": "parser",
"ner": "nertagger",
"sentiment": "sentiment",
+ "constituency": "constituency",
"pretrain": "pretrain",
"forward_charlm": "forward_charlm",
- "backward_charlm": "backward_charlm"
+ "backward_charlm": "backward_charlm",
+ "langid": "langid"
}
ending_to_processor = {j: i for i, j in processor_to_ending.items()}
@@ -263,7 +279,7 @@ def ensure_dir(dir):
def copy_file(src, dst):
ensure_dir(Path(dst).parent)
- shutil.copy(src, dst)
+ shutil.copy2(src, dst)
def get_md5(path):
@@ -271,14 +287,6 @@ def get_md5(path):
return hashlib.md5(data).hexdigest()
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('--input_dir', type=str, help='Input dir for various models.')
- parser.add_argument('--output_dir', type=str, help='Output dir for various models.')
- args = parser.parse_args()
- return args
-
-
def split_model_name(model):
"""
Split model names by _
@@ -313,18 +321,17 @@ def process_dirs(args):
dirs = sorted(os.listdir(args.input_dir))
resources = {}
- for dir in dirs:
- print(f"Processing models in {dir}")
- models = sorted(os.listdir(os.path.join(args.input_dir, dir)))
+ for model_dir in dirs:
+ print(f"Processing models in {model_dir}")
+ models = sorted(os.listdir(os.path.join(args.input_dir, model_dir)))
for model in models:
if not model.endswith('.pt'): continue
# get processor
lang, package, processor = split_model_name(model)
# copy file
- input_path = os.path.join(args.input_dir, dir, model)
+ input_path = os.path.join(args.input_dir, model_dir, model)
output_path = os.path.join(args.output_dir, lang, processor, package + '.pt')
- ensure_dir(Path(output_path).parent)
- shutil.copy(input_path, output_path)
+ copy_file(input_path, output_path)
# maintain md5
md5 = get_md5(output_path)
# maintain dependencies
@@ -337,6 +344,11 @@ def process_dirs(args):
# sentiment models use the default pretrain for the language
pretrain_package = default_treebanks[lang]
dependencies = [{'model': 'pretrain', 'package': pretrain_package}]
+ elif processor == 'constituency':
+ # so far, this invariant is true:
+ # constituency models use the default pretrain for the language
+ pretrain_package = default_treebanks[lang]
+ dependencies = [{'model': 'pretrain', 'package': pretrain_package}]
else:
dependencies = None
# maintain resources
@@ -368,6 +380,8 @@ def process_defaults(args):
charlm_package = default_charlms[lang]
if lang in default_sentiment:
sentiment_package = default_sentiment[lang]
+ if lang in default_constituency:
+ constituency_package = default_constituency[lang]
if lang in default_ners and lang in default_charlms:
ner_dependencies = get_ner_dependencies(lang, ner_package)
@@ -376,6 +390,9 @@ def process_defaults(args):
if lang in default_sentiment:
# All of the sentiment models created so far have used the default pretrain
default_dependencies['sentiment'] = [{'model': 'pretrain', 'package': ud_package}]
+ if lang in default_constituency:
+ # All of the constituency models created so far also use the default pretrain
+ default_dependencies['constituency'] = [{'model': 'pretrain', 'package': ud_package}]
processors = ['tokenize', 'mwt', 'lemma', 'pos', 'depparse', 'pretrain']
if lang in default_ners:
@@ -384,20 +401,28 @@ def process_defaults(args):
processors.extend(['forward_charlm', 'backward_charlm'])
if lang in default_sentiment:
processors.append('sentiment')
+ if lang in default_constituency:
+ processors.append('constituency')
+
+ if lang == 'multilingual':
+ processors = ['langid']
+ default_dependencies = {}
with zipfile.ZipFile('default.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
for processor in processors:
if processor == 'ner': package = ner_package
elif processor in ['forward_charlm', 'backward_charlm']: package = charlm_package
elif processor == 'sentiment': package = sentiment_package
+ elif processor == 'constituency': package = constituency_package
+ elif processor == 'langid': package = 'ud'
else: package = ud_package
filename = os.path.join(args.output_dir, lang, processor, package + '.pt')
+
if os.path.exists(filename):
print(" Model {} package {}: file {}".format(processor, package, filename))
- if processor in ['tokenize', 'mwt', 'lemma', 'pos', 'depparse', 'ner', 'sentiment']:
+ if processor in ['tokenize', 'mwt', 'lemma', 'pos', 'depparse', 'ner', 'sentiment', 'constituency', 'langid']:
default_processors[processor] = package
- zipf.write(processor)
zipf.write(os.path.join(processor, package + '.pt'))
elif lang in allowed_empty_languages:
# we don't have a lot of Thai support yet
@@ -424,6 +449,7 @@ def process_defaults(args):
def process_lcode(args):
resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
resources_new = {}
+ resources_new["multilingual"] = resources["multilingual"]
for lang in resources:
if lang not in lcode2lang:
print(lang + ' not found in lcode2lang!')
@@ -439,7 +465,7 @@ def process_misc(args):
resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
resources['no'] = {'alias': 'nb'}
resources['zh'] = {'alias': 'zh-hans'}
- resources['url'] = 'http://nlp.stanford.edu/software/stanza'
+ resources['url'] = 'https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}'
json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)
@@ -453,3 +479,4 @@ def main():
if __name__ == '__main__':
main()
+
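Taken together, prepare_resources.py now registers constituency and langid models next to the existing processors and points the download URL at Hugging Face. A rough, hand-written sketch of the resources.json shape this produces (keys inferred from the code above and from how common.py reads the file; md5 values and most languages elided):

    # Rough sketch only; not an actual resources.json.
    resources = {
        "en": {
            "constituency": {
                "wsj": {"md5": "...",
                        "dependencies": [{"model": "pretrain", "package": "..."}]},  # default en pretrain
            },
            "default_processors": {"tokenize": "...", "constituency": "wsj"},  # plus pos, lemma, depparse, ...
            "default_md5": "...",
        },
        "multilingual": {
            "langid": {"ud": {"md5": "..."}},
            "default_processors": {"langid": "ud"},
        },
        "no": {"alias": "nb"},
        "zh": {"alias": "zh-hans"},
        "url": "https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}",
    }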
diff --git a/stanza/server/java_protobuf_requests.py b/stanza/server/java_protobuf_requests.py
index ee036387..d3f67e91 100644
--- a/stanza/server/java_protobuf_requests.py
+++ b/stanza/server/java_protobuf_requests.py
@@ -1,5 +1,8 @@
+from collections import deque
import subprocess
+from stanza.models.constituency.parse_tree import Tree
+from stanza.protobuf import FlattenedParseTree
from stanza.server.client import resolve_classpath
def send_request(request, response_type, java_main, classpath=None):
@@ -16,6 +19,95 @@ def send_request(request, response_type, java_main, classpath=None):
response.ParseFromString(pipe.stdout)
return response
+def add_tree_nodes(proto_tree, tree, score):
+ # add an open node
+ node = proto_tree.nodes.add()
+ node.openNode = True
+ if score is not None:
+ node.score = score
+
+ # add the content of this node
+ node = proto_tree.nodes.add()
+ node.value = tree.label
+
+ # add all children...
+ # leaves get just one node
+ # branches are called recursively
+ for child in tree.children:
+ if child.is_leaf():
+ node = proto_tree.nodes.add()
+ node.value = child.label
+ else:
+ add_tree_nodes(proto_tree, child, None)
+
+ node = proto_tree.nodes.add()
+ node.closeNode = True
+
+def build_tree(tree, score):
+ """
+ Builds a FlattenedParseTree (the message defined in CoreNLP.proto)
+
+ Populates the value field from tree.label and iterates through the
+ children via tree.children. Should work on any tree structure
+ which follows that layout
+
+ The score will be added to the top node (if it is not None)
+
+ Operates by recursively calling add_tree_nodes
+ """
+ proto_tree = FlattenedParseTree()
+ add_tree_nodes(proto_tree, tree, score)
+ return proto_tree
+
+def from_tree(proto_tree):
+ """
+ Convert a FlattenedParseTree back into a Tree
+
+ returns Tree, score
+ (score might be None if it is missing)
+ """
+ score = None
+ stack = deque()
+ for node in proto_tree.nodes:
+ if node.HasField("score") and score is None:
+ score = node.score
+
+ if node.openNode:
+ if len(stack) > 0 and isinstance(stack[-1], FlattenedParseTree.Node) and stack[-1].openNode:
+ raise ValueError("Got a proto with no label on a node: {}".format(proto_tree))
+ stack.append(node)
+ continue
+ if not node.closeNode:
+ child = Tree(label=node.value)
+ # TODO: do something with the score
+ stack.append(child)
+ continue
+
+ # must be a close operation...
+ if len(stack) <= 1:
+ raise ValueError("Got a proto with too many close operations: {}".format(proto_tree))
+ # on a close operation, pop until we hit the open
+ # then turn everything in that span into a new node
+ children = []
+ nextNode = stack.pop()
+ while not isinstance(nextNode, FlattenedParseTree.Node):
+ children.append(nextNode)
+ nextNode = stack.pop()
+ if len(children) == 0:
+ raise ValueError("Got a proto with an open immediately followed by a close: {}".format(proto_tree))
+ children.reverse()
+ label = children[0]
+ children = children[1:]
+ subtree = Tree(label=label.label, children=children)
+ stack.append(subtree)
+
+ if len(stack) > 1:
+ raise ValueError("Got a proto which does not close all of the nodes: {}".format(proto_tree))
+ tree = stack.pop()
+ if not isinstance(tree, Tree):
+ raise ValueError("Got a proto which was just one Open operation: {}".format(proto_tree))
+ return tree, score
+
def add_token(token_list, word, token):
"""
Add a token to a proto request.
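The flattened encoding used by add_tree_nodes and from_tree is a preorder walk with explicit brackets: an open marker, the node's label, the children (a leaf contributes just its label), then a close marker, with the optional score attached to the first open node. A small round-trip sketch, illustrative only, using the Tree constructor exactly as the code above does:

    # Sketch only: round-trip a tiny tree through the flattened encoding.
    from stanza.models.constituency.parse_tree import Tree
    from stanza.server.java_protobuf_requests import build_tree, from_tree

    leaf = Tree(label="Unban")                  # leaves carry only a label
    vp = Tree(label="VP", children=[leaf])
    root = Tree(label="ROOT", children=[vp])

    proto_tree = build_tree(root, 0.5)
    # node sequence: open, "ROOT", open, "VP", "Unban", close, close

    recovered, score = from_tree(proto_tree)
    assert score == 0.5
    assert recovered.label == "ROOT"
    assert recovered.children[0].label == "VP"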
diff --git a/stanza/server/parser_eval.py b/stanza/server/parser_eval.py
new file mode 100644
index 00000000..c5b30f6a
--- /dev/null
+++ b/stanza/server/parser_eval.py
@@ -0,0 +1,41 @@
+
+
+
+import stanza
+from stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse
+from stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext
+
+
+EVALUATE_JAVA = "edu.stanford.nlp.parser.metrics.EvaluateExternalParser"
+
+def build_request(treebank):
+ """
+ treebank should be a list of pairs: [gold, predictions]
+ each predictions entry is a list of (prediction, score) pairs
+ Note that for now only one tree is measured, but this may be extended in the future
+ Trees should be in the form of a Tree from parse_tree.py
+ """
+ request = EvaluateParserRequest()
+ for gold, predictions in treebank:
+ parse_result = request.treebank.add()
+ parse_result.gold.CopyFrom(build_tree(gold, None))
+ for prediction, score in predictions:
+ parse_result.predicted.append(build_tree(prediction, score))
+
+ return request
+
+
+class EvaluateParser(JavaProtobufContext):
+ """
+ Parser evaluation context manager
+
+ This is a context manager which keeps a Java process open, allowing
+ multiple requests without launching a new process each time.
+ """
+ def __init__(self, classpath=None):
+ super(EvaluateParser, self).__init__(classpath, EvaluateParserResponse, EVALUATE_JAVA)
+
+ def process(self, treebank):
+ request = build_request(treebank)
+ return self.process_request(request)
+
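parser_eval.py wraps the Java class EvaluateExternalParser: the request carries (gold, predictions) pairs serialized with build_tree, and the response is just an F1 score. A hedged usage sketch; it assumes CoreNLP is reachable via the classpath resolution done by JavaProtobufContext and that the base class is used as a context manager, as ud_enhancer.py does:

    # Sketch only: requires a CoreNLP jar reachable through the usual classpath resolution.
    from stanza.models.constituency.parse_tree import Tree
    from stanza.server.parser_eval import EvaluateParser

    word = Tree(label="word")
    gold = Tree(label="ROOT", children=[Tree(label="X", children=[word])])
    predicted = gold                           # pretend the parser got it exactly right

    treebank = [(gold, [(predicted, None)])]   # [(gold, [(prediction, score), ...]), ...]
    with EvaluateParser() as evaluator:
        response = evaluator.process(treebank)
        print(response.f1)                     # expect 1.0 for an identical tree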
diff --git a/stanza/server/ud_enhancer.py b/stanza/server/ud_enhancer.py
index e3d64afa..92d6b7ff 100644
--- a/stanza/server/ud_enhancer.py
+++ b/stanza/server/ud_enhancer.py
@@ -72,7 +72,7 @@ def main():
nlp = stanza.Pipeline('en',
processors='tokenize,pos,lemma,depparse')
- with UniversalEnhancer(language="en", classpath="$CLASSPATH") as enhancer:
+ with UniversalEnhancer(language="en") as enhancer:
doc = nlp("This is the car that I bought")
result = enhancer.process(doc)
print(result.sentence[0].enhancedDependencies)
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
new file mode 100644
index 00000000..b9bb6f80
--- /dev/null
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -0,0 +1,143 @@
+import os
+
+import pytest
+
+from stanza.models.constituency import parse_transitions
+from stanza.tests import *
+from stanza.tests.constituency import test_parse_transitions
+from stanza.tests.constituency.test_trainer import build_trainer, pt
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+def build_model(pt, *args):
+ trainer = build_trainer(pt, *args)
+ return trainer.model
+
+@pytest.fixture(scope="module")
+def unary_model(pt):
+ return build_model(pt, "--transition_scheme", "TOP_DOWN_UNARY")
+
+def test_initial_state(unary_model):
+ test_parse_transitions.test_initial_state(unary_model)
+
+def test_shift(pt):
+ # TODO: might be good to include some tests specifically for shift
+ # in the context of a model with unaries
+ model = build_model(pt)
+ test_parse_transitions.test_shift(model)
+
+def test_unary(unary_model):
+ test_parse_transitions.test_unary(unary_model)
+
+def test_unary_requires_root(unary_model):
+ test_parse_transitions.test_unary_requires_root(unary_model)
+
+def test_open(unary_model):
+ test_parse_transitions.test_open(unary_model)
+
+def test_compound_open(pt):
+ model = build_model(pt, '--transition_scheme', "TOP_DOWN_COMPOUND")
+ test_parse_transitions.test_compound_open(model)
+
+def test_in_order_open(pt):
+ model = build_model(pt, '--transition_scheme', "IN_ORDER")
+ test_parse_transitions.test_in_order_open(model)
+
+def test_close(unary_model):
+ test_parse_transitions.test_close(unary_model)
+
+def run_forward_checks(model):
+ state = test_parse_transitions.build_initial_state(model)[0]
+ model((state,))
+
+ shift = parse_transitions.Shift()
+ state = shift.apply(state, model)
+ model((state,))
+
+ open_transition = parse_transitions.OpenConstituent("NP")
+ assert open_transition.is_legal(state, model)
+ state = open_transition.apply(state, model)
+ assert state.num_opens == 1
+ model((state,))
+
+ state = shift.apply(state, model)
+ model((state,))
+ state = shift.apply(state, model)
+ model((state,))
+ assert state.num_opens == 1
+ # now should have "mox", "opal" on the constituents
+
+ close_transition = parse_transitions.CloseConstituent()
+ assert close_transition.is_legal(state, model)
+ state = close_transition.apply(state, model)
+ assert state.num_opens == 0
+
+ model((state,))
+
+def test_unary_forward(pt, unary_model):
+ """
+ Checks that the forward pass doesn't crash when run after various operations
+
+ Doesn't check whether the forward pass produces reasonable answers
+ """
+ run_forward_checks(unary_model)
+
+def test_lstm_forward(pt):
+ model = build_model(pt, '--num_lstm_layers', '1')
+ run_forward_checks(model)
+ model = build_model(pt, '--num_lstm_layers', '2')
+ run_forward_checks(model)
+ model = build_model(pt, '--num_lstm_layers', '3')
+ run_forward_checks(model)
+
+def test_multiple_output_forward(pt):
+ """
+ Test a couple different sizes of output layers
+ """
+ model = build_model(pt, '--num_output_layers', '1', '--num_lstm_layers', '2')
+ run_forward_checks(model)
+
+ model = build_model(pt, '--num_output_layers', '2', '--num_lstm_layers', '2')
+ run_forward_checks(model)
+
+def test_no_tag_embedding_forward(pt):
+ """
+ Test that the model continues to work if the tag embedding is turned on or off
+ """
+ model = build_model(pt, '--tag_embedding_dim', '20')
+ run_forward_checks(model)
+
+ model = build_model(pt, '--tag_embedding_dim', '0')
+ run_forward_checks(model)
+
+def test_forward_con_lstm(pt):
+ """
+ Tests an older version of the model
+ """
+ model = build_model(pt, '--num_lstm_layers', '2', '--constituency_lstm')
+ run_forward_checks(model)
+
+def test_forward_combined_dummy(pt):
+ """
+ Tests combined dummy and open node embeddings
+ """
+ model = build_model(pt, '--combined_dummy_embedding')
+ run_forward_checks(model)
+
+ model = build_model(pt, '--no_combined_dummy_embedding')
+ run_forward_checks(model)
+
+def test_forward_charlm(pt):
+ """
+ Tests loading and running a charlm
+
+ Note that this doesn't test the results of the charlm itself,
+ just that the model is shaped correctly
+ """
+ forward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "1billion.pt")
+ backward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "1billion.pt")
+ assert os.path.exists(forward_charlm_path), "Need to download en test models (or update path to the forward charlm)"
+ assert os.path.exists(backward_charlm_path), "Need to download en test models (or update path to the backward charlm)"
+
+ model = build_model(pt, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path)
+ run_forward_checks(model)
diff --git a/stanza/tests/constituency/test_parse_transitions.py b/stanza/tests/constituency/test_parse_transitions.py
new file mode 100644
index 00000000..a28b9b19
--- /dev/null
+++ b/stanza/tests/constituency/test_parse_transitions.py
@@ -0,0 +1,412 @@
+import pytest
+
+from stanza.models.constituency import parse_transitions
+from stanza.models.constituency.base_model import SimpleModel
+from stanza.models.constituency.parse_transitions import TransitionScheme
+from stanza.tests import *
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+
+def build_initial_state(model):
+ words = ["Unban", "Mox", "Opal"]
+ tags = ["VB", "NNP", "NNP"]
+
+ state = parse_transitions.initial_state_from_words([list(zip(words, tags))], model)
+ assert len(state) == 1
+ assert state[0].num_transitions() == 0
+ return state
+
+def test_initial_state(model=None):
+ if model is None:
+ model = SimpleModel()
+ states = build_initial_state(model)
+ assert len(states) == 1
+ state = states[0]
+
+ assert state.sentence_length == 3
+ assert state.num_opens == 0
+ # each stack has a sentinel value at the end
+ assert len(state.word_queue) == 4
+ assert len(state.constituents) == 1
+ assert len(state.transitions) == 1
+ assert state.word_position == 0
+
+def test_shift(model=None):
+ if model is None:
+ model = SimpleModel()
+ state = build_initial_state(model)[0]
+
+ open_transition = parse_transitions.OpenConstituent("ROOT")
+ state = open_transition.apply(state, model)
+ open_transition = parse_transitions.OpenConstituent("S")
+ state = open_transition.apply(state, model)
+ shift = parse_transitions.Shift()
+ assert shift.is_legal(state, model)
+ assert len(state.word_queue) == 4
+ assert state.word_position == 0
+
+ state = shift.apply(state, model)
+ assert len(state.word_queue) == 4
+ # 4 because of the dummy created by the opens
+ assert len(state.constituents) == 4
+ assert len(state.transitions) == 4
+ assert shift.is_legal(state, model)
+ assert state.word_position == 1
+ assert not state.empty_word_queue()
+
+ state = shift.apply(state, model)
+ assert len(state.word_queue) == 4
+ assert len(state.constituents) == 5
+ assert len(state.transitions) == 5
+ assert shift.is_legal(state, model)
+ assert state.word_position == 2
+ assert not state.empty_word_queue()
+
+ state = shift.apply(state, model)
+ assert len(state.word_queue) == 4
+ assert len(state.constituents) == 6
+ assert len(state.transitions) == 6
+ assert not shift.is_legal(state, model)
+ assert state.word_position == 3
+ assert state.empty_word_queue()
+
+ constituents = state.constituents
+ assert model.get_top_constituent(constituents).children[0].label == 'Opal'
+ constituents = constituents.pop()
+ assert model.get_top_constituent(constituents).children[0].label == 'Mox'
+ constituents = constituents.pop()
+ assert model.get_top_constituent(constituents).children[0].label == 'Unban'
+
+def test_initial_unary(model=None):
+ # it doesn't make sense to start with a CompoundUnary
+ if model is None:
+ model = SimpleModel()
+
+ state = build_initial_state(model)[0]
+ unary = parse_transitions.CompoundUnary(['ROOT', 'VP'])
+ assert not unary.is_legal(state, model)
+ unary = parse_transitions.CompoundUnary(['VP'])
+ assert not unary.is_legal(state, model)
+
+
+def test_unary(model=None):
+ if model is None:
+ model = SimpleModel()
+ state = build_initial_state(model)[0]
+
+ shift = parse_transitions.Shift()
+ state = shift.apply(state, model)
+
+ # this is technically the wrong parse but we're being lazy
+ unary = parse_transitions.CompoundUnary(['S', 'VP'])
+ assert unary.is_legal(state, model)
+ state = unary.apply(state, model)
+ assert not unary.is_legal(state, model)
+
+ tree = model.get_top_constituent(state.constituents)
+ assert tree.label == 'S'
+ assert len(tree.children) == 1
+ tree = tree.children[0]
+ assert tree.label == 'VP'
+ assert len(tree.children) == 1
+ tree = tree.children[0]
+ assert tree.label == 'VB'
+ assert tree.is_preterminal()
+
+def test_unary_requires_root(model=None):
+ if model is None:
+ model = SimpleModel()
+ state = build_initial_state(model)[0]
+
+ open_transition = parse_transitions.OpenConstituent("S")
+ assert open_transition.is_legal(state, model)
+ state = open_transition.apply(state, model)
+
+ shift = parse_transitions.Shift()
+ assert shift.is_legal(state, model)
+ state = shift.apply(state, model)
+ assert shift.is_legal(state, model)
+ state = shift.apply(state, model)
+ assert shift.is_legal(state, model)
+ state = shift.apply(state, model)
+ assert not shift.is_legal(state, model)
+
+ close_transition = parse_transitions.CloseConstituent()
+ assert close_transition.is_legal(state, model)
+ state = close_transition.apply(state, model)
+ assert not open_transition.is_legal(state, model)
+ assert not close_transition.is_legal(state, model)
+
+ np_unary = parse_transitions.CompoundUnary("NP")
+ assert not np_unary.is_legal(state, model)
+ root_unary = parse_transitions.CompoundUnary("ROOT")
+ assert root_unary.is_legal(state, model)
+ assert not state.finished(model)
+ state = root_unary.apply(state, model)
+ assert not root_unary.is_legal(state, model)
+
+ assert state.finished(model)
+
+def test_open(model=None):
+ if model is None:
+ model = SimpleModel()
+ state = build_initial_state(model)[0]
+
+ shift = parse_transitions.Shift()
+ state = shift.apply(state, model)
+ state = shift.apply(state, model)
+ assert state.num_opens == 0
+
+ open_transition = parse_transitions.OpenConstituent("VP")
+ assert open_transition.is_legal(state, model)
+ state = open_transition.apply(state, model)
+ assert open_transition.is_legal(state, model)
+ assert state.num_opens == 1
+
+ # check that it is illegal if there are too many opens already
+ for i in range(20):
+ state = open_transition.apply(state, model)
+ assert not open_transition.is_legal(state, model)
+ assert state.num_opens == 21
+
+ # check that it is illegal if the state is out of words
+ state = build_initial_state(model)[0]
+ state = shift.apply(state, model)
+ state = shift.apply(state, model)
+ state = shift.apply(state, model)
+ assert not open_transition.is_legal(state, model)
+
+def test_compound_open(model=None):
+ if model is None:
+ model = SimpleModel()
+ state = build_initial_state(model)[0]
+
+ open_transition = parse_transitions.OpenConstituent("ROOT", "S")
+ assert open_transition.is_legal(state, model)
+ shift = parse_transitions.Shift()
+ close_transition = parse_transitions.CloseConstituent()
+
+ state = open_transition.apply(state, model)
+ state = shift.apply(state, model)
+ state = shift.apply(state, model)
+ state = shift.apply(state, model)
+ state = close_transition.apply(state, model)
+
+ tree = model.get_top_constituent(state.constituents)
+ assert tree.label == 'ROOT'
+ assert len(tree.children) == 1
+ tree = tree.children[0]
+ assert tree.label == 'S'
+ assert len(tree.children) == 3
+ assert tree.children[0].children[0].label == 'Unban'
+ assert tree.children[1].children[0].label == 'Mox'
+ assert tree.children[2].children[0].label == 'Opal'
+
+def test_in_order_open(model=None):
+ if model is None:
+ model = SimpleModel(TransitionScheme.IN_ORDER)
+ state = build_initial_state(model)[0]
+
+ shift = parse_transitions.Shift()
+ assert shift.is_legal(state, model)
+ state = shift.apply(state, model)
+ assert not shift.is_legal(state, model)
+
+ open_vp = parse_transitions.OpenConstituent("VP")
+ assert open_vp.is_legal(state, model)
+ state = open_vp.apply(state, model)
+ assert not open_vp.is_legal(state, model)
+
+ close_trans = parse_transitions.CloseConstituent()
+ assert close_trans.is_legal(state, model)
+ state = close_trans.apply(state, model)
+
+ open_s = parse_transitions.OpenConstituent("S")
+ assert open_s.is_legal(state, model)
+ state = open_s.apply(state, model)
+ assert not open_vp.is_legal(state, model)
+
+ # check that root transitions won't happen in the middle of a parse
+ open_root = parse_transitions.OpenConstituent("ROOT")
+ assert not open_root.is_legal(state, model)
+
+ # build (NP (NNP Mox) (NNP Opal))
+ open_np = parse_transitions.OpenConstituent("NP")
+ assert shift.is_legal(state, model)
+ state = shift.apply(state, model)
+ assert open_np.is_legal(state, model)
+ # make sure root can't happen in places where an arbitrary open is legal
+ assert not open_root.is_legal(state, model)
+ state = open_np.apply(state, model)
+ assert shift.is_legal(state, model)
+ state = shift.apply(state, model)
+ assert close_trans.is_legal(state, model)
+ state = close_trans.apply(state, model)
+
+ assert close_trans.is_legal(state, model)
+ state = close_trans.apply(state, model)
+
+ assert open_root.is_legal(state, model)
+ state = open_root.apply(state, model)
+
+def test_too_many_unaries_close():
+ """
+ This tests rejecting Close at the start of a sequence after too many unary transitions
+
+ The model should reject doing multiple "unaries" - e.g., Open then Close - in an IN_ORDER sequence
+ """
+ model = SimpleModel(TransitionScheme.IN_ORDER)
+ state = build_initial_state(model)[0]
+
+ shift = parse_transitions.Shift()
+ assert shift.is_legal(state, model)
+ state = shift.apply(state, model)
+
+ open_np = parse_transitions.OpenConstituent("NP")
+ close_trans = parse_transitions.CloseConstituent()
+ for _ in range(parse_transitions.UNARY_LIMIT):
+ assert open_np.is_legal(state, model)
+ state = open_np.apply(state, model)
+
+ assert close_trans.is_legal(state, model)
+ state = close_trans.apply(state, model)
+
+ assert open_np.is_legal(state, model)
+ state = open_np.apply(state, model)
+ assert not close_trans.is_legal(state, model)
+
+def test_too_many_unaries_open():
+ """
+ This tests rejecting Open in the middle of a sequence after too many unary transitions
+
+ The model should reject doing multiple "unaries" - e.g., Open then Close - in an IN_ORDER sequence
+ """
+ model = SimpleModel(TransitionScheme.IN_ORDER)
+ state = build_initial_state(model)[0]
+
+ shift = parse_transitions.Shift()
+ assert shift.is_legal(state, model)
+ state = shift.apply(state, model)
+
+ open_np = parse_transitions.OpenConstituent("NP")
+ close_trans = parse_transitions.CloseConstituent()
+
+ assert open_np.is_legal(state, model)
+ state = open_np.apply(state, model)
+ assert not open_np.is_legal(state, model)
+ assert shift.is_legal(state, model)
+ state = shift.apply(state, model)
+
+ for _ in range(parse_transitions.UNARY_LIMIT):
+ assert open_np.is_legal(state, model)
+ state = open_np.apply(state, model)
+
+ assert close_trans.is_legal(state, model)
+ state = close_trans.apply(state, model)
+
+ assert not open_np.is_legal(state, model)
+
+def test_close(model=None):
+ if model is None:
+ model = SimpleModel()
+ # this one actually tests an entire subtree building
+ state = build_initial_state(model)[0]
+
+ shift = parse_transitions.Shift()
+ state = shift.apply(state, model)
+
+ open_transition = parse_transitions.OpenConstituent("NP")
+ assert open_transition.is_legal(state, model)
+ state = open_transition.apply(state, model)
+ assert state.num_opens == 1
+
+ state = shift.apply(state, model)
+ state = shift.apply(state, model)
+ assert state.num_opens == 1
+ # now should have "Mox" and "Opal" on the constituents
+
+ close_transition = parse_transitions.CloseConstituent()
+ assert close_transition.is_legal(state, model)
+ state = close_transition.apply(state, model)
+ assert state.num_opens == 0
+ assert not close_transition.is_legal(state, model)
+
+ tree = model.get_top_constituent(state.constituents)
+ assert tree.label == 'NP'
+ assert len(tree.children) == 2
+ assert tree.children[0].is_preterminal()
+ assert tree.children[1].is_preterminal()
+ assert tree.children[0].children[0].label == 'Mox'
+ assert tree.children[1].children[0].label == 'Opal'
+
+ assert len(state.constituents) == 3
+
+ assert state.all_transitions(model) == [shift, open_transition, shift, shift, close_transition]
+
+def test_hashes():
+ transitions = set()
+
+ shift = parse_transitions.Shift()
+ assert shift not in transitions
+ transitions.add(shift)
+ assert shift in transitions
+ shift = parse_transitions.Shift()
+ assert shift in transitions
+
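+ # repeatedly adding an equal Shift should not grow the set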
+ for i in range(5):
+ transitions.add(shift)
+ assert len(transitions) == 1
+
+ unary = parse_transitions.CompoundUnary("asdf")
+ assert unary not in transitions
+ transitions.add(unary)
+ assert unary in transitions
+
+ unary = parse_transitions.CompoundUnary(["asdf", "zzzz"])
+ assert unary not in transitions
+ transitions.add(unary)
+ transitions.add(unary)
+ transitions.add(unary)
+ unary = parse_transitions.CompoundUnary(["asdf", "zzzz"])
+ assert unary in transitions
+
+ # check that the str and the list constructors result in the same item
+ assert len(transitions) == 3
+ unary = parse_transitions.CompoundUnary(["asdf"])
+ assert unary in transitions
+
+ oc = parse_transitions.OpenConstituent("asdf")
+ assert oc not in transitions
+ transitions.add(oc)
+ assert oc in transitions
+ transitions.add(oc)
+ transitions.add(oc)
+ assert len(transitions) == 4
+ assert parse_transitions.OpenConstituent("asdf") in transitions
+
+ cc = parse_transitions.CloseConstituent()
+ assert cc not in transitions
+ transitions.add(cc)
+ transitions.add(cc)
+ transitions.add(cc)
+ assert cc in transitions
+ cc = parse_transitions.CloseConstituent()
+ assert cc in transitions
+ assert len(transitions) == 5
+
+
+def test_sort():
+ expected = []
+
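+ # expected sort order: Shift, then Close, then CompoundUnary (by label), then OpenConstituent (by label)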
+ expected.append(parse_transitions.Shift())
+ expected.append(parse_transitions.CloseConstituent())
+ expected.append(parse_transitions.CompoundUnary(["NP"]))
+ expected.append(parse_transitions.CompoundUnary(["NP", "VP"]))
+ expected.append(parse_transitions.OpenConstituent("mox"))
+ expected.append(parse_transitions.OpenConstituent("opal"))
+ expected.append(parse_transitions.OpenConstituent("unban"))
+
+ transitions = set(expected)
+ transitions = sorted(transitions)
+ assert transitions == expected
diff --git a/stanza/tests/constituency/test_parse_tree.py b/stanza/tests/constituency/test_parse_tree.py
new file mode 100644
index 00000000..959e9936
--- /dev/null
+++ b/stanza/tests/constituency/test_parse_tree.py
@@ -0,0 +1,196 @@
+import pytest
+
+from stanza.models.constituency.parse_tree import Tree
+from stanza.models.constituency import tree_reader
+
+from stanza.tests import *
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+def test_leaf_preterminal():
+ foo = Tree(label="foo")
+ assert foo.is_leaf()
+ assert not foo.is_preterminal()
+ assert len(foo.children) == 0
+ assert str(foo) == 'foo'
+
+ bar = Tree(label="bar", children=foo)
+ assert not bar.is_leaf()
+ assert bar.is_preterminal()
+ assert len(bar.children) == 1
+ assert str(bar) == "(bar foo)"
+
+ baz = Tree(label="baz", children=[bar])
+ assert not baz.is_leaf()
+ assert not baz.is_preterminal()
+ assert len(baz.children) == 1
+ assert str(baz) == "(baz (bar foo))"
+
+
+def test_yield_preterminals():
+ text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
+ trees = tree_reader.read_trees(text)
+
+ preterminals = trees[0].preterminals()
+ assert len(preterminals) == 3
+ assert str(preterminals) == "[(VB Unban), (NNP Mox), (NNP Opal)]"
+
+def test_depth():
+ text = "(foo) ((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
+ trees = tree_reader.read_trees(text)
+ assert trees[0].depth() == 0
+ assert trees[1].depth() == 4
+
+def test_unique_labels():
+ """
+ Test getting the unique labels from a tree
+
+ Assumes tree_reader works, which should be fine since it is tested elsewhere
+ """
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+
+ trees = tree_reader.read_trees(text)
+
+ labels = Tree.get_unique_constituent_labels(trees)
+ expected = ['NP', 'PP', 'ROOT', 'SBARQ', 'SQ', 'VP', 'WHNP']
+ assert labels == expected
+
+def test_unique_tags():
+ """
+ Test getting the unique tags from a tree
+ """
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+
+ trees = tree_reader.read_trees(text)
+
+ tags = Tree.get_unique_tags(trees)
+ expected = ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP']
+ assert tags == expected
+
+
+def test_unique_words():
+ """
+ Test getting the unique words from a tree
+ """
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+
+ trees = tree_reader.read_trees(text)
+
+ words = Tree.get_unique_words(trees)
+ expected = ['?', 'Who', 'in', 'seat', 'sits', 'this']
+ assert words == expected
+
+def test_rare_words():
+ """
+ Test getting the rare words from a tree
+ """
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
+
+ trees = tree_reader.read_trees(text)
+
+ words = Tree.get_rare_words(trees, 0.5)
+ expected = ['Who', 'in', 'sits']
+ assert words == expected
+
+def test_root_labels():
+ text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+ trees = tree_reader.read_trees(text)
+ assert ["ROOT"] == Tree.get_root_labels(trees)
+
+ text=("( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" +
+ "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" +
+ "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))")
+ trees = tree_reader.read_trees(text)
+ assert ["ROOT"] == Tree.get_root_labels(trees)
+
+ text="(FOO) (BAR)"
+ trees = tree_reader.read_trees(text)
+ assert ["BAR", "FOO"] == Tree.get_root_labels(trees)
+
+def test_prune_none():
+ text=["((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (-NONE- in) (NP (DT this) (NN seat))))) (. ?)))", # test one dead node
+ "((SBARQ (WHNP (-NONE- Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))", # test recursive dead nodes
+ "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (-NONE- this) (-NONE- seat))))) (. ?)))"] # test all children dead
+ expected=["(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (NP (DT this) (NN seat))))) (. ?)))",
+ "(ROOT (SBARQ (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))",
+ "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"]
+
+ for t, e in zip(text, expected):
+ trees = tree_reader.read_trees(t)
+ assert len(trees) == 1
+ tree = trees[0].prune_none()
+ assert e == str(tree)
+
+def test_simplify_labels():
+ text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))"
+ expected = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))"
+ trees = tree_reader.read_trees(text)
+ trees = [t.simplify_labels() for t in trees]
+ assert len(trees) == 1
+ assert expected == str(trees[0])
+
+def test_remap_constituent_labels():
+ text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
+ expected="(ROOT (FOO (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
+
+ label_map = { "SBARQ": "FOO" }
+ trees = tree_reader.read_trees(text)
+ trees = [t.remap_constituent_labels(label_map) for t in trees]
+ assert len(trees) == 1
+ assert expected == str(trees[0])
+
+def test_remap_constituent_words():
+ text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
+ expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))"
+
+ word_map = { "Who": "unban", "sits": "mox", "in": "opal" }
+ trees = tree_reader.read_trees(text)
+ trees = [t.remap_words(word_map) for t in trees]
+ assert len(trees) == 1
+ assert expected == str(trees[0])
+
+def test_replace_words():
+ text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
+ expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))"
+ new_words = ["unban", "mox", "opal", "?"]
+
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+ tree = trees[0]
+ new_tree = tree.replace_words(new_words)
+ assert expected == str(new_tree)
+
+
+def test_compound_constituents():
+ # TODO: add skinny trees like this to the various transition tests
+ text="((VP (VB Unban)))"
+ trees = tree_reader.read_trees(text)
+ assert Tree.get_compound_constituents(trees) == [('ROOT', 'VP')]
+
+ text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
+ trees = tree_reader.read_trees(text)
+ assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('SQ', 'VP'), ('WHNP',)]
+
+ text="((VP (VB Unban))) (ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
+ trees = tree_reader.read_trees(text)
+ assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('ROOT', 'VP'), ('SQ', 'VP'), ('WHNP',)]
+
+def test_equals():
+ """
+ Check one tree from the actual dataset for ==
+
+ When built with compound Open transitions, == previously failed because of a bug
+ """
+ text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))"
+
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+ tree = trees[0]
+
+ assert tree == tree
+
+ trees2 = tree_reader.read_trees(text)
+ tree2 = trees2[0]
+
+ assert tree is not tree2
+ assert tree == tree2
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
new file mode 100644
index 00000000..d9fd18f6
--- /dev/null
+++ b/stanza/tests/constituency/test_trainer.py
@@ -0,0 +1,89 @@
+import logging
+import tempfile
+
+import pytest
+
+from stanza.models import constituency_parser
+from stanza.models.common import pretrain
+from stanza.models.constituency import lstm_model
+from stanza.models.constituency import trainer
+from stanza.models.constituency import tree_reader
+from stanza.tests import *
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+logger = logging.getLogger('stanza.constituency.trainer')
+logger.setLevel(logging.WARNING)
+
+TREEBANK = """
+( (S
+ (VP (VBG Enjoying)
+ (NP (PRP$ my) (JJ favorite) (NN Friday) (NN tradition)))
+ (. .)))
+
+( (NP
+ (VP (VBG Sitting)
+ (PP (IN in)
+ (NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station)))
+ (VP (VBG waiting)
+ (PP (IN for)
+ (NP (PRP$ my) (JJ delayed) (NNP @MBTA) (NN train)))))
+ (. .)))
+
+( (S
+ (NP (PRP I))
+ (VP
+ (ADVP (RB really))
+ (VBP hate)
+ (NP (DT the) (NNP @MBTA)))))
+
+( (S
+ (S (VP (VB Seek)))
+ (CC and)
+ (S (NP (PRP ye))
+ (VP (MD shall)
+ (VP (VB find))))
+ (. .)))
+"""
+
+@pytest.fixture(scope="module")
+def pt():
+ return pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False)
+
+def build_trainer(pt, *args):
+ # TODO: build a fake embedding some other way?
+ train_trees = tree_reader.read_trees(TREEBANK)
+ dev_trees = train_trees[-1:]
+
+ args = constituency_parser.parse_args(args)
+ forward_charlm = trainer.load_charlm(args['charlm_forward_file'])
+ backward_charlm = trainer.load_charlm(args['charlm_backward_file'])
+
+ model, _, _ = trainer.build_trainer(args, train_trees, dev_trees, pt, forward_charlm, backward_charlm)
+ assert isinstance(model.model, lstm_model.LSTMModel)
+ return model
+
+def test_initial_model(pt):
+ """
+ does nothing, just tests that the construction went okay
+ """
+ build_trainer(pt)
+
+
+def test_save_load_model(pt):
+ """
+ Just tests that saving and loading works without crashes.
+
+ Currently no test of the values themselves
+ """
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ tr = build_trainer(pt)
+
+ # attempt saving
+ filename = os.path.join(tmpdirname, "parser.pt")
+ tr.save(filename)
+
+ assert os.path.exists(filename)
+
+ # load it back in
+ tr.load(filename, pt, None, None, False)
diff --git a/stanza/tests/constituency/test_transition_sequence.py b/stanza/tests/constituency/test_transition_sequence.py
new file mode 100644
index 00000000..6c77db3f
--- /dev/null
+++ b/stanza/tests/constituency/test_transition_sequence.py
@@ -0,0 +1,87 @@
+import pytest
+from stanza.models.constituency import parse_transitions
+from stanza.models.constituency import transition_sequence
+from stanza.models.constituency import tree_reader
+from stanza.models.constituency.base_model import SimpleModel
+from stanza.models.constituency.parse_transitions import *
+
+from stanza.tests import *
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+def check_reproduce_tree(transition_scheme):
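+ # build the gold transition sequence for a tree, replay it on a fresh state, and check the original tree is reproduced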
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+ trees = tree_reader.read_trees(text)
+
+ model = SimpleModel(transition_scheme)
+ transitions = transition_sequence.build_sequence(trees[0], transition_scheme)
+ states = parse_transitions.initial_state_from_gold_trees(trees, model)
+ assert len(states) == 1
+ state = states[0]
+ assert state.num_transitions() == 0
+
+ for t in transitions:
+ assert t.is_legal(state, model)
+ state = t.apply(state, model)
+
+ # one item for the final tree
+ # one item for the sentinel at the end
+ assert len(state.constituents) == 2
+ # the transition sequence should put all of the words
+ # from the buffer onto the tree
+ # one spot left for the sentinel value
+ assert len(state.word_queue) == 7
+ assert state.sentence_length == 6
+ assert state.word_position == state.sentence_length
+ assert len(state.transitions) == len(transitions) + 1
+
+ result_tree = state.constituents.value
+ assert result_tree == trees[0]
+
+def test_top_down_unary():
+ check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN_UNARY)
+
+def test_top_down_no_unary():
+ check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN)
+
+def test_in_order():
+ check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER)
+
+def test_all_transitions():
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+ trees = tree_reader.read_trees(text)
+ model = SimpleModel()
+ transitions = transition_sequence.build_treebank(trees)
+
+ expected = [Shift(), CloseConstituent(), CompoundUnary("ROOT"), CompoundUnary("SQ"), CompoundUnary("WHNP"), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("SBARQ"), OpenConstituent("VP")]
+ assert transition_sequence.all_transitions(transitions) == expected
+
+
+def test_all_transitions_no_unary():
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+ trees = tree_reader.read_trees(text)
+ model = SimpleModel()
+ transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN)
+
+ expected = [Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("ROOT"), OpenConstituent("SBARQ"), OpenConstituent("SQ"), OpenConstituent("VP"), OpenConstituent("WHNP")]
+ assert transition_sequence.all_transitions(transitions) == expected
+
+def test_top_down_compound_unary():
+ text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))"
+
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+
+ model = SimpleModel()
+ transitions = transition_sequence.build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN_COMPOUND)
+
+ states = parse_transitions.initial_state_from_gold_trees(trees, model)
+ assert len(states) == 1
+ state = states[0]
+
+ for t in transitions:
+ assert t.is_legal(state, model)
+ state = t.apply(state, model)
+
+ result = model.get_top_constituent(state.constituents)
+ assert trees[0] == result
diff --git a/stanza/tests/constituency/test_tree_reader.py b/stanza/tests/constituency/test_tree_reader.py
new file mode 100644
index 00000000..feee74fa
--- /dev/null
+++ b/stanza/tests/constituency/test_tree_reader.py
@@ -0,0 +1,61 @@
+import pytest
+from stanza.models.constituency import tree_reader
+
+from stanza.tests import *
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+def test_simple():
+ """
+ Tests reading two simple trees from the same text
+ """
+ text = "(VB Unban) (NNP Opal)"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 2
+ assert trees[0].is_preterminal()
+ assert trees[0].label == 'VB'
+ assert trees[0].children[0].label == 'Unban'
+ assert trees[1].is_preterminal()
+ assert trees[1].label == 'NNP'
+ assert trees[1].children[0].label == 'Opal'
+
+def test_newlines():
+ """
+ The same test should work if there are newlines
+ """
+ text = "(VB Unban)\n\n(NNP Opal)"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 2
+
+def test_complicated():
+ """
+ A more complicated tree that should successfully read
+ """
+ text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+ tree = trees[0]
+ assert not tree.is_leaf()
+ assert not tree.is_preterminal()
+ assert tree.label == 'ROOT'
+ assert len(tree.children) == 1
+ assert tree.children[0].label == 'SBARQ'
+ assert len(tree.children[0].children) == 3
+ assert [x.label for x in tree.children[0].children] == ['WHNP', 'SQ', '.']
+ # etc etc
+
+def test_one_word():
+ """
+ Check that one node trees are correctly read
+
+ probably not super relevant for the parsing use case
+ """
+ text="(FOO) (BAR)"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 2
+
+ assert trees[0].is_leaf()
+ assert trees[0].label == 'FOO'
+
+ assert trees[1].is_leaf()
+ assert trees[1].label == 'BAR'
diff --git a/stanza/tests/constituency/test_tree_stack.py b/stanza/tests/constituency/test_tree_stack.py
new file mode 100644
index 00000000..e7859a3b
--- /dev/null
+++ b/stanza/tests/constituency/test_tree_stack.py
@@ -0,0 +1,50 @@
+import pytest
+
+from stanza.models.constituency.tree_stack import TreeStack
+
+from stanza.tests import *
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+def test_simple():
+ stack = TreeStack(value=5, parent=None, length=1)
+ stack = stack.push(3)
+ stack = stack.push(1)
+
+ expected_values = [1, 3, 5]
+ for value in expected_values:
+ assert stack.value == value
+ stack = stack.pop()
+ assert stack is None
+
+def test_iter():
+ stack = TreeStack(value=5, parent=None, length=1)
+ stack = stack.push(3)
+ stack = stack.push(1)
+
+ stack_list = list(stack)
+ assert stack_list == [1, 3, 5]
+
+def test_str():
+ stack = TreeStack(value=5, parent=None, length=1)
+ stack = stack.push(3)
+ stack = stack.push(1)
+
+ assert str(stack) == "TreeStack(1, 3, 5)"
+
+def test_len():
+ stack = TreeStack(value=5, parent=None, length=1)
+ assert len(stack) == 1
+
+ stack = stack.push(3)
+ stack = stack.push(1)
+ assert len(stack) == 3
+
+def test_long_len():
+ """
+ Original stack had a bug where this took exponential time...
+ """
+ stack = TreeStack(value=0, parent=None, length=1)
+ for i in range(1, 40):
+ stack = stack.push(i)
+ assert len(stack) == 40
diff --git a/stanza/tests/constituency/test_utils.py b/stanza/tests/constituency/test_utils.py
new file mode 100644
index 00000000..7e3b5d9e
--- /dev/null
+++ b/stanza/tests/constituency/test_utils.py
@@ -0,0 +1,68 @@
+import pytest
+
+from stanza import Pipeline
+from stanza.models.constituency import tree_reader
+from stanza.models.constituency import utils
+
+from stanza.tests import *
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+
+@pytest.fixture(scope="module")
+def pipeline():
+ return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
+
+
+
+def test_xpos_retag(pipeline):
+ """
+ Test using the English tagger that trees will be correctly retagged by read_trees using xpos
+ """
+ text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))"
+ expected = "((S (VP (VB Find)) (NP (NNP Mox) (NNP Opal)))) ((S (NP (NNP Ragavan)) (VP (VBZ steals) (NP (JJ important) (NNS cards)))))"
+
+ trees = tree_reader.read_trees(text)
+
+ new_trees = utils.retag_trees(trees, pipeline, xpos=True)
+ assert new_trees == tree_reader.read_trees(expected)
+
+
+
+def test_upos_retag(pipeline):
+ """
+ Test using the English tagger that trees will be correctly retagged by read_trees using upos
+ """
+ text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))"
+ expected = "((S (VP (VERB Find)) (NP (PROPN Mox) (PROPN Opal)))) ((S (NP (PROPN Ragavan)) (VP (VERB steals) (NP (ADJ important) (NOUN cards)))))"
+
+ trees = tree_reader.read_trees(text)
+
+ new_trees = utils.retag_trees(trees, pipeline, xpos=False)
+ assert new_trees == tree_reader.read_trees(expected)
+
+
+def test_replace_tags():
+ """
+ Test the underlying replace_tags method
+
+ Also tests that the method throws exceptions when it is supposed to
+ """
+ text = "((S (VP (X Find)) (NP (X Mox) (X Opal))))"
+ expected = "((S (VP (A Find)) (NP (B Mox) (C Opal))))"
+
+ trees = tree_reader.read_trees(text)
+
+ new_tags = ["A", "B", "C"]
+ new_tree = utils.replace_tags(trees[0], new_tags)
+
+ assert new_tree == tree_reader.read_trees(expected)[0]
+
+ with pytest.raises(ValueError):
+ new_tags = ["A", "B"]
+ new_tree = utils.replace_tags(trees[0], new_tags)
+
+ with pytest.raises(ValueError):
+ new_tags = ["A", "B", "C", "D"]
+ new_tree = utils.replace_tags(trees[0], new_tags)
+
diff --git a/stanza/tests/resources/test_common.py b/stanza/tests/resources/test_common.py
new file mode 100644
index 00000000..3594cac9
--- /dev/null
+++ b/stanza/tests/resources/test_common.py
@@ -0,0 +1,19 @@
+"""
+Test various resource downloading functions from resources/common.py
+"""
+
+import pytest
+import tempfile
+
+import stanza
+
+pytestmark = [pytest.mark.travis, pytest.mark.client]
+
+
+def test_download_tokenize_mwt():
+ with tempfile.TemporaryDirectory(dir=".") as test_dir:
+ stanza.download("en", model_dir=test_dir, processors="tokenize", package="ewt", verbose=False)
+ pipeline = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package="ewt")
+ assert isinstance(pipeline, stanza.Pipeline)
+ # mwt should be added to the list
+ assert len(pipeline.loaded_processors) == 2
diff --git a/stanza/tests/test_installation.py b/stanza/tests/resources/test_installation.py
index 69a7bb0f..03fff24d 100644
--- a/stanza/tests/test_installation.py
+++ b/stanza/tests/resources/test_installation.py
@@ -18,7 +18,7 @@ def test_install_corenlp():
# the download method doesn't install over existing directories
shutil.rmtree(test_dir)
- stanza.install_corenlp(dir=test_dir, url='http://nlp.stanford.edu/software/')
+ stanza.install_corenlp(dir=test_dir)
assert os.path.isdir(test_dir), "Installation destination directory not found."
jar_files = [f for f in os.listdir(test_dir) \
diff --git a/stanza/tests/setup_test.sh b/stanza/tests/setup_test.sh
index a9d4bbf2..aec98367 100644
--- a/stanza/tests/setup_test.sh
+++ b/stanza/tests/setup_test.sh
@@ -23,6 +23,7 @@ mkdir -p $models_dir
$PYTHON -c "import stanza; stanza.download(lang='en', model_dir='${models_dir}', logging_level='info')" || echo "failed to download english model"
$PYTHON -c "import stanza; stanza.download(lang='fr', model_dir='${models_dir}', logging_level='info')" || echo "failed to download french model"
$PYTHON -c "import stanza; stanza.download(lang='zh', model_dir='${models_dir}', logging_level='info')" || echo "failed to download chinese model"
+$PYTHON -c "import stanza; stanza.download(lang='multilingual', model_dir='${models_dir}', logging_level='info')" || echo "failed to download chinese model"
echo "Models downloaded to ${models_dir}."
export STANZA_TEST_HOME=$test_dir
diff --git a/stanza/tests/test_constant.py b/stanza/tests/test_constant.py
new file mode 100644
index 00000000..3afcc8d6
--- /dev/null
+++ b/stanza/tests/test_constant.py
@@ -0,0 +1,35 @@
+"""
+Test the conversion to lcodes and splitting of dataset names
+"""
+
+import tempfile
+
+import pytest
+
+import stanza
+from stanza.models.common.constant import treebank_to_short_name
+from stanza.tests import *
+
+pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+def test_treebank():
+ """
+ Test the entire treebank name conversion
+ """
+ # conversion of a UD_ name
+ assert "hi_hdtb" == treebank_to_short_name("UD_Hindi-HDTB")
+ # conversion of names without UD
+ assert "hi_fire2013" == treebank_to_short_name("Hindi-fire2013")
+ assert "hi_fire2013" == treebank_to_short_name("Hindi-Fire2013")
+ assert "hi_fire2013" == treebank_to_short_name("Hindi-FIRE2013")
+ # already short names are generally preserved
+ assert "hi_fire2013" == treebank_to_short_name("hi-fire2013")
+ assert "hi_fire2013" == treebank_to_short_name("hi_fire2013")
+ # a special case
+ assert "zh-hant_pud" == treebank_to_short_name("UD_Chinese-PUD")
+ # a special case already converted once
+ assert "zh-hant_pud" == treebank_to_short_name("zh-hant_pud")
+ assert "zh-hant_pud" == treebank_to_short_name("zh-hant-pud")
+ assert "zh-hans_gsdsimp" == treebank_to_short_name("zh-hans_gsdsimp")
+
+
diff --git a/stanza/tests/test_english_pipeline.py b/stanza/tests/test_english_pipeline.py
index f270c1d4..3e3729ac 100644
--- a/stanza/tests/test_english_pipeline.py
+++ b/stanza/tests/test_english_pipeline.py
@@ -166,6 +166,7 @@ def test_dependency_parse(processed_doc):
def test_empty(pipeline):
# make sure that various models handle the degenerate empty case
pipeline("")
+ pipeline("--")
@pytest.fixture(scope="module")
def processed_multidoc(pipeline):
@@ -200,3 +201,8 @@ def processed_multidoc_variant():
def test_dependency_parse_multidoc_variant(processed_multidoc_variant):
assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc_variant for sent in processed_doc.sentences]) == \
EN_DOC_DEPENDENCY_PARSES_GOLD
+
+def test_constituency_parser():
+ nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency")
+ doc = nlp("This is a test")
+ assert str(doc.sentences[0].constituency) == '(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))'
diff --git a/stanza/tests/test_java_protobuf_requests.py b/stanza/tests/test_java_protobuf_requests.py
new file mode 100644
index 00000000..0c7ee7d8
--- /dev/null
+++ b/stanza/tests/test_java_protobuf_requests.py
@@ -0,0 +1,23 @@
+import tempfile
+
+import pytest
+
+from stanza.models.constituency import tree_reader
+from stanza.server import java_protobuf_requests
+from stanza.tests import *
+
+pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+def check_tree(proto_tree, py_tree, py_score):
+ tree, tree_score = java_protobuf_requests.from_tree(proto_tree)
+ assert tree_score == py_score
+ assert tree == py_tree
+
+def test_build_tree():
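+ # each tree should survive a round trip through the protobuf representation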
+ text="((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))\n( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 2
+
+ for tree in trees:
+ proto_tree = java_protobuf_requests.build_tree(tree, 1.0)
+ check_tree(proto_tree, tree, 1.0)
diff --git a/stanza/tests/test_langid.py b/stanza/tests/test_langid.py
new file mode 100644
index 00000000..5dd36125
--- /dev/null
+++ b/stanza/tests/test_langid.py
@@ -0,0 +1,613 @@
+"""
+Basic tests of langid module
+"""
+
+import pytest
+
+from stanza.models.common.doc import Document
+from stanza.pipeline.core import Pipeline
+from stanza.pipeline.multilingual import MultilingualPipeline
+from stanza.tests import *
+
+#pytestmark = pytest.mark.skip
+
+def test_langid():
+ """
+ Basic test of language identification
+ """
+ english_text = "This is an English sentence."
+ french_text = "C'est une phrase française."
+ docs = [english_text, french_text]
+
+ nlp = Pipeline(dir=TEST_MODELS_DIR, lang='multilingual', processors="langid")
+ docs = [Document([], text=text) for text in docs]
+ nlp(docs)
+ predictions = [doc.lang for doc in docs]
+ assert predictions == ["en", "fr"]
+
+def test_langid_benchmark():
+ """
+ Run lang id model on 500 examples, confirm reasonable accuracy.
+ """
+ examples = [
+ {"text": "contingentiam in naturalibus causis.", "label": "la"},
+ {"text": "I jak opowiadał nieżyjący już pan Czesław", "label": "pl"},
+ {"text": "Sonera gilt seit längerem als Übernahmekandidat", "label": "de"},
+ {"text": "与银类似,汞也可以与空气中的硫化氢反应。", "label": "zh-hans"},
+ {"text": "contradictionem implicat.", "label": "la"},
+ {"text": "Bis zu Prozent gingen die Offerten etwa im", "label": "de"},
+ {"text": "inneren Sicherheit vorgeschlagene Ausweitung der", "label": "de"},
+ {"text": "Multimedia-PDA mit Mini-Tastatur", "label": "de"},
+ {"text": "Ponášalo sa to na rovnicu o dvoch neznámych.", "label": "sk"},
+ {"text": "이처럼 앞으로 심판의 그 날에 다시 올 메시아가 예수 그리스도이며 , 그는 모든 인류의", "label": "ko"},
+ {"text": "Die Arbeitsgruppe bedauert , dass der weit über", "label": "de"},
+ {"text": "И только раз довелось поговорить с ним не вполне", "label": "ru"},
+ {"text": "de a-l lovi cu piciorul și conștiința că era", "label": "ro"},
+ {"text": "relación coas pretensións do demandante e que, nos", "label": "gl"},
+ {"text": "med petdeset in sedemdeset", "label": "sl"},
+ {"text": "Catalunya; el Consell Comarcal del Vallès Oriental", "label": "ca"},
+ {"text": "kunnen worden.", "label": "nl"},
+ {"text": "Witkin je ve většině ohledů zcela jiný.", "label": "cs"},
+ {"text": "lernen, so zu agieren, dass sie positive oder auch", "label": "de"},
+ {"text": "olurmuş...", "label": "tr"},
+ {"text": "sarcasmo de Altman, desde as «peruas» que discutem", "label": "pt"},
+ {"text": "خلاف فوجداری مقدمہ درج کرے۔", "label": "ur"},
+ {"text": "Norddal kommune :", "label": "no"},
+ {"text": "dem Windows-.-Zeitalter , soll in diesem Jahr", "label": "de"},
+ {"text": "przeklętych ucieleśniają mit poety-cygana,", "label": "pl"},
+ {"text": "We do not believe the suspect has ties to this", "label": "en"},
+ {"text": "groziņu pīšanu.", "label": "lv"},
+ {"text": "Senior Vice-President David M. Thomas möchte", "label": "de"},
+ {"text": "neomylně vybral nějakou knihu a začetl se.", "label": "cs"},
+ {"text": "Statt dessen darf beispielsweise der Browser des", "label": "de"},
+ {"text": "outubro, alcançando R $ bilhões em .", "label": "pt"},
+ {"text": "(Porte, ), as it does other disciplines", "label": "en"},
+ {"text": "uskupení se mylně domnívaly, že podporu", "label": "cs"},
+ {"text": "Übernahme von Next Ende an dem System herum , das", "label": "de"},
+ {"text": "No podemos decir a la Hacienda que los alemanes", "label": "es"},
+ {"text": "и рѣста еи братья", "label": "orv"},
+ {"text": "الذي اتخذ قرارا بتجميد اعلان الدولة الفلسطينية", "label": "ar"},
+ {"text": "uurides Rootsi sõjaarhiivist toodud . sajandi", "label": "et"},
+ {"text": "selskapets penger til å pusse opp sin enebolig på", "label": "no"},
+ {"text": "средней полосе и севернее в Ярославской,", "label": "ru"},
+ {"text": "il-massa żejda fil-ġemgħat u superġemgħat ta'", "label": "mt"},
+ {"text": "The Global Beauties on internetilehekülg, mida", "label": "et"},
+ {"text": "이스라엘 인들은 하나님이 그 큰 팔을 펴 이집트 인들을 치는 것을 보고 하나님을 두려워하며", "label": "ko"},
+ {"text": "Snad ještě dodejme jeden ekonomický argument.", "label": "cs"},
+ {"text": "Spalio d. vykusiame pirmajame rinkimų ture", "label": "lt"},
+ {"text": "und schlechter Journalismus ein gutes Geschäft .", "label": "de"},
+ {"text": "Du sodiečiai sėdi ant potvynio apsemtų namų stogo.", "label": "lt"},
+ {"text": "цей є автентичним.", "label": "uk"},
+ {"text": "Și îndegrabă fu cu îngerul mulțime de șireaguri", "label": "ro"},
+ {"text": "sobra personal cualificado.", "label": "es"},
+ {"text": "Tako se u Njemačkoj dvije trećine liječnika služe", "label": "hr"},
+ {"text": "Dual-Athlon-Chipsatz noch in diesem Jahr", "label": "de"},
+ {"text": "यहां तक कि चीन के चीफ ऑफ जनरल स्टाफ भी भारत का", "label": "hi"},
+ {"text": "Li forestier du mont avale", "label": "fro"},
+ {"text": "Netzwerken für Privatanwender zu bewundern .", "label": "de"},
+ {"text": "만해는 승적을 가진 중이 결혼할 수 없다는 불교의 계율을 시대에 맞지 않는 것으로 보았다", "label": "ko"},
+ {"text": "balance and weight distribution but not really for", "label": "en"},
+ {"text": "og så e # tente vi opp den om morgonen å sfyrte", "label": "nn"},
+ {"text": "변화는 의심의 여지가 없는 것이지만 반면에 진화는 논쟁의 씨앗이다 .", "label": "ko"},
+ {"text": "puteare fac aceastea.", "label": "ro"},
+ {"text": "Waitt seine Führungsmannschaft nicht dem", "label": "de"},
+ {"text": "juhtimisega, tulid sealt.", "label": "et"},
+ {"text": "Veränderungen .", "label": "de"},
+ {"text": "banda en el Bayer Leverkusen de la Bundesliga de", "label": "es"},
+ {"text": "В туже зиму посла всеволодъ сн҃а своѥго ст҃ослава", "label": "orv"},
+ {"text": "пославъ приведе я мастеры ѿ грекъ", "label": "orv"},
+ {"text": "En un nou escenari difícil d'imaginar fa poques", "label": "ca"},
+ {"text": "καὶ γὰρ τινὲς αὐτοὺς εὐεργεσίαι εἶχον ἐκ Κροίσου", "label": "grc"},
+ {"text": "직접적인 관련이 있다 .", "label": "ko"},
+ {"text": "가까운 듯하면서도 멀다 .", "label": "ko"},
+ {"text": "Er bietet ein ähnliches Leistungsniveau und", "label": "de"},
+ {"text": "民都洛水牛是獨居的,並不會以群族聚居。", "label": "zh-hant"},
+ {"text": "την τρομοκρατία.", "label": "el"},
+ {"text": "hurbiltzen diren neurrian.", "label": "eu"},
+ {"text": "Ah dimenticavo, ma tutta sta caciara per fare un", "label": "it"},
+ {"text": "На первом этапе (-) прошла так называемая", "label": "ru"},
+ {"text": "of games are on the market.", "label": "en"},
+ {"text": "находится Мост дружбы, соединяющий узбекский и", "label": "ru"},
+ {"text": "lessié je voldroie que li saint fussent aporté", "label": "fro"},
+ {"text": "Дошла очередь и до Гималаев.", "label": "ru"},
+ {"text": "vzácným suknem táhly pouští, si jednou chtěl do", "label": "cs"},
+ {"text": "E no terceiro tipo sitúa a familias (%), nos que a", "label": "gl"},
+ {"text": "وجابت دوريات امريكية وعراقية شوارع المدينة، فيما", "label": "ar"},
+ {"text": "Jeg har bodd her i år .", "label": "no"},
+ {"text": "Pohrozil, že odbory zostří postoj, pokud se", "label": "cs"},
+ {"text": "tinham conseguido.", "label": "pt"},
+ {"text": "Nicht-Erkrankten einen Anfangsverdacht für einen", "label": "de"},
+ {"text": "permanece em aberto.", "label": "pt"},
+ {"text": "questi possono promettere rendimenti fino a un", "label": "it"},
+ {"text": "Tema juurutatud kahevedurisüsteemita oleksid", "label": "et"},
+ {"text": "Поведение внешне простой игрушки оказалось", "label": "ru"},
+ {"text": "Bundesländern war vom Börsenverein des Deutschen", "label": "de"},
+ {"text": "acció, 'a mesura que avanci l'estiu, amb l'augment", "label": "ca"},
+ {"text": "Dove trovare queste risorse? Jay Naidoo, ministro", "label": "it"},
+ {"text": "essas gordurinhas.", "label": "pt"},
+ {"text": "Im zweiten Schritt sollen im übernächsten Jahr", "label": "de"},
+ {"text": "allveelaeva pole enam vaja, kuna külm sõda on läbi", "label": "et"},
+ {"text": "उपद्रवी दुकानों को लूटने के साथ ही उनमें आग लगा", "label": "hi"},
+ {"text": "@user nella sfortuna sei fortunata ..", "label": "it"},
+ {"text": "математических школ в виде грозовых туч.", "label": "ru"},
+ {"text": "No cambiaremos nunca nuestra forma de jugar por un", "label": "es"},
+ {"text": "dla tej klasy ani wymogów minimalnych, z wyjątkiem", "label": "pl"},
+ {"text": "en todo el mundo, mientras que en España consiguió", "label": "es"},
+ {"text": "политики считать надежное обеспечение военной", "label": "ru"},
+ {"text": "gogoratzen du, genio alemana delakoaren", "label": "eu"},
+ {"text": "Бычий глаз.", "label": "ru"},
+ {"text": "Opeření se v pravidelných obdobích obnovuje", "label": "cs"},
+ {"text": "I no és només la seva, es tracta d'una resposta", "label": "ca"},
+ {"text": "오경을 가르쳤다 .", "label": "ko"},
+ {"text": "Nach der so genannten Start-up-Periode vergibt die", "label": "de"},
+ {"text": "Saulista huomasi jo lapsena , että hänellä on", "label": "fi"},
+ {"text": "Министерство культуры сочло нецелесообразным, и", "label": "ru"},
+ {"text": "znepřátelené tábory v Tádžikistánu předseda", "label": "cs"},
+ {"text": "καὶ ἦν ὁ λαὸς προσδοκῶν τὸν Ζαχαρίαν καὶ ἐθαύμαζον", "label": "grc"},
+ {"text": "Вечером, в продукте, этот же человек говорил о", "label": "ru"},
+ {"text": "lugar á formación de xuizos máis complexos.", "label": "gl"},
+ {"text": "cheaper, in the end?", "label": "en"},
+ {"text": "الوزارة في شأن صفقات بيع الشركات العامة التي تم", "label": "ar"},
+ {"text": "tärkeintä elämässäni .", "label": "fi"},
+ {"text": "Виконання Мінських угод було заблоковано Росією та", "label": "uk"},
+ {"text": "Aby szybko rozpoznać żołnierzy desantu, należy", "label": "pl"},
+ {"text": "Bankengeschäfte liegen vorn , sagte Strothmann .", "label": "de"},
+ {"text": "продолжение работы.", "label": "ru"},
+ {"text": "Metro AG plant Online-Offensive", "label": "de"},
+ {"text": "nu vor veni, și să vor osîndi, aceia nu pot porni", "label": "ro"},
+ {"text": "Ich denke , es geht in Wirklichkeit darum , NT bei", "label": "de"},
+ {"text": "de turism care încasează contravaloarea", "label": "ro"},
+ {"text": "Aurkaria itotzea da helburua, baloia lapurtu eta", "label": "eu"},
+ {"text": "com a centre de formació en Tecnologies de la", "label": "ca"},
+ {"text": "oportet igitur quod omne agens in agendo intendat", "label": "la"},
+ {"text": "Jerzego Andrzejewskiego, oparty na chińskich", "label": "pl"},
+ {"text": "sau một vài câu chuyện xã giao không dính dáng tới", "label": "vi"},
+ {"text": "что экономическому прорыву жесткий авторитарный", "label": "ru"},
+ {"text": "DRAM-Preisen scheinen DSPs ein", "label": "de"},
+ {"text": "Jos dajan nubbái: Mana!", "label": "sme"},
+ {"text": "toți carii ascultară de el să răsipiră.", "label": "ro"},
+ {"text": "odpowiedzialności, które w systemie własności", "label": "pl"},
+ {"text": "Dvomesečno potovanje do Mollenda v Peruju je", "label": "sl"},
+ {"text": "d'entre les agències internacionals.", "label": "ca"},
+ {"text": "Fahrzeugzugangssysteme gefertigt und an viele", "label": "de"},
+ {"text": "in an answer to the sharers' petition in Cuthbert", "label": "en"},
+ {"text": "Europa-Domain per Verordnung zu regeln .", "label": "de"},
+ {"text": "#Balotelli. Su ebay prezzi stracciati per Silvio", "label": "it"},
+ {"text": "Ne na košickém trávníku, ale už včera v letadle se", "label": "cs"},
+ {"text": "zaměstnanosti a investičních strategií.", "label": "cs"},
+ {"text": "Tatínku, udělej den", "label": "cs"},
+ {"text": "frecuencia con Mary.", "label": "es"},
+ {"text": "Свеаборге.", "label": "ru"},
+ {"text": "opatření slovenské strany o certifikaci nejvíce", "label": "cs"},
+ {"text": "En todas me decían: 'Espera que hagamos un estudio", "label": "es"},
+ {"text": "Die Demonstration sollte nach Darstellung der", "label": "de"},
+ {"text": "Ci vorrà un assoluto rigore se dietro i disavanzi", "label": "it"},
+ {"text": "Tatínku, víš, že Honzovi odešla maminka?", "label": "cs"},
+ {"text": "Die Anzahl der Rechner wuchs um % auf und die", "label": "de"},
+ {"text": "האמריקאית על אדמת סעודיה עלולה לסבך את ישראל, אין", "label": "he"},
+ {"text": "Volán Egyesülés, a Közlekedési Főfelügyelet is.", "label": "hu"},
+ {"text": "Schejbala, který stejnou hru s velkým úspěchem", "label": "cs"},
+ {"text": "depends on the data type of the field.", "label": "en"},
+ {"text": "Umsatzwarnung zu Wochenbeginn zeitweise auf ein", "label": "de"},
+ {"text": "niin heti nukun .", "label": "fi"},
+ {"text": "Mobilfunkunternehmen gegen die Anwendung der so", "label": "de"},
+ {"text": "sapessi le intenzioni del governo Monti e dell'UE", "label": "it"},
+ {"text": "Di chi è figlia Martine Aubry?", "label": "it"},
+ {"text": "avec le reste du monde.", "label": "fr"},
+ {"text": "Այդ մաքոքը ինքնին նոր չէ, աշխարհը արդեն մի քանի", "label": "hy"},
+ {"text": "și în cazul destrămării cenaclului.", "label": "ro"},
+ {"text": "befriedigen kann , und ohne die auftretenden", "label": "de"},
+ {"text": "Κύκνον τ̓ ἐξεναρεῖν καὶ ἀπὸ κλυτὰ τεύχεα δῦσαι.", "label": "grc"},
+ {"text": "færdiguddannede.", "label": "da"},
+ {"text": "Schmidt war Sohn eines Rittergutsbesitzers.", "label": "de"},
+ {"text": "и вдаша попадь ѡпрати", "label": "orv"},
+ {"text": "cine nu știe învățătură”.", "label": "ro"},
+ {"text": "détacha et cette dernière tenta de tuer le jeune", "label": "fr"},
+ {"text": "Der har saka også ei lengre forhistorie.", "label": "nn"},
+ {"text": "Pieprz roztłuc w moździerzu, dodać do pasty,", "label": "pl"},
+ {"text": "Лежа за гребнем оврага, как за бруствером, Ушаков", "label": "ru"},
+ {"text": "gesucht habe, vielen Dank nochmals!", "label": "de"},
+ {"text": "инструментальных сталей, повышения", "label": "ru"},
+ {"text": "im Halbfinale Patrick Smith und im Finale dann", "label": "de"},
+ {"text": "البنوك التريث في منح تسهيلات جديدة لمنتجي حديد", "label": "ar"},
+ {"text": "una bolsa ventral, la cual se encuentra debajo de", "label": "es"},
+ {"text": "za SETimes.", "label": "sr"},
+ {"text": "de Irak, a un piloto italiano que había violado el", "label": "es"},
+ {"text": "Er könne sich nicht erklären , wie die Zeitung auf", "label": "de"},
+ {"text": "Прохорова.", "label": "ru"},
+ {"text": "la democrazia perde sulla tecnocrazia? #", "label": "it"},
+ {"text": "entre ambas instituciones, confirmó al medio que", "label": "es"},
+ {"text": "Austlandet, vart det funne om lag førti", "label": "nn"},
+ {"text": "уровнями власти.", "label": "ru"},
+ {"text": "Dá tedy primáři úplatek, a často ne malý.", "label": "cs"},
+ {"text": "brillantes del acto, al llevar a cabo en el", "label": "es"},
+ {"text": "eee druga zadeva je majhen priročen gre kamorkoli", "label": "sl"},
+ {"text": "Das ATX-Board paßt in herkömmliche PC-ATX-Gehäuse", "label": "de"},
+ {"text": "Za vodné bylo v prvním pololetí zaplaceno v ČR", "label": "cs"},
+ {"text": "Даже на полсантиметра.", "label": "ru"},
+ {"text": "com la del primer tinent d'alcalde en funcions,", "label": "ca"},
+ {"text": "кількох оповідань в цілості — щось на зразок того", "label": "uk"},
+ {"text": "sed ad divitias congregandas, vel superfluum", "label": "la"},
+ {"text": "Norma Talmadge, spela mot Valentino i en version", "label": "sv"},
+ {"text": "Dlatego chciał się jej oświadczyć w niezwykłym", "label": "pl"},
+ {"text": "будут выступать на одинаковых снарядах.", "label": "ru"},
+ {"text": "Orang-orang terbunuh di sana.", "label": "id"},
+ {"text": "لدى رايت شقيق اسمه أوسكار, وهو يعمل كرسام للكتب", "label": "ar"},
+ {"text": "Wirklichkeit verlagerten und kaum noch", "label": "de"},
+ {"text": "как перемешивают костяшки перед игрой в домино, и", "label": "ru"},
+ {"text": "В средине дня, когда солнце светило в нашу", "label": "ru"},
+ {"text": "d'aventure aux rôles de jeune romantique avec une", "label": "fr"},
+ {"text": "My teď hledáme organizace, jež by s námi chtěly", "label": "cs"},
+ {"text": "Urteilsfähigkeit einbüßen , wenn ich eigene", "label": "de"},
+ {"text": "sua appartenenza anche a voci diverse da quella in", "label": "it"},
+ {"text": "Aufträge dieses Jahr verdoppeln werden .", "label": "de"},
+ {"text": "M.E.: Miała szanse mnie odnaleźć, gdyby naprawdę", "label": "pl"},
+ {"text": "secundum contactum virtutis, cum careat dimensiva", "label": "la"},
+ {"text": "ezinbestekoa dela esan zuen.", "label": "eu"},
+ {"text": "Anek hurbiltzeko eskatzen zion besaulkitik, eta", "label": "eu"},
+ {"text": "perfectius alio videat, quamvis uterque videat", "label": "la"},
+ {"text": "Die Strecke war anspruchsvoll und führte unter", "label": "de"},
+ {"text": "саморазоблачительным уроком, западные СМИ не", "label": "ru"},
+ {"text": "han representerer radikal islamisme .", "label": "no"},
+ {"text": "Què s'hi respira pel que fa a la reforma del", "label": "ca"},
+ {"text": "previsto para também ser desconstruido.", "label": "pt"},
+ {"text": "Ὠκεανοῦ βαθυκόλποις ἄνθεά τ̓ αἰνυμένην, ῥόδα καὶ", "label": "grc"},
+ {"text": "para jovens de a anos nos Cieps.", "label": "pt"},
+ {"text": "संघर्ष को अंजाम तक पहुंचाने का ऐलान किया है ।", "label": "hi"},
+ {"text": "objeví i u nás.", "label": "cs"},
+ {"text": "kvitteringer.", "label": "da"},
+ {"text": "This report is no exception.", "label": "en"},
+ {"text": "Разлепват доносниците до избирателните списъци", "label": "bg"},
+ {"text": "anderem ihre Bewegungsfreiheit in den USA", "label": "de"},
+ {"text": "Ñu tegoon ca kaw gor ña ay njotti bopp yu kenn", "label": "wo"},
+ {"text": "Struktur kann beispielsweise der Schwerpunkt mehr", "label": "de"},
+ {"text": "% la velocidad permitida, la sanción es muy grave.", "label": "es"},
+ {"text": "Teles-Einstieg in ADSL-Markt", "label": "de"},
+ {"text": "ettekäändeks liiga suure osamaksu.", "label": "et"},
+ {"text": "als Indiz für die geänderte Marktpolitik des", "label": "de"},
+ {"text": "quod quidem aperte consequitur ponentes", "label": "la"},
+ {"text": "de negociación para el próximo de junio.", "label": "es"},
+ {"text": "Tyto důmyslné dekorace doznaly v poslední době", "label": "cs"},
+ {"text": "največjega uspeha doslej.", "label": "sl"},
+ {"text": "Paul Allen je jedan od suosnivača Interval", "label": "hr"},
+ {"text": "Federal (Seac / DF) eo Sindicato das Empresas de", "label": "pt"},
+ {"text": "Quartal mit . Mark gegenüber dem gleichen Quartal", "label": "de"},
+ {"text": "otros clubes y del Barça B saldrán varios", "label": "es"},
+ {"text": "Jaskula (Pol.) -", "label": "cs"},
+ {"text": "umožnily říci, že je možné přejít k mnohem", "label": "cs"},
+ {"text": "اعلن الجنرال تومي فرانكس قائد القوات الامريكية", "label": "ar"},
+ {"text": "Telekom-Chef Ron Sommer und der Vorstandssprecher", "label": "de"},
+ {"text": "My, jako průmyslový a finanční holding, můžeme", "label": "cs"},
+ {"text": "voorlichting onder andere betrekking kan hebben:", "label": "nl"},
+ {"text": "Hinrichtung geistig Behinderter applaudiert oder", "label": "de"},
+ {"text": "wie beispielsweise Anzahl erzielte Klicks ,", "label": "de"},
+ {"text": "Intel-PC-SDRAM-Spezifikation in der Version . (", "label": "de"},
+ {"text": "plângere în termen de zile de la comunicarea", "label": "ro"},
+ {"text": "и Испания ще изгубят втория си комисар в ЕК.", "label": "bg"},
+ {"text": "इसके चलते इस आदिवासी जनजाति का क्षरण हो रहा है ।", "label": "hi"},
+ {"text": "aunque se mostró contrario a establecer un", "label": "es"},
+ {"text": "des letzten Jahres von auf Millionen Euro .", "label": "de"},
+ {"text": "Ankara se također poziva da u cijelosti ratificira", "label": "hr"},
+ {"text": "herunterlädt .", "label": "de"},
+ {"text": "стрессовую ситуацию для организма, каковой", "label": "ru"},
+ {"text": "Státního shromáždění (parlamentu).", "label": "cs"},
+ {"text": "diskutieren , ob und wie dieser Dienst weiterhin", "label": "de"},
+ {"text": "Verbindungen zu FPÖ-nahen Polizisten gepflegt und", "label": "de"},
+ {"text": "Pražského volebního lídra ovšem nevybírá Miloš", "label": "cs"},
+ {"text": "Nach einem Bericht der Washington Post bleibt das", "label": "de"},
+ {"text": "للوضع آنذاك، لكني في قرارة نفسي كنت سعيداً لما", "label": "ar"},
+ {"text": "не желаят запазването на статуквото.", "label": "bg"},
+ {"text": "Offenburg gewesen .", "label": "de"},
+ {"text": "ἐὰν ὑμῖν εἴπω οὐ μὴ πιστεύσητε", "label": "grc"},
+ {"text": "all'odiato compagno di squadra Prost, il quale", "label": "it"},
+ {"text": "historischen Gänselieselbrunnens.", "label": "de"},
+ {"text": "למידע מלווייני הריגול האמריקאיים העוקבים אחר", "label": "he"},
+ {"text": "οὐδὲν ἄρα διαφέρεις Ἀμάσιος τοῦ Ἠλείου, ὃν", "label": "grc"},
+ {"text": "movementos migratorios.", "label": "gl"},
+ {"text": "Handy und ein Spracherkennungsprogramm sämtliche", "label": "de"},
+ {"text": "Kümne aasta jooksul on Eestisse ohjeldamatult", "label": "et"},
+ {"text": "H.G. Bücknera.", "label": "pl"},
+ {"text": "protiv krijumčarenja, ili pak traženju ukidanja", "label": "hr"},
+ {"text": "Topware-Anteile mehrere Millionen Mark gefordert", "label": "de"},
+ {"text": "Maar de mensen die nu over Van Dijk bij FC Twente", "label": "nl"},
+ {"text": "poidan experimentar as percepcións do interesado,", "label": "gl"},
+ {"text": "Miał przecież w kieszeni nóż.", "label": "pl"},
+ {"text": "Avšak žádná z nich nepronikla za hranice přímé", "label": "cs"},
+ {"text": "esim. helpottamalla luottoja muiden", "label": "fi"},
+ {"text": "Podle předběžných výsledků zvítězila v", "label": "cs"},
+ {"text": "Nicht nur das Web-Frontend , auch die", "label": "de"},
+ {"text": "Regierungsinstitutionen oder Universitäten bei", "label": "de"},
+ {"text": "Խուլեն Լոպետեգիին, պատճառաբանելով, որ վերջինս", "label": "hy"},
+ {"text": "Афганистана, где в последние дни идут ожесточенные", "label": "ru"},
+ {"text": "лѧхове же не идоша", "label": "orv"},
+ {"text": "Mit Hilfe von IBMs Chip-Management-Systemen sollen", "label": "de"},
+ {"text": ", als Manager zu Telefonica zu wechseln .", "label": "de"},
+ {"text": "którym zajmuje się człowiek, zmienia go i pozwala", "label": "pl"},
+ {"text": "činí kyperských liber, to je asi USD.", "label": "cs"},
+ {"text": "Studienplätze getauscht werden .", "label": "de"},
+ {"text": "учёных, орнитологов признают вид.", "label": "ru"},
+ {"text": "acordare a concediilor prevăzute de legislațiile", "label": "ro"},
+ {"text": "at større innsats for fornybar, berekraftig energi", "label": "nn"},
+ {"text": "Politiet veit ikkje kor mange personar som deltok", "label": "nn"},
+ {"text": "offentligheten av unge , sinte menn som har", "label": "no"},
+ {"text": "însuși în jurul lapunei, care încet DISPARE în", "label": "ro"},
+ {"text": "O motivo da decisão é evitar uma sobrecarga ainda", "label": "pt"},
+ {"text": "El Apostolado de la prensa contribuye en modo", "label": "es"},
+ {"text": "Teltow ( Kreis Teltow-Fläming ) ist Schmitt einer", "label": "de"},
+ {"text": "grozījumus un iesniegt tos Apvienoto Nāciju", "label": "lv"},
+ {"text": "Gestalt einer deutschen Nationalmannschaft als", "label": "de"},
+ {"text": "D überholt zu haben , konterte am heutigen Montag", "label": "de"},
+ {"text": "Softwarehersteller Oracle hat im dritten Quartal", "label": "de"},
+ {"text": "Během nich se ekonomické podmínky mohou radikálně", "label": "cs"},
+ {"text": "Dziki kot w górach zeskakuje z kamienia.", "label": "pl"},
+ {"text": "Ačkoliv ligový nováček prohrál, opět potvrdil, že", "label": "cs"},
+ {"text": "des Tages , Portraits internationaler Stars sowie", "label": "de"},
+ {"text": "Communicator bekannt wurde .", "label": "de"},
+ {"text": "τῷ δ’ ἄρα καὶ αὐτῷ ἡ γυνή ἐπίτεξ ἐοῦσα πᾶσαν", "label": "grc"},
+ {"text": "Triadú tenia, mentre redactava 'Dies de memòria',", "label": "ca"},
+ {"text": "دسته‌جمعی در درخشندگی ماه سیم‌گون زمزمه ستاینده و", "label": "fa"},
+ {"text": "Книгу, наполненную мелочной заботой об одежде,", "label": "ru"},
+ {"text": "putares canem leporem persequi.", "label": "la"},
+ {"text": "В дальнейшем эта яркость слегка померкла, но в", "label": "ru"},
+ {"text": "offizielles Verfahren gegen die Telekom", "label": "de"},
+ {"text": "podrían haber sido habitantes de la Península", "label": "es"},
+ {"text": "Grundlage für dieses Verfahren sind spezielle", "label": "de"},
+ {"text": "Rechtsausschuß vorgelegten Entwurf der Richtlinie", "label": "de"},
+ {"text": "Im so genannten Portalgeschäft sei das Unternehmen", "label": "de"},
+ {"text": "ⲏ ⲉⲓϣⲁⲛϥⲓ ⲛⲉⲓⲇⲱⲗⲟⲛ ⲉⲧϩⲙⲡⲉⲕⲏⲓ ⲙⲏ ⲉⲓⲛⲁϣϩⲱⲡ ⲟⲛ ⲙⲡⲣⲏ", "label": "cop"},
+ {"text": "juego podían matar a cualquier herbívoro, pero", "label": "es"},
+ {"text": "Nach Angaben von Axent nutzen Unternehmen aus der", "label": "de"},
+ {"text": "hrdiny Havlovy Zahradní slavnosti (premiéra ) se", "label": "cs"},
+ {"text": "Een zin van heb ik jou daar", "label": "nl"},
+ {"text": "hat sein Hirn an der CeBIT-Kasse vergessen .", "label": "de"},
+ {"text": "καὶ τοὺς ἐκπλαγέντας οὐκ ἔχειν ἔτι ἐλεγχομένους", "label": "grc"},
+ {"text": "nachgewiesenen langfristigen Kosten , sowie den im", "label": "de"},
+ {"text": "jučer nakon četiri dana putovanja u Helsinki.", "label": "hr"},
+ {"text": "pašto paslaugos teikėjas gali susitarti su", "label": "lt"},
+ {"text": "В результате, эти золотые кадры переходят из одной", "label": "ru"},
+ {"text": "द फाइव-ईयर एंगेजमेंट में अभिनय किया जिसमें जैसन", "label": "hi"},
+ {"text": "výpis o počtu akcií.", "label": "cs"},
+ {"text": "Enfin, elles arrivent à un pavillon chinois", "label": "fr"},
+ {"text": "Tentu saja, tren yang berhubungandengan", "label": "id"},
+ {"text": "Arbeidarpartiet og SV har sikra seg fleirtal mot", "label": "nn"},
+ {"text": "eles: 'Tudo isso está errado' , disse um", "label": "pt"},
+ {"text": "The islands are in their own time zone, minutes", "label": "en"},
+ {"text": "Auswahl debütierte er am .", "label": "de"},
+ {"text": "Bu komisyonlar, arazilerini satın almak için", "label": "tr"},
+ {"text": "Geschütze gegen Redmond aufgefahren .", "label": "de"},
+ {"text": "Time scything the hours, but at the top, over the", "label": "en"},
+ {"text": "Di musim semi , berharap mengadaptasi Tintin untuk", "label": "id"},
+ {"text": "крупнейшей геополитической катастрофой XX века.", "label": "ru"},
+ {"text": "Rajojen avaaminen ei suju ongelmitta .", "label": "fi"},
+ {"text": "непроницаемым, как для СССР.", "label": "ru"},
+ {"text": "Ma non mancano le polemiche.", "label": "it"},
+ {"text": "Internet als Ort politischer Diskussion und auch", "label": "de"},
+ {"text": "incomplets.", "label": "ca"},
+ {"text": "Su padre luchó al lado de Luis Moya, primer Jefe", "label": "es"},
+ {"text": "informazione.", "label": "it"},
+ {"text": "Primacom bietet für Telekom-Kabelnetz", "label": "de"},
+ {"text": "Oświadczenie prezydencji w imieniu Unii", "label": "pl"},
+ {"text": "foran rattet i familiens gamle Baleno hvis døra på", "label": "no"},
+ {"text": "[speaker:laughter]", "label": "sl"},
+ {"text": "Dog med langt mindre utstyr med seg.", "label": "nn"},
+ {"text": "dass es nicht schon mit der anfänglichen", "label": "de"},
+ {"text": "इस पर दोनों पक्षों में नोकझोंक शुरू हो गई ।", "label": "hi"},
+ {"text": "کے ترجمان منیش تیواری اور دگ وجئے سنگھ نے بھی یہ", "label": "ur"},
+ {"text": "dell'Assemblea Costituente che posseggono i", "label": "it"},
+ {"text": "и аште вьси съблазнѧтъ сѧ нъ не азъ", "label": "cu"},
+ {"text": "In Irvine hat auch das Logistikunternehmen Atlas", "label": "de"},
+ {"text": "законодательных норм, принимаемых существующей", "label": "ru"},
+ {"text": "Κροίσῳ προτείνων τὰς χεῖρας ἐπικατασφάξαι μιν", "label": "grc"},
+ {"text": "МИНУСЫ: ИНФЛЯЦИЯ И КРИЗИС В ЖИВОТНОВОДСТВЕ.", "label": "ru"},
+ {"text": "unterschiedlicher Meinung .", "label": "de"},
+ {"text": "Jospa joku ystävällinen sielu auttaisi kassieni", "label": "fi"},
+ {"text": "Añadió que, en el futuro se harán otros", "label": "es"},
+ {"text": "Sessiz tonlama hem Fince, hem de Kuzey Sami", "label": "tr"},
+ {"text": "nicht ihnen gehört und sie nicht alles , was sie", "label": "de"},
+ {"text": "Etelästä Kuivajärveen laskee Tammelan Liesjärvestä", "label": "fi"},
+ {"text": "ICANNs Vorsitzender Vint Cerf warb mit dem Hinweis", "label": "de"},
+ {"text": "Norsk politikk frå til kan dermed, i", "label": "nn"},
+ {"text": "Głosowało posłów.", "label": "pl"},
+ {"text": "Danny Jones -- smithjones@ev.net", "label": "en"},
+ {"text": "sebeuvědomění moderní civilizace sehrála lučavka", "label": "cs"},
+ {"text": "относительно спокойный сон: тому гарантия", "label": "ru"},
+ {"text": "A halte voiz prist li pedra a crïer", "label": "fro"},
+ {"text": "آن‌ها امیدوارند این واکسن به‌زودی در دسترس بیماران", "label": "fa"},
+ {"text": "vlastní důstojnou vousatou tváří.", "label": "cs"},
+ {"text": "ora aprire la strada a nuove cause e alimentare il", "label": "it"},
+ {"text": "Die Zahl der Vielleser nahm von auf Prozent zu ,", "label": "de"},
+ {"text": "Finanzvorstand von Hotline-Dienstleister InfoGenie", "label": "de"},
+ {"text": "entwickeln .", "label": "de"},
+ {"text": "incolumità pubblica.", "label": "it"},
+ {"text": "lehtija televisiomainonta", "label": "fi"},
+ {"text": "joistakin kohdista eri mieltä.", "label": "fi"},
+ {"text": "Hlavně anglická nezávislá scéna, Dead Can Dance,", "label": "cs"},
+ {"text": "pásmech od do bodů bodové stupnice.", "label": "cs"},
+ {"text": "Zu Beginn des Ersten Weltkrieges zählte das", "label": "de"},
+ {"text": "Així van sorgir, damunt els antics cementiris,", "label": "ca"},
+ {"text": "In manchem Gedicht der spätern Alten, wie zum", "label": "de"},
+ {"text": "gaweihaida jah insandida in þana fairƕu jus qiþiþ", "label": "got"},
+ {"text": "Beides sollte gelöscht werden!", "label": "de"},
+ {"text": "modifiqués la seva petició inicial de anys de", "label": "ca"},
+ {"text": "В день открытия симпозиума состоялась закладка", "label": "ru"},
+ {"text": "tõestatud.", "label": "et"},
+ {"text": "ἵππῳ πίπτει αὐτοῦ ταύτῃ", "label": "grc"},
+ {"text": "bisher nie enttäuscht!", "label": "de"},
+ {"text": "De bohte ollu tuollárat ja suttolaččat ja", "label": "sme"},
+ {"text": "Klarsignal från röstlängdsläsaren, tre tryck i", "label": "sv"},
+ {"text": "Tvůrcem nového termínu je Joseph Fisher.", "label": "cs"},
+ {"text": "Nie miałem czasu na reakcję twierdzi Norbert,", "label": "pl"},
+ {"text": "potentia Schöpfer.", "label": "de"},
+ {"text": "Un poquito caro, pero vale mucho la pena;", "label": "es"},
+ {"text": "οὔ τε γὰρ ἴφθιμοι Λύκιοι Δαναῶν ἐδύναντο τεῖχος", "label": "grc"},
+ {"text": "vajec, sladového výtažku a některých vitamínových", "label": "cs"},
+ {"text": "Настоящие герои, те, чьи истории потом", "label": "ru"},
+ {"text": "praesumptio:", "label": "la"},
+ {"text": "Olin justkui nende vastutusel.", "label": "et"},
+ {"text": "Jokainen keinahdus tuo lähemmäksi hetkeä jolloin", "label": "fi"},
+ {"text": "ekonomicky výhodných způsobů odvodnění těžkých,", "label": "cs"},
+ {"text": "Poprvé ve své historii dokázala v kvalifikaci pro", "label": "cs"},
+ {"text": "zpracovatelského a spotřebního průmyslu bude nutné", "label": "cs"},
+ {"text": "Windows CE zu integrieren .", "label": "de"},
+ {"text": "Armangué, a través d'un decret, ordenés l'aturada", "label": "ca"},
+ {"text": "to, co nás Evropany spojuje, než to, co nás od", "label": "cs"},
+ {"text": "ergänzt durch einen gesetzlich verankertes", "label": "de"},
+ {"text": "Насчитал, что с начала года всего три дня были", "label": "ru"},
+ {"text": "Borisovu tražeći od njega da prihvati njenu", "label": "sr"},
+ {"text": "la presenza di ben veleni diversi: . chili di", "label": "it"},
+ {"text": "καὶ τῶν ἐκλεκτῶν ἀγγέλων ἵνα ταῦτα φυλάξῃς χωρὶς", "label": "grc"},
+ {"text": "pretraživale obližnju bolnicu i stambene zgrade u", "label": "hr"},
+ {"text": "An rund Katzen habe Wolf seine Spiele getestet ,", "label": "de"},
+ {"text": "investigating since March.", "label": "en"},
+ {"text": "Tonböden (Mullböden).", "label": "de"},
+ {"text": "Stálý dopisovatel LN v SRN Bedřich Utitz", "label": "cs"},
+ {"text": "červnu předložené smlouvy.", "label": "cs"},
+ {"text": "πνεύματι ᾧ ἐλάλει", "label": "grc"},
+ {"text": ".%의 신장세를 보였다.", "label": "ko"},
+ {"text": "Foae verde, foi de nuc, Prin pădure, prin colnic,", "label": "ro"},
+ {"text": "διαπέμψας ἄλλους ἄλλῃ τοὺς μὲν ἐς Δελφοὺς ἰέναι", "label": "grc"},
+ {"text": "المسلمين أو أي تيار سياسي طالما عمل ذلك التيار في", "label": "ar"},
+ {"text": "As informações são da Dow Jones.", "label": "pt"},
+ {"text": "Milliarde DM ausgestattet sein .", "label": "de"},
+ {"text": "De utgår fortfarande från att kvinnans jämlikhet", "label": "sv"},
+ {"text": "Sneeuw maakte in Davos bij de voorbereiding een", "label": "nl"},
+ {"text": "De ahí que en este mercado puedan negociarse", "label": "es"},
+ {"text": "intenzívnějšímu sbírání a studiu.", "label": "cs"},
+ {"text": "और औसकर ४.० पैकेज का प्रयोग किया गया है ।", "label": "hi"},
+ {"text": "Adipati Kuningan karena Kuningan menjadi bagian", "label": "id"},
+ {"text": "Svako je bar jednom poželeo da mašine prosto umeju", "label": "sr"},
+ {"text": "Im vergangenen Jahr haben die Regierungen einen", "label": "de"},
+ {"text": "durat motus, aliquid fit et non est;", "label": "la"},
+ {"text": "Dominować będą piosenki do tekstów Edwarda", "label": "pl"},
+ {"text": "beantwortet .", "label": "de"},
+ {"text": "О гуманитариях было кому рассказывать, а вот за", "label": "ru"},
+ {"text": "Helsingin kaupunki riitautti vuokrasopimuksen", "label": "fi"},
+ {"text": "chợt tan biến.", "label": "vi"},
+ {"text": "avtomobil ločuje od drugih.", "label": "sl"},
+ {"text": "Congress has proven itself ineffective as a body.", "label": "en"},
+ {"text": "मैक्सिको ने इस तरह का शो इस समय आयोजित करने का", "label": "hi"},
+ {"text": "No minimum order amount.", "label": "en"},
+ {"text": "Convertassa .", "label": "fi"},
+ {"text": "Как это можно сделать?", "label": "ru"},
+ {"text": "tha mi creidsinn gu robh iad ceart cho saor shuas", "label": "gd"},
+ {"text": "실제 일제는 이런 만해의 논리를 묵살하고 한반도를 침략한 다음 , 이어 만주를 침략하고", "label": "ko"},
+ {"text": "Da un semplice richiamo all'ordine fino a grandi", "label": "it"},
+ {"text": "pozoruhodný nejen po umělecké stránce, jež", "label": "cs"},
+ {"text": "La comida y el servicio aprueban.", "label": "es"},
+ {"text": "again, connected not with each other but to the", "label": "en"},
+ {"text": "Protokol výslovně stanoví, že nikdo nemůže být", "label": "cs"},
+ {"text": "ఒక విషయం అడగాలని ఉంది .", "label": "te"},
+ {"text": "Безгранично почитая дирекцию, ловя на лету каждое", "label": "ru"},
+ {"text": "rovnoběžných růstových vrstev, zůstávají krychlové", "label": "cs"},
+ {"text": "प्रवेश और पूर्व प्रधानमंत्री लाल बहादुर शास्त्री", "label": "hi"},
+ {"text": "Bronzen medaille in de Europese marathon.", "label": "nl"},
+ {"text": "- gadu vecumā viņi to nesaprot.", "label": "lv"},
+ {"text": "Realizó sus estudios primarios en la Escuela Julia", "label": "es"},
+ {"text": "cuartos de final, su clasificación para la final a", "label": "es"},
+ {"text": "Sem si pro něho přiletí americký raketoplán, na", "label": "cs"},
+ {"text": "Way to go!", "label": "en"},
+ {"text": "gehört der neuen SPD-Führung unter Parteichef", "label": "de"},
+ {"text": "Somit simuliert der Player mit einer GByte-Platte", "label": "de"},
+ {"text": "Berufung auf kommissionsnahe Kreise , die bereits", "label": "de"},
+ {"text": "Dist Clarïen", "label": "fro"},
+ {"text": "Schon nach den Gerüchten , die Telekom wolle den", "label": "de"},
+ {"text": "Software von NetObjects ist nach Angaben des", "label": "de"},
+ {"text": "si enim per legem iustitia ergo Christus gratis", "label": "la"},
+ {"text": "ducerent in ipsam magis quam in corpus christi,", "label": "la"},
+ {"text": "Neustar-Melbourne-IT-Partnerschaft NeuLevel .", "label": "de"},
+ {"text": "forderte dagegen seine drastische Verschärfung.", "label": "de"},
+ {"text": "pemmican på hundrede forskellige måder.", "label": "da"},
+ {"text": "Lehån, själv matematiklärare, visar hur den nya", "label": "sv"},
+ {"text": "I highly recommend his shop.", "label": "en"},
+ {"text": "verità, giovani fedeli prostratevi #amen", "label": "it"},
+ {"text": "उत्तर प्रदेश के अध्यक्ष पद से हटाए गए विनय कटियार", "label": "hi"},
+ {"text": "() روزی مےں کشادگی ہوتی ہے۔", "label": "ur"},
+ {"text": "Prozessorgeschäft profitieren kann , stellen", "label": "de"},
+ {"text": "školy začalo počítat pytle s moukou a zjistilo, že", "label": "cs"},
+ {"text": "प्रभावशाली पर गैर सरकारी लोगों के घरों में भी", "label": "hi"},
+ {"text": "geschichtslos , oder eine Farce , wie sich", "label": "de"},
+ {"text": "Ústrednými mocnosťami v marci však spôsobilo, že", "label": "sk"},
+ {"text": "التسليح بدون مبرر، واستمرار الأضرار الناجمة عن فرض", "label": "ar"},
+ {"text": "Například Pedagogická fakulta Univerzity Karlovy", "label": "cs"},
+ {"text": "nostris ut eriperet nos de praesenti saeculo", "label": "la"}]
+
+ nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid")
+ docs = [Document([], text=example["text"]) for example in examples]
+ gold_labels = [example["label"] for example in examples]
+ nlp(docs)
+    accuracy = sum(doc.lang == label for doc, label in zip(docs, gold_labels)) / len(docs)
+ assert accuracy >= 0.98
+
+
+def test_text_cleaning():
+ """
+ Basic test of cleaning text
+ """
+ docs = ["Bonjour le monde! #thisisfrench #ilovefrance",
+ "Bonjour le monde! https://t.co/U0Zjp3tusD"]
+ docs = [Document([], text=text) for text in docs]
+
+ nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid")
+ nlp(docs)
+ assert [doc.lang for doc in docs] == ["it", "it"]
+
+ nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_clean_text=True)
+ assert nlp.processors["langid"]._clean_text
+ nlp(docs)
+ assert [doc.lang for doc in docs] == ["fr", "fr"]
+
+def test_lang_subset():
+ """
+ Basic test of restricting output to subset of languages
+ """
+ docs = ["Bonjour le monde! #thisisfrench #ilovefrance",
+ "Bonjour le monde! https://t.co/U0Zjp3tusD"]
+ docs = [Document([], text=text) for text in docs]
+
+ nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid")
+ nlp(docs)
+ assert [doc.lang for doc in docs] == ["it", "it"]
+
+ nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en","fr"])
+ assert nlp.processors["langid"]._model.lang_subset == ["en", "fr"]
+ nlp(docs)
+ assert [doc.lang for doc in docs] == ["fr", "fr"]
+
+ nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en"])
+ assert nlp.processors["langid"]._model.lang_subset == ["en"]
+ nlp(docs)
+ assert [doc.lang for doc in docs] == ["en", "en"]
+
+def test_multilingual_pipeline():
+ """
+ Basic test of multilingual pipeline
+ """
+ english_text = "This is an English sentence."
+ english_deps_gold = "\n".join((
+ "('This', 5, 'nsubj')",
+ "('is', 5, 'cop')",
+ "('an', 5, 'det')",
+ "('English', 5, 'amod')",
+ "('sentence', 0, 'root')",
+ "('.', 5, 'punct')"
+ ))
+
+ french_text = "C'est une phrase française."
+ french_deps_gold = "\n".join((
+ "(\"C'\", 4, 'nsubj')",
+ "('est', 4, 'cop')",
+ "('une', 4, 'det')",
+ "('phrase', 0, 'root')",
+ "('française', 4, 'amod')",
+ "('.', 4, 'punct')"
+ ))
+
+ nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR)
+ docs = [english_text, french_text]
+ docs = nlp(docs)
+
+ assert docs[0].lang == "en"
+ assert docs[0].sentences[0].dependencies_string() == english_deps_gold
+ assert docs[1].lang == "fr"
+ assert docs[1].sentences[0].dependencies_string() == french_deps_gold
+
diff --git a/stanza/tests/test_parser_eval.py b/stanza/tests/test_parser_eval.py
new file mode 100644
index 00000000..d633c529
--- /dev/null
+++ b/stanza/tests/test_parser_eval.py
@@ -0,0 +1,40 @@
+"""
+Test the parser eval interface
+"""
+
+import pytest
+import stanza
+from stanza.models.constituency import tree_reader
+from stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse
+from stanza.server.parser_eval import build_request, EvaluateParser
+from stanza.tests.test_java_protobuf_requests import check_tree
+
+from stanza.tests import *
+
+pytestmark = [pytest.mark.travis, pytest.mark.client]
+
+def build_one_tree_treebank():
+ text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+ gold = trees[0]
+ prediction = (gold, 1.0)
+ treebank = [(gold, [prediction])]
+ return treebank
+
+def test_build_request_one_tree():
+ treebank = build_one_tree_treebank()
+ request = build_request(treebank)
+
+ assert len(request.treebank) == 1
+ check_tree(request.treebank[0].gold, treebank[0][0], None)
+ assert len(request.treebank[0].predicted) == 1
+ check_tree(request.treebank[0].predicted[0], treebank[0][1][0][0], treebank[0][1][0][1])
+
+
+def test_score_one_tree():
+ treebank = build_one_tree_treebank()
+
+ with EvaluateParser(classpath="$CLASSPATH") as ep:
+ response = ep.process(treebank)
+ assert response.f1 == pytest.approx(1.0)
diff --git a/stanza/tests/test_pipeline_sentiment_processor.py b/stanza/tests/test_pipeline_sentiment_processor.py
index b46eedf4..d78dbbb6 100644
--- a/stanza/tests/test_pipeline_sentiment_processor.py
+++ b/stanza/tests/test_pipeline_sentiment_processor.py
@@ -36,3 +36,12 @@ def test_multiple_sentences(pipeline):
results = [sentence.sentiment for sentence in doc.sentences]
assert EXPECTED == results
+def test_empty_text(pipeline):
+ """
+ Test empty text and a text which might get reduced to empty text by removing dashes
+ """
+ doc = pipeline("")
+ assert len(doc.sentences) == 0
+
+ doc = pipeline("--")
+ assert len(doc.sentences) == 1
diff --git a/stanza/tests/test_tokenization_lst20.py b/stanza/tests/test_tokenization_lst20.py
new file mode 100644
index 00000000..a0728123
--- /dev/null
+++ b/stanza/tests/test_tokenization_lst20.py
@@ -0,0 +1,236 @@
+import os
+import tempfile
+
+import pytest
+
+import stanza
+from stanza.tests import *
+
+from stanza.utils.datasets.prepare_tokenizer_treebank import convert_conllu_to_txt
+from stanza.utils.datasets.tokenization.convert_th_lst20 import read_document
+from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section
+
+pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+SMALL_LST_SAMPLE="""
+สุรยุทธ์ NN B_PER B_CLS
+ยัน VV O I_CLS
+ปฏิเสธ VV O I_CLS
+ลงนาม VV O I_CLS
+_ PU O I_CLS
+MOU NN O I_CLS
+_ PU O I_CLS
+กับ PS O I_CLS
+อียู NN B_ORG I_CLS
+ไม่ NG O I_CLS
+กระทบ VV O I_CLS
+สัมพันธ์ NN O E_CLS
+
+1 NU B_DTM B_CLS
+_ PU I_DTM I_CLS
+กันยายน NN I_DTM I_CLS
+_ PU I_DTM I_CLS
+2550 NU E_DTM I_CLS
+_ PU O I_CLS
+12:21 NU B_DTM I_CLS
+_ PU I_DTM I_CLS
+น. CL E_DTM E_CLS
+
+ผู้สื่อข่าว NN O B_CLS
+รายงาน VV O I_CLS
+เพิ่มเติม VV O I_CLS
+ว่า CC O E_CLS
+_ PU O O
+จาก PS O B_CLS
+การ FX O I_CLS
+ลง VV O I_CLS
+พื้นที่ NN O I_CLS
+พบ VV O I_CLS
+ว่า CC O E_CLS
+""".strip()
+
+EXPECTED_CONLLU="""
+1 สุรยุทธ์ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes
+2 ยัน _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 ปฏิเสธ _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 ลงนาม _ _ _ _ 3 dep 3:dep _
+5 MOU _ _ _ _ 4 dep 4:dep _
+6 กับ _ _ _ _ 5 dep 5:dep SpaceAfter=No
+7 อียู _ _ _ _ 6 dep 6:dep SpaceAfter=No
+8 ไม่ _ _ _ _ 7 dep 7:dep SpaceAfter=No
+9 กระทบ _ _ _ _ 8 dep 8:dep SpaceAfter=No
+10 สัมพันธ์ _ _ _ _ 9 dep 9:dep SpaceAfter=No
+
+1 1 _ _ _ _ 0 root 0:root _
+2 กันยายน _ _ _ _ 1 dep 1:dep _
+3 2550 _ _ _ _ 2 dep 2:dep _
+4 12:21 _ _ _ _ 3 dep 3:dep _
+5 น. _ _ _ _ 4 dep 4:dep SpaceAfter=No
+
+1 ผู้สื่อข่าว _ _ _ _ 0 root 0:root SpaceAfter=No
+2 รายงาน _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 เพิ่มเติม _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 ว่า _ _ _ _ 3 dep 3:dep _
+5 จาก _ _ _ _ 4 dep 4:dep SpaceAfter=No
+6 การ _ _ _ _ 5 dep 5:dep SpaceAfter=No
+7 ลง _ _ _ _ 6 dep 6:dep SpaceAfter=No
+8 พื้นที่ _ _ _ _ 7 dep 7:dep SpaceAfter=No
+9 พบ _ _ _ _ 8 dep 8:dep SpaceAfter=No
+10 ว่า _ _ _ _ 9 dep 9:dep SpaceAfter=No
+""".strip()
+
+# Note: these DO NOT line up perfectly (in an emacs window, at least)
+# because Thai characters are not all rendered exactly one column wide.
+# The lengths of the words are:
+# สุรยุทธ์ 8
+# ยัน 3
+# ปฏิเสธ 6
+# ลงนาม 5
+# MOU 3
+# กับ 3
+# อียู 4
+# ไม่ 3
+# กระทบ 5
+# สัมพันธ์ 8
+# 1 1
+# กันยายน 7
+# 2550 4
+# 12:21 5
+# น. 2
+# ผู้สื่อข่าว 11
+# รายงาน 6
+# เพิ่มเติม 9
+# ว่า 3
+# จาก 3
+# การ 3
+# ลง 2
+# พื้นที่ 7
+# พบ 2
+# ว่า 3
+EXPECTED_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์1 กันยายน 2550 12:21 น.ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\n\n"
+EXPECTED_LABELS = "000000010010000010000100010001000100100001000000021000000010000100000100200000000001000001000000001001000100101000000101002\n\n"
+# counting spaces 1234567812312345612345_123_123123412312345123456781_1234567_1234_12345_12123456789AB123456123456789123_12312312123456712123
+
+# note that the word splits go on the final letter of the word in the
+# UD conllu datasets, so that is what we mimic here
+# for example, from EWT:
+# Al-Zaman : American forces killed Shaikh Abdullah
+# 0110000101000000001000000100000010000001000000001
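+
+# A minimal sketch (illustration only, not part of the conversion code) of how
+# such a label string can be derived from (token, space_after) pairs for a
+# single sentence: every character gets "0", except the final character of a
+# token, which gets "1", or "2" if that token also ends the sentence.
+def label_sentence(tokens):
+    text, labels = "", ""
+    for i, (word, space_after) in enumerate(tokens):
+        text += word
+        labels += "0" * (len(word) - 1)
+        labels += "2" if i == len(tokens) - 1 else "1"
+        if space_after:
+            text += " "
+            labels += "0"
+    return text, labels
+
+# label_sentence([("Way", True), ("to", True), ("go", False), ("!", False)])
+#   == ("Way to go!", "0010010012")
+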
+
+def check_results(documents, expected_conllu, expected_txt, expected_labels):
+ with tempfile.TemporaryDirectory() as output_dir:
+ write_section(output_dir, "lst20", "train", documents)
+ with open(os.path.join(output_dir, "th_lst20.train.gold.conllu")) as fin:
+ conllu = fin.read().strip()
+ with open(os.path.join(output_dir, "th_lst20.train.txt")) as fin:
+ txt = fin.read()
+ with open(os.path.join(output_dir, "th_lst20-ud-train.toklabels")) as fin:
+ labels = fin.read()
+ assert conllu == expected_conllu
+ assert txt == expected_txt
+ assert labels == expected_labels
+
+ assert len(txt) == len(labels)
+
+
+def test_small():
+ """
+ A small test just to verify that the output is being produced as we want
+
+    Note that there are currently no spaces after the first sentence.
+    That is apparently wrong, but oddly, adding the spaces makes the model even worse.
+ """
+ lines = SMALL_LST_SAMPLE.strip().split("\n")
+ documents = read_document(lines, spaces_after=False, split_clauses=False)
+ check_results(documents, EXPECTED_CONLLU, EXPECTED_TXT, EXPECTED_LABELS)
+
+EXPECTED_SPACE_CONLLU="""
+1 สุรยุทธ์ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes
+2 ยัน _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 ปฏิเสธ _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 ลงนาม _ _ _ _ 3 dep 3:dep _
+5 MOU _ _ _ _ 4 dep 4:dep _
+6 กับ _ _ _ _ 5 dep 5:dep SpaceAfter=No
+7 อียู _ _ _ _ 6 dep 6:dep SpaceAfter=No
+8 ไม่ _ _ _ _ 7 dep 7:dep SpaceAfter=No
+9 กระทบ _ _ _ _ 8 dep 8:dep SpaceAfter=No
+10 สัมพันธ์ _ _ _ _ 9 dep 9:dep _
+
+1 1 _ _ _ _ 0 root 0:root _
+2 กันยายน _ _ _ _ 1 dep 1:dep _
+3 2550 _ _ _ _ 2 dep 2:dep _
+4 12:21 _ _ _ _ 3 dep 3:dep _
+5 น. _ _ _ _ 4 dep 4:dep _
+
+1 ผู้สื่อข่าว _ _ _ _ 0 root 0:root SpaceAfter=No
+2 รายงาน _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 เพิ่มเติม _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 ว่า _ _ _ _ 3 dep 3:dep _
+5 จาก _ _ _ _ 4 dep 4:dep SpaceAfter=No
+6 การ _ _ _ _ 5 dep 5:dep SpaceAfter=No
+7 ลง _ _ _ _ 6 dep 6:dep SpaceAfter=No
+8 พื้นที่ _ _ _ _ 7 dep 7:dep SpaceAfter=No
+9 พบ _ _ _ _ 8 dep 8:dep SpaceAfter=No
+10 ว่า _ _ _ _ 9 dep 9:dep _
+""".strip()
+
+EXPECTED_SPACE_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์ 1 กันยายน 2550 12:21 น. ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\n\n"
+EXPECTED_SPACE_LABELS = "00000001001000001000010001000100010010000100000002010000000100001000001002000000000001000001000000001001000100101000000101002\n\n"
+
+def test_space_after():
+ """
+ This version of the test adds the space after attribute
+ """
+ lines = SMALL_LST_SAMPLE.strip().split("\n")
+ documents = read_document(lines, spaces_after=True, split_clauses=False)
+ check_results(documents, EXPECTED_SPACE_CONLLU, EXPECTED_SPACE_TXT, EXPECTED_SPACE_LABELS)
+
+
+EXPECTED_CLAUSE_CONLLU="""
+1 สุรยุทธ์ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes
+2 ยัน _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 ปฏิเสธ _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 ลงนาม _ _ _ _ 3 dep 3:dep _
+5 MOU _ _ _ _ 4 dep 4:dep _
+6 กับ _ _ _ _ 5 dep 5:dep SpaceAfter=No
+7 อียู _ _ _ _ 6 dep 6:dep SpaceAfter=No
+8 ไม่ _ _ _ _ 7 dep 7:dep SpaceAfter=No
+9 กระทบ _ _ _ _ 8 dep 8:dep SpaceAfter=No
+10 สัมพันธ์ _ _ _ _ 9 dep 9:dep _
+
+1 1 _ _ _ _ 0 root 0:root _
+2 กันยายน _ _ _ _ 1 dep 1:dep _
+3 2550 _ _ _ _ 2 dep 2:dep _
+4 12:21 _ _ _ _ 3 dep 3:dep _
+5 น. _ _ _ _ 4 dep 4:dep _
+
+1 ผู้สื่อข่าว _ _ _ _ 0 root 0:root SpaceAfter=No
+2 รายงาน _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 เพิ่มเติม _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 ว่า _ _ _ _ 3 dep 3:dep _
+
+1 จาก _ _ _ _ 0 root 0:root SpaceAfter=No
+2 การ _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 ลง _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 พื้นที่ _ _ _ _ 3 dep 3:dep SpaceAfter=No
+5 พบ _ _ _ _ 4 dep 4:dep SpaceAfter=No
+6 ว่า _ _ _ _ 5 dep 5:dep _
+""".strip()
+
+EXPECTED_CLAUSE_TXT = "สุรยุทธ์ยันปฏิเสธลงนาม MOU กับอียูไม่กระทบสัมพันธ์ 1 กันยายน 2550 12:21 น. ผู้สื่อข่าวรายงานเพิ่มเติมว่า จากการลงพื้นที่พบว่า\n\n"
+EXPECTED_CLAUSE_LABELS = "00000001001000001000010001000100010010000100000002010000000100001000001002000000000001000001000000001002000100101000000101002\n\n"
+
+
+def test_split_clause():
+ """
+ This version of the test also resplits on spaces between clauses
+ """
+ lines = SMALL_LST_SAMPLE.strip().split("\n")
+ documents = read_document(lines, spaces_after=True, split_clauses=True)
+ check_results(documents, EXPECTED_CLAUSE_CONLLU, EXPECTED_CLAUSE_TXT, EXPECTED_CLAUSE_LABELS)
+
+if __name__ == "__main__":
+ lines = SMALL_LST_SAMPLE.strip().split("\n")
+ documents = read_document(lines, spaces_after=False, split_clauses=False)
+
+ write_section("foo", "lst20", "train", documents)
diff --git a/stanza/tests/test_tokenization_orchid.py b/stanza/tests/test_tokenization_orchid.py
new file mode 100644
index 00000000..8c0fb9f5
--- /dev/null
+++ b/stanza/tests/test_tokenization_orchid.py
@@ -0,0 +1,107 @@
+import os
+import tempfile
+
+import pytest
+
+import xml.etree.ElementTree as ET
+
+import stanza
+from stanza.tests import *
+
+from stanza.utils.datasets.prepare_tokenizer_treebank import convert_conllu_to_txt
+from stanza.utils.datasets.tokenization.convert_th_orchid import parse_xml
+from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section
+
+pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+
+SMALL_DOC="""
+<corpus>
+<document TPublisher="ศูนย์เทคโนโลยีอิเล็กทรอนิกส์และคอมพิวเตอร์แห่งชาติ, กระทรวงวิทยาศาสตร์ เทคโนโลยีและการพลังงาน" EPublisher="National Electronics and Computer Technology Center, Ministry of Science, Technology and Energy" TInbook="การประชุมทางวิชาการ ครั้งที่ 1, โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์, ปีงบประมาณ 2531, เล่ม 1" TTitle="การประชุมทางวิชาการ ครั้งที่ 1" Year="1989" EInbook="The 1st Annual Conference, Electronics and Computer Research and Development Project, Fiscal Year 1988, Book 1" ETitle="[1st Annual Conference]">
+<paragraph id="1" line_num="12">
+<sentence id="1" line_num = "13" raw_txt = "การประชุมทางวิชาการ ครั้งที่ 1">
+<word surface="การ" pos="FIXN"/>
+<word surface="ประชุม" pos="VACT"/>
+<word surface="ทาง" pos="NCMN"/>
+<word surface="วิชาการ" pos="NCMN"/>
+<word surface="&lt;space&gt;" pos="PUNC"/>
+<word surface="ครั้ง" pos="CFQC"/>
+<word surface="ที่ 1" pos="DONM"/>
+</sentence>
+<sentence id="2" line_num = "23" raw_txt = "โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์">
+<word surface="โครงการวิจัยและพัฒนา" pos="NCMN"/>
+<word surface="อิเล็กทรอนิกส์" pos="NCMN"/>
+<word surface="และ" pos="JCRG"/>
+<word surface="คอมพิวเตอร์" pos="NCMN"/>
+</sentence>
+</paragraph>
+<paragraph id="3" line_num="51">
+<sentence id="1" line_num = "52" raw_txt = "วันที่ 15-16 สิงหาคม 2532">
+<word surface="วัน" pos="NCMN"/>
+<word surface="ที่ 15" pos="DONM"/>
+<word surface="&lt;minus&gt;" pos="PUNC"/>
+<word surface="16" pos="DONM"/>
+<word surface="&lt;space&gt;" pos="PUNC"/>
+<word surface="สิงหาคม" pos="NCMN"/>
+<word surface="&lt;space&gt;" pos="PUNC"/>
+<word surface="2532" pos="NCNM"/>
+</sentence>
+</paragraph>
+</document>
+</corpus>
+"""
+
+
+EXPECTED_RESULTS="""
+1 การ _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes
+2 ประชุม _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 ทาง _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 วิชาการ _ _ _ _ 3 dep 3:dep _
+5 ครั้ง _ _ _ _ 4 dep 4:dep SpaceAfter=No
+6 ที่ 1 _ _ _ _ 5 dep 5:dep _
+
+1 โครงการวิจัยและพัฒนา _ _ _ _ 0 root 0:root SpaceAfter=No
+2 อิเล็กทรอนิกส์ _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 และ _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 คอมพิวเตอร์ _ _ _ _ 3 dep 3:dep _
+
+1 วัน _ _ _ _ 0 root 0:root SpaceAfter=No|NewPar=Yes
+2 ที่ 15 _ _ _ _ 1 dep 1:dep SpaceAfter=No
+3 - _ _ _ _ 2 dep 2:dep SpaceAfter=No
+4 16 _ _ _ _ 3 dep 3:dep _
+5 สิงหาคม _ _ _ _ 4 dep 4:dep _
+6 2532 _ _ _ _ 5 dep 5:dep _
+""".strip()
+
+EXPECTED_TEXT="""การประชุมทางวิชาการ ครั้งที่ 1 โครงการวิจัยและพัฒนาอิเล็กทรอนิกส์และคอมพิวเตอร์
+
+วันที่ 15-16 สิงหาคม 2532
+
+"""
+
+EXPECTED_LABELS="""0010000010010000001000001000020000000000000000000010000000000000100100000000002
+
+0010000011010000000100002
+
+"""
+
+def check_results(documents, expected_conllu, expected_txt, expected_labels):
+ with tempfile.TemporaryDirectory() as output_dir:
+ write_section(output_dir, "orchid", "train", documents)
+ with open(os.path.join(output_dir, "th_orchid.train.gold.conllu")) as fin:
+ conllu = fin.read().strip()
+ with open(os.path.join(output_dir, "th_orchid.train.txt")) as fin:
+ txt = fin.read()
+ with open(os.path.join(output_dir, "th_orchid-ud-train.toklabels")) as fin:
+ labels = fin.read()
+ assert conllu == expected_conllu
+ assert txt == expected_txt
+ assert labels == expected_labels
+
+ assert len(txt) == len(labels)
+
+def test_orchid():
+ tree = ET.ElementTree(ET.fromstring(SMALL_DOC))
+ documents = parse_xml(tree)
+ check_results(documents, EXPECTED_RESULTS, EXPECTED_TEXT, EXPECTED_LABELS)
+
diff --git a/stanza/tests/test_tokenize_data.py b/stanza/tests/test_tokenize_data.py
index 1bd2ae39..c37ace4e 100644
--- a/stanza/tests/test_tokenize_data.py
+++ b/stanza/tests/test_tokenize_data.py
@@ -23,6 +23,7 @@ FAKE_PROPERTIES = {
"lang":"de",
'feat_funcs': ("space_before","capitalized"),
'max_seqlen': 300,
+ 'use_dictionary': False,
}
def test_has_mwt():
diff --git a/stanza/tests/test_tokenizer.py b/stanza/tests/test_tokenizer.py
index fc5c96b8..9a48eb54 100644
--- a/stanza/tests/test_tokenizer.py
+++ b/stanza/tests/test_tokenizer.py
@@ -166,22 +166,21 @@ ZH_DOC1_GOLD_TOKENS="""
<Token id=1;words=[<Word id=1;text=北京;lemma=北京;upos=PROPN;xpos=NNP;head=5;deprel=nsubj>]>
<Token id=2;words=[<Word id=2;text=是;lemma=是;upos=AUX;xpos=VC;head=5;deprel=cop>]>
<Token id=3;words=[<Word id=3;text=中国;lemma=中国;upos=PROPN;xpos=NNP;head=5;deprel=nmod>]>
-<Token id=4;words=[<Word id=4;text=的;lemma=的;upos=PART;xpos=DEC;feats=Case=Gen;head=3;deprel=case:dec>]>
+<Token id=4;words=[<Word id=4;text=的;lemma=的;upos=PART;xpos=DEC;feats=Case=Gen;head=3;deprel=case>]>
<Token id=5;words=[<Word id=5;text=首都;lemma=首都;upos=NOUN;xpos=NN;head=0;deprel=root>]>
<Token id=6;words=[<Word id=6;text=。;lemma=。;upos=PUNCT;xpos=.;head=5;deprel=punct>]>
<Token id=1;words=[<Word id=1;text=北京;lemma=北京;upos=PROPN;xpos=NNP;head=2;deprel=nsubj>]>
-<Token id=2;words=[<Word id=2;text=有;lemma=有;upos=VERB;xpos=VV;head=11;deprel=acl>]>
+<Token id=2;words=[<Word id=2;text=有;lemma=有;upos=VERB;xpos=VV;head=10;deprel=acl>]>
<Token id=3;words=[<Word id=3;text=2100万;lemma=2100万;upos=NUM;xpos=CD;feats=NumType=Card;head=4;deprel=nummod>]>
-<Token id=4;words=[<Word id=4;text=人;lemma=人;upos=NOUN;xpos=NN;head=5;deprel=compound>]>
-<Token id=5;words=[<Word id=5;text=口;lemma=口;upos=PART;xpos=SFN;head=2;deprel=obj>]>
-<Token id=6;words=[<Word id=6;text=,;lemma=,;upos=PUNCT;xpos=,;head=11;deprel=punct>]>
-<Token id=7;words=[<Word id=7;text=是;lemma=是;upos=AUX;xpos=VC;head=11;deprel=cop>]>
-<Token id=8;words=[<Word id=8;text=一;lemma=一;upos=NUM;xpos=CD;feats=NumType=Card;head=9;deprel=nummod>]>
-<Token id=9;words=[<Word id=9;text=个;lemma=个;upos=NOUN;xpos=NNB;head=11;deprel=nmod>]>
-<Token id=10;words=[<Word id=10;text=直辖;lemma=直辖;upos=VERB;xpos=VV;head=11;deprel=compound>]>
-<Token id=11;words=[<Word id=11;text=市;lemma=市;upos=PART;xpos=SFN;head=0;deprel=root>]>
-<Token id=12;words=[<Word id=12;text=。;lemma=。;upos=PUNCT;xpos=.;head=11;deprel=punct>]>
+<Token id=4;words=[<Word id=4;text=人口;lemma=人口;upos=NOUN;xpos=NN;head=2;deprel=obj>]>
+<Token id=5;words=[<Word id=5;text=,;lemma=,;upos=PUNCT;xpos=,;head=10;deprel=punct>]>
+<Token id=6;words=[<Word id=6;text=是;lemma=是;upos=AUX;xpos=VC;head=10;deprel=cop>]>
+<Token id=7;words=[<Word id=7;text=一;lemma=一;upos=NUM;xpos=CD;feats=NumType=Card;head=8;deprel=nummod>]>
+<Token id=8;words=[<Word id=8;text=个;lemma=个;upos=NOUN;xpos=NNB;head=10;deprel=nmod>]>
+<Token id=9;words=[<Word id=9;text=直辖;lemma=直辖;upos=VERB;xpos=VV;head=10;deprel=compound>]>
+<Token id=10;words=[<Word id=10;text=市;lemma=市;upos=PART;xpos=SFN;head=0;deprel=root>]>
+<Token id=11;words=[<Word id=11;text=。;lemma=。;upos=PUNCT;xpos=.;head=10;deprel=punct>]>
""".strip()
ZH_DOC_GOLD_NOSSPLIT_TOKENS = """
diff --git a/stanza/utils/charlm/make_lm_data.py b/stanza/utils/charlm/make_lm_data.py
index a2a7e3e8..e1a8ca16 100644
--- a/stanza/utils/charlm/make_lm_data.py
+++ b/stanza/utils/charlm/make_lm_data.py
@@ -86,6 +86,9 @@ def prepare_lm_data(src_dir, tgt_dir, lang, dataset_name):
for src_fn in glob.glob(str(src_dir) + '/*.txt.xz'):
cmd = f"xzcat {src_fn} >> {tgt_tmp}"
subprocess.run(cmd, shell=True)
+ for src_fn in glob.glob(str(src_dir) + '/*.txt.gz'):
+ cmd = f"zcat {src_fn} >> {tgt_tmp}"
+ subprocess.run(cmd, shell=True)
tgt_tmp_shuffled = Path(str(tgt_tmp) + ".shuffled")
print(f"--> Shuffling files into {tgt_tmp_shuffled}...")
diff --git a/stanza/utils/datasets/common.py b/stanza/utils/datasets/common.py
index 7ea7e617..8c1631fb 100644
--- a/stanza/utils/datasets/common.py
+++ b/stanza/utils/datasets/common.py
@@ -115,6 +115,7 @@ def get_ud_treebanks(udbase_dir, filtered=True):
def build_argparse():
parser = argparse.ArgumentParser()
parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on. Use all_ud or ud_all for all UD treebanks')
+
return parser
diff --git a/stanza/utils/datasets/constituency/convert_it_turin.py b/stanza/utils/datasets/constituency/convert_it_turin.py
new file mode 100644
index 00000000..018073e3
--- /dev/null
+++ b/stanza/utils/datasets/constituency/convert_it_turin.py
@@ -0,0 +1,322 @@
+"""
+Converts Turin's constituency dataset
+
+Turin University put out a freely available constituency dataset in 2011.
+It is not as large as VIT or ISST, but it is free, which is nice.
+
+The 2011 parsing task combines trees from several sources:
+http://www.di.unito.it/~tutreeb/evalita-parsingtask-11.html
+
+There is another site for Turin treebanks:
+http://www.di.unito.it/~tutreeb/treebanks.html
+
+Weirdly, the most recent versions of the Evalita trees are not there.
+The most relevant parts are the ParTUT downloads. As of Sep. 2021:
+
+http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/JRCAcquis_It.pen
+http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/UDHR_It.pen
+http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/CC_It.pen
+http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/FB_It.pen
+http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/WIT3_It.pen
+
+We can't simply cat all these files together as there are a bunch of
+asterisks as comments and the files may have some duplicates. For
+example, the JRCAcquis piece has many duplicates. Also, some don't
+pass validation for one reason or another.
+
+One oddity of these data files is that the MWT are denoted by doubling
+the token. The token is not split as would be expected, though. We try
+to use stanza's MWT tokenizer for IT to split the tokens, with some
+rules added by hand in BIWORD_SPLITS. Two are still unsplit, though...
+"""
+
+import glob
+import os
+import re
+import sys
+
+import stanza
+from stanza.models.constituency import parse_tree
+from stanza.models.constituency import tree_reader
+
+def load_without_asterisks(in_file, encoding='utf-8'):
+ with open(in_file, encoding=encoding) as fin:
+ new_lines = [x if x.find("********") < 0 else "\n" for x in fin.readlines()]
+ if len(new_lines) > 0 and not new_lines[-1].endswith("\n"):
+ new_lines[-1] = new_lines[-1] + "\n"
+ return new_lines
+
+CONSTITUENT_SPLIT = re.compile("[-=#+0-9]")
+
+# JRCA is almost entirely duplicates
+# WIT3 follows a different annotation scheme
+FILES_TO_ELIMINATE = ["JRCAcquis_It.pen", "WIT3_It.pen"]
+
+# assuming this is a typo
+REMAP_NODES = { "Sbar" : "SBAR" }
+
+REMAP_WORDS = { "-LSB-": "[", "-RSB-": "]" }
+
+# these mostly seem to be mistakes
+# maybe Vbar and ADVbar should be converted to something else?
+NODES_TO_ELIMINATE = ["C", "PHRASP", "PRDT", "Vbar", "parte", "ADVbar"]
+
+UNKNOWN_SPLITS = set()
+
+# a map of splits that the tokenizer or MWT doesn't handle well
+BIWORD_SPLITS = { "offertogli": ("offerto", "gli"),
+ "offertegli": ("offerte", "gli"),
+ "formatasi": ("formata", "si"),
+ "formatosi": ("formato", "si"),
+ "multiplexarlo": ("multiplexar", "lo"),
+ "esibirsi": ("esibir", "si"),
+ "pagarne": ("pagar", "ne"),
+ "recarsi": ("recar", "si"),
+ "trarne": ("trar", "ne"),
+ "esserci": ("esser", "ci"),
+ "aprirne": ("aprir", "ne"),
+ "farle": ("far", "le"),
+ "disporne": ("dispor", "ne"),
+ "andargli": ("andar", "gli"),
+ "CONSIDERARSI": ("CONSIDERAR", "SI"),
+ "conferitegli": ("conferite", "gli"),
+ "formatasi": ("formata", "si"),
+ "formatosi": ("formato", "si"),
+ "Formatisi": ("Formati", "si"),
+ "multiplexarlo": ("multiplexar", "lo"),
+ "esibirsi": ("esibir", "si"),
+ "pagarne": ("pagar", "ne"),
+ "recarsi": ("recar", "si"),
+ "trarne": ("trar", "ne"),
+ "temerne": ("temer", "ne"),
+ "esserci": ("esser", "ci"),
+ "esservi": ("esser", "vi"),
+ "restituirne": ("restituir", "ne"),
+ "col": ("con", "il"),
+ "cogli": ("con", "gli"),
+ "dirgli": ("dir", "gli"),
+ "opporgli": ("oppor", "gli"),
+ "eccolo": ("ecco", "lo"),
+ "Eccolo": ("Ecco", "lo"),
+ "Eccole": ("Ecco", "le"),
+ "farci": ("far", "ci"),
+ "farli": ("far", "li"),
+ "farne": ("far", "ne"),
+ "farsi": ("far", "si"),
+ "farvi": ("far", "vi"),
+ "Connettiti": ("Connetti", "ti"),
+ "APPLICARSI": ("APPLICAR", "SI"),
+ # This is not always two words, but if it IS two words,
+ # it gets split like this
+ "assicurati": ("assicura", "ti"),
+ "Fatti": ("Fai", "te"),
+ "ai": ("a", "i"),
+ "Ai": ("A", "i"),
+ "AI": ("A", "I"),
+ "al": ("a", "il"),
+ "Al": ("A", "il"),
+ "AL": ("A", "IL"),
+ "coi": ("con", "i"),
+ "colla": ("con", "la"),
+ "colle": ("con", "le"),
+ "dal": ("da", "il"),
+ "Dal": ("Da", "il"),
+ "DAL": ("DA", "IL"),
+ "dei": ("di", "i"),
+ "Dei": ("Di", "i"),
+ "DEI": ("DI", "I"),
+ "del": ("di", "il"),
+ "Del": ("Di", "il"),
+ "DEL": ("DI", "IL"),
+ "nei": ("in", "i"),
+ "NEI": ("IN", "I"),
+ "nel": ("in", "il"),
+ "Nel": ("In", "il"),
+ "NEL": ("IN", "IL"),
+ "pel": ("per", "il"),
+ "sui": ("su", "i"),
+ "Sui": ("Su", "i"),
+ "sul": ("su", "il"),
+ "Sul": ("Su", "il"),
+ ",": (",", ","),
+ ".": (".", "."),
+ '"': ('"', '"'),
+ '-': ('-', '-'),
+ '-LRB-': ('-LRB-', '-LRB-'),
+ "garantirne": ("garantir", "ne"),
+ "aprirvi": ("aprir", "vi"),
+ "esimersi": ("esimer", "si"),
+ "opporsi": ("oppor", "si"),
+}
+
+CAP_BIWORD = re.compile("[A-Z]+_[A-Z]+")
+
+def split_mwe(tree, pipeline):
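+    """
+    Replace doubled leaves (the corpus' notation for MWTs) with their split forms.
+
+    Splits come from BIWORD_SPLITS, the CAP_BIWORD pattern, or the tokenize/mwt
+    pipeline; words which cannot be split are reported and left doubled.
+    Returns the tree with the replaced words, or the original tree if nothing
+    needed splitting.
+    """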
+ words = list(tree.leaf_labels())
+ found = False
+ for idx, word in enumerate(words[:-3]):
+ if word == words[idx+1] and word == words[idx+2] and word == words[idx+3]:
+ raise ValueError("Oh no, 4 consecutive words")
+
+ for idx, word in enumerate(words[:-2]):
+ if word == words[idx+1] and word == words[idx+2]:
+ doc = pipeline(word)
+ assert len(doc.sentences) == 1
+ if len(doc.sentences[0].words) != 3:
+ raise RuntimeError("Word {} not tokenized into 3 parts... thought all 3 part words were handled!".format(word))
+ words[idx] = doc.sentences[0].words[0].text
+ words[idx+1] = doc.sentences[0].words[1].text
+ words[idx+2] = doc.sentences[0].words[2].text
+ found = True
+
+ for idx, word in enumerate(words[:-1]):
+ if word == words[idx+1]:
+ if word in BIWORD_SPLITS:
+ first_word = BIWORD_SPLITS[word][0]
+ second_word = BIWORD_SPLITS[word][1]
+ elif CAP_BIWORD.match(word):
+ first_word, second_word = word.split("_")
+ else:
+ doc = pipeline(word)
+ assert len(doc.sentences) == 1
+ if len(doc.sentences[0].words) == 2:
+ first_word = doc.sentences[0].words[0].text
+ second_word = doc.sentences[0].words[1].text
+ else:
+ if word not in UNKNOWN_SPLITS:
+ UNKNOWN_SPLITS.add(word)
+ print("Could not figure out how to split {}\n {}\n {}".format(word, " ".join(words), tree))
+ continue
+
+ words[idx] = first_word
+ words[idx+1] = second_word
+ found = True
+
+ if found:
+ tree = tree.replace_words(words)
+ return tree
+
+
+def load_trees(filename, pipeline):
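+    """
+    Read the trees in filename, skip the ones which are broken or use an
+    unwanted label, and normalize the rest: prune None nodes, simplify and
+    remap constituent labels, remap bracket tokens, and split doubled MWT leaves.
+    """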
+ # some of the files are in latin-1 encoding rather than utf-8
+ try:
+ raw_text = load_without_asterisks(filename, "utf-8")
+ except UnicodeDecodeError:
+ raw_text = load_without_asterisks(filename, "latin-1")
+
+ # also, some have messed up validation (it will be logged)
+ # hence the broken_ok=True argument
+ trees = tree_reader.read_trees("".join(raw_text), broken_ok=True)
+
+ filtered_trees = []
+ for tree in trees:
+ if tree.children[0].label is None:
+ print("Skipping a broken tree (missing label) in {}: {}".format(filename, tree))
+ continue
+
+ try:
+ words = tuple(tree.leaf_labels())
+ except ValueError:
+ print("Skipping a broken tree (missing preterminal) in {}: {}".format(filename, tree))
+ continue
+
+ if any('www.facebook' in pt.label for pt in tree.preterminals()):
+ print("Skipping a tree with a weird preterminal label in {}: {}".format(filename, tree))
+ continue
+
+ tree = tree.prune_none().simplify_labels(CONSTITUENT_SPLIT)
+ tree = tree.remap_constituent_labels(REMAP_NODES)
+ tree = tree.remap_words(REMAP_WORDS)
+
+ tree = split_mwe(tree, pipeline)
+ if tree is None:
+ continue
+
+ constituents = set(parse_tree.Tree.get_unique_constituent_labels(tree))
+ for weird_label in NODES_TO_ELIMINATE:
+ if weird_label in constituents:
+ break
+ else:
+ weird_label = None
+ if weird_label is not None:
+ print("Skipping a tree with a weird label {} in {}: {}".format(weird_label, filename, tree))
+ continue
+
+ filtered_trees.append(tree)
+
+ return filtered_trees
+
+def save_trees(out_file, trees):
+ print("Saving {} trees to {}".format(len(trees), out_file))
+ with open(out_file, "w", encoding="utf-8") as fout:
+ for tree in trees:
+ fout.write(str(tree))
+ fout.write("\n")
+
+def main():
+ pipeline = stanza.Pipeline("it", processors="tokenize, mwt", tokenize_no_ssplit=True)
+
+ input_path = sys.argv[1]
+ output_path = sys.argv[2]
+
+ os.makedirs(output_path, exist_ok=True)
+
+ evalita_dir = os.path.join(input_path, "evalita")
+
+ evalita_test = os.path.join(evalita_dir, "evalita11_TESTgold_CONPARSE.penn")
+ it_test = os.path.join(output_path, "it_turin_test.mrg")
+ test_trees = load_trees(evalita_test, pipeline)
+ save_trees(it_test, test_trees)
+
+ known_text = set()
+ for tree in test_trees:
+ words = tuple(tree.leaf_labels())
+ assert words not in known_text
+ known_text.add(words)
+
+ evalita_train = os.path.join(output_path, "it_turin_train.mrg")
+ evalita_files = glob.glob(os.path.join(evalita_dir, "*2011*penn"))
+ turin_files = glob.glob(os.path.join(input_path, "turin", "*pen"))
+ filenames = evalita_files + turin_files
+ filtered_trees = []
+ for filename in filenames:
+ if os.path.split(filename)[1] in FILES_TO_ELIMINATE:
+ continue
+
+ trees = load_trees(filename, pipeline)
+ file_trees = []
+
+ for tree in trees:
+ words = tuple(tree.leaf_labels())
+ if words in known_text:
+ print("Skipping a duplicate in {}: {}".format(filename, tree))
+ continue
+
+ known_text.add(words)
+
+ file_trees.append(tree)
+
+ filtered_trees.append((filename, file_trees))
+
+ print("{} contains {} usable trees".format(evalita_test, len(test_trees)))
+ print(" Unique constituents in {}: {}".format(evalita_test, parse_tree.Tree.get_unique_constituent_labels(test_trees)))
+
+ train_trees = []
+ dev_trees = []
+ for filename, file_trees in filtered_trees:
+ print("{} contains {} usable trees".format(filename, len(file_trees)))
+ print(" Unique constituents in {}: {}".format(filename, parse_tree.Tree.get_unique_constituent_labels(file_trees)))
+ for tree in file_trees:
+ if len(train_trees) <= len(dev_trees) * 9:
+ train_trees.append(tree)
+ else:
+ dev_trees.append(tree)
+
+ it_train = os.path.join(output_path, "it_turin_train.mrg")
+ save_trees(it_train, train_trees)
+
+ it_dev = os.path.join(output_path, "it_turin_dev.mrg")
+ save_trees(it_dev, dev_trees)
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/utils/datasets/constituency/vtb_convert.py b/stanza/utils/datasets/constituency/vtb_convert.py
new file mode 100644
index 00000000..d9b92e0e
--- /dev/null
+++ b/stanza/utils/datasets/constituency/vtb_convert.py
@@ -0,0 +1,75 @@
+"""
+Script for processing the VTB files and turning their trees into the desired tree syntax
+
+The VTB original trees are stored in the directory:
+VietTreebank_VLSP_SP73/Kho ngu lieu 10000 cay cu phap
+
+The script requires two arguments:
+1. Original directory storing the original trees
+2. New directory storing the converted trees
+"""
+
+
+import os
+import argparse
+
+
+def convert_file(org_dir, new_dir):
+ """
+    :param org_dir: path to an original file storing the original trees
+    :param new_dir: path to the new file storing the formatted constituency trees
+
+    This function reads the trees from org_dir and writes the converted trees to new_dir
+ """
+ with open(org_dir, 'r') as reader, open(new_dir, 'w') as writer:
+ content = reader.readlines()
+ for line in content:
+ line = ' '.join(line.split())
+ if line == '':
+ continue
+ elif line == '<s>':
+ writer.write('(ROOT ')
+ elif line == '</s>':
+ writer.write(')\n')
+ else:
+ writer.write(line)
+
+
+def main():
+ """
+ Main function for the script
+
+ Process args, loop through each file in the directory and convert
+ to the desired tree format
+ """
+ parser = argparse.ArgumentParser(
+ description="Script that converts a VTB Tree into the desired format",
+ )
+ parser.add_argument(
+ 'org_dir',
+ help='The location of the original directory storing original trees '
+ )
+ parser.add_argument(
+ 'new_dir',
+ help='The location of new directory storing the new formatted trees'
+ )
+
+ args = parser.parse_args()
+
+ org_dir = args.org_dir
+ new_dir = args.new_dir
+
+ for filename in os.listdir(org_dir):
+ file_name, file_extension = os.path.splitext(filename)
+        # Skip the .raw files and convert the remaining (.prd) tree files
+ if file_extension == '.raw':
+ continue
+ file_path = os.path.join(org_dir, filename)
+ new_path = os.path.join(new_dir, file_name)
+ new_file_path = f'{new_path}.mrg'
+ # Convert the tree and write to new_file_path
+ convert_file(file_path, new_file_path)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/utils/datasets/constituency/vtb_split.py b/stanza/utils/datasets/constituency/vtb_split.py
new file mode 100644
index 00000000..27a7161b
--- /dev/null
+++ b/stanza/utils/datasets/constituency/vtb_split.py
@@ -0,0 +1,130 @@
+"""
+From a directory of files with VTB Trees, split into train/dev/test set
+with a split of 70/15/15
+
+The script requires two arguments
+1. org_dir: the original directory obtainable from running vtb_convert.py
+2. split_dir: the directory where the train/dev/test splits will be stored
+"""
+
+import os
+import argparse
+import random
+
+
+def create_shuffle_list(org_dir):
+ """
+    This function creates the random order in which we loop through the files
+ :param org_dir: original directory storing the files that store the trees
+ :return: list of file names randomly shuffled
+ """
+ file_names = []
+ for filename in os.listdir(org_dir):
+ file_names.append(filename)
+ random.shuffle(file_names)
+
+ return file_names
+
+
+def create_paths(split_dir):
+ """
+ This function creates the necessary paths for the train/dev/test splits
+ :param split_dir: directory that stores the splits
+ :return: train path, dev path, test path
+ """
+ train_path = os.path.join(split_dir, 'train.mrg')
+ dev_path = os.path.join(split_dir, 'dev.mrg')
+ test_path = os.path.join(split_dir, 'test.mrg')
+
+ return train_path, dev_path, test_path
+
+
+def get_num_samples(org_dir, file_names):
+ """
+ Function for obtaining the number of samples
+ :param org_dir: original directory storing the tree files
+ :param file_names: list of file names in the directory
+ :return: number of samples
+ """
+ count = 0
+    # Loop through the files, then through the trees in each file
+ for filename in file_names:
+ # Skip files that are not .mrg
+ if not filename.endswith('.mrg'):
+ continue
+ # File is .mrg. Start processing
+ file_dir = os.path.join(org_dir, filename)
+ with open(file_dir, 'r') as reader:
+ content = reader.readlines()
+ for _ in content:
+ count += 1
+
+ return count
+
+
+def main():
+ """
+ Main function for the script
+
+ Process args, loop through each tree in each file in the directory
+ and write the trees to the train/dev/test split with a split of
+ 70/15/15
+ """
+ parser = argparse.ArgumentParser(
+ description="Script that splits a list of files of vtb trees into train/dev/test sets",
+ )
+ parser.add_argument(
+ 'org_dir',
+ help='The location of the original directory storing correctly formatted vtb trees '
+ )
+ parser.add_argument(
+ 'split_dir',
+ help='The location of new directory storing the train/dev/test set'
+ )
+
+ args = parser.parse_args()
+
+ org_dir = args.org_dir
+ split_dir = args.split_dir
+
+ random.seed(1234)
+
+ # Create a random shuffle list of the file names in the original directory
+ file_names = create_shuffle_list(org_dir)
+
+ # Create train_path, dev_path, test_path
+ train_path, dev_path, test_path = create_paths(split_dir)
+
+ # Set up the number of samples for each train/dev/test set
+ num_samples = get_num_samples(org_dir, file_names)
+ stop_train = int(num_samples * 0.7)
+ stop_dev = int(num_samples * 0.85)
+
+ # Write directory and write count
+ write_dir = train_path
+ count = 0
+
+    # Loop through the files, then through the trees in each file, writing each tree to write_dir
+ for filename in file_names:
+ # Skip files that are not .mrg
+ if not filename.endswith('.mrg'):
+ continue
+ # File is .mrg. Start processing
+ file_dir = os.path.join(org_dir, filename)
+ with open(file_dir, 'r') as reader, open(write_dir, 'a') as writer:
+ content = reader.readlines()
+ for line in content:
+ # Write to write_dir
+ writer.write(line)
+ # Check current count to switch write_dir
+ count += 1
+ # Switch to writing dev set
+ if count > stop_train:
+ write_dir = dev_path
+ # Switch to writing test set
+ if count > stop_dev:
+ write_dir = test_path
+
+
+if __name__ == '__main__':
+ main()
diff --git a/stanza/utils/datasets/ner/convert_bsf_to_beios.py b/stanza/utils/datasets/ner/convert_bsf_to_beios.py
index 6309efe2..16b7150d 100644
--- a/stanza/utils/datasets/ner/convert_bsf_to_beios.py
+++ b/stanza/utils/datasets/ner/convert_bsf_to_beios.py
@@ -4,8 +4,9 @@ import os
import glob
from collections import namedtuple
import re
+from typing import Tuple
from tqdm import tqdm
-from random import choices
+from random import choices, shuffle
BsfInfo = namedtuple('BsfInfo', 'id, tag, start_idx, end_idx, token')
@@ -93,14 +94,16 @@ def parse_bsf(bsf_data: str) -> list:
CORPUS_NAME = 'Ukrainian-languk'
+
def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str = 'beios',
- doc_delim: str = '\n') -> None:
+ doc_delim: str = '\n', train_test_split_file: str = None) -> None:
"""
:param doc_delim: delimiter to be used between documents
:param src_dir_path: path to directory with BSF marked files
:param dst_dir_path: where to save output data
:param converter: `beios` or `iob` output formats
+    :param train_test_split_file: path to file containing train/test lists of file names
:return:
"""
ann_path = os.path.join(src_dir_path, '*.tok.ann')
@@ -127,7 +130,10 @@ def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str =
data_sets = [train_set, dev_set, test_set]
split_weights = (8, 1, 1)
- log.info(f'Found {len(tok_files)} files')
+ if train_test_split_file is not None:
+ train_names, dev_names, test_names = read_languk_train_test_split(train_test_split_file)
+
+ log.info(f'Found {len(tok_files)} files in data folder "{src_dir_path}"')
for (tok_fname, ann_fname) in tqdm(zip(tok_files, ann_files), total=len(tok_files), unit='file'):
if tok_fname[:-3] != ann_fname[:-3]:
tqdm.write(f'Token and Annotation file names do not match ann={ann_fname}, tok={tok_fname}')
@@ -138,7 +144,16 @@ def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str =
ann_data = ann_file.read()
out_data = convert_bsf(token_data, ann_data, converter)
- target_dataset = choices(data_sets, split_weights)[0]
+ if train_test_split_file is None:
+ target_dataset = choices(data_sets, split_weights)[0]
+ else:
+ target_dataset = train_set
+ fkey = os.path.basename(tok_fname)[:-4]
+ if fkey in dev_names:
+ target_dataset = dev_set
+ elif fkey in test_names:
+ target_dataset = test_set
+
target_dataset.append(out_data)
log.info(f'Data is split as following: train={len(train_set)}, dev={len(dev_set)}, test={len(test_set)}')
@@ -155,6 +170,43 @@ def convert_bsf_in_folder(src_dir_path: str, dst_dir_path: str, converter: str =
log.info('All done')
+def read_languk_train_test_split(file_path: str, dev_split: float = 0.1) -> Tuple:
+ """
+ Read predefined split of train and test files in data set.
+ Originally located under doc/dev-test-split.txt
+ :param file_path: path to dev-test-split.txt file (should include file name with extension)
+ :param dev_split: 0 to 1 float value defining how much to allocate to dev split
+ :return: tuple of (train, dev, test) each containing list of files to be used for respective data sets
+ """
+ log.info(f'Trying to read train/dev/test split from file "{file_path}". Dev allocation = {dev_split}')
+ train_files, test_files, dev_files = [], [], []
+ container = test_files
+ with open(file_path, 'r') as f:
+ for ln in f:
+ ln = ln.strip()
+ if ln == 'DEV':
+ container = train_files
+ elif ln == 'TEST':
+ container = test_files
+ elif ln == '':
+ pass
+ else:
+ container.append(ln)
+
+    # The split file only contains a train and a test split.
+    # For Stanza training we need train, dev, and test,
+    # so we take part of the train set as the dev set.
+    # This way anyone using the test set outside of this code base can be sure that there was no data set pollution.
+ shuffle(train_files)
+ dev_files = train_files[: int(len(train_files) * dev_split)]
+ train_files = train_files[int(len(train_files) * dev_split):]
+
+ assert len(set(train_files).intersection(set(dev_files))) == 0
+
+ log.info(f'Files in each set: train={len(train_files)}, dev={len(dev_files)}, test={len(test_files)}')
+ return train_files, dev_files, test_files
+
+
if __name__ == '__main__':
logging.basicConfig()
@@ -165,7 +217,8 @@ if __name__ == '__main__':
parser.add_argument('--dst', type=str, default='data/ner', help='Where to store the converted dataset')
parser.add_argument('-c', type=str, default='beios', help='`beios` or `iob` formats to be used for output')
parser.add_argument('--doc_delim', type=str, default='\n', help='Delimiter to be used to separate documents in the output data')
+ parser.add_argument('--split_file', type=str, help='Name of a file containing Train/Test split (files in train and test set)')
parser.print_help()
args = parser.parse_args()
- convert_bsf_in_folder(args.src_dataset, args.dst, args.c, args.doc_delim)
+ convert_bsf_in_folder(args.src_dataset, args.dst, args.c, args.doc_delim, train_test_split_file=args.split_file)
diff --git a/stanza/utils/datasets/ner/convert_fire_2013.py b/stanza/utils/datasets/ner/convert_fire_2013.py
index f76aa696..b95275be 100644
--- a/stanza/utils/datasets/ner/convert_fire_2013.py
+++ b/stanza/utils/datasets/ner/convert_fire_2013.py
@@ -13,6 +13,7 @@ This script keeps just the word and the ner1. It is quite possible that using t
import argparse
import glob
import os
+import random
def normalize(entity):
if entity == 'o':
@@ -41,6 +42,10 @@ def convert_fileset(output_csv_file, filenames):
for sentence in sentences:
for line in sentence:
pieces = line.split("\t")
+ if len(pieces) != 6:
+ raise ValueError("Found %d pieces instead of the expected 6" % len(pieces))
+ if pieces[3] == 'o' and (pieces[4] != 'o' or pieces[5] != 'o'):
+ raise ValueError("Inner NER labeled but the top layer was O")
fout.write("%s\t%s\n" % (pieces[0], normalize(pieces[3])))
fout.write("\n")
@@ -49,6 +54,7 @@ def convert_fire_2013(input_path, train_csv_file, dev_csv_file, test_csv_file):
# won't be numerically sorted... shouldn't matter
filenames = sorted(filenames)
+ random.shuffle(filenames)
train_cutoff = int(0.8 * len(filenames))
dev_cutoff = int(0.9 * len(filenames))
@@ -65,11 +71,13 @@ def convert_fire_2013(input_path, train_csv_file, dev_csv_file, test_csv_file):
convert_fileset(test_csv_file, test_files)
if __name__ == '__main__':
+ random.seed(1234)
+
parser = argparse.ArgumentParser()
parser.add_argument('--input_path', type=str, default="/home/john/extern_data/ner/FIRE2013/hindi_train", help="Directory with raw files to read")
parser.add_argument('--train_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.train.csv", help="Where to put the train file")
- parser.add_argument('--dev_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.dev.csv", help="Where to put the train file")
- parser.add_argument('--test_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.test.csv", help="Where to put the train file")
+ parser.add_argument('--dev_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.dev.csv", help="Where to put the dev file")
+ parser.add_argument('--test_file', type=str, default="/home/john/stanza/data/ner/hi_fire2013.test.csv", help="Where to put the test file")
args = parser.parse_args()
convert_fire_2013(args.input_path, args.train_file, args.dev_file, args.test_file)
diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py
index 36a96da4..f248748b 100644
--- a/stanza/utils/datasets/ner/prepare_ner_dataset.py
+++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py
@@ -10,7 +10,11 @@ Also, Finnish Turku dataset, available here:
- https://turkunlp.org/fin-ner.html
- Download and unzip the corpus, putting the .tsv files into
$NERBASE/fi_turku
- - prepare_ner_dataset.py hu_nytk fi_turku
+ - prepare_ner_dataset.py fi_turku
+
+FBK in Italy produced an Italian dataset.
+ The processing here is for a combined .tsv file they sent us.
+ - prepare_ner_dataset.py it_fbk
FBK in Italy produced an Italian dataset.
The processing here is for a combined .tsv file they sent us.
@@ -22,12 +26,14 @@ IJCNLP 2008 produced a few Indian language NER datasets.
download:
http://ltrc.iiit.ac.in/ner-ssea-08/index.cgi?topic=5
The models produced from these datasets have extremely low recall, unfortunately.
+ - prepare_ner_dataset.py hi_ijc
FIRE 2013 also produced NER datasets for Indian languages.
http://au-kbc.org/nlp/NER-FIRE2013/index.html
The datasets are password locked.
For Stanford users, contact Chris Manning for license details.
For external users, please contact the organizers for more information.
+ - prepare_ner_dataset.py hi-fire2013
Ukranian NER is provided by lang-uk, available here:
https://github.com/lang-uk/ner-uk
@@ -56,10 +62,12 @@ The two Hungarian datasets can be combined with hu_combined
BSNLP publishes NER datasets for Eastern European languages.
- In 2019 they published BG, CS, PL, RU.
+ - http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html
- In 2021 they added some more data, but the test sets
were not publicly available as of April 2021.
Therefore, currently the model is made from 2019.
- - http://bsnlp.cs.helsinki.fi/bsnlp-2019/shared_task.html
+ In 2021, the link to the 2021 task is here:
+ http://bsnlp.cs.helsinki.fi/shared-task.html
- The below method processes the 2019 version of the corpus.
It has specific adjustments for the BG section, which has
quite a few typos or mis-annotations in it. Other languages
@@ -100,11 +108,11 @@ import tempfile
from stanza.models.common.constant import treebank_to_short_name, lcode2lang
import stanza.utils.default_paths as default_paths
-from stanza.utils.datasets.ner.convert_fire_2013 import convert_fire_2013
from stanza.utils.datasets.ner.preprocess_wikiner import preprocess_wikiner
from stanza.utils.datasets.ner.split_wikiner import split_wikiner
import stanza.utils.datasets.ner.convert_bsf_to_beios as convert_bsf_to_beios
import stanza.utils.datasets.ner.convert_bsnlp as convert_bsnlp
+import stanza.utils.datasets.ner.convert_fire_2013 as convert_fire_2013
import stanza.utils.datasets.ner.convert_ijc as convert_ijc
import stanza.utils.datasets.ner.convert_rgai as convert_rgai
import stanza.utils.datasets.ner.convert_nytk as convert_nytk
@@ -154,7 +162,8 @@ def process_languk(paths):
short_name = 'uk_languk'
base_input_path = os.path.join(paths["NERBASE"], 'lang-uk', 'ner-uk', 'data')
base_output_path = paths["NER_DATA_DIR"]
- convert_bsf_to_beios.convert_bsf_in_folder(base_input_path, base_output_path)
+ train_test_split_fname = os.path.join(paths["NERBASE"], 'lang-uk', 'ner-uk', 'doc', 'dev-test-split.txt')
+ convert_bsf_to_beios.convert_bsf_in_folder(base_input_path, base_output_path, train_test_split_file=train_test_split_fname)
for shard in SHARDS:
input_filename = os.path.join(base_output_path, convert_bsf_to_beios.CORPUS_NAME, "%s.bio" % shard)
if not os.path.exists(input_filename):
@@ -204,6 +213,7 @@ def process_fire_2013(paths, dataset):
"""
short_name = treebank_to_short_name(dataset)
langcode, _ = short_name.split("_")
+ short_name = "%s_fire2013" % langcode
if not langcode in ("hi", "en", "ta", "bn", "mal"):
raise ValueError("Language %s not one of the FIRE 2013 languages" % langcode)
language = lcode2lang[langcode].lower()
@@ -216,7 +226,7 @@ def process_fire_2013(paths, dataset):
dev_csv_file = os.path.join(base_output_path, "%s.dev.csv" % short_name)
test_csv_file = os.path.join(base_output_path, "%s.test.csv" % short_name)
- convert_fire_2013(base_input_path, train_csv_file, dev_csv_file, test_csv_file)
+ convert_fire_2013.convert_fire_2013(base_input_path, train_csv_file, dev_csv_file, test_csv_file)
for csv_file, shard in zip((train_csv_file, dev_csv_file, test_csv_file), SHARDS):
output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))
@@ -398,7 +408,7 @@ def main(dataset_name):
process_languk(paths)
elif dataset_name == 'hi_ijc':
process_ijc(paths, dataset_name)
- elif dataset_name.endswith("FIRE2013"):
+ elif dataset_name.endswith("FIRE2013") or dataset_name.endswith("fire2013"):
process_fire_2013(paths, dataset_name)
elif dataset_name.endswith('WikiNER'):
process_wikiner(paths, dataset_name)
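As a sketch of how the dispatch above is driven, the FIRE 2013 branch can be called programmatically with the same paths object main() uses; this assumes the password-protected FIRE 2013 data is already unpacked where process_fire_2013 expects it under NERBASE:

import stanza.utils.default_paths as default_paths
from stanza.utils.datasets.ner.prepare_ner_dataset import process_fire_2013

paths = default_paths.get_default_paths()
# mirrors the docstring example: prepare_ner_dataset.py hi-fire2013
process_fire_2013(paths, "hi-fire2013")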
diff --git a/stanza/utils/datasets/ner/prepare_ner_file.py b/stanza/utils/datasets/ner/prepare_ner_file.py
index e5fe2220..71f2aeee 100644
--- a/stanza/utils/datasets/ner/prepare_ner_file.py
+++ b/stanza/utils/datasets/ner/prepare_ner_file.py
@@ -34,8 +34,8 @@ def process_dataset(input_filename, output_filename):
document += [sent]
with open(output_filename, 'w') as outfile:
- json.dump(document, outfile)
- print("Generated json file {}.".format(output_filename))
+ json.dump(document, outfile, indent=1)
+ print("Generated json file {}".format(output_filename))
# TODO: make skip_doc_start an argument
def load_conll03(filename, skip_doc_start=True):
@@ -47,7 +47,9 @@ def load_conll03(filename, skip_doc_start=True):
if skip_doc_start and DOC_START_TOKEN in line:
continue
if len(line) > 0:
- array = line.split()
+ array = line.split("\t")
+ if len(array) < MIN_NUM_FIELD:
+ array = line.split()
if len(array) < MIN_NUM_FIELD:
continue
else:
@@ -64,8 +66,10 @@ def process_cache(cached_lines):
tokens = []
ner_tags = []
for line in cached_lines:
- array = line.split()
- assert len(array) >= MIN_NUM_FIELD and len(array) <= MAX_NUM_FIELD
+ array = line.split("\t")
+ if len(array) < MIN_NUM_FIELD:
+ array = line.split()
+ assert len(array) >= MIN_NUM_FIELD and len(array) <= MAX_NUM_FIELD, "Got unexpected line length: {}".format(array)
tokens.append(array[0])
ner_tags.append(array[-1])
return (tokens, ner_tags)
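The tab-first split with a whitespace fallback now appears in both load_conll03 and process_cache; a small sketch of that logic as a standalone helper, assuming the same MIN_NUM_FIELD convention:

def split_fields(line, min_num_field):
    # Prefer tab-separated columns, which allows tokens that contain spaces,
    # then fall back to generic whitespace splitting for classic CoNLL files
    fields = line.split("\t")
    if len(fields) < min_num_field:
        fields = line.split()
    return fields

# e.g. split_fields("New York\tB-LOC", 2) == ["New York", "B-LOC"]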
diff --git a/stanza/utils/datasets/prepare_tokenizer_treebank.py b/stanza/utils/datasets/prepare_tokenizer_treebank.py
index 459a6a74..35941cce 100755
--- a/stanza/utils/datasets/prepare_tokenizer_treebank.py
+++ b/stanza/utils/datasets/prepare_tokenizer_treebank.py
@@ -34,7 +34,10 @@ from collections import Counter
import stanza.utils.datasets.common as common
import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data
-
+import stanza.utils.datasets.tokenization.convert_vi_vlsp as convert_vi_vlsp
+import stanza.utils.datasets.tokenization.convert_th_best as convert_th_best
+import stanza.utils.datasets.tokenization.convert_th_lst20 as convert_th_lst20
+import stanza.utils.datasets.tokenization.convert_th_orchid as convert_th_orchid
def copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name):
original = f"{tokenizer_dir}/{short_name}.{tokenizer_file}.conllu"
@@ -136,15 +139,24 @@ def prepare_treebank_labels(tokenizer_dir, short_name):
for dataset in ("train", "dev", "test"):
output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
- prepare_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset)
+ try:
+ prepare_dataset_labels(output_txt, output_conllu, tokenizer_dir, short_name, dataset)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except:
+ print("Failed to convert %s to %s" % (output_txt, output_conllu))
+ raise
CONLLU_TO_TXT_PERL = os.path.join(os.path.split(__file__)[0], "conllu_to_text.pl")
-def convert_conllu_to_txt(tokenizer_dir, short_name):
- for dataset in ("train", "dev", "test"):
+def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")):
+ for dataset in shards:
output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
+ if not os.path.exists(output_conllu):
+ # the perl script doesn't raise an error code for file not found!
+ raise FileNotFoundError("Cannot convert %s as the file cannot be found" % output_conllu)
# use an external script to produce the txt files
subprocess.check_output(f"perl {CONLLU_TO_TXT_PERL} {output_conllu} > {output_txt}", shell=True)
@@ -944,7 +956,7 @@ def build_combined_english_gum(udbase_dir, tokenizer_dir, short_name, augment):
build_combined_english_gum_dataset(udbase_dir, tokenizer_dir, short_name, dataset, augment)
def prepare_ud_dataset(treebank, udbase_dir, tokenizer_dir, short_name, short_language, dataset, augment=True):
- input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu")
+ input_conllu = common.find_treebank_dataset_file(treebank, udbase_dir, dataset, "conllu", fail=True)
output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
if short_name == "te_mtg" and dataset == 'train' and augment:
@@ -1009,14 +1021,15 @@ def add_specific_args(parser):
help='Augment the dataset in various ways')
parser.add_argument('--no_prepare_labels', action='store_false', dest='prepare_labels', default=True,
help='Prepare tokenizer and MWT labels. Expensive, but obviously necessary for training those models.')
+ convert_th_lst20.add_lst20_args(parser)
+ convert_vi_vlsp.add_vlsp_args(parser)
def process_treebank(treebank, paths, args):
"""
Processes a single treebank into train, dev, test parts
- TODO
- Currently assumes it is always a UD treebank. There are Thai
- treebanks which are not included in UD.
+ Includes processing for a few external tokenization datasets:
+ vi_vlsp, th_orchid, th_best
Also, there is no specific mechanism for UD_Arabic-NYUAD or
similar treebanks, which need integration with LDC datasets
@@ -1030,7 +1043,15 @@ def process_treebank(treebank, paths, args):
os.makedirs(tokenizer_dir, exist_ok=True)
- if short_name.startswith("ko_combined"):
+ if short_name == "vi_vlsp":
+ convert_vi_vlsp.convert_vi_vlsp(paths["EXTERN_DIR"], tokenizer_dir, args)
+ elif short_name == "th_orchid":
+ convert_th_orchid.main(paths["EXTERN_DIR"], tokenizer_dir)
+ elif short_name == "th_lst20":
+ convert_th_lst20.convert(paths["EXTERN_DIR"], tokenizer_dir, args)
+ elif short_name == "th_best":
+ convert_th_best.main(paths["EXTERN_DIR"], tokenizer_dir)
+ elif short_name.startswith("ko_combined"):
build_combined_korean(udbase_dir, tokenizer_dir, short_name)
elif short_name in ("it_combined", "en_combined", "es_combined"):
build_combined_dataset(udbase_dir, tokenizer_dir, handparsed_dir, short_name, args.augment)
@@ -1049,7 +1070,8 @@ def process_treebank(treebank, paths, args):
else:
process_ud_treebank(treebank, udbase_dir, tokenizer_dir, short_name, short_language, args.augment)
- convert_conllu_to_txt(tokenizer_dir, short_name)
+ if not short_name in ('th_orchid', 'th_lst20'):
+ convert_conllu_to_txt(tokenizer_dir, short_name)
if args.prepare_labels:
prepare_treebank_labels(tokenizer_dir, short_name)
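A usage sketch of the new shards parameter; the dataset name is hypothetical. The point is that a split without a gold conllu file can simply be skipped, while a genuinely missing file now raises FileNotFoundError instead of the perl script silently writing nothing:

from stanza.utils.datasets.prepare_tokenizer_treebank import convert_conllu_to_txt

# assumes data/tokenize/xx_example.train.gold.conllu and .dev.gold.conllu exist
convert_conllu_to_txt("data/tokenize", "xx_example", shards=("train", "dev"))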
diff --git a/stanza/utils/datasets/process_thai_tokenization.py b/stanza/utils/datasets/process_thai_tokenization.py
deleted file mode 100644
index 27e347dd..00000000
--- a/stanza/utils/datasets/process_thai_tokenization.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import os
-import random
-
-def write_section(output_dir, dataset_name, section, documents):
- """
- Writes a list of documents for tokenization, including a file in conll format
-
- The Thai datasets generally have no MWT (apparently not relevant for Thai)
-
- output_dir: the destination directory for the output files
- dataset_name: orchid, BEST, lst20, etc
- section: train/dev/test
- documents: a nested list of documents, paragraphs, sentences, words
- words is a list of (word, space_follows)
- """
- with open(os.path.join(output_dir, 'th_%s-ud-%s-mwt.json' % (dataset_name, section)), 'w') as fout:
- fout.write("[]\n")
-
- text_out = open(os.path.join(output_dir, 'th_%s.%s.txt' % (dataset_name, section)), 'w')
- label_out = open(os.path.join(output_dir, 'th_%s-ud-%s.toklabels' % (dataset_name, section)), 'w')
- for document in documents:
- for paragraph in document:
- for sentence_idx, sentence in enumerate(paragraph):
- for word_idx, word in enumerate(sentence):
- # TODO: split with newlines to make it more readable?
- text_out.write(word[0])
- for i in range(len(word[0]) - 1):
- label_out.write("0")
- if word_idx == len(sentence) - 1:
- label_out.write("2")
- else:
- label_out.write("1")
- if word[1] and sentence_idx != len(paragraph) - 1:
- text_out.write(' ')
- label_out.write('0')
-
- text_out.write("\n\n")
- label_out.write("\n\n")
-
- text_out.close()
- label_out.close()
-
- with open(os.path.join(output_dir, 'th_%s.%s.gold.conllu' % (dataset_name, section)), 'w') as fout:
- for document in documents:
- for paragraph in document:
- for sentence in paragraph:
- for word_idx, word in enumerate(sentence):
- # SpaceAfter is left blank if there is space after the word
- space = '_' if word[1] else 'SpaceAfter=No'
- # Note the faked dependency structure: the conll reading code
- # needs it even if it isn't being used in any way
- fake_dep = 'root' if word_idx == 0 else 'dep'
- fout.write('{}\t{}\t_\t_\t_\t_\t{}\t{}\t{}:{}\t{}\n'.format(word_idx+1, word[0], word_idx, fake_dep, word_idx, fake_dep, space))
- fout.write('\n')
-
-def write_dataset(documents, output_dir, dataset_name):
- """
- Shuffle a list of documents, write three sections
- """
- random.shuffle(documents)
- num_train = int(len(documents) * 0.8)
- num_dev = int(len(documents) * 0.1)
- os.makedirs(output_dir, exist_ok=True)
- write_section(output_dir, dataset_name, 'train', documents[:num_train])
- write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev])
- write_section(output_dir, dataset_name, 'test', documents[num_train+num_dev:])
diff --git a/stanza/utils/datasets/thai_syllable_dict_generator.py b/stanza/utils/datasets/thai_syllable_dict_generator.py
new file mode 100644
index 00000000..ca658e16
--- /dev/null
+++ b/stanza/utils/datasets/thai_syllable_dict_generator.py
@@ -0,0 +1,53 @@
+import argparse
+import json
+import pathlib
+import re
+
+
+def create_dictionary(dataset_dir, save_dir):
+ syllables = set()
+
+ for p in pathlib.Path(dataset_dir).rglob("*.ssg"): # iterate through all files
+
+ with open(p) as f: # for each file
+ sentences = f.readlines()
+
+ for i in range(len(sentences)):
+
+ sentences[i] = sentences[i].replace("\n", "")
+ sentences[i] = sentences[i].replace("<s/>", "~")
+ sentences[i] = sentences[i].split("~") # create list of all syllables
+
+ syllables = syllables.union(sentences[i])
+
+
+    print("Found {} syllables".format(len(syllables)))
+
+    # Filter out blanks and syllables containing non-Thai characters or spaces
+    a = []
+
+    for s in syllables:
+        if re.match("^[\u0E00-\u0E7F]*$", s) and s != "" and " " not in s:
+            a.append(s)
+
+    a = set(a)
+    a = dict(zip(list(a), range(len(a))))
+
+    print("Kept {} Thai syllables".format(len(a)))
+ with open(save_dir, "w") as fp:
+ json.dump(a, fp)
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset_dir', type=str, default="syllable_segmentation_data", help="Directory for syllable dataset")
+ parser.add_argument('--save_dir', type=str, default="thai-syllable.json", help="Directory for generated file")
+ args = parser.parse_args()
+
+ create_dictionary(args.dataset_dir, args.save_dir)
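A sketch of reading the generated dictionary back; with the default --save_dir it is a JSON object mapping each Thai syllable to an arbitrary integer id:

import json

with open("thai-syllable.json") as fin:
    syllable_to_id = json.load(fin)

# ids only need to be distinct; .get returns None if the syllable never appeared in the corpus
some_id = syllable_to_id.get("ไทย")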
diff --git a/stanza/utils/datasets/tokenization/__init__.py b/stanza/utils/datasets/tokenization/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/stanza/utils/datasets/tokenization/__init__.py
diff --git a/stanza/utils/datasets/process_best.py b/stanza/utils/datasets/tokenization/convert_th_best.py
index 21125455..778f2dac 100644
--- a/stanza/utils/datasets/process_best.py
+++ b/stanza/utils/datasets/tokenization/convert_th_best.py
@@ -12,10 +12,13 @@ This outputs the tokenization results in a conll format similar to
that of the UD treebanks, so we pretend to be a UD treebank for ease
of compatibility with the stanza tools.
-python3 -m stanza.utils.datasets.process_best extern_data/thai/best data/tokenize
+BEST can be downloaded from here:
+
+https://aiforthai.in.th/corpus.php
+
+python3 -m stanza.utils.datasets.tokenization.convert_th_best extern_data/thai/best data/tokenize
./scripts/run_tokenize.sh UD_Thai-best --dropout 0.05 --unit_dropout 0.05 --steps 50000
"""
-
import glob
import os
import random
@@ -24,7 +27,7 @@ import sys
from pythainlp import sent_tokenize
-from stanza.utils.datasets.process_thai_tokenization import write_dataset
+from stanza.utils.datasets.tokenization.process_thai_tokenization import reprocess_lines, write_dataset, convert_processed_lines, write_dataset_best
def clean_line(line):
line = line.replace("html>", "html|>")
@@ -46,6 +49,8 @@ def clean_line(line):
# news_00008.txt and other news articles
line = re.sub("</AB>([0-9])", "</AB>|\\1", line)
line = line.replace("</AB> ", "</AB>|")
+ line = line.replace("<EM>", "<POEM>")
+ line = line.replace("</EM>", "</POEM>")
line = line.strip()
return line
@@ -60,6 +65,12 @@ def clean_word(word):
return word[4:-5]
if word.startswith("<POEM>") and word.endswith("</POEM>"):
return word[6:-7]
+ """
+ if word.startswith("<EM>"):
+ return word[4:]
+ if word.endswith("</EM>"):
+ return word[:-5]
+ """
if word.startswith("<NE>"):
return word[4:]
if word.endswith("</NE>"):
@@ -72,45 +83,12 @@ def clean_word(word):
return word
return word
-def reprocess_lines(processed_lines):
- reprocessed_lines = []
- for line in processed_lines:
- text = "".join(line)
- chunks = sent_tokenize(text)
- if sum(len(x) for x in chunks) != len(text):
- raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks))
-
- chunk_lengths = [len(x) for x in chunks]
-
- current_length = 0
- new_line = []
- for word in line:
- if len(word) + current_length < chunk_lengths[0]:
- new_line.append(word)
- current_length = current_length + len(word)
- elif len(word) + current_length == chunk_lengths[0]:
- new_line.append(word)
- reprocessed_lines.append(new_line)
- new_line = []
- chunk_lengths = chunk_lengths[1:]
- current_length = 0
- else:
- remaining_len = chunk_lengths[0] - current_length
- new_line.append(word[:remaining_len])
- reprocessed_lines.append(new_line)
- word = word[remaining_len:]
- chunk_lengths = chunk_lengths[1:]
- while len(word) > chunk_lengths[0]:
- new_line = [word[:chunk_lengths[0]]]
- reprocessed_lines.append(new_line)
- word = word[chunk_lengths[0]:]
- chunk_lengths = chunk_lengths[1:]
- new_line = [word]
- current_length = len(word)
- reprocessed_lines.append(new_line)
- return reprocessed_lines
-
def read_data(input_dir):
+ # data for test sets
+ test_files = [os.path.join(input_dir, 'TEST_100K_ANS.txt')]
+ print(test_files)
+
+ # data for train and dev sets
subdirs = [os.path.join(input_dir, 'article'),
os.path.join(input_dir, 'encyclopedia'),
os.path.join(input_dir, 'news'),
@@ -121,10 +99,32 @@ def read_data(input_dir):
raise FileNotFoundError("Expected a directory that did not exist: {}".format(subdir))
files.extend(glob.glob(os.path.join(subdir, '*.txt')))
+ test_documents = []
+ for filename in test_files:
+ print("File name:", filename)
+ with open(filename) as fin:
+ processed_lines = []
+ for line in fin.readlines():
+ line = clean_line(line)
+ words = line.split("|")
+ words = [clean_word(x) for x in words]
+ for word in words:
+ if len(word) > 1 and word[0] == '<':
+ raise ValueError("Unexpected word '{}' in document {}".format(word, filename))
+ words = [x for x in words if x]
+ processed_lines.append(words)
+
+ processed_lines = reprocess_lines(processed_lines)
+ paragraphs = convert_processed_lines(processed_lines)
+
+ test_documents.extend(paragraphs)
+ print("Test document finished.")
+
documents = []
+
for filename in files:
with open(filename) as fin:
- sentences = []
+ print("File:", filename)
processed_lines = []
for line in fin.readlines():
line = clean_line(line)
@@ -137,38 +137,32 @@ def read_data(input_dir):
processed_lines.append(words)
processed_lines = reprocess_lines(processed_lines)
+ paragraphs = convert_processed_lines(processed_lines)
- for words in processed_lines:
- # turn the words into a sentence
- sentence = []
- for word in words:
- word = word.strip()
- if not word:
- if len(sentence) == 0:
- raise ValueError("Unexpected space at start of sentence in document {}".format(filename))
- sentence[-1] = (sentence[-1][0], True)
- else:
- sentence.append((word, False))
- # blank lines are very rare in best, but why not treat them as a paragraph break
- if len(sentence) == 0:
- paragraphs = [sentences]
- documents.append(paragraphs)
- sentences = []
- continue
- sentence[-1] = (sentence[-1][0], True)
- sentences.append(sentence)
- paragraphs = [sentences]
- documents.append(paragraphs)
-
- return documents
-
-def main():
+ documents.extend(paragraphs)
+
+ print("All documents finished.")
+
+ return documents, test_documents
+
+
+def main(*args):
random.seed(1000)
- input_dir = sys.argv[1]
- output_dir = sys.argv[2]
- documents = read_data(input_dir)
- write_dataset(documents, output_dir, "best")
+ if not args:
+ args = sys.argv[1:]
+
+ input_dir = args[0]
+ full_input_dir = os.path.join(input_dir, "thai", "best")
+ if os.path.exists(full_input_dir):
+ # otherwise hopefully the user gave us the full path?
+ input_dir = full_input_dir
+
+ output_dir = args[1]
+ documents, test_documents = read_data(input_dir)
+ print("Finished reading data.")
+ write_dataset_best(documents, test_documents, output_dir, "best")
if __name__ == '__main__':
main()
+
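To illustrate the markup handling, a sketch of clean_line and clean_word on a made-up BEST-style line (words separated by '|', a named entity wrapped in <NE></NE>); importing the module assumes pythainlp is installed, since it is imported at module load time:

from stanza.utils.datasets.tokenization.convert_th_best import clean_line, clean_word

line = clean_line("<NE>กรุงเทพ</NE>|เป็น|เมืองหลวง\n")
words = [clean_word(w) for w in line.split("|")]
words = [w for w in words if w]
# expected: the three Thai words, with the <NE> markers stripped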
diff --git a/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/utils/datasets/tokenization/convert_th_lst20.py
new file mode 100644
index 00000000..744c44cd
--- /dev/null
+++ b/stanza/utils/datasets/tokenization/convert_th_lst20.py
@@ -0,0 +1,131 @@
+"""Processes the tokenization section of the LST20 Thai dataset
+
+The dataset is available here:
+
+https://aiforthai.in.th/corpus.php
+
+The data should be installed under ${EXTERN_DATA}/thai/LST20_Corpus
+
+python3 -m stanza.utils.datasets.tokenization.convert_th_lst20 extern_data data/tokenize
+
+Unlike Orchid and BEST, LST20 has train/eval/test splits, which we relabel train/dev/test.
+
+./scripts/run_tokenize.sh UD_Thai-lst20 --dropout 0.05 --unit_dropout 0.05
+"""
+
+
+import argparse
+import glob
+import os
+import sys
+
+from stanza.utils.datasets.tokenization.process_thai_tokenization import write_section, convert_processed_lines, reprocess_lines
+
+def read_document(lines, spaces_after, split_clauses):
+ document = []
+ sentence = []
+ for line in lines:
+ line = line.strip()
+ if not line:
+ if sentence:
+ if spaces_after:
+ sentence[-1] = (sentence[-1][0], True)
+ document.append(sentence)
+ sentence = []
+ else:
+ pieces = line.split("\t")
+ # there are some nbsp in tokens in lst20, but the downstream tools expect spaces
+ pieces = [p.replace("\xa0", " ") for p in pieces]
+ if split_clauses and pieces[0] == '_' and pieces[3] == 'O':
+ if sentence:
+ # note that we don't need to check spaces_after
+ # the "token" is a space anyway
+ sentence[-1] = (sentence[-1][0], True)
+ document.append(sentence)
+ sentence = []
+ elif pieces[0] == '_':
+ sentence[-1] = (sentence[-1][0], True)
+ else:
+ sentence.append((pieces[0], False))
+
+ if sentence:
+ if spaces_after:
+ sentence[-1] = (sentence[-1][0], True)
+ document.append(sentence)
+ sentence = []
+ # TODO: is there any way to divide up a single document into paragraphs?
+ return [[document]]
+
+def retokenize_document(lines):
+ processed_lines = []
+ sentence = []
+ for line in lines:
+ line = line.strip()
+ if not line:
+ if sentence:
+ processed_lines.append(sentence)
+ sentence = []
+ else:
+ pieces = line.split("\t")
+ if pieces[0] == '_':
+ sentence.append(' ')
+ else:
+ sentence.append(pieces[0])
+ if sentence:
+ processed_lines.append(sentence)
+
+ processed_lines = reprocess_lines(processed_lines)
+ paragraphs = convert_processed_lines(processed_lines)
+ return paragraphs
+
+
+def read_data(input_dir, section, resegment, spaces_after, split_clauses):
+ glob_path = os.path.join(input_dir, section, "*.txt")
+ filenames = glob.glob(glob_path)
+ print(" Found {} files in {}".format(len(filenames), glob_path))
+ if len(filenames) == 0:
+ raise FileNotFoundError("Could not find any files for the {} section. Is LST20 installed in {}?".format(section, input_dir))
+ documents = []
+ for filename in filenames:
+ with open(filename) as fin:
+ lines = fin.readlines()
+ if resegment:
+ document = retokenize_document(lines)
+ else:
+ document = read_document(lines, spaces_after, split_clauses)
+ documents.extend(document)
+ return documents
+
+def add_lst20_args(parser):
+ parser.add_argument('--no_lst20_resegment', action='store_false', dest="lst20_resegment", default=True, help='When processing th_lst20 tokenization, use pythainlp to resegment the text. The other option is to keep the original sentence segmentation. Currently our model is not good at that')
+ parser.add_argument('--lst20_spaces_after', action='store_true', dest="lst20_spaces_after", default=False, help='When processing th_lst20 without pythainlp, put spaces after each sentence. This better fits the language but gets lower scores for some reason')
+ parser.add_argument('--split_clauses', action='store_true', dest="split_clauses", default=False, help='When processing th_lst20 without pythainlp, turn spaces which are labeled as between clauses into sentence splits')
+
+def parse_lst20_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('input_dir', help="Directory to use when processing lst20")
+ parser.add_argument('output_dir', help="Directory to use when saving lst20")
+ add_lst20_args(parser)
+ return parser.parse_args()
+
+
+
+def convert(input_dir, output_dir, args):
+ input_dir = os.path.join(input_dir, "thai", "LST20_Corpus")
+ if not os.path.exists(input_dir):
+ raise FileNotFoundError("Could not find LST20 corpus in {}".format(input_dir))
+
+ for (in_section, out_section) in (("train", "train"),
+ ("eval", "dev"),
+ ("test", "test")):
+ print("Processing %s" % out_section)
+ documents = read_data(input_dir, in_section, args.lst20_resegment, args.lst20_spaces_after, args.split_clauses)
+ print(" Read in %d documents" % len(documents))
+ write_section(output_dir, "lst20", out_section, documents)
+
+def main():
+ args = parse_lst20_args()
+ convert(args.input_dir, args.output_dir, args)
+
+if __name__ == '__main__':
+ main()
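A sketch of the input format read_document assumes: tab-separated rows where column 0 is the token, '_' marks a space token, and column 3 carries the clause tag consulted by --split_clauses. The rows below are made up, not real LST20 data:

from stanza.utils.datasets.tokenization.convert_th_lst20 import read_document

sample_lines = [
    "วันนี้\tNN\tO\tB_CLS",
    "_\tPU\tO\tI_CLS",
    "อากาศ\tNN\tO\tI_CLS",
    "ดี\tVV\tO\tE_CLS",
    "",   # blank line closes the sentence
]
document = read_document(sample_lines, spaces_after=False, split_clauses=False)
# -> [[[ [('วันนี้', True), ('อากาศ', False), ('ดี', False)] ]]]
#    one document, one paragraph, one sentence of (token, space_follows) pairs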
diff --git a/stanza/utils/datasets/process_orchid.py b/stanza/utils/datasets/tokenization/convert_th_orchid.py
index 794c3925..871e87d1 100644
--- a/stanza/utils/datasets/process_orchid.py
+++ b/stanza/utils/datasets/tokenization/convert_th_orchid.py
@@ -5,7 +5,7 @@ https://github.com/korakot/thainlp/blob/master/xmlchid.xml
For example, if you put the data file in the above link in
extern_data/thai/orchid/xmlchid.xml
you would then run
-python3 -m stanza.utils.datasets.process_orchid extern_data/thai/orchid/xmlchid.xml data/tokenize
+python3 -m stanza.utils.datasets.tokenization.convert_th_orchid extern_data/thai/orchid/xmlchid.xml data/tokenize
Because there is no definitive train/dev/test split that we have found
so far, we randomly shuffle the data on a paragraph level and split it
@@ -17,7 +17,7 @@ give it a fake UD name to make life easier for the downstream tools.
Training on this dataset seems to work best with low dropout numbers.
For example:
-./scripts/run_tokenize.sh UD_Thai-orchid --dropout 0.05 --unit_dropout 0.05
+python3 -m stanza.utils.training.run_tokenizer th_orchid --dropout 0.05 --unit_dropout 0.05
This results in a model with dev set scores:
th_orchid 87.98 70.94
@@ -27,11 +27,12 @@ test set scores:
Apparently the random split produced a test set easier than the dev set.
"""
+import os
import random
import sys
import xml.etree.ElementTree as ET
-from stanza.utils.datasets.process_thai_tokenization import write_dataset
+from stanza.utils.datasets.tokenization.process_thai_tokenization import write_dataset
# line "122819" has some error in the tokenization of the musical notation
# line "209380" is also messed up
@@ -91,8 +92,14 @@ allowed_sequences = {
}
def read_data(input_filename):
+ print("Reading {}".format(input_filename))
tree = ET.parse(input_filename)
+ documents = parse_xml(tree)
+ print("Number of documents: {}".format(len(documents)))
+ print("Number of paragraphs: {}".format(sum(len(document) for document in documents)))
+    return documents
+
+def parse_xml(tree):
# we will put each paragraph in a separate block in the output file
# we won't pay any attention to the document boundaries unless we
# later find out it was necessary
@@ -132,19 +139,22 @@ def read_data(input_filename):
words.append((word, False))
if len(words) == 0:
continue
+ words[-1] = (words[-1][0], True)
sentences.append(words)
paragraphs.append(sentences)
documents.append(paragraphs)
- print("Number of documents: {}".format(len(documents)))
- print("Number of paragraphs: {}".format(sum(len(document) for document in documents)))
return documents
-def main():
+def main(*args):
random.seed(1007)
- input_filename = sys.argv[1]
- output_dir = sys.argv[2]
+ if not args:
+ args = sys.argv[1:]
+ input_filename = args[0]
+ if os.path.isdir(input_filename):
+ input_filename = os.path.join(input_filename, "thai", "orchid", "xmlchid.xml")
+ output_dir = args[1]
documents = read_data(input_filename)
write_dataset(documents, output_dir, "orchid")
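With the new directory handling, the converter accepts either the xmlchid.xml file itself or the extern_data root used by prepare_tokenizer_treebank; a sketch of both call styles:

import stanza.utils.datasets.tokenization.convert_th_orchid as convert_th_orchid

convert_th_orchid.main("extern_data/thai/orchid/xmlchid.xml", "data/tokenize")
# or pass the extern_data root; thai/orchid/xmlchid.xml is appended automatically
convert_th_orchid.main("extern_data", "data/tokenize")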
diff --git a/stanza/utils/datasets/tokenization/convert_vi_vlsp.py b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py
new file mode 100644
index 00000000..947fe17f
--- /dev/null
+++ b/stanza/utils/datasets/tokenization/convert_vi_vlsp.py
@@ -0,0 +1,153 @@
+
+import os
+
+punctuation_set = (',', '.', '!', '?', ')', ':', ';', '”', '…', '...')
+
+def find_spaces(sentence):
+ # TODO: there are some sentences where there is only one quote,
+ # and some of them should be attached to the previous word instead
+ # of the next word. Training should work this way, though
+ odd_quotes = False
+
+ spaces = []
+ for word_idx, word in enumerate(sentence):
+ space = True
+ # Quote period at the end of a sentence needs to be attached
+ # to the rest of the text. Some sentences have `"... text`
+ # in the middle, though, so look for that
+ if word_idx < len(sentence) - 2 and sentence[word_idx+1] == '"':
+ if sentence[word_idx+2] == '.':
+ space = False
+ elif word_idx == len(sentence) - 3 and sentence[word_idx+2] == '...':
+ space = False
+ if word_idx < len(sentence) - 1:
+ if sentence[word_idx+1] in (',', '.', '!', '?', ')', ':', ';', '”', '…', '...','/', '%'):
+ space = False
+ if word in ('(', '“', '/'):
+ space = False
+ if word == '"':
+ if odd_quotes:
+ # already saw one quote. put this one at the end of the PREVIOUS word
+ # note that we know there must be at least one word already
+ odd_quotes = False
+ spaces[word_idx-1] = False
+ else:
+ odd_quotes = True
+ space = False
+ spaces.append(space)
+ return spaces
+
+def add_vlsp_args(parser):
+    parser.add_argument('--include_pos_data', action='store_true', default=False, help='Whether to also include the POS training dataset when building the tokenization training data. The POS dataset is expected to be in the same directory as the WS data, for example extern_dir/vietnamese/VLSP2013-POS-data')
+ parser.add_argument('--vlsp_include_spaces', action='store_true', default=False, help='When processing vi_vlsp tokenization, include all of the spaces. Otherwise, we try to turn the text back into standard text')
+def write_file(vlsp_include_spaces, output_filename, sentences, shard):
+ with open(output_filename, "w") as fout:
+ check_headlines = False
+ for sent_idx, sentence in enumerate(sentences):
+ fout.write("# sent_id = %s.%d\n" % (shard, sent_idx))
+ orig_text = " ".join(sentence)
+            # if the previous line was a headline (no sentence-final punctuation), start a new paragraph with this sentence
+ if check_headlines:
+ fout.write("# newpar id =%s.%d.1\n" % (shard, sent_idx))
+ check_headlines = False
+ if sentence[len(sentence) - 1] not in punctuation_set:
+ check_headlines = True
+
+ if vlsp_include_spaces:
+ fout.write("# text = %s\n" % orig_text)
+ else:
+ spaces = find_spaces(sentence)
+ full_text = ""
+ for word, space in zip(sentence, spaces):
+ # could be made more efficient, but shouldn't matter
+ full_text = full_text + word
+ if space:
+ full_text = full_text + " "
+ fout.write("# text = %s\n" % full_text)
+ fout.write("# orig_text = %s\n" % orig_text)
+ for word_idx, word in enumerate(sentence):
+ fake_dep = "root" if word_idx == 0 else "dep"
+ fout.write("%d\t%s\t%s" % ((word_idx+1), word, word))
+ fout.write("\t_\t_\t_")
+ fout.write("\t%d\t%s" % (word_idx, fake_dep))
+ fout.write("\t_\t")
+ if vlsp_include_spaces or spaces[word_idx]:
+ fout.write("_")
+ else:
+ fout.write("SpaceAfter=No")
+ fout.write("\n")
+ fout.write("\n")
+
+def convert_pos_dataset(file_path):
+ """
+    Reads the VLSP POS dataset and returns its sentences as lists of words
+    """
+
+    with open(file_path, "r") as fin:
+        document = fin.readlines()
+ sentences = []
+ sent = []
+ for line in document:
+ if line == "\n" and len(sent)>1:
+ if sent not in sentences:
+ sentences.append(sent)
+ sent = []
+ elif line != "\n":
+ sent.append(line.split("\t")[0].replace("_"," ").strip())
+ return sentences
+
+def convert_file(vlsp_include_spaces, input_filename, output_filename, shard, split_filename=None, split_shard=None, pos_data = None):
+ with open(input_filename) as fin:
+ lines = fin.readlines()
+
+ sentences = []
+ set_sentences = set()
+ for line in lines:
+ if len(line.replace("_", " ").split())>1:
+ words = line.split()
+ #one syllable lines are eliminated
+ if len(words) == 1 and len(words[0].split("_")) == 1:
+ continue
+ else:
+ words = [w.replace("_", " ") for w in words]
+                #only add sentences that haven't been added before
+ if words not in sentences:
+ sentences.append(words)
+ set_sentences.add(' '.join(words))
+
+ if split_filename is not None:
+ # even this is a larger dev set than the train set
+ split_point = int(len(sentences) * 0.95)
+ #check pos_data that aren't overlapping with current VLSP WS dataset
+ sentences_pos = [] if pos_data is None else [sent for sent in pos_data if ' '.join(sent) not in set_sentences]
+ print("Added ", len(sentences_pos), " sentences from POS dataset.")
+ write_file(vlsp_include_spaces, output_filename, sentences[:split_point]+sentences_pos, shard)
+ write_file(vlsp_include_spaces, split_filename, sentences[split_point:], split_shard)
+ else:
+ write_file(vlsp_include_spaces, output_filename, sentences, shard)
+
+def convert_vi_vlsp(extern_dir, tokenizer_dir, args):
+ input_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-WS-data")
+ input_pos_path = os.path.join(extern_dir, "vietnamese", "VLSP2013-POS-data")
+ input_train_filename = os.path.join(input_path, "VLSP2013_WS_train_gold.txt")
+ input_test_filename = os.path.join(input_path, "VLSP2013_WS_test_gold.txt")
+
+ input_pos_filename = os.path.join(input_pos_path, "VLSP2013_POS_train_BI_POS_Column.txt.goldSeg")
+ if not os.path.exists(input_train_filename):
+ raise FileNotFoundError("Cannot find train set for VLSP at %s" % input_train_filename)
+ if not os.path.exists(input_test_filename):
+ raise FileNotFoundError("Cannot find test set for VLSP at %s" % input_test_filename)
+ pos_data = None
+ if args.include_pos_data:
+ if not os.path.exists(input_pos_filename):
+            raise FileNotFoundError("Cannot find pos dataset for VLSP at %s" % input_pos_filename)
+ else:
+ pos_data = convert_pos_dataset(input_pos_filename)
+
+ output_train_filename = os.path.join(tokenizer_dir, "vi_vlsp.train.gold.conllu")
+ output_dev_filename = os.path.join(tokenizer_dir, "vi_vlsp.dev.gold.conllu")
+ output_test_filename = os.path.join(tokenizer_dir, "vi_vlsp.test.gold.conllu")
+
+ convert_file(args.vlsp_include_spaces, input_train_filename, output_train_filename, "train", output_dev_filename, "dev", pos_data)
+ convert_file(args.vlsp_include_spaces, input_test_filename, output_test_filename, "test")
+
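A sketch of what find_spaces produces on a made-up, already word-segmented Vietnamese sentence (underscores already converted back to spaces):

from stanza.utils.datasets.tokenization.convert_vi_vlsp import find_spaces

sentence = ['Hôm nay', 'trời', 'đẹp', '.']
print(find_spaces(sentence))
# -> [True, True, False, True]: no space before the sentence-final period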
diff --git a/stanza/utils/datasets/tokenization/process_thai_tokenization.py b/stanza/utils/datasets/tokenization/process_thai_tokenization.py
new file mode 100644
index 00000000..5ef0e3d5
--- /dev/null
+++ b/stanza/utils/datasets/tokenization/process_thai_tokenization.py
@@ -0,0 +1,187 @@
+import os
+import random
+
+try:
+ from pythainlp import sent_tokenize
+except ImportError:
+ pass
+
+def write_section(output_dir, dataset_name, section, documents):
+ """
+ Writes a list of documents for tokenization, including a file in conll format
+
+ The Thai datasets generally have no MWT (apparently not relevant for Thai)
+
+ output_dir: the destination directory for the output files
+ dataset_name: orchid, BEST, lst20, etc
+ section: train/dev/test
+ documents: a nested list of documents, paragraphs, sentences, words
+ words is a list of (word, space_follows)
+ """
+ with open(os.path.join(output_dir, 'th_%s-ud-%s-mwt.json' % (dataset_name, section)), 'w') as fout:
+ fout.write("[]\n")
+
+ text_out = open(os.path.join(output_dir, 'th_%s.%s.txt' % (dataset_name, section)), 'w')
+ label_out = open(os.path.join(output_dir, 'th_%s-ud-%s.toklabels' % (dataset_name, section)), 'w')
+ for document in documents:
+ for paragraph in document:
+ for sentence_idx, sentence in enumerate(paragraph):
+ for word_idx, word in enumerate(sentence):
+ # TODO: split with newlines to make it more readable?
+ text_out.write(word[0])
+ for i in range(len(word[0]) - 1):
+ label_out.write("0")
+ if word_idx == len(sentence) - 1:
+ label_out.write("2")
+ else:
+ label_out.write("1")
+ if word[1] and (sentence_idx != len(paragraph) - 1 or word_idx != len(sentence) - 1):
+ text_out.write(' ')
+ label_out.write('0')
+
+ text_out.write("\n\n")
+ label_out.write("\n\n")
+
+ text_out.close()
+ label_out.close()
+
+ with open(os.path.join(output_dir, 'th_%s.%s.gold.conllu' % (dataset_name, section)), 'w') as fout:
+ for document in documents:
+ for paragraph in document:
+ new_par = True
+ for sentence in paragraph:
+ for word_idx, word in enumerate(sentence):
+ # SpaceAfter is left blank if there is space after the word
+ if word[1] and new_par:
+ space = 'NewPar=Yes'
+ elif word[1]:
+ space = '_'
+ elif new_par:
+ space = 'SpaceAfter=No|NewPar=Yes'
+ else:
+ space = 'SpaceAfter=No'
+ new_par = False
+
+ # Note the faked dependency structure: the conll reading code
+ # needs it even if it isn't being used in any way
+ fake_dep = 'root' if word_idx == 0 else 'dep'
+ fout.write('{}\t{}\t_\t_\t_\t_\t{}\t{}\t{}:{}\t{}\n'.format(word_idx+1, word[0], word_idx, fake_dep, word_idx, fake_dep, space))
+ fout.write('\n')
+
+def write_dataset(documents, output_dir, dataset_name):
+ """
+ Shuffle a list of documents, write three sections
+ """
+ random.shuffle(documents)
+ num_train = int(len(documents) * 0.8)
+ num_dev = int(len(documents) * 0.1)
+ os.makedirs(output_dir, exist_ok=True)
+ write_section(output_dir, dataset_name, 'train', documents[:num_train])
+ write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev])
+ write_section(output_dir, dataset_name, 'test', documents[num_train+num_dev:])
+
+def write_dataset_best(documents, test_documents, output_dir, dataset_name):
+ """
+    Shuffle the given documents into train and dev sections; the test section
+    comes from the separately provided test_documents
+ """
+ random.shuffle(documents)
+ num_train = int(len(documents) * 0.85)
+ num_dev = int(len(documents) * 0.15)
+ os.makedirs(output_dir, exist_ok=True)
+ write_section(output_dir, dataset_name, 'train', documents[:num_train])
+ write_section(output_dir, dataset_name, 'dev', documents[num_train:num_train+num_dev])
+ write_section(output_dir, dataset_name, 'test', test_documents)
+
+
+def reprocess_lines(processed_lines):
+ """
+ Reprocesses lines using pythainlp to cut up sentences into shorter sentences.
+
+ Many of the lines in BEST seem to be multiple Thai sentences concatenated, according to native Thai speakers.
+
+ Input: a list of lines, where each line is a list of words. Space characters can be included as words
+ Output: a new list of lines, resplit using pythainlp
+ """
+ reprocessed_lines = []
+ for line in processed_lines:
+ text = "".join(line)
+ try:
+ chunks = sent_tokenize(text)
+ except NameError as e:
+ raise NameError("Sentences cannot be reprocessed without first installing pythainlp") from e
+        # Check that the resegmented chunks add back up to the original text length
+ if sum(len(x) for x in chunks) != len(text):
+ raise ValueError("Got unexpected text length: \n{}\nvs\n{}".format(text, chunks))
+
+ chunk_lengths = [len(x) for x in chunks]
+
+ current_length = 0
+ new_line = []
+ for word in line:
+ if len(word) + current_length < chunk_lengths[0]:
+ new_line.append(word)
+ current_length = current_length + len(word)
+ elif len(word) + current_length == chunk_lengths[0]:
+ new_line.append(word)
+ reprocessed_lines.append(new_line)
+ new_line = []
+ chunk_lengths = chunk_lengths[1:]
+ current_length = 0
+ else:
+ remaining_len = chunk_lengths[0] - current_length
+ new_line.append(word[:remaining_len])
+ reprocessed_lines.append(new_line)
+ word = word[remaining_len:]
+ chunk_lengths = chunk_lengths[1:]
+ while len(word) > chunk_lengths[0]:
+ new_line = [word[:chunk_lengths[0]]]
+ reprocessed_lines.append(new_line)
+ word = word[chunk_lengths[0]:]
+ chunk_lengths = chunk_lengths[1:]
+ new_line = [word]
+ current_length = len(word)
+ reprocessed_lines.append(new_line)
+ return reprocessed_lines
+
+def convert_processed_lines(processed_lines):
+ """
+ Convert a list of sentences into documents suitable for the output methods in this module.
+
+ Input: a list of lines, including space words
+    Output: a list of documents, each containing a single paragraph, which is a list of sentences
+ Each sentence is a list of words: (text, space_follows)
+ Space words will be eliminated.
+ """
+ paragraphs = []
+ sentences = []
+ for words in processed_lines:
+ # turn the words into a sentence
+ if len(words) > 1 and " " == words[0]:
+ words = words[1:]
+ elif len(words) == 1 and " " == words[0]:
+ words = []
+
+ sentence = []
+ for word in words:
+ word = word.strip()
+ if not word:
+ if len(sentence) == 0:
+                    raise ValueError("Unexpected space at start of sentence: {}".format(words))
+ sentence[-1] = (sentence[-1][0], True)
+ else:
+ sentence.append((word, False))
+ # blank lines are very rare in best, but why not treat them as a paragraph break
+ if len(sentence) == 0:
+ paragraphs.append([sentences])
+ sentences = []
+ continue
+ sentence[-1] = (sentence[-1][0], True)
+ sentences.append(sentence)
+ paragraphs.append([sentences])
+ return paragraphs
+
+
+
+
+
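A toy sketch of convert_processed_lines, the step that turns resegmented word lists into the nested structure the writers expect; it does not need pythainlp, since reprocess_lines is not involved:

from stanza.utils.datasets.tokenization.process_thai_tokenization import convert_processed_lines

lines = [["สวัสดี", " ", "ครับ"],
         ["ลา", "ก่อน"]]
paragraphs = convert_processed_lines(lines)
# one document holding a single paragraph with two sentences of (word, space_follows)
# pairs; the " " entry becomes space_follows=True on "สวัสดี" and is dropped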
diff --git a/stanza/utils/default_paths.py b/stanza/utils/default_paths.py
index ce40efc2..6326ba10 100644
--- a/stanza/utils/default_paths.py
+++ b/stanza/utils/default_paths.py
@@ -25,13 +25,16 @@ def get_default_paths():
# TODO: not sure what other people actually have
# TODO: also, could make this automatically update to the latest
- "UDBASE": "extern_data/ud2/ud-treebanks-v2.7",
+ "UDBASE": "extern_data/ud2/ud-treebanks-v2.8",
"NERBASE": "extern_data/ner",
# there's a stanford github, stanfordnlp/handparsed-treebank,
# with some data for different languages
"HANDPARSED_DIR": "extern_data/handparsed-treebank",
+
+ # data root for other general input files, such as VI_VLSP
+ "EXTERN_DIR": "extern_data"
}
paths = { "DATA_ROOT" : DATA_ROOT }