github.com/stanfordnlp/stanza.git

author     John Bauer <horatio@gmail.com>   2021-07-13 07:49:50 +0300
committer  GitHub <noreply@github.com>      2021-07-13 07:49:50 +0300
commit     de44be871282e05f79f23f5f5e284aceb672726b (patch)
tree       a933dd38dab2e5f8aebddb3b23ba40c6337e9705
parent     0a61b80d44d2868a7044ad389917287fea754614 (diff)
parent     d193a5a34a2e055f16ac69e96568759e059006e7 (diff)
Merge pull request #749 from stanfordnlp/1.2.2 (tag: v1.2.2)
1.2.2
-rwxr-xr-x  scripts/run_ner.sh                                   33
-rw-r--r--  stanza/_version.py                                    4
-rw-r--r--  stanza/models/classifiers/cnn_classifier.py          18
-rw-r--r--  stanza/models/common/pretrain.py                      6
-rw-r--r--  stanza/models/common/utils.py                        47
-rw-r--r--  stanza/models/common/vocab.py                         3
-rw-r--r--  stanza/models/ner/model.py                            7
-rw-r--r--  stanza/models/ner_tagger.py                          18
-rw-r--r--  stanza/models/pos/model.py                            2
-rw-r--r--  stanza/pipeline/ner_processor.py                      9
-rw-r--r--  stanza/pipeline/sentiment_processor.py                6
-rw-r--r--  stanza/resources/common.py                            8
-rw-r--r--  stanza/resources/prepare_resources.py                 2
-rw-r--r--  stanza/tests/test_english_pipeline.py                15
-rw-r--r--  stanza/tests/test_pipeline_ner_processor.py          81
-rw-r--r--  stanza/tests/test_pipeline_sentiment_processor.py    38
-rw-r--r--  stanza/tests/test_utils.py                           39
-rw-r--r--  stanza/utils/datasets/ner/prepare_ner_dataset.py      5
-rw-r--r--  stanza/utils/training/common.py                       6
-rw-r--r--  stanza/utils/training/run_ner.py                     167
20 files changed, 451 insertions, 63 deletions
diff --git a/scripts/run_ner.sh b/scripts/run_ner.sh
deleted file mode 100755
index 4edaf931..00000000
--- a/scripts/run_ner.sh
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/bin/bash
-#
-# Train and evaluate NER tagger. Run as:
-# ./run_ner.sh CORPUS OTHER_ARGS
-# where CORPUS is the full corpus name (e.g., English-CoNLL03) and OTHER_ARGS are additional training arguments (see tagger code) or empty.
-# This script assumes UDBASE and NER_DATA_DIR are correctly set in config.sh.
-
-source scripts/config.sh
-
-corpus=$1; shift
-args=$@
-
-lang=`echo $corpus | sed -e 's#-.*$##g'`
-lcode=`python scripts/lang2code.py $lang`
-corpus_name=`echo $corpus | sed -e 's#^.*-##g' | tr '[:upper:]' '[:lower:]'`
-short=${lcode}_${corpus_name}
-
-train_file=${NER_DATA_DIR}/${short}.train.json
-dev_file=${NER_DATA_DIR}/${short}.dev.json
-test_file=${NER_DATA_DIR}/${short}.test.json
-
-if [ ! -e $train_file ]; then
- bash scripts/prep_ner_data.sh $corpus
-fi
-
-echo "Running ner with $args..."
-python -m stanza.models.ner_tagger --wordvec_dir $WORDVEC_DIR --train_file $train_file --eval_file $dev_file \
- --lang $lang --shorthand $short --mode train $args
-python -m stanza.models.ner_tagger --wordvec_dir $WORDVEC_DIR --eval_file $dev_file \
- --lang $lang --shorthand $short --mode predict $args
-python -m stanza.models.ner_tagger --wordvec_dir $WORDVEC_DIR --eval_file $test_file \
- --lang $lang --shorthand $short --mode predict $args
-
diff --git a/stanza/_version.py b/stanza/_version.py
index 4bce72e3..100ecb6e 100644
--- a/stanza/_version.py
+++ b/stanza/_version.py
@@ -1,4 +1,4 @@
""" Single source of truth for version number """
-__version__ = "1.2.1"
-__resources_version__ = '1.2.1'
+__version__ = "1.2.2"
+__resources_version__ = '1.2.2'
diff --git a/stanza/models/classifiers/cnn_classifier.py b/stanza/models/classifiers/cnn_classifier.py
index 3a329db8..fa5160bf 100644
--- a/stanza/models/classifiers/cnn_classifier.py
+++ b/stanza/models/classifiers/cnn_classifier.py
@@ -12,6 +12,7 @@ import stanza.models.classifiers.classifier_args as classifier_args
import stanza.models.classifiers.data as data
from stanza.models.common.vocab import PAD_ID, UNK_ID
from stanza.models.common.data import get_long_tensor, sort_all
+from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort
# TODO: move CharVocab to common
from stanza.models.pos.vocab import CharVocab
@@ -73,10 +74,7 @@ class CNNClassifier(nn.Module):
charlm_projection = args.charlm_projection,
model_type = 'CNNClassifier')
- if args.char_lowercase:
- self.char_case = lambda x: x.lower()
- else:
- self.char_case = lambda x: x
+ self.char_lowercase = args.char_lowercase
self.unsaved_modules = []
@@ -169,7 +167,6 @@ class CNNClassifier(nn.Module):
self.dropout = nn.Dropout(self.config.dropout)
-
def add_unsaved_module(self, name, module):
self.unsaved_modules += [name]
setattr(self, name, module)
@@ -200,6 +197,8 @@ class CNNClassifier(nn.Module):
return char_reps
+ def char_case(self, x: str) -> str:
+ return x.lower() if self.char_lowercase else x
def forward(self, inputs, device=None):
if not device:
@@ -445,9 +444,10 @@ def label_text(model, text, batch_size=None, reverse_label_map=None, device=None
if batch_size is None:
intervals = [(0, len(text))]
+ orig_idx = None
else:
- # TODO: results would be better if we sort by length and then unsort
- intervals = [(i, min(i+batch_size, len(text))) for i in range(0, len(text), batch_size)]
+ text, orig_idx = sort_with_indices(text, key=len, reverse=True)
+ intervals = split_into_batches(text, batch_size)
labels = []
for interval in intervals:
if interval[1] - interval[0] == 0:
@@ -457,6 +457,10 @@ def label_text(model, text, batch_size=None, reverse_label_map=None, device=None
predicted = torch.argmax(output, dim=1)
labels.extend(predicted.tolist())
+ if orig_idx:
+ text = unsort(text, orig_idx)
+ labels = unsort(labels, orig_idx)
+
logger.debug("Found labels")
for (label, sentence) in zip(labels, text):
logger.debug((label, sentence))
diff --git a/stanza/models/common/pretrain.py b/stanza/models/common/pretrain.py
index e18accbf..193cc71d 100644
--- a/stanza/models/common/pretrain.py
+++ b/stanza/models/common/pretrain.py
@@ -20,6 +20,12 @@ class PretrainedWordVocab(BaseVocab):
self._id2unit = VOCAB_PREFIX + self.data
self._unit2id = {w:i for i, w in enumerate(self._id2unit)}
+ def normalize_unit(self, unit):
+ unit = super().normalize_unit(unit)
+ if unit:
+ unit = unit.replace(" ","\xa0")
+ return unit
+
class Pretrain:
""" A loader and saver for pretrained embeddings. """
diff --git a/stanza/models/common/utils.py b/stanza/models/common/utils.py
index a739d366..32f1b2f8 100644
--- a/stanza/models/common/utils.py
+++ b/stanza/models/common/utils.py
@@ -207,6 +207,53 @@ def unsort(sorted_list, oidx):
_, unsorted = [list(t) for t in zip(*sorted(zip(oidx, sorted_list)))]
return unsorted
+def sort_with_indices(data, key=None, reverse=False):
+ """
+    Sort the data and return both the sorted data and the original indices.
+
+    One useful application is to sort by length, which can be done with key=len.
+    Returns the sorted data as a tuple, followed by a tuple giving each item's index in the original list.
+ """
+ if key:
+ ordered = sorted(enumerate(data), key=lambda x: key(x[1]), reverse=reverse)
+ else:
+ ordered = sorted(enumerate(data), key=lambda x: x[1], reverse=reverse)
+
+ result = tuple(zip(*ordered))
+ return result[1], result[0]
+
+def split_into_batches(data, batch_size):
+ """
+    Returns a list of intervals so that the items in each interval have a total length of at
+    most batch_size, unless a single item is itself longer than batch_size, in which case that
+    item gets an interval of its own rather than being dropped.
+
+    data is a list of lists
+    batch_size is the maximum total length of a batch
+    return value is a list of (start_idx, end_idx) pairs
+ """
+ intervals = []
+ interval_start = 0
+ interval_size = 0
+ for idx, line in enumerate(data):
+ if len(line) > batch_size:
+ # guess we'll just hope the model can handle a batch of this size after all
+ if interval_size > 0:
+ intervals.append((interval_start, idx))
+ intervals.append((idx, idx+1))
+ interval_start = idx+1
+ interval_size = 0
+ elif len(line) + interval_size > batch_size:
+ # this line puts us over batch_size
+ intervals.append((interval_start, idx))
+ interval_start = idx
+ interval_size = len(line)
+ else:
+ interval_size = interval_size + len(line)
+ if interval_size > 0:
+ # there's some leftover
+ intervals.append((interval_start, len(data)))
+ return intervals
+
def tensor_unsort(sorted_tensor, oidx):
"""
Unsort a sorted tensor on its 0-th dimension, based on the original idx.
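
For reference, a minimal sketch (not part of this commit) of how the two new helpers are meant to be combined for batched inference, mirroring the pattern now used in cnn_classifier.label_text; the data and batch size below are illustrative:

    from stanza.models.common.utils import sort_with_indices, split_into_batches, unsort

    data = [["a", "b", "c"], ["d"], ["e", "f"]]
    # sort longest-first so items of similar length land in the same batch
    ordered, orig_idx = sort_with_indices(data, key=len, reverse=True)
    # (start_idx, end_idx) intervals whose total length is at most 4
    intervals = split_into_batches(ordered, 4)
    results = []
    for start, end in intervals:
        batch = ordered[start:end]
        results.extend(len(item) for item in batch)   # stand-in for running the model
    # restore the original ordering of the inputs
    results = unsort(results, orig_idx)
    assert results == [3, 1, 2]
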
diff --git a/stanza/models/common/vocab.py b/stanza/models/common/vocab.py
index e3e2c300..cade67c3 100644
--- a/stanza/models/common/vocab.py
+++ b/stanza/models/common/vocab.py
@@ -47,9 +47,10 @@ class BaseVocab:
return new
def normalize_unit(self, unit):
+ # be sure to look in subclasses for other normalization being done
+ # especially PretrainWordVocab
if unit is None:
return unit
- unit = unit.replace(" ","\xa0")
if self.lower:
return unit.lower()
return unit
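
The net effect of the two vocab changes above, as a standalone sketch: the space-to-NBSP replacement now applies only to pretrained word vector lookups, while BaseVocab.normalize_unit only lowercases. The helper below is illustrative, not part of stanza:

    # illustrative stand-in for PretrainedWordVocab.normalize_unit after this change
    def normalize_pretrain_unit(unit, lower=False):
        if unit is None:
            return unit
        if lower:
            unit = unit.lower()
        # multi-word entries in pretrained embeddings use non-breaking spaces
        return unit.replace(" ", "\xa0")

    assert normalize_pretrain_unit("New York") == "New\xa0York"
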
diff --git a/stanza/models/ner/model.py b/stanza/models/ner/model.py
index bf8e25b1..efad8d51 100644
--- a/stanza/models/ner/model.py
+++ b/stanza/models/ner/model.py
@@ -37,10 +37,11 @@ class NERTagger(nn.Module):
if self.args['charlm']:
add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False))
add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False))
+ input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
else:
self.charmodel = CharacterModel(args, vocab, bidirectional=True, attention=False)
- input_size += self.args['char_hidden_dim'] * 2
-
+ input_size += self.args['char_hidden_dim'] * 2
+
# optionally add a input transformation layer
if self.args.get('input_transform', False):
self.input_transform = nn.Linear(input_size, input_size)
@@ -73,7 +74,7 @@ class NERTagger(nn.Module):
vocab_size = len(self.vocab['word'])
dim = self.args['word_emb_dim']
assert emb_matrix.size() == (vocab_size, dim), \
- "Input embedding matrix must match size: {} x {}".format(vocab_size, dim)
+ "Input embedding matrix must match size: {} x {}, found {}".format(vocab_size, dim, emb_matrix.size())
self.word_emb.weight.data.copy_(emb_matrix)
def forward(self, word, word_mask, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx):
diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py
index f1525ca7..b4b0a09f 100644
--- a/stanza/models/ner_tagger.py
+++ b/stanza/models/ner_tagger.py
@@ -31,7 +31,7 @@ logger = logging.getLogger('stanza')
def parse_args(args=None):
parser = argparse.ArgumentParser()
- parser.add_argument('--data_dir', type=str, default='data/ner', help='Root dir for saving models.')
+ parser.add_argument('--data_dir', type=str, default='data/ner', help='Directory of NER data.')
parser.add_argument('--wordvec_dir', type=str, default='extern_data/word2vec', help='Directory of word vectors')
parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')
parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
@@ -62,6 +62,8 @@ def parse_args(args=None):
parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.")
parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
+ parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
+ parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")
parser.add_argument('--no_lowercase', dest='lowercase', action='store_false', help="Use cased word vectors.")
parser.add_argument('--no_emb_finetune', dest='emb_finetune', action='store_false', help="Turn off finetuning of the embedding matrix.")
@@ -114,6 +116,7 @@ def train(args):
pretrain = None
vocab = None
trainer = None
+
if args['finetune'] and os.path.exists(model_file):
logger.warning('Finetune is ON. Using model from "{}"'.format(model_file))
_, trainer, vocab = load_model(args, model_file)
@@ -137,8 +140,10 @@ def train(args):
if args['charlm_shorthand'] is None:
raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...")
logger.info('Using pretrained contextualized char embedding')
- args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
- args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
+ if not args['charlm_forward_file']:
+ args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
+ if not args['charlm_backward_file']:
+ args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
# load data
logger.info("Loading data with batch size {}...".format(args['batch_size']))
@@ -258,7 +263,12 @@ def evaluate(args):
def load_model(args, model_file):
# load model
use_cuda = args['cuda'] and not args['cpu']
- trainer = Trainer(model_file=model_file, use_cuda=use_cuda, train_classifier_only=args['train_classifier_only'])
+ charlm_args = {}
+ if 'charlm_forward_file' in args:
+ charlm_args['charlm_forward_file'] = args['charlm_forward_file']
+ if 'charlm_backward_file' in args:
+ charlm_args['charlm_backward_file'] = args['charlm_backward_file']
+ trainer = Trainer(args=charlm_args, model_file=model_file, use_cuda=use_cuda, train_classifier_only=args['train_classifier_only'])
loaded_args, vocab = trainer.args, trainer.vocab
# load config
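
With the two new flags, the charlm files can be pointed at explicitly instead of being derived from --charlm_save_dir and --charlm_shorthand. A hedged sketch of a training invocation using them, mirroring how run_ner.py (added below) builds its argument list; the file paths here are illustrative:

    from stanza.models import ner_tagger

    args = ['--mode', 'train',
            '--lang', 'vi',
            '--shorthand', 'vi_vlsp',
            '--train_file', 'data/ner/vi_vlsp.train.json',
            '--eval_file', 'data/ner/vi_vlsp.dev.json',
            '--wordvec_pretrain_file', '~/stanza_resources/vi/pretrain/vtb.pt',
            '--charlm',
            '--charlm_shorthand', 'vi_conll17',
            '--charlm_forward_file', 'saved_models/charlm/vi_conll17_forward_charlm.pt',
            '--charlm_backward_file', 'saved_models/charlm/vi_conll17_backward_charlm.pt']
    # ner_tagger.main(args)   # uncomment to run; requires the data and charlm files to exist
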
diff --git a/stanza/models/pos/model.py b/stanza/models/pos/model.py
index f7af1ffc..452f7dda 100644
--- a/stanza/models/pos/model.py
+++ b/stanza/models/pos/model.py
@@ -39,7 +39,7 @@ class Tagger(nn.Module):
self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)
input_size += self.args['transformed_dim']
- if self.args['pretrain']:
+ if self.args['pretrain']:
# pretrained embeddings, by default this won't be saved into model file
add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix), freeze=True))
self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
diff --git a/stanza/pipeline/ner_processor.py b/stanza/pipeline/ner_processor.py
index eab66b62..52003961 100644
--- a/stanza/pipeline/ner_processor.py
+++ b/stanza/pipeline/ner_processor.py
@@ -38,3 +38,12 @@ class NERProcessor(UDProcessor):
total = len(batch.doc.build_ents())
logger.debug(f'{total} entities found in document.')
return batch.doc
+
+ def bulk_process(self, docs):
+ """
+ NER processor has a collation step after running inference
+ """
+ docs = super().bulk_process(docs)
+ for doc in docs:
+ doc.build_ents()
+ return docs
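
A minimal sketch of the bulk path this method supports: passing a list of Document objects through a tokenize,ner pipeline and reading the collated entities. The English models must already be downloaded and the texts are illustrative:

    import stanza
    from stanza.models.common.doc import Document

    nlp = stanza.Pipeline("en", processors="tokenize,ner")
    docs = [Document([], text=t) for t in ["Barack Obama was born in Hawaii.",
                                           "Obama attended Harvard."]]
    docs = nlp(docs)
    for doc in docs:
        print([(ent.text, ent.type) for ent in doc.ents])
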
diff --git a/stanza/pipeline/sentiment_processor.py b/stanza/pipeline/sentiment_processor.py
index a96c80a0..48117142 100644
--- a/stanza/pipeline/sentiment_processor.py
+++ b/stanza/pipeline/sentiment_processor.py
@@ -24,6 +24,9 @@ class SentimentProcessor(UDProcessor):
# set of processor requirements for this processor
REQUIRES_DEFAULT = set([TOKENIZE])
+ # default batch size, measured in words per batch
+ DEFAULT_BATCH_SIZE = 5000
+
def _set_up_model(self, config, use_gpu):
# get pretrained word vectors
pretrain_path = config.get('pretrain_path', None)
@@ -37,7 +40,8 @@ class SentimentProcessor(UDProcessor):
pretrain=self._pretrain,
charmodel_forward=charmodel_forward,
charmodel_backward=charmodel_backward)
- self._batch_size = config.get('batch_size', None)
+ # batch size counted as words
+ self._batch_size = config.get('batch_size', SentimentProcessor.DEFAULT_BATCH_SIZE)
# TODO: move this call to load()
if use_gpu:
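
A hedged sketch of overriding the new default of 5000 words per batch, assuming the usual {processor}_{option} naming of pipeline keyword arguments:

    import stanza

    # smaller batches trade throughput for lower memory use
    nlp = stanza.Pipeline("en", processors="tokenize,sentiment", sentiment_batch_size=1000)
    doc = nlp("Today is okay.")
    print(doc.sentences[0].sentiment)
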
diff --git a/stanza/resources/common.py b/stanza/resources/common.py
index 8e70e861..d9837468 100644
--- a/stanza/resources/common.py
+++ b/stanza/resources/common.py
@@ -103,6 +103,12 @@ def file_exists(path, md5):
"""
return os.path.exists(path) and get_md5(path) == md5
+def assert_file_exists(path, md5=None):
+ assert os.path.exists(path), "Could not find file at %s" % path
+ if md5:
+ file_md5 = get_md5(path)
+ assert file_md5 == md5, "md5 for %s is %s, expected %s" % (path, file_md5, md5)
+
def download_file(url, path, proxies, raise_for_status=False):
"""
Download a URL into a file as specified by `path`.
@@ -134,7 +140,7 @@ def request_file(url, path, proxies=None, md5=None, raise_for_status=False):
logger.info(f'File exists: {path}.')
return
download_file(url, path, proxies, raise_for_status)
- assert(not md5 or file_exists(path, md5))
+ assert_file_exists(path, md5)
def sort_processors(processor_list):
sorted_list = []
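
For reference, a small sketch of the stricter post-download check introduced above; the path and checksum are illustrative:

    from stanza.resources.common import assert_file_exists

    # raises AssertionError naming the offending path if the file is missing
    assert_file_exists("/tmp/en_default.zip")
    # additionally verifies the md5 checksum when one is given
    assert_file_exists("/tmp/en_default.zip", md5="0123456789abcdef0123456789abcdef")
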
diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py
index 31791e21..31177863 100644
--- a/stanza/resources/prepare_resources.py
+++ b/stanza/resources/prepare_resources.py
@@ -100,6 +100,7 @@ default_ners = {
"nl": "conll02",
"ru": "wikiner",
"uk": "languk",
+ "vi": "vlsp",
"zh-hans": "ontonotes",
}
@@ -115,6 +116,7 @@ default_charlms = {
"fr": "newswiki",
"nl": "ccwiki",
"ru": "newswiki",
+ "vi": "conll17",
"zh-hans": "gigaword"
}
diff --git a/stanza/tests/test_english_pipeline.py b/stanza/tests/test_english_pipeline.py
index 8c89774b..73569a9a 100644
--- a/stanza/tests/test_english_pipeline.py
+++ b/stanza/tests/test_english_pipeline.py
@@ -135,11 +135,13 @@ EN_DOC_CONLLU_GOLD_MULTIDOC = """
@pytest.fixture(scope="module")
-def processed_doc():
- """ Document created by running full English pipeline on a few sentences """
- nlp = stanza.Pipeline(dir=TEST_MODELS_DIR)
- return nlp(EN_DOC)
+def pipeline():
+ return stanza.Pipeline(dir=TEST_MODELS_DIR)
+@pytest.fixture(scope="module")
+def processed_doc(pipeline):
+ """ Document created by running full English pipeline on a few sentences """
+ return pipeline(EN_DOC)
def test_text(processed_doc):
assert processed_doc.text == EN_DOC
@@ -163,11 +165,10 @@ def test_dependency_parse(processed_doc):
@pytest.fixture(scope="module")
-def processed_multidoc():
+def processed_multidoc(pipeline):
""" Document created by running full English pipeline on a few sentences """
docs = [Document([], text=t) for t in EN_DOCS]
- nlp = stanza.Pipeline(dir=TEST_MODELS_DIR)
- return nlp(docs)
+ return pipeline(docs)
def test_conllu_multidoc(processed_multidoc):
diff --git a/stanza/tests/test_pipeline_ner_processor.py b/stanza/tests/test_pipeline_ner_processor.py
new file mode 100644
index 00000000..3f88a8d0
--- /dev/null
+++ b/stanza/tests/test_pipeline_ner_processor.py
@@ -0,0 +1,81 @@
+
+import pytest
+import stanza
+from stanza.utils.conll import CoNLL
+from stanza.models.common.doc import Document
+
+from stanza.tests import *
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+# data for testing
+EN_DOCS = ["Barack Obama was born in Hawaii.", "He was elected president in 2008.", "Obama attended Harvard."]
+
+EXPECTED_ENTS = [[{
+ "text": "Barack Obama",
+ "type": "PERSON",
+ "start_char": 0,
+ "end_char": 12
+}, {
+ "text": "Hawaii",
+ "type": "GPE",
+ "start_char": 25,
+ "end_char": 31
+}],
+[{
+ "text": "2008",
+ "type": "DATE",
+ "start_char": 28,
+ "end_char": 32
+}],
+[{
+ "text": "Obama",
+ "type": "PERSON",
+ "start_char": 0,
+ "end_char": 5
+}, {
+ "text": "Harvard",
+ "type": "ORG",
+ "start_char": 15,
+ "end_char": 22
+}]]
+
+
+@pytest.fixture(scope="module")
+def pipeline():
+ """
+ A reusable pipeline with the NER module
+ """
+ return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,ner")
+
+
+@pytest.fixture(scope="module")
+def processed_doc(pipeline):
+ """ Document created by running full English pipeline on a few sentences """
+ return [pipeline(text) for text in EN_DOCS]
+
+
+@pytest.fixture(scope="module")
+def processed_bulk(pipeline):
+ """ Document created by running full English pipeline on a few sentences """
+ docs = [Document([], text=t) for t in EN_DOCS]
+ return pipeline(docs)
+
+def check_entities_equal(doc, expected):
+ """
+ Checks that the entities of a doc are equal to the given list of maps
+ """
+ assert len(doc.ents) == len(expected)
+ for doc_entity, expected_entity in zip(doc.ents, expected):
+ for k in expected_entity:
+ assert getattr(doc_entity, k) == expected_entity[k]
+
+def test_bulk_ents(processed_bulk):
+ assert len(processed_bulk) == len(EXPECTED_ENTS)
+ for doc, expected in zip(processed_bulk, EXPECTED_ENTS):
+ check_entities_equal(doc, expected)
+
+def test_ents(processed_doc):
+ assert len(processed_doc) == len(EXPECTED_ENTS)
+ for doc, expected in zip(processed_doc, EXPECTED_ENTS):
+ check_entities_equal(doc, expected)
diff --git a/stanza/tests/test_pipeline_sentiment_processor.py b/stanza/tests/test_pipeline_sentiment_processor.py
new file mode 100644
index 00000000..b46eedf4
--- /dev/null
+++ b/stanza/tests/test_pipeline_sentiment_processor.py
@@ -0,0 +1,38 @@
+
+import pytest
+import stanza
+from stanza.utils.conll import CoNLL
+from stanza.models.common.doc import Document
+
+from stanza.tests import *
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+# data for testing
+EN_DOCS = ["Ragavan is terrible and should go away.", "Today is okay.", "Urza's Saga is great."]
+
+EN_DOC = " ".join(EN_DOCS)
+
+EXPECTED = [0, 1, 2]
+
+@pytest.fixture(scope="module")
+def pipeline():
+ """
+    A reusable pipeline with the sentiment module
+ """
+ return stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,sentiment")
+
+def test_simple(pipeline):
+ results = []
+ for text in EN_DOCS:
+ doc = pipeline(text)
+ assert len(doc.sentences) == 1
+ results.append(doc.sentences[0].sentiment)
+ assert EXPECTED == results
+
+def test_multiple_sentences(pipeline):
+ doc = pipeline(EN_DOC)
+ assert len(doc.sentences) == 3
+ results = [sentence.sentiment for sentence in doc.sentences]
+ assert EXPECTED == results
+
diff --git a/stanza/tests/test_utils.py b/stanza/tests/test_utils.py
index 7b654492..bc5cf4e4 100644
--- a/stanza/tests/test_utils.py
+++ b/stanza/tests/test_utils.py
@@ -75,3 +75,42 @@ def test_wordvec_type():
with pytest.raises(FileNotFoundError):
utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
+def test_sort_with_indices():
+ data = [[1, 2, 3], [4, 5], [6]]
+ ordered, orig_idx = utils.sort_with_indices(data, key=len)
+ assert ordered == ([6], [4, 5], [1, 2, 3])
+ assert orig_idx == (2, 1, 0)
+
+ unsorted = utils.unsort(ordered, orig_idx)
+ assert data == unsorted
+
+def test_split_into_batches():
+ data = []
+ for i in range(5):
+ data.append(["Unban", "mox", "opal", str(i)])
+
+ data.append(["Do", "n't", "ban", "Urza", "'s", "Saga", "that", "card", "is", "great"])
+ data.append(["Ban", "Ragavan"])
+
+ # small batches will put one element in each interval
+ batches = utils.split_into_batches(data, 5)
+ assert batches == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
+
+ # this one has a batch interrupted in the middle by a large element
+ batches = utils.split_into_batches(data, 8)
+ assert batches == [(0, 2), (2, 4), (4, 5), (5, 6), (6, 7)]
+
+ # this one has the large element at the start of its own batch
+ batches = utils.split_into_batches(data[1:], 8)
+ assert batches == [(0, 2), (2, 4), (4, 5), (5, 6)]
+
+ # overloading the test! assert that the key & reverse is working
+ ordered, orig_idx = utils.sort_with_indices(data, key=len, reverse=True)
+ assert [len(x) for x in ordered] == [10, 4, 4, 4, 4, 4, 2]
+
+ # this has the large element at the start
+ batches = utils.split_into_batches(ordered, 8)
+ assert batches == [(0, 1), (1, 3), (3, 5), (5, 7)]
+
+ # double check that unsort is working as expected
+ assert data == utils.unsort(ordered, orig_idx)
diff --git a/stanza/utils/datasets/ner/prepare_ner_dataset.py b/stanza/utils/datasets/ner/prepare_ner_dataset.py
index 9d7e089a..54a2c7e3 100644
--- a/stanza/utils/datasets/ner/prepare_ner_dataset.py
+++ b/stanza/utils/datasets/ner/prepare_ner_dataset.py
@@ -316,10 +316,9 @@ def process_bsnlp(paths, short_name):
output_filename = os.path.join(base_output_path, '%s.%s.json' % (short_name, shard))
prepare_ner_file.process_dataset(csv_file, output_filename)
-def main():
+def main(dataset_name):
paths = default_paths.get_default_paths()
- dataset_name = sys.argv[1]
random.seed(1234)
if dataset_name == 'fi_turku':
@@ -344,4 +343,4 @@ def main():
raise ValueError(f"dataset {dataset_name} currently not handled")
if __name__ == '__main__':
- main()
+ main(sys.argv[1])
diff --git a/stanza/utils/training/common.py b/stanza/utils/training/common.py
index b414bf56..c3635bbb 100644
--- a/stanza/utils/training/common.py
+++ b/stanza/utils/training/common.py
@@ -40,6 +40,12 @@ def build_argparse():
SHORTNAME_RE = re.compile("[a-z-]+_[a-z0-9]+")
def main(run_treebank, model_dir, model_name, add_specific_args=None):
+ """
+ A main program for each of the run_xyz scripts
+
+ It collects the arguments and runs the main method for each dataset provided.
+ It also tries to look for an existing model and not overwrite it unless --force is provided
+ """
logger.info("Training program called with:\n" + " ".join(sys.argv))
paths = default_paths.get_default_paths()
diff --git a/stanza/utils/training/run_ner.py b/stanza/utils/training/run_ner.py
new file mode 100644
index 00000000..1ee9979f
--- /dev/null
+++ b/stanza/utils/training/run_ner.py
@@ -0,0 +1,167 @@
+"""
+Trains or scores an NER model.
+
+Will attempt to guess the appropriate word vector file if none is
+specified, and will use the charlms specified in the resources
+for a given dataset or language if possible.
+
+Example command line:
+    python3 -m stanza.utils.training.run_ner hu_combined
+
+This script expects the prepared data to be in
+ data/ner/{lang}_{dataset}.train.json, {lang}_{dataset}.dev.json, {lang}_{dataset}.test.json
+
+If those files don't exist, it will make an attempt to rebuild them
+using the prepare_ner_dataset script. However, this will fail if the
+data is not already downloaded. More information on where to find
+most of the datasets online is in that script. Some of the datasets
+have licenses which must be agreed to, so no attempt is made to
+automatically download the data.
+"""
+
+import glob
+import logging
+import os
+
+from stanza.models import ner_tagger
+from stanza.utils.datasets.ner import prepare_ner_dataset
+from stanza.utils.training import common
+from stanza.utils.training.common import Mode
+
+from stanza.resources.prepare_resources import default_charlms, ner_charlms
+from stanza.resources.common import DEFAULT_MODEL_DIR
+
+# extra arguments specific to a particular dataset
+DATASET_EXTRA_ARGS = {
+ "vi_vlsp": [ "--dropout", "0.6",
+ "--word_dropout", "0.1",
+ "--locked_dropout", "0.1",
+ "--char_dropout", "0.1" ],
+}
+
+logger = logging.getLogger('stanza')
+
+def add_ner_args(parser):
+    parser.add_argument('--charlm', default=None, type=str, help='Which charlm to run on.  Will use the default charlm for this language/model if not set.  Pass the literal string "None" to turn off charlm for languages which have a default charlm')
+
+def find_charlm(direction, language, charlm):
+ saved_path = 'saved_models/charlm/{}_{}_{}_charlm.pt'.format(language, charlm, direction)
+ if os.path.exists(saved_path):
+ logger.info(f'Using model {saved_path} for {direction} charlm')
+ return saved_path
+
+ resource_path = '{}/{}/{}_charlm/{}.pt'.format(DEFAULT_MODEL_DIR, language, direction, charlm)
+ if os.path.exists(resource_path):
+ logger.info(f'Using model {resource_path} for {direction} charlm')
+ return resource_path
+
+ raise FileNotFoundError(f"Cannot find {direction} charlm in either {saved_path} or {resource_path}")
+
+def find_wordvec_pretrain(language):
+ # TODO: try to extract/remember the specific pretrain for the given model
+ # That would be a good way to archive which pretrains are used for which NER models, anyway
+ pretrain_path = '{}/{}/pretrain/*.pt'.format(DEFAULT_MODEL_DIR, language)
+ pretrains = glob.glob(pretrain_path)
+ if len(pretrains) == 0:
+        raise FileNotFoundError(f"Cannot find any pretrains in {pretrain_path}.  Try 'stanza.download(\"{language}\")' to get a default pretrain or use --wordvec_pretrain_file to specify a .pt file to use")
+ if len(pretrains) > 1:
+        raise FileNotFoundError(f"Too many pretrains to choose from in {pretrain_path}.  Must specify an exact path to a --wordvec_pretrain_file")
+ pretrain = pretrains[0]
+    logger.info(f"Using pretrain found in {pretrain}.  To use a different pretrain, specify --wordvec_pretrain_file")
+ return pretrain
+
+# Technically NER datasets are not necessarily treebanks
+# (usually not, in fact)
+# However, to keep the naming consistent, we leave the
+# method which does the training as run_treebank
+# TODO: rename treebank -> dataset everywhere
+def run_treebank(mode, paths, treebank, short_name,
+ temp_output_file, command_args, extra_args):
+ ner_dir = paths["NER_DATA_DIR"]
+ language, dataset = short_name.split("_")
+
+ train_file = os.path.join(ner_dir, f"{short_name}.train.json")
+ dev_file = os.path.join(ner_dir, f"{short_name}.dev.json")
+ test_file = os.path.join(ner_dir, f"{short_name}.test.json")
+
+ if not os.path.exists(train_file) or not os.path.exists(dev_file) or not os.path.exists(test_file):
+ logger.warning(f"The data for {short_name} is missing or incomplete. Attempting to rebuild...")
+ try:
+ prepare_ner_dataset.main(short_name)
+ except:
+ logger.error(f"Unable to build the data. Please correctly build the files in {train_file}, {dev_file}, {test_file} and then try again.")
+ raise
+
+ default_charlm = default_charlms.get(language, None)
+ specific_charlm = ner_charlms.get(language, {}).get(dataset, None)
+ if command_args.charlm:
+ charlm = command_args.charlm
+ if charlm == 'None':
+ charlm = None
+ elif specific_charlm:
+ charlm = specific_charlm
+ elif default_charlm:
+ charlm = default_charlm
+ else:
+ charlm = None
+
+ if charlm:
+ forward = find_charlm('forward', language, charlm)
+ backward = find_charlm('backward', language, charlm)
+ charlm_args = ['--charlm',
+ '--charlm_shorthand', f'{language}_{charlm}',
+ '--charlm_forward_file', forward,
+ '--charlm_backward_file', backward]
+ else:
+ charlm_args = []
+
+ if mode == Mode.TRAIN:
+ # VI example arguments:
+ # --wordvec_pretrain_file ~/stanza_resources/vi/pretrain/vtb.pt
+ # --train_file data/ner/vi_vlsp.train.json
+ # --eval_file data/ner/vi_vlsp.dev.json
+ # --lang vi
+ # --shorthand vi_vlsp
+ # --mode train
+ # --charlm --charlm_shorthand vi_conll17
+ # --dropout 0.6 --word_dropout 0.1 --locked_dropout 0.1 --char_dropout 0.1
+ dataset_args = DATASET_EXTRA_ARGS.get(short_name, [])
+
+ train_args = ['--train_file', train_file,
+ '--eval_file', dev_file,
+ '--lang', language,
+ '--shorthand', short_name,
+ '--mode', 'train']
+ train_args = train_args + charlm_args + dataset_args + extra_args
+ if '--wordvec_pretrain_file' not in train_args:
+ # will throw an error if the pretrain can't be found
+ wordvec_pretrain = find_wordvec_pretrain(language)
+ train_args = train_args + ['--wordvec_pretrain_file', wordvec_pretrain]
+ logger.info("Running train step with args: {}".format(train_args))
+ ner_tagger.main(train_args)
+
+ if mode == Mode.SCORE_DEV or mode == Mode.TRAIN:
+ dev_args = ['--eval_file', dev_file,
+ '--lang', language,
+ '--shorthand', short_name,
+ '--mode', 'predict']
+ dev_args = dev_args + charlm_args + extra_args
+ logger.info("Running dev step with args: {}".format(dev_args))
+ ner_tagger.main(dev_args)
+
+ if mode == Mode.SCORE_TEST or mode == Mode.TRAIN:
+ test_args = ['--eval_file', test_file,
+ '--lang', language,
+ '--shorthand', short_name,
+ '--mode', 'predict']
+ test_args = test_args + charlm_args + extra_args
+ logger.info("Running test step with args: {}".format(test_args))
+ ner_tagger.main(test_args)
+
+
+def main():
+ common.main(run_treebank, "ner", "nertagger", add_ner_args)
+
+if __name__ == "__main__":
+ main()
+
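
Finally, a hedged sketch of driving the new runner. The dataset name is illustrative and assumes its prepared data exists under data/ner/ (or can be rebuilt by prepare_ner_dataset); programmatic invocation relies on common.main parsing sys.argv, which is an assumption about the shared runner rather than something shown in this diff:

    # the usual entry point is the command line from the docstring above:
    #   python3 -m stanza.utils.training.run_ner vi_vlsp
    # a rough programmatic equivalent (assumption: common.main reads sys.argv):
    import sys
    from stanza.utils.training import run_ner

    sys.argv = ["run_ner", "vi_vlsp"]
    run_ner.main()   # in train mode this trains, then scores the dev and test sets
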