| author | John Bauer <horatio@gmail.com> | 2022-08-31 08:43:23 +0300 |
|---|---|---|
| committer | John Bauer <horatio@gmail.com> | 2022-08-31 08:43:23 +0300 |
| commit | a5b13345fa18b53a8a54c7fe1117f2808d970e8d (patch) | |
| tree | b73d4d546959859a7ac5a431ced7bdcb044cd9f3 | |
| parent | 2495c498eaf3325fecf61f2cd3aff7d58d2ef12b (diff) | |
Simplify the load mechanism in the classifier Trainer so that the load() call loads the pretrain, charlm, etc.
| -rw-r--r-- | stanza/models/classifier.py | 2 |
| -rw-r--r-- | stanza/models/classifiers/trainer.py | 42 |
| -rw-r--r-- | stanza/pipeline/sentiment_processor.py | 16 |
| -rw-r--r-- | stanza/tests/classifiers/test_classifier.py | 4 |

4 files changed, 25 insertions, 39 deletions
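The upshot of the change, before the per-file hunks: `Trainer.load_model` is folded into `Trainer.load`, which now resolves the model path and loads the pretrain, charlms, and (optional) elmo model itself, optionally through a `FoundationCache`. A minimal sketch of the call-site change, drawn from the classifier.py hunk below; `args` stands for that script's argparse namespace:

```python
# Before: the caller assembled every auxiliary model itself, then called
# Trainer.load(load_name, args, pretrain, charmodel_forward,
#              charmodel_backward, elmo_model, load_optimizer=...)

# After: one call, and load() gathers the pretrain, charlms, etc. itself
trainer = Trainer.load(args.load_name, args, load_optimizer=args.train)
```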
```diff
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index 17f9d231..4e420a39 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -495,7 +495,7 @@ def main(args=None):
     train_set = None
 
     if args.load_name:
-        trainer = Trainer.load_model(args, load_optimizer=args.train)
+        trainer = Trainer.load(args.load_name, args, load_optimizer=args.train)
     else:
         trainer = Trainer.build_new_model(args, train_set)
 
```
```diff
diff --git a/stanza/models/classifiers/trainer.py b/stanza/models/classifiers/trainer.py
index 2eea9c86..c74e7b0a 100644
--- a/stanza/models/classifiers/trainer.py
+++ b/stanza/models/classifiers/trainer.py
@@ -11,7 +11,7 @@ import torch.optim as optim
 
 import stanza.models.classifiers.data as data
 import stanza.models.classifiers.cnn_classifier as cnn_classifier
-from stanza.models.common.foundation_cache import load_bert, load_charlm
+from stanza.models.common.foundation_cache import load_bert, load_charlm, load_pretrain
 from stanza.models.common.pretrain import Pretrain
 
 logger = logging.getLogger('stanza')
@@ -52,29 +52,12 @@ class Trainer:
         logger.info("Model saved to {}".format(filename))
 
     @staticmethod
-    def load_model(args, load_optimizer=False):
-        """
-        Load both the pretrained embedding and other pieces from the args as well as the model itself
-        """
-        pretrain = Trainer.load_pretrain(args)
-        elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
-        charmodel_forward = load_charlm(args.charlm_forward_file)
-        charmodel_backward = load_charlm(args.charlm_backward_file)
-
-        if os.path.exists(args.load_name):
-            load_name = args.load_name
-        else:
-            load_name = os.path.join(args.save_dir, args.load_name)
-            if not os.path.exists(load_name):
-                raise FileNotFoundError("Could not find model to load in either %s or %s" % (args.load_name, load_name))
-
-        trainer = Trainer.load(load_name, args, pretrain, charmodel_forward, charmodel_backward, elmo_model, load_optimizer=load_optimizer)
-        return trainer
-
-    # TODO: load the pretrain and all that stuff here
-    # in other words, combine load_model and load...
-    @staticmethod
-    def load(filename, args, pretrain, charmodel_forward, charmodel_backward, elmo_model, foundation_cache=None, load_optimizer=False):
+    def load(filename, args, foundation_cache=None, load_optimizer=False):
+        if not os.path.exists(filename):
+            if os.path.exists(os.path.join(args.save_dir, filename)):
+                filename = os.path.join(args.save_dir, filename)
+            else:
+                raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
         try:
             checkpoint = torch.load(filename, lambda storage, loc: storage)
         except BaseException:
@@ -98,6 +81,11 @@ class Trainer:
         # TODO: the getattr is not needed when all models have this baked into the config
         model_type = getattr(checkpoint['config'], 'model_type', 'CNNClassifier')
 
+        pretrain = Trainer.load_pretrain(args, foundation_cache)
+        elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
+        charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)
+        charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)
+
         bert_model = checkpoint['config'].bert_model
         bert_model, bert_tokenizer = load_bert(bert_model, foundation_cache)
         if model_type == 'CNNClassifier':
@@ -138,7 +126,7 @@ class Trainer:
 
         return trainer
 
-    def load_pretrain(args):
+    def load_pretrain(args, foundation_cache):
         if args.wordvec_pretrain_file:
             pretrain_file = args.wordvec_pretrain_file
         elif args.wordvec_type:
@@ -148,7 +136,7 @@ class Trainer:
 
         logger.info("Looking for pretrained vectors in {}".format(pretrain_file))
         if os.path.exists(pretrain_file):
-            vec_file = None
+            return load_pretrain(pretrain_file, foundation_cache)
         elif args.wordvec_raw_file:
             vec_file = args.wordvec_raw_file
             logger.info("Pretrain not found.  Looking in {}".format(vec_file))
@@ -168,7 +156,7 @@ class Trainer:
         if train_set is None:
             raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")
 
-        pretrain = Trainer.load_pretrain(args)
+        pretrain = Trainer.load_pretrain(args, foundation_cache=None)
         elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
         charmodel_forward = load_charlm(args.charlm_forward_file)
         charmodel_backward = load_charlm(args.charlm_backward_file)
```
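The path fallback that `load_model` used to perform now lives at the top of `load()`: a filename that does not exist as given is retried under `args.save_dir` before raising. An equivalent standalone sketch of that lookup; `resolve_model_path` is an illustrative helper, not a function in the diff:

```python
import os

def resolve_model_path(filename, save_dir):
    """Return filename if it exists; otherwise look under save_dir."""
    if os.path.exists(filename):
        return filename
    candidate = os.path.join(save_dir, filename)
    if os.path.exists(candidate):
        return candidate
    raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, candidate))
```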
```diff
diff --git a/stanza/pipeline/sentiment_processor.py b/stanza/pipeline/sentiment_processor.py
index 8ff25017..60b2307b 100644
--- a/stanza/pipeline/sentiment_processor.py
+++ b/stanza/pipeline/sentiment_processor.py
@@ -31,24 +31,22 @@ class SentimentProcessor(UDProcessor):
     def _set_up_model(self, config, pipeline, use_gpu):
         # get pretrained word vectors
         pretrain_path = config.get('pretrain_path', None)
-        self._pretrain = pipeline.foundation_cache.load_pretrain(pretrain_path) if pretrain_path else None
         forward_charlm_path = config.get('forward_charlm_path', None)
-        charmodel_forward = pipeline.foundation_cache.load_charlm(forward_charlm_path)
         backward_charlm_path = config.get('backward_charlm_path', None)
-        charmodel_backward = pipeline.foundation_cache.load_charlm(backward_charlm_path)
-        args = SimpleNamespace(cuda = use_gpu)
-        # set up model
         # elmo does not have a convenient way to download intermediate
         # models the way stanza downloads charlms & pretrains or
         # transformers downloads bert etc
         # however, elmo in general is not as good as using a
         # transformer, so it is unlikely we will ever fix this
+        args = SimpleNamespace(cuda = use_gpu,
+                               charlm_forward_file = forward_charlm_path,
+                               charlm_backward_file = backward_charlm_path,
+                               wordvec_pretrain_file = pretrain_path,
+                               elmo_model = None,
+                               use_elmo = False)
+        # set up model
         trainer = Trainer.load(filename=config['model_path'],
                                args=args,
-                               pretrain=self._pretrain,
-                               charmodel_forward=charmodel_forward,
-                               charmodel_backward=charmodel_backward,
-                               elmo_model=None,
                                foundation_cache=pipeline.foundation_cache)
         self._model = trainer.model
         # batch size counted as words
```

```diff
diff --git a/stanza/tests/classifiers/test_classifier.py b/stanza/tests/classifiers/test_classifier.py
index 3b4eb3e4..72cc185b 100644
--- a/stanza/tests/classifiers/test_classifier.py
+++ b/stanza/tests/classifiers/test_classifier.py
@@ -148,9 +148,9 @@ def test_save_load(tmp_path, fake_embeddings, train_file, dev_file):
     trainer.save(save_filename)
 
     args.load_name = args.save_name
-    trainer = Trainer.load_model(args)
+    trainer = Trainer.load(args.load_name, args)
     args.load_name = save_filename
-    trainer = Trainer.load_model(args)
+    trainer = Trainer.load(args.load_name, args)
 
 def test_train_basic(tmp_path, fake_embeddings, train_file, dev_file):
     run_training(tmp_path, fake_embeddings, train_file, dev_file)
```
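With the extra model arguments gone, a caller outside the training CLI only has to pack the relevant paths into a namespace, as the sentiment processor now does. A minimal usage sketch under those assumptions; the paths are placeholders, and the import path follows the file locations shown above:

```python
from types import SimpleNamespace

from stanza.models.classifiers.trainer import Trainer

# Placeholder paths; the real pipeline pulls these out of its processor config
args = SimpleNamespace(cuda=False,
                       charlm_forward_file=None,    # optional charlm weights
                       charlm_backward_file=None,
                       wordvec_pretrain_file="path/to/pretrain.pt",
                       elmo_model=None,
                       use_elmo=False)

# foundation_cache is optional and defaults to None
trainer = Trainer.load(filename="path/to/sentiment.pt", args=args)
model = trainer.model
```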