diff options
author | ZJaume <jzaragoza@prompsit.com> | 2022-09-05 17:21:33 +0300 |
---|---|---|
committer | ZJaume <jzaragoza@prompsit.com> | 2022-09-05 17:21:33 +0300 |
commit | dddce29bfe6047cd9f198990e916ac18b156efe1 (patch) | |
tree | b51c1514a52505699ce79988a4b385e3763123bd | |
parent | c820c7a407cbc8da8ec2365d43759212b737dcdd (diff) |
Look for model and vocab subdirectories when loading
-rw-r--r-- | bicleaner_ai/models.py | 22 |
1 file changed, 16 insertions, 6 deletions
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py index 3327e47..d5b21e8 100644 --- a/bicleaner_ai/models.py +++ b/bicleaner_ai/models.py @@ -489,8 +489,6 @@ class BCXLMRoberta(BaseModel): self.tokenizer = None self.settings = { - "model_file": "model.tf", - "vocab_file": "vocab", "base_model": 'jplu/tf-xlm-roberta-base', "batch_size": 16, "maxlen": 150, @@ -540,9 +538,19 @@ class BCXLMRoberta(BaseModel): def load(self): ''' Load fine-tuned model ''' - vocab_file = self.dir + '/' + self.settings["vocab_file"] + # If vocab and model files are in a subdirectory, load from there + if "vocab_file" in self.settings: + vocab_file = f'{self.dir}/{self.settings["vocab_file"]}' + else: + vocab_file = self.dir + + if "model_file" in self.settings: + model_file = f'{self.dir}/{self.settings["model_file"]}' + else: + model_file = self.dir + self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(vocab_file) - self.model = self.load_model(f'{self.dir}/{self.settings["model_file"]}') + self.model = self.load_model(model_file) def softmax_pos_prob(self, x): # Compute softmax probability of the second (positive) class @@ -623,8 +631,10 @@ class BCXLMRoberta(BaseModel): batch_size=self.settings["batch_size"], callbacks=[earlystop], verbose=verbose) - self.model.save_pretrained(self.dir) - self.tokenizer.save_pretrained(self.dir) + self.model.save_pretrained(self.dir + '/' + + self.settings["model_file"]) + self.tokenizer.save_pretrained(self.dir + '/' + + self.settings["vocab_file"]) y_true = dev_generator.y with redirect_stdout(sys.stderr): |