diff options
author | John Bauer <horatio@gmail.com> | 2021-08-03 19:30:08 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2021-08-10 02:11:56 +0300 |
commit | c5bbfab41fa200692bab80ce695f8ff2a6ad8262 (patch) | |
tree | fd6b8c518a6348e9d7d149c2f9f7591f7fc6bf52 | |
parent | 2b1a28646a18effcd4c8e83883aaa4a1e4c729e6 (diff) |
Add a test to see if any tags are in the dev set but not the train set
-rw-r--r-- | stanza/models/common/utils.py | 20 | ||||
-rw-r--r-- | stanza/models/ner_tagger.py | 2 | ||||
-rw-r--r-- | stanza/tests/test_utils.py | 6 |
3 files changed, 28 insertions, 0 deletions
# Added to stanza/models/common/utils.py.
#
# Call site added by the same change in stanza/models/ner_tagger.py, train():
#     utils.warn_missing_tags(train_batch.tags, dev_batch.tags, "dev")
# That call runs *before* train() checks whether train_batch / dev_batch are
# empty, so these helpers must not fail on empty tag lists.


def _flatten_tags(tags):
    """Collapse a list of tag lists into a set of the distinct tags.

    Only a non-empty ``list`` whose first element is itself a ``list`` is
    treated as nested; any other value (flat list, set, empty list, ...) is
    returned unchanged.
    """
    # `tags and ...` guards the tags[0] lookup so an empty list is not indexed.
    if isinstance(tags, list) and tags and isinstance(tags[0], list):
        return {tag for group in tags for tag in group}
    return tags


def find_missing_tags(known_tags, test_tags):
    """Return the sorted, de-duplicated tags of test_tags absent from known_tags.

    Each argument may be a flat collection of tags or a list of lists of
    tags; nesting is detected by inspecting the first element only.
    Empty inputs are accepted: an empty test_tags yields ``[]`` and an empty
    known_tags makes every test tag count as missing.
    """
    known_tags = _flatten_tags(known_tags)
    test_tags = _flatten_tags(test_tags)
    # Build a set first so repeated tags in a flat test_tags are reported once,
    # matching what already happens for nested input.
    return sorted({tag for tag in test_tags if tag not in known_tags})


def warn_missing_tags(known_tags, test_tags, test_set_name):
    """
    Print a warning if any tags present in the second list are not in the first list.

    Can also handle a list of lists.

    test_set_name is only used to label the data set in the warning message.
    Returns True if a warning was logged, False otherwise.
    """
    missing_tags = find_missing_tags(known_tags, test_tags)
    if not missing_tags:
        return False
    # `logger` is the module-level logger of the enclosing utils module.
    logger.warning("Found tags in %s missing from the expected tag set: %s", test_set_name, missing_tags)
    return True
def test_find_missing_tags():
    """Check find_missing_tags on flat tag lists and on lists of tag lists."""
    # Each case is (known_tags, test_tags, expected_missing).
    cases = [
        # identical flat lists: nothing is missing
        (["O", "PER", "LOC"], ["O", "PER", "LOC"], []),
        # one extra tag in the test list
        (["O", "PER", "LOC"], ["O", "PER", "LOC", "ORG"], ['ORG']),
        # nested input on both sides
        ([["O", "PER"], ["O", "LOC"]], [["O", "PER"], ["LOC", "ORG"]], ['ORG']),
    ]
    for known, test, expected in cases:
        assert utils.find_missing_tags(known, test) == expected