From dddce29bfe6047cd9f198990e916ac18b156efe1 Mon Sep 17 00:00:00 2001
From: ZJaume
Date: Mon, 5 Sep 2022 14:21:33 +0000
Subject: Look for model and vocab subdirectories when loading

---
Review notes (not part of the commit; free-form text between the "---"
separator and the diff is skipped when the patch is applied):

- The line breaks and indentation below were reconstructed from a copy in
  which the whole patch had been collapsed onto a single line. Check the
  whitespace against the repository before applying.
- After this change load() treats "vocab_file" and "model_file" as optional
  settings: when a key is present the path is <dir>/<value>, otherwise
  self.dir itself is passed to from_pretrained() / self.load_model().
- The first hunk removes both keys from the default settings dict, but the
  last hunk indexes self.settings["model_file"] and
  self.settings["vocab_file"] without a guard. If nothing else sets those
  keys (not visible in this patch), the save step would raise KeyError.
  TODO confirm against the code that fills self.settings.

 bicleaner_ai/models.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py
index 3327e47..d5b21e8 100644
--- a/bicleaner_ai/models.py
+++ b/bicleaner_ai/models.py
@@ -489,8 +489,6 @@ class BCXLMRoberta(BaseModel):
         self.tokenizer = None
 
         self.settings = {
-            "model_file": "model.tf",
-            "vocab_file": "vocab",
             "base_model": 'jplu/tf-xlm-roberta-base',
             "batch_size": 16,
             "maxlen": 150,
@@ -540,9 +538,19 @@ class BCXLMRoberta(BaseModel):
 
     def load(self):
         ''' Load fine-tuned model '''
-        vocab_file = self.dir + '/' + self.settings["vocab_file"]
+        # If vocab and model files are in a subdirectory, load from there
+        if "vocab_file" in self.settings:
+            vocab_file = f'{self.dir}/{self.settings["vocab_file"]}'
+        else:
+            vocab_file = self.dir
+
+        if "model_file" in self.settings:
+            model_file = f'{self.dir}/{self.settings["model_file"]}'
+        else:
+            model_file = self.dir
+
         self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(vocab_file)
-        self.model = self.load_model(f'{self.dir}/{self.settings["model_file"]}')
+        self.model = self.load_model(model_file)
 
     def softmax_pos_prob(self, x):
         # Compute softmax probability of the second (positive) class
@@ -623,8 +631,10 @@ class BCXLMRoberta(BaseModel):
                        batch_size=self.settings["batch_size"],
                        callbacks=[earlystop],
                        verbose=verbose)
-        self.model.save_pretrained(self.dir)
-        self.tokenizer.save_pretrained(self.dir)
+        self.model.save_pretrained(self.dir + '/'
+                                   + self.settings["model_file"])
+        self.tokenizer.save_pretrained(self.dir + '/'
+                                       + self.settings["vocab_file"])
 
         y_true = dev_generator.y
         with redirect_stdout(sys.stderr):
-- 
cgit v1.2.3