github.com/stanfordnlp/stanza.git

author     John Bauer <horatio@gmail.com>   2022-08-31 08:43:23 +0300
committer  John Bauer <horatio@gmail.com>   2022-08-31 08:43:23 +0300
commit     a5b13345fa18b53a8a54c7fe1117f2808d970e8d (patch)
tree       b73d4d546959859a7ac5a431ced7bdcb044cd9f3
parent     2495c498eaf3325fecf61f2cd3aff7d58d2ef12b (diff)
Simplify the load mechanism in classifier Trainer so that the load() call loads the pretrain, charlm, etc.
 stanza/models/classifier.py                 |  2
 stanza/models/classifiers/trainer.py        | 42
 stanza/pipeline/sentiment_processor.py      | 16
 stanza/tests/classifiers/test_classifier.py |  4
 4 files changed, 25 insertions(+), 39 deletions(-)
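
The commit collapses Trainer.load_model() into Trainer.load(): callers now pass the checkpoint path directly, and load() resolves the path and loads the pretrain, charlms, and elmo pieces itself. A minimal sketch of the call-site change (the args values here are hypothetical):

    # Before this commit (caller loaded the embeddings separately):
    #   trainer = Trainer.load_model(args, load_optimizer=args.train)
    # After (load() takes the filename and does the loading itself):
    trainer = Trainer.load(args.load_name, args, load_optimizer=args.train)
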
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index 17f9d231..4e420a39 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -495,7 +495,7 @@ def main(args=None):
     train_set = None
 
     if args.load_name:
-        trainer = Trainer.load_model(args, load_optimizer=args.train)
+        trainer = Trainer.load(args.load_name, args, load_optimizer=args.train)
     else:
         trainer = Trainer.build_new_model(args, train_set)
diff --git a/stanza/models/classifiers/trainer.py b/stanza/models/classifiers/trainer.py
index 2eea9c86..c74e7b0a 100644
--- a/stanza/models/classifiers/trainer.py
+++ b/stanza/models/classifiers/trainer.py
@@ -11,7 +11,7 @@ import torch.optim as optim
 import stanza.models.classifiers.data as data
 import stanza.models.classifiers.cnn_classifier as cnn_classifier
-from stanza.models.common.foundation_cache import load_bert, load_charlm
+from stanza.models.common.foundation_cache import load_bert, load_charlm, load_pretrain
 from stanza.models.common.pretrain import Pretrain
 
 logger = logging.getLogger('stanza')
@@ -52,29 +52,12 @@ class Trainer:
         logger.info("Model saved to {}".format(filename))
 
     @staticmethod
-    def load_model(args, load_optimizer=False):
-        """
-        Load both the pretrained embedding and other pieces from the args as well as the model itself
-        """
-        pretrain = Trainer.load_pretrain(args)
-        elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
-        charmodel_forward = load_charlm(args.charlm_forward_file)
-        charmodel_backward = load_charlm(args.charlm_backward_file)
-
-        if os.path.exists(args.load_name):
-            load_name = args.load_name
-        else:
-            load_name = os.path.join(args.save_dir, args.load_name)
-            if not os.path.exists(load_name):
-                raise FileNotFoundError("Could not find model to load in either %s or %s" % (args.load_name, load_name))
-
-        trainer = Trainer.load(load_name, args, pretrain, charmodel_forward, charmodel_backward, elmo_model, load_optimizer=load_optimizer)
-        return trainer
-
-    # TODO: load the pretrain and all that stuff here
-    # in other words, combine load_model and load...
-    @staticmethod
-    def load(filename, args, pretrain, charmodel_forward, charmodel_backward, elmo_model, foundation_cache=None, load_optimizer=False):
+    def load(filename, args, foundation_cache=None, load_optimizer=False):
+        if not os.path.exists(filename):
+            if os.path.exists(os.path.join(args.save_dir, filename)):
+                filename = os.path.join(args.save_dir, filename)
+            else:
+                raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
         try:
             checkpoint = torch.load(filename, lambda storage, loc: storage)
         except BaseException:
@@ -98,6 +81,11 @@ class Trainer:
         # TODO: the getattr is not needed when all models have this baked into the config
         model_type = getattr(checkpoint['config'], 'model_type', 'CNNClassifier')
 
+        pretrain = Trainer.load_pretrain(args, foundation_cache)
+        elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
+        charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)
+        charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)
+
         bert_model = checkpoint['config'].bert_model
         bert_model, bert_tokenizer = load_bert(bert_model, foundation_cache)
         if model_type == 'CNNClassifier':
@@ -138,7 +126,7 @@ class Trainer:
         return trainer
 
     @staticmethod
-    def load_pretrain(args):
+    def load_pretrain(args, foundation_cache):
         if args.wordvec_pretrain_file:
             pretrain_file = args.wordvec_pretrain_file
         elif args.wordvec_type:
@@ -148,7 +136,7 @@ class Trainer:
 
         logger.info("Looking for pretrained vectors in {}".format(pretrain_file))
         if os.path.exists(pretrain_file):
-            vec_file = None
+            return load_pretrain(pretrain_file, foundation_cache)
         elif args.wordvec_raw_file:
             vec_file = args.wordvec_raw_file
             logger.info("Pretrain not found. Looking in {}".format(vec_file))
@@ -168,7 +156,7 @@ class Trainer:
         if train_set is None:
             raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")
 
-        pretrain = Trainer.load_pretrain(args)
+        pretrain = Trainer.load_pretrain(args, foundation_cache=None)
         elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
         charmodel_forward = load_charlm(args.charlm_forward_file)
         charmodel_backward = load_charlm(args.charlm_backward_file)
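
The path handling that load_model() used to do now lives at the top of load(): use the filename as given if it exists, otherwise look under args.save_dir, otherwise fail. Paraphrased as a standalone sketch (the helper name resolve_model_path is made up here; the stanza code inlines this logic):

    import os

    def resolve_model_path(filename, save_dir):
        # Use the path as given if it exists; otherwise look under save_dir.
        if os.path.exists(filename):
            return filename
        candidate = os.path.join(save_dir, filename)
        if os.path.exists(candidate):
            return candidate
        raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, candidate))
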
diff --git a/stanza/pipeline/sentiment_processor.py b/stanza/pipeline/sentiment_processor.py
index 8ff25017..60b2307b 100644
--- a/stanza/pipeline/sentiment_processor.py
+++ b/stanza/pipeline/sentiment_processor.py
@@ -31,24 +31,22 @@ class SentimentProcessor(UDProcessor):
     def _set_up_model(self, config, pipeline, use_gpu):
         # get pretrained word vectors
         pretrain_path = config.get('pretrain_path', None)
-        self._pretrain = pipeline.foundation_cache.load_pretrain(pretrain_path) if pretrain_path else None
         forward_charlm_path = config.get('forward_charlm_path', None)
-        charmodel_forward = pipeline.foundation_cache.load_charlm(forward_charlm_path)
         backward_charlm_path = config.get('backward_charlm_path', None)
-        charmodel_backward = pipeline.foundation_cache.load_charlm(backward_charlm_path)
-        args = SimpleNamespace(cuda = use_gpu)
-        # set up model
         # elmo does not have a convenient way to download intermediate
         # models the way stanza downloads charlms & pretrains or
         # transformers downloads bert etc
         # however, elmo in general is not as good as using a
         # transformer, so it is unlikely we will ever fix this
+        args = SimpleNamespace(cuda = use_gpu,
+                               charlm_forward_file = forward_charlm_path,
+                               charlm_backward_file = backward_charlm_path,
+                               wordvec_pretrain_file = pretrain_path,
+                               elmo_model = None,
+                               use_elmo = False)
+        # set up model
         trainer = Trainer.load(filename=config['model_path'],
                                args=args,
-                               pretrain=self._pretrain,
-                               charmodel_forward=charmodel_forward,
-                               charmodel_backward=charmodel_backward,
-                               elmo_model=None,
                                foundation_cache=pipeline.foundation_cache)
         self._model = trainer.model
         # batch size counted as words
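
Since load() now reads the charlm/pretrain paths off the args object, the processor no longer loads those models itself; it just packs the config paths into a SimpleNamespace and passes the pipeline's foundation_cache through. A sketch of that pattern (the file paths are hypothetical placeholders):

    from types import SimpleNamespace

    # Plain file paths; Trainer.load() turns them into loaded models,
    # consulting the shared foundation_cache to avoid duplicate loads.
    args = SimpleNamespace(cuda=False,
                           charlm_forward_file="forward_charlm.pt",
                           charlm_backward_file="backward_charlm.pt",
                           wordvec_pretrain_file="pretrain.pt",
                           elmo_model=None,
                           use_elmo=False)
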
diff --git a/stanza/tests/classifiers/test_classifier.py b/stanza/tests/classifiers/test_classifier.py
index 3b4eb3e4..72cc185b 100644
--- a/stanza/tests/classifiers/test_classifier.py
+++ b/stanza/tests/classifiers/test_classifier.py
@@ -148,9 +148,9 @@ def test_save_load(tmp_path, fake_embeddings, train_file, dev_file):
     trainer.save(save_filename)
 
     args.load_name = args.save_name
-    trainer = Trainer.load_model(args)
+    trainer = Trainer.load(args.load_name, args)
     args.load_name = save_filename
-    trainer = Trainer.load_model(args)
+    trainer = Trainer.load(args.load_name, args)
 
 def test_train_basic(tmp_path, fake_embeddings, train_file, dev_file):
     run_training(tmp_path, fake_embeddings, train_file, dev_file)