diff options
author | John Bauer <horatio@gmail.com> | 2021-08-03 20:31:01 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2021-08-10 02:11:56 +0300 |
commit | c457a9309ad15c522e94230f919c25d1e7aebf64 (patch) | |
tree | c3b59fc42cd8dc37d1a0a542f0b35d11ee34b2b5 | |
parent | c5bbfab41fa200692bab80ce695f8ff2a6ad8262 (diff) |
Also check if the test set has tags not present in the tagger or if the train set has tags not present in a finetune NER model (tag: v1.2.3)
-rw-r--r-- | stanza/models/ner_tagger.py | 7 |
1 file changed, 5 insertions, 2 deletions
diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py index 585065fa..020e2c68 100644 --- a/stanza/models/ner_tagger.py +++ b/stanza/models/ner_tagger.py @@ -166,7 +166,9 @@ def train(args): dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True) dev_gold_tags = dev_batch.tags - utils.warn_missing_tags(train_batch.tags, dev_batch.tags, "dev") + if args['finetune']: + utils.warn_missing_tags([i for i in trainer.vocab['tag']], train_batch.tags, "training set") + utils.warn_missing_tags(train_batch.tags, dev_batch.tags, "dev set") # skip training if the language does not have training or dev data if len(train_batch) == 0 or len(dev_batch) == 0: @@ -261,7 +263,8 @@ def evaluate(args): logger.info("Loading data with batch size {}...".format(args['batch_size'])) doc = Document(json.load(open(args['eval_file']))) batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True) - + utils.warn_missing_tags([i for i in trainer.vocab['tag']], batch.tags, "eval_file") + logger.info("Start evaluation...") preds = [] for i, b in enumerate(batch): |