Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: John Bauer <horatio@gmail.com> 2021-08-03 20:31:01 +0300
committer: John Bauer <horatio@gmail.com> 2021-08-10 02:11:56 +0300
commit: c457a9309ad15c522e94230f919c25d1e7aebf64 (patch)
tree: c3b59fc42cd8dc37d1a0a542f0b35d11ee34b2b5
parent: c5bbfab41fa200692bab80ce695f8ff2a6ad8262 (diff)
Also check if the test set has tags not present in the tagger or if the train set has tags not present in a finetune NER model (tag: v1.2.3)
-rw-r--r-- stanza/models/ner_tagger.py 7
1 files changed, 5 insertions, 2 deletions
diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py
index 585065fa..020e2c68 100644
--- a/stanza/models/ner_tagger.py
+++ b/stanza/models/ner_tagger.py
@@ -166,7 +166,9 @@ def train(args):
dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True)
dev_gold_tags = dev_batch.tags
- utils.warn_missing_tags(train_batch.tags, dev_batch.tags, "dev")
+ if args['finetune']:
+ utils.warn_missing_tags([i for i in trainer.vocab['tag']], train_batch.tags, "training set")
+ utils.warn_missing_tags(train_batch.tags, dev_batch.tags, "dev set")
# skip training if the language does not have training or dev data
if len(train_batch) == 0 or len(dev_batch) == 0:
@@ -261,7 +263,8 @@ def evaluate(args):
logger.info("Loading data with batch size {}...".format(args['batch_size']))
doc = Document(json.load(open(args['eval_file'])))
batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)
-
+ utils.warn_missing_tags([i for i in trainer.vocab['tag']], batch.tags, "eval_file")
+
logger.info("Start evaluation...")
preds = []
for i, b in enumerate(batch):