Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/bitextor/bicleaner-ai.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZJaume <jzaragoza@prompsit.com>2022-08-09 17:28:24 +0300
committerZJaume <jzaragoza@prompsit.com>2022-08-09 17:28:24 +0300
commitf589cd62583c261b94510918da29d58bc8777e87 (patch)
treecfdbfd2beed447eca33458a7ad5adcc179df87cb
parent5266d5a5b9378c58f1f9a895e718544b8ebe8071 (diff)
Restore retrocompatibility with older models
-rw-r--r--bicleaner_ai/models.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py
index b76067b..15dc19a 100644
--- a/bicleaner_ai/models.py
+++ b/bicleaner_ai/models.py
@@ -486,6 +486,8 @@ class BCXLMRoberta(BaseModel):
self.tokenizer = None
self.settings = {
+ "model_file": "model.tf",
+ "vocab_file": "vocab",
"base_model": 'jplu/tf-xlm-roberta-base',
"batch_size": 16,
"maxlen": 150,
@@ -535,8 +537,9 @@ class BCXLMRoberta(BaseModel):
def load(self):
''' Load fine-tuned model '''
- self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(self.dir)
- self.model = self.load_model(self.dir)
+ vocab_file = self.dir + '/' + self.settings["vocab_file"]
+ self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(vocab_file)
+ self.model = self.load_model(f'{self.dir}/{self.settings["model_file"]}')
def softmax_pos_prob(self, x):
# Compute softmax probability of the second (positive) class