diff options
author | ZJaume <jzaragoza@prompsit.com> | 2022-08-09 17:28:24 +0300 |
---|---|---|
committer | ZJaume <jzaragoza@prompsit.com> | 2022-08-09 17:28:24 +0300 |
commit | f589cd62583c261b94510918da29d58bc8777e87 (patch) | |
tree | cfdbfd2beed447eca33458a7ad5adcc179df87cb | |
parent | 5266d5a5b9378c58f1f9a895e718544b8ebe8071 (diff) |
Restore retrocompatibility with older models
-rw-r--r-- | bicleaner_ai/models.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py index b76067b..15dc19a 100644 --- a/bicleaner_ai/models.py +++ b/bicleaner_ai/models.py @@ -486,6 +486,8 @@ class BCXLMRoberta(BaseModel): self.tokenizer = None self.settings = { + "model_file": "model.tf", + "vocab_file": "vocab", "base_model": 'jplu/tf-xlm-roberta-base', "batch_size": 16, "maxlen": 150, @@ -535,8 +537,9 @@ class BCXLMRoberta(BaseModel): def load(self): ''' Load fine-tuned model ''' - self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(self.dir) - self.model = self.load_model(self.dir) + vocab_file = self.dir + '/' + self.settings["vocab_file"] + self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(vocab_file) + self.model = self.load_model(f'{self.dir}/{self.settings["model_file"]}') def softmax_pos_prob(self, x): # Compute softmax probability of the second (positive) class |