1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
|
from collections import defaultdict
import os
import re
import sys
import warnings

from stanza.models.common.vocab import VOCAB_PREFIX
from stanza.models.common.constant import treebank_to_short_name
from stanza.models.pos.vocab import XPOSVocab, WordVocab
from stanza.models.common.doc import *
from stanza.utils.conll import CoNLL
# Treebank entries that already look like "<lang>_<corpus>" shorthands (e.g. "en_ewt")
# are used as-is by main(); any other entry is converted with treebank_to_short_name().
SHORTNAME_RE = re.compile(r"[a-z-]+_[a-z0-9]+")
def filter_data(data, idx):
    """Keep only the sentences in which every token has a value at position ``idx``.

    Args:
        data: iterable of sentences, each an iterable of indexable token records
            (in this script, the output of ``doc.get(..., as_sentences=True)``).
        idx: index of the token field that must not be ``None``.

    Returns:
        A new list with the qualifying sentences in their original order. A sentence
        with no tokens is kept, since it has no missing values.
    """
    # all() stops at the first missing value instead of scanning the whole sentence.
    return [
        sentence
        for sentence in data
        if all(token[idx] is not None for token in sentence)
    ]
def get_factory(sh, fn):
    """Choose the XPOS vocabulary constructor expression for one treebank.

    Builds a ``WordVocab`` over the XPOS field (index 2 of TEXT/UPOS/XPOS/FEATS) of
    the treebank's training file. When that vocabulary has more than 20 entries
    beyond ``VOCAB_PREFIX``, it also tries ``XPOSVocab`` with each candidate
    separator and keeps whichever option yields the fewest classes.

    Args:
        sh: treebank shorthand; locates ``data/pos/<sh>.train.in.conllu``.
        fn: treebank name as listed in the input file; used in the warning message.

    Returns:
        A string holding a Python expression (in terms of ``data`` and
        ``shorthand``) that constructs the chosen vocabulary. ``main`` writes it
        verbatim into the generated factory function.

    Warns:
        UserWarning: if the training file does not exist. In that case
        ``'WordVocab(data, shorthand, idx=2)'`` is returned.
    """
    print('Resolving vocab option for {}...'.format(sh))
    train_file = 'data/pos/{}.train.in.conllu'.format(sh)
    if not os.path.exists(train_file):
        # Warn rather than raise: raising aborted the entire run and made the
        # fallback below unreachable, even though the message promises a fallback.
        warnings.warn('Training data for {} not found in the data directory, falling back to using WordVocab. To generate the '
                      'XPOS vocabulary for this treebank properly, please run the following command first:\n'
                      '\tstanza/utils/datasets/prepare_pos_treebank.py {}'.format(fn, fn),
                      UserWarning, stacklevel=2)
        # without the training file, there's not much we can do
        return 'WordVocab(data, shorthand, idx=2)'

    doc = CoNLL.conll2doc(input_file=train_file)
    data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)
    print(f'Original length = {len(data)}')
    # Drop sentences where any token lacks an XPOS value (index 2).
    data = filter_data(data, idx=2)
    print(f'Filtered length = {len(data)}')

    # Baseline: one class per distinct XPOS string (presumably "_" is excluded via ignore).
    vocab = WordVocab(data, sh, idx=2, ignore=["_"])
    key = 'WordVocab(data, shorthand, idx=2, ignore=["_"])'
    best_size = len(vocab) - len(VOCAB_PREFIX)
    # Only try splitting tags on a separator when the flat tag set is not already small.
    if best_size > 20:
        for sep in ['', '-', '+', '|', ',', ':']:  # candidate separators
            vocab = XPOSVocab(data, sh, idx=2, sep=sep)
            # Sum the sizes of all entries in _id2unit, discounting VOCAB_PREFIX from
            # each one (presumably one sub-vocabulary per split position — verify).
            length = sum(len(x) - len(VOCAB_PREFIX) for x in vocab._id2unit.values())
            if length < best_size:
                key = 'XPOSVocab(data, shorthand, idx=2, sep="{}")'.format(sep)
                best_size = length
    return key
def _render_factory_source(mapping):
    """Return the source text of the generated ``xpos_vocab_factory`` module.

    Args:
        mapping: dict from a vocab-constructor expression string (as returned by
            ``get_factory``) to the list of treebank shorthands that use it.

    Returns:
        The full module text, ending with a newline. Each mapping entry becomes an
        ``if``/``elif`` branch; unknown shorthands raise ``NotImplementedError``.
    """
    lines = [
        "# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.",
        "# Please don't edit it!",
        "",
        "from stanza.models.pos.vocab import WordVocab, XPOSVocab",
        "",
        "",
        "def xpos_vocab_factory(data, shorthand):",
    ]
    for position, (key, shorthands) in enumerate(mapping.items()):
        keyword = 'if' if position == 0 else 'elif'
        names = ', '.join('"{}"'.format(x) for x in shorthands)
        # Explicit 4/8-space indents so the generated module is valid Python.
        lines.append("    {} shorthand in [{}]:".format(keyword, names))
        lines.append("        return {}".format(key))
    raise_stmt = "raise NotImplementedError('Language shorthand \"{}\" not found!'.format(shorthand))"
    if mapping:
        lines.append("    else:")
        lines.append("        " + raise_stmt)
    else:
        # No branches at all: a lone "else:" would be a syntax error in the output.
        lines.append("    " + raise_stmt)
    return "\n".join(lines) + "\n"


def main():
    """Command-line entry point: ``<script> list_of_tb_file output_factory_file``.

    Reads one treebank name (full name or shorthand) per non-blank line of
    ``list_of_tb_file``, resolves a vocab constructor for each with
    ``get_factory``, groups treebanks that share a constructor, and writes the
    generated ``xpos_vocab_factory`` module to ``output_factory_file``.
    Exits with status 1 if the argument count is wrong.
    """
    if len(sys.argv) != 3:
        print('Usage: {} list_of_tb_file output_factory_file'.format(sys.argv[0]), file=sys.stderr)
        sys.exit(1)
    list_of_tb_file, output_file = sys.argv[1:]

    # Read list of all treebanks of concern; blank lines are ignored.
    shorthands = []
    fullnames = []
    with open(list_of_tb_file, encoding='utf-8') as f:
        for line in f:
            treebank = line.strip()
            if not treebank:
                continue
            fullnames.append(treebank)
            if SHORTNAME_RE.match(treebank):
                shorthands.append(treebank)
            else:
                shorthands.append(treebank_to_short_name(treebank))

    # For each treebank, we would like to find the XPOS Vocab configuration that minimizes
    # the number of total classes needed to predict by all tagger classifiers. This is
    # achieved by enumerating different options of separators that different treebanks might
    # use, and comparing that to treating the XPOS tags as separate categories (using a
    # WordVocab).
    mapping = defaultdict(list)
    for sh, fn in zip(shorthands, fullnames):
        mapping[get_factory(sh, fn)].append(sh)

    # Generate code. This takes the XPOS vocabulary classes selected above, and generates the
    # actual factory class as seen in models.pos.xpos_vocab_factory.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(_render_factory_source(mapping))
    print('Done!')


if __name__ == "__main__":
    main()
|