Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'stanza/models/common/pretrain.py')
-rw-r--r--stanza/models/common/pretrain.py66
1 file changed, 63 insertions, 3 deletions
diff --git a/stanza/models/common/pretrain.py b/stanza/models/common/pretrain.py
index 6dd50d8c..e18accbf 100644
--- a/stanza/models/common/pretrain.py
+++ b/stanza/models/common/pretrain.py
@@ -11,6 +11,8 @@ import torch
from .vocab import BaseVocab, VOCAB_PREFIX
+from stanza.resources.common import DEFAULT_MODEL_DIR
+
logger = logging.getLogger('stanza')
class PretrainedWordVocab(BaseVocab):
@@ -52,6 +54,8 @@ class Pretrain:
logger.warning("Pretrained file exists but cannot be loaded from {}, due to the following exception:\n\t{}".format(self.filename, e))
vocab, emb = self.read_pretrain()
else:
+ if self.filename is not None:
+ logger.info("Pretrained filename %s specified, but file does not exist. Attempting to load from text file" % self.filename)
vocab, emb = self.read_pretrain()
self._vocab = vocab
@@ -74,10 +78,23 @@ class Pretrain:
logger.warning("Saving pretrained data failed due to the following exception... continuing anyway.\n\t{}".format(e))
+ def write_text(self, filename):
+ """
+ Write the vocab & values to a text file
+ """
+ with open(filename, "w") as fout:
+ for i in range(len(self.vocab)):
+ row = self.emb[i]
+ fout.write(self.vocab[i])
+ fout.write("\t")
+ fout.write("\t".join(map(str, row)))
+ fout.write("\n")
+
+
def read_pretrain(self):
# load from pretrained filename
if self._vec_filename is None:
- raise Exception("Vector file is not provided.")
+ raise RuntimeError("Vector file is not provided.")
logger.info("Reading pretrained vectors from {}...".format(self._vec_filename))
# first try reading as xz file, if failed retry as text file
@@ -90,7 +107,7 @@ class Pretrain:
if failed > 0: # recover failure
emb = emb[:-failed]
if len(emb) - len(VOCAB_PREFIX) != len(words):
- raise Exception("Loaded number of vectors does not match number of words.")
+ raise RuntimeError("Loaded number of vectors does not match number of words.")
# Use a fixed vocab size
if self._max_vocab > len(VOCAB_PREFIX) and self._max_vocab < len(words):
@@ -127,10 +144,53 @@ class Pretrain:
line = tab_space_pattern.split((line.rstrip()))
emb[i+len(VOCAB_PREFIX)-1-failed, :] = [float(x) for x in line[-cols:]]
- words.append(' '.join(line[:-cols]))
+ # if there were word pieces separated with spaces, rejoin them with nbsp instead
+ # this way, the normalize_unit method in vocab.py can find the word at test time
+ words.append('\xa0'.join(line[:-cols]))
return words, emb, failed
+def find_pretrain_file(wordvec_pretrain_file, save_dir, shorthand, lang):
+ """
+ When training a model, look in a few different places for a .pt file
+
+ If a specific argument was passed in, prefer that location
+ Otherwise, check in a few places:
+ saved_models/{model}/{shorthand}.pretrain.pt
+ saved_models/{model}/{shorthand}_pretrain.pt
+ ~/stanza_resources/{language}/pretrain/{shorthand}_pretrain.pt
+ """
+ if wordvec_pretrain_file:
+ return wordvec_pretrain_file
+
+ default_pretrain_file = os.path.join(save_dir, '{}.pretrain.pt'.format(shorthand))
+ if os.path.exists(default_pretrain_file):
+ logger.debug("Found existing .pt file in %s" % default_pretrain_file)
+ return default_pretrain_file
+ else:
+ logger.debug("Cannot find pretrained vectors in %s" % default_pretrain_file)
+
+ pretrain_file = os.path.join(save_dir, '{}_pretrain.pt'.format(shorthand))
+ if os.path.exists(pretrain_file):
+ logger.debug("Found existing .pt file in %s" % pretrain_file)
+ return pretrain_file
+ else:
+ logger.debug("Cannot find pretrained vectors in %s" % pretrain_file)
+
+ if shorthand.find("_") >= 0:
+ # try to assemble /home/user/stanza_resources/vi/pretrain/vtb.pt for example
+ pretrain_file = os.path.join(DEFAULT_MODEL_DIR, lang, 'pretrain', '{}.pt'.format(shorthand.split('_', 1)[1]))
+ if os.path.exists(pretrain_file):
+ logger.debug("Found existing .pt file in %s" % pretrain_file)
+ return pretrain_file
+ else:
+ logger.debug("Cannot find pretrained vectors in %s" % pretrain_file)
+
+ # if we can't find it anywhere, just return the first location searched...
+ # maybe we'll get lucky and the original .txt file can be found
+ return default_pretrain_file
+
+
if __name__ == '__main__':
with open('test.txt', 'w') as fout:
fout.write('3 2\na 1 1\nb -1 -1\nc 0 0\n')