Welcome to mirror list, hosted at ThFree Co, Russian Federation.

count_ner_coverage.py « common « models « stanza - github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: b5a592c7ca9758b2ff49f1a33c36dce0a948f497 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from stanza.models.common import pretrain
import argparse

def parse_args(args=None):
    """Parse command-line options for the NER coverage script.

    Args:
        args: optional list of argument strings; defaults to sys.argv[1:]
            (passing a list makes the function testable without touching
            the real command line — backward compatible with no-arg calls).

    Returns:
        argparse.Namespace with:
            ners: list of NER dataset paths to score (defaults to the
                hi_fire2013 train/dev CSVs — hard-coded developer paths).
            pretrain: path to the word-vector pretrain file to score against.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('ners', type=str, nargs='*', help='Which treebanks to run on')
    parser.add_argument('--pretrain', type=str, default="/home/john/stanza_resources/hi/pretrain/hdtb.pt", help='Which pretrain to use')
    # NOTE(review): defaults point at a specific developer's home directory;
    # fine for a one-off analysis script, but callers elsewhere should pass paths.
    parser.set_defaults(ners=["/home/john/stanza/data/ner/hi_fire2013.train.csv",
                              "/home/john/stanza/data/ner/hi_fire2013.dev.csv"])
    return parser.parse_args(args)


def read_ner(filename):
    """Read a tab-separated NER file and return the words tagged as entities.

    Each non-blank line is expected to look like "word\ttag"; words whose
    tag is 'O' (outside any entity) are skipped, all others are collected.

    Args:
        filename: path to the tab-separated NER data file.

    Returns:
        List of words (first column) whose tag (second column) is not 'O'.

    Raises:
        IndexError: if a non-blank line has no tab-separated tag column.
    """
    words = []
    # 'with' closes the file handle (the original leaked it), and iterating
    # the handle avoids reading the whole file into memory via readlines().
    with open(filename) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            # split once instead of twice per line
            pieces = line.split("\t")
            if pieces[1] == 'O':
                continue
            words.append(pieces[0])
    return words

def count_coverage(pretrain, words):
    """Return the fraction of words present in the pretrain's vocabulary.

    Args:
        pretrain: an object with a `vocab` supporting membership tests
            (e.g. a stanza Pretrain).  NOTE: parameter name shadows the
            module-level `pretrain` import; kept for interface compatibility.
        words: list of words to check.

    Returns:
        Fraction in [0, 1] of words found in the vocab; 0.0 for an empty
        list (the original raised ZeroDivisionError on empty input).
    """
    if not words:
        return 0.0
    found = sum(1 for w in words if w in pretrain.vocab)
    return found / len(words)

def main():
    """Print vocab coverage of each NER dataset against the given pretrain."""
    args = parse_args()
    pt = pretrain.Pretrain(args.pretrain)
    for dataset in args.ners:
        words = read_ner(dataset)
        print(dataset)
        print(count_coverage(pt, words))
        print()


# Guard so importing this module no longer triggers argument parsing,
# pretrain loading, and file I/O as side effects (the original ran at import).
if __name__ == "__main__":
    main()