Welcome to mirror list, hosted at ThFree Co, Russian Federation.

pretrain.py « common « models « stanza - github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 193cc71df5a2b3d270ba3ef325e03baa74955138 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
Supports for pretrained data.
"""
import os
import re

import lzma
import logging
import numpy as np
import torch

from .vocab import BaseVocab, VOCAB_PREFIX

from stanza.resources.common import DEFAULT_MODEL_DIR

logger = logging.getLogger('stanza')

class PretrainedWordVocab(BaseVocab):
    def build_vocab(self):
        self._id2unit = VOCAB_PREFIX + self.data
        self._unit2id = {w:i for i, w in enumerate(self._id2unit)}

    def normalize_unit(self, unit):
        unit = super().normalize_unit(unit)
        if unit:
            unit = unit.replace(" ","\xa0")
        return unit

class Pretrain:
    """ A loader and saver for pretrained embeddings. """

    def __init__(self, filename=None, vec_filename=None, max_vocab=-1, save_to_file=True):
        self.filename = filename
        self._vec_filename = vec_filename
        self._max_vocab = max_vocab
        self._save_to_file = save_to_file

    @property
    def vocab(self):
        if not hasattr(self, '_vocab'):
            self.load()
        return self._vocab

    @property
    def emb(self):
        if not hasattr(self, '_emb'):
            self.load()
        return self._emb

    def load(self):
        if self.filename is not None and os.path.exists(self.filename):
            try:
                data = torch.load(self.filename, lambda storage, loc: storage)
                logger.debug("Loaded pretrain from {}".format(self.filename))
                self._vocab, self._emb = PretrainedWordVocab.load_state_dict(data['vocab']), data['emb']
                return
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as e:
                logger.warning("Pretrained file exists but cannot be loaded from {}, due to the following exception:\n\t{}".format(self.filename, e))
                vocab, emb = self.read_pretrain()
        else:
            if self.filename is not None:
                logger.info("Pretrained filename %s specified, but file does not exist.  Attempting to load from text file" % self.filename)
            vocab, emb = self.read_pretrain()

        self._vocab = vocab
        self._emb = emb

        if self._save_to_file:
            # save to file
            assert self.filename is not None, "Filename must be provided to save pretrained vector to file."
            self.save(self.filename)

    def save(self, filename):
        # should not infinite loop since the load function sets _vocab and _emb before trying to save
        data = {'vocab': self.vocab.state_dict(), 'emb': self.emb}
        try:
            torch.save(data, filename, _use_new_zipfile_serialization=False, pickle_protocol=3)
            logger.info("Saved pretrained vocab and vectors to {}".format(filename))
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as e:
            logger.warning("Saving pretrained data failed due to the following exception... continuing anyway.\n\t{}".format(e))


    def write_text(self, filename):
        """
        Write the vocab & values to a text file
        """
        with open(filename, "w") as fout:
            for i in range(len(self.vocab)):
                row = self.emb[i]
                fout.write(self.vocab[i])
                fout.write("\t")
                fout.write("\t".join(map(str, row)))
                fout.write("\n")


    def read_pretrain(self):
        # load from pretrained filename
        if self._vec_filename is None:
            raise RuntimeError("Vector file is not provided.")
        logger.info("Reading pretrained vectors from {}...".format(self._vec_filename))

        # first try reading as xz file, if failed retry as text file
        try:
            words, emb, failed = self.read_from_file(self._vec_filename, open_func=lzma.open)
        except lzma.LZMAError as err:
            logger.warning("Cannot decode vector file %s as xz file. Retrying as text file..." % self._vec_filename)
            words, emb, failed = self.read_from_file(self._vec_filename, open_func=open)

        if failed > 0: # recover failure
            emb = emb[:-failed]
        if len(emb) - len(VOCAB_PREFIX) != len(words):
            raise RuntimeError("Loaded number of vectors does not match number of words.")
        
        # Use a fixed vocab size
        if self._max_vocab > len(VOCAB_PREFIX) and self._max_vocab < len(words):
            words = words[:self._max_vocab - len(VOCAB_PREFIX)]
            emb = emb[:self._max_vocab]

        vocab = PretrainedWordVocab(words)
        
        return vocab, emb

    def read_from_file(self, filename, open_func=open):
        """
        Open a vector file using the provided function and read from it.
        """
        # some vector files, such as Google News, use tabs
        tab_space_pattern = re.compile(r"[ \t]+")
        first = True
        words = []
        failed = 0
        with open_func(filename, 'rb') as f:
            for i, line in enumerate(f):
                try:
                    line = line.decode()
                except UnicodeDecodeError:
                    failed += 1
                    continue
                if first:
                    # the first line contains the number of word vectors and the dimensionality
                    first = False
                    line = line.strip().split(' ')
                    rows, cols = [int(x) for x in line]
                    emb = np.zeros((rows + len(VOCAB_PREFIX), cols), dtype=np.float32)
                    continue

                line = tab_space_pattern.split((line.rstrip()))
                emb[i+len(VOCAB_PREFIX)-1-failed, :] = [float(x) for x in line[-cols:]]
                # if there were word pieces separated with spaces, rejoin them with nbsp instead
                # this way, the normalize_unit method in vocab.py can find the word at test time
                words.append('\xa0'.join(line[:-cols]))
        return words, emb, failed


def find_pretrain_file(wordvec_pretrain_file, save_dir, shorthand, lang):
    """
    When training a model, look in a few different places for a .pt file

    If a specific argument was passsed in, prefer that location
    Otherwise, check in a few places:
      saved_models/{model}/{shorthand}.pretrain.pt
      saved_models/{model}/{shorthand}_pretrain.pt
      ~/stanza_resources/{language}/pretrain/{shorthand}_pretrain.pt
    """
    if wordvec_pretrain_file:
        return wordvec_pretrain_file

    default_pretrain_file = os.path.join(save_dir, '{}.pretrain.pt'.format(shorthand))
    if os.path.exists(default_pretrain_file):
        logger.debug("Found existing .pt file in %s" % default_pretrain_file)
        return default_pretrain_file
    else:
        logger.debug("Cannot find pretrained vectors in %s" % default_pretrain_file)

    pretrain_file = os.path.join(save_dir, '{}_pretrain.pt'.format(shorthand))
    if os.path.exists(pretrain_file):
        logger.debug("Found existing .pt file in %s" % pretrain_file)
        return pretrain_file
    else:
        logger.debug("Cannot find pretrained vectors in %s" % pretrain_file)

    if shorthand.find("_") >= 0:
        # try to assemble /home/user/stanza_resources/vi/pretrain/vtb.pt for example
        pretrain_file = os.path.join(DEFAULT_MODEL_DIR, lang, 'pretrain', '{}.pt'.format(shorthand.split('_', 1)[1]))
        if os.path.exists(pretrain_file):
            logger.debug("Found existing .pt file in %s" % pretrain_file)
            return pretrain_file
        else:
            logger.debug("Cannot find pretrained vectors in %s" % pretrain_file)

    # if we can't find it anywhere, just return the first location searched...
    # maybe we'll get lucky and the original .txt file can be found
    return default_pretrain_file


if __name__ == '__main__':
    with open('test.txt', 'w') as fout:
        fout.write('3 2\na 1 1\nb -1 -1\nc 0 0\n')
    # 1st load: save to pt file
    pretrain = Pretrain('test.pt', 'test.txt')
    print(pretrain.emb)
    # verify pt file
    x = torch.load('test.pt')
    print(x)
    # 2nd load: load saved pt file
    pretrain = Pretrain('test.pt', 'test.txt')
    print(pretrain.emb)