diff options
Diffstat (limited to 'stanza/models/common/pretrain.py')
-rw-r--r-- | stanza/models/common/pretrain.py | 66 |
1 files changed, 63 insertions, 3 deletions
diff --git a/stanza/models/common/pretrain.py b/stanza/models/common/pretrain.py index 6dd50d8c..e18accbf 100644 --- a/stanza/models/common/pretrain.py +++ b/stanza/models/common/pretrain.py @@ -11,6 +11,8 @@ import torch from .vocab import BaseVocab, VOCAB_PREFIX +from stanza.resources.common import DEFAULT_MODEL_DIR + logger = logging.getLogger('stanza') class PretrainedWordVocab(BaseVocab): @@ -52,6 +54,8 @@ class Pretrain: logger.warning("Pretrained file exists but cannot be loaded from {}, due to the following exception:\n\t{}".format(self.filename, e)) vocab, emb = self.read_pretrain() else: + if self.filename is not None: + logger.info("Pretrained filename %s specified, but file does not exist. Attempting to load from text file" % self.filename) vocab, emb = self.read_pretrain() self._vocab = vocab @@ -74,10 +78,23 @@ class Pretrain: logger.warning("Saving pretrained data failed due to the following exception... continuing anyway.\n\t{}".format(e)) + def write_text(self, filename): + """ + Write the vocab & values to a text file + """ + with open(filename, "w") as fout: + for i in range(len(self.vocab)): + row = self.emb[i] + fout.write(self.vocab[i]) + fout.write("\t") + fout.write("\t".join(map(str, row))) + fout.write("\n") + + def read_pretrain(self): # load from pretrained filename if self._vec_filename is None: - raise Exception("Vector file is not provided.") + raise RuntimeError("Vector file is not provided.") logger.info("Reading pretrained vectors from {}...".format(self._vec_filename)) # first try reading as xz file, if failed retry as text file @@ -90,7 +107,7 @@ class Pretrain: if failed > 0: # recover failure emb = emb[:-failed] if len(emb) - len(VOCAB_PREFIX) != len(words): - raise Exception("Loaded number of vectors does not match number of words.") + raise RuntimeError("Loaded number of vectors does not match number of words.") # Use a fixed vocab size if self._max_vocab > len(VOCAB_PREFIX) and self._max_vocab < len(words): @@ -127,10 
+144,53 @@ class Pretrain: line = tab_space_pattern.split((line.rstrip())) emb[i+len(VOCAB_PREFIX)-1-failed, :] = [float(x) for x in line[-cols:]] - words.append(' '.join(line[:-cols])) + # if there were word pieces separated with spaces, rejoin them with nbsp instead + # this way, the normalize_unit method in vocab.py can find the word at test time + words.append('\xa0'.join(line[:-cols])) return words, emb, failed +def find_pretrain_file(wordvec_pretrain_file, save_dir, shorthand, lang): + """ + When training a model, look in a few different places for a .pt file + + If a specific argument was passed in, prefer that location + Otherwise, check in a few places: + saved_models/{model}/{shorthand}.pretrain.pt + saved_models/{model}/{shorthand}_pretrain.pt + ~/stanza_resources/{language}/pretrain/{dataset}.pt + """ + if wordvec_pretrain_file: + return wordvec_pretrain_file + + default_pretrain_file = os.path.join(save_dir, '{}.pretrain.pt'.format(shorthand)) + if os.path.exists(default_pretrain_file): + logger.debug("Found existing .pt file in %s" % default_pretrain_file) + return default_pretrain_file + else: + logger.debug("Cannot find pretrained vectors in %s" % default_pretrain_file) + + pretrain_file = os.path.join(save_dir, '{}_pretrain.pt'.format(shorthand)) + if os.path.exists(pretrain_file): + logger.debug("Found existing .pt file in %s" % pretrain_file) + return pretrain_file + else: + logger.debug("Cannot find pretrained vectors in %s" % pretrain_file) + + if shorthand.find("_") >= 0: + # try to assemble /home/user/stanza_resources/vi/pretrain/vtb.pt for example + pretrain_file = os.path.join(DEFAULT_MODEL_DIR, lang, 'pretrain', '{}.pt'.format(shorthand.split('_', 1)[1])) + if os.path.exists(pretrain_file): + logger.debug("Found existing .pt file in %s" % pretrain_file) + return pretrain_file + else: + logger.debug("Cannot find pretrained vectors in %s" % pretrain_file) + + # if we can't find it anywhere, just return the first location
searched... + # maybe we'll get lucky and the original .txt file can be found + return default_pretrain_file + + if __name__ == '__main__': with open('test.txt', 'w') as fout: fout.write('3 2\na 1 1\nb -1 -1\nc 0 0\n') |