Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: John Bauer <horatio@gmail.com> 2021-08-03 19:30:08 +0300
committer: John Bauer <horatio@gmail.com> 2021-08-10 02:11:56 +0300
commit: c5bbfab41fa200692bab80ce695f8ff2a6ad8262 (patch)
tree: fd6b8c518a6348e9d7d149c2f9f7591f7fc6bf52
parent: 2b1a28646a18effcd4c8e83883aaa4a1e4c729e6 (diff)
Add a test to see if any tags are in the dev set but not the train set
-rw-r--r-- stanza/models/common/utils.py 20
-rw-r--r-- stanza/models/ner_tagger.py 2
-rw-r--r-- stanza/tests/test_utils.py 6
3 files changed, 28 insertions, 0 deletions
diff --git a/stanza/models/common/utils.py b/stanza/models/common/utils.py
index 25491754..842bf6fd 100644
--- a/stanza/models/common/utils.py
+++ b/stanza/models/common/utils.py
@@ -279,3 +279,23 @@ def set_random_seed(seed, cuda):
if cuda:
torch.cuda.manual_seed(seed)
return seed
+
def find_missing_tags(known_tags, test_tags):
    """
    Return a sorted list of the distinct tags in test_tags which are not in known_tags.

    Each argument can be a flat collection of tags or a list of lists of tags,
    in which case the inner lists are flattened first.  As before, only the
    first element is inspected to decide whether a list is nested.

    Empty lists are accepted: an empty test_tags gives [], and an empty
    known_tags means every tag in test_tags is reported as missing.
    """
    if isinstance(known_tags, list):
        # check the list is non-empty before peeking at the first element,
        # since callers may pass the tags of an empty dataset
        if known_tags and isinstance(known_tags[0], list):
            known_tags = {x for y in known_tags for x in y}
        else:
            # a set gives constant time membership tests below
            known_tags = set(known_tags)
    if isinstance(test_tags, list) and test_tags and isinstance(test_tags[0], list):
        test_tags = {x for y in test_tags for x in y}
    # set() so that a flat list with repeated tags reports each missing tag once,
    # the same as a nested list does
    return sorted(x for x in set(test_tags) if x not in known_tags)
+
def warn_missing_tags(known_tags, test_tags, test_set_name):
    """
    Log a warning when test_tags contains tags which are not in known_tags.

    Both tag arguments are handed to find_missing_tags as is, so a list of
    lists is also accepted.  test_set_name is used in the warning message
    to say which data set the unexpected tags came from.

    Returns True if a warning was logged, False otherwise.
    """
    unexpected = find_missing_tags(known_tags, test_tags)
    if len(unexpected) == 0:
        return False
    message = "Found tags in {} missing from the expected tag set: {}".format(test_set_name, unexpected)
    logger.warning(message)
    return True
diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py
index c34e9036..585065fa 100644
--- a/stanza/models/ner_tagger.py
+++ b/stanza/models/ner_tagger.py
@@ -166,6 +166,8 @@ def train(args):
dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True)
dev_gold_tags = dev_batch.tags
+ utils.warn_missing_tags(train_batch.tags, dev_batch.tags, "dev")
+
# skip training if the language does not have training or dev data
if len(train_batch) == 0 or len(dev_batch) == 0:
logger.info("Skip training because no data available...")
diff --git a/stanza/tests/test_utils.py b/stanza/tests/test_utils.py
index 220cd224..4b02ab07 100644
--- a/stanza/tests/test_utils.py
+++ b/stanza/tests/test_utils.py
@@ -123,3 +123,9 @@ def test_split_into_batches():
# double check that unsort is working as expected
assert data == utils.unsort(ordered, orig_idx)
+
+
def test_find_missing_tags():
    """
    Check find_missing_tags on flat lists and on lists of lists
    """
    # each case is (known tags, test tags, expected missing tags)
    cases = [
        (["O", "PER", "LOC"], ["O", "PER", "LOC"], []),
        (["O", "PER", "LOC"], ["O", "PER", "LOC", "ORG"], ['ORG']),
        ([["O", "PER"], ["O", "LOC"]], [["O", "PER"], ["LOC", "ORG"]], ['ORG']),
    ]
    for known, test, expected in cases:
        assert utils.find_missing_tags(known, test) == expected