diff options
author | ZJaume <jzaragoza@prompsit.com> | 2022-09-05 17:32:30 +0300 |
---|---|---|
committer | ZJaume <jzaragoza@prompsit.com> | 2022-09-05 17:32:30 +0300 |
commit | a75201d3f298b2b99057802236c8a4c3d780af1a (patch) | |
tree | a748863fbc1439b46100b4b2582de86482286066 | |
parent | bc7bddfd68244b4b5cd206c4ac8e6f3094c76b3b (diff) | |
parent | dddce29bfe6047cd9f198990e916ac18b156efe1 (diff) |
Merge branch 'master' into noise
-rw-r--r-- | bicleaner_ai/models.py | 22 |
1 file changed, 16 insertions, 6 deletions
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py index fcc04df..cdb403e 100644 --- a/bicleaner_ai/models.py +++ b/bicleaner_ai/models.py @@ -489,8 +489,6 @@ class BCXLMRoberta(BaseModel): self.tokenizer = None self.settings = { - "model_file": "model.tf", - "vocab_file": "vocab", "base_model": 'jplu/tf-xlm-roberta-base', "batch_size": 16, "maxlen": 150, @@ -540,9 +538,19 @@ class BCXLMRoberta(BaseModel): def load(self): ''' Load fine-tuned model ''' - vocab_file = self.dir + '/' + self.settings["vocab_file"] + # If vocab and model files are in a subdirectory, load from there + if "vocab_file" in self.settings: + vocab_file = f'{self.dir}/{self.settings["vocab_file"]}' + else: + vocab_file = self.dir + + if "model_file" in self.settings: + model_file = f'{self.dir}/{self.settings["model_file"]}' + else: + model_file = self.dir + self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(vocab_file) - self.model = self.load_model(f'{self.dir}/{self.settings["model_file"]}') + self.model = self.load_model(model_file) def softmax_pos_prob(self, x): # Compute softmax probability of the second (positive) class @@ -623,8 +631,10 @@ class BCXLMRoberta(BaseModel): batch_size=self.settings["batch_size"], callbacks=[earlystop], verbose=verbose) - self.model.save_pretrained(self.dir) - self.tokenizer.save_pretrained(self.dir) + self.model.save_pretrained(self.dir + '/' + + self.settings["model_file"]) + self.tokenizer.save_pretrained(self.dir + '/' + + self.settings["vocab_file"]) y_true = dev_generator.y with redirect_stdout(sys.stderr): |