github.com/bitextor/bicleaner-ai.git
author     ZJaume <jzaragoza@prompsit.com>    2022-09-05 17:21:33 +0300
committer  ZJaume <jzaragoza@prompsit.com>    2022-09-05 17:21:33 +0300
commit     dddce29bfe6047cd9f198990e916ac18b156efe1 (patch)
tree       b51c1514a52505699ce79988a4b385e3763123bd
parent     c820c7a407cbc8da8ec2365d43759212b737dcdd (diff)
Look for model and vocab subdirectories when loading
-rw-r--r--   bicleaner_ai/models.py   22
 1 file changed, 16 insertions(+), 6 deletions(-)
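
Note: with this change the loader accepts both package layouts. If the model settings name model_file / vocab_file subdirectories, the tokenizer and weights are read from there; otherwise load() falls back to the package directory itself, so older flat packages keep working. The snippet below is a minimal sketch of that lookup pulled out of the class; resolve_pretrained_paths, model_dir and settings are hypothetical stand-ins for the method, self.dir and self.settings, and "en-fr" is only an illustrative package directory.

# Sketch of the fallback introduced in BCXLMRoberta.load() below; names are stand-ins.
def resolve_pretrained_paths(model_dir, settings):
    # New layout: tokenizer and weights live in subdirectories named by the settings.
    if "vocab_file" in settings:
        vocab_path = f'{model_dir}/{settings["vocab_file"]}'
    else:
        # Legacy layout: everything was saved directly into the package directory.
        vocab_path = model_dir

    if "model_file" in settings:
        model_path = f'{model_dir}/{settings["model_file"]}'
    else:
        model_path = model_dir

    return vocab_path, model_path

# resolve_pretrained_paths("en-fr", {"model_file": "model.tf", "vocab_file": "vocab"})
#   -> ('en-fr/vocab', 'en-fr/model.tf')
# resolve_pretrained_paths("en-fr", {})
#   -> ('en-fr', 'en-fr')
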
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py
index 3327e47..d5b21e8 100644
--- a/bicleaner_ai/models.py
+++ b/bicleaner_ai/models.py
@@ -489,8 +489,6 @@ class BCXLMRoberta(BaseModel):
         self.tokenizer = None
 
         self.settings = {
-            "model_file": "model.tf",
-            "vocab_file": "vocab",
             "base_model": 'jplu/tf-xlm-roberta-base',
             "batch_size": 16,
             "maxlen": 150,
@@ -540,9 +538,19 @@ class BCXLMRoberta(BaseModel):
 
     def load(self):
         ''' Load fine-tuned model '''
-        vocab_file = self.dir + '/' + self.settings["vocab_file"]
+        # If vocab and model files are in a subdirectory, load from there
+        if "vocab_file" in self.settings:
+            vocab_file = f'{self.dir}/{self.settings["vocab_file"]}'
+        else:
+            vocab_file = self.dir
+
+        if "model_file" in self.settings:
+            model_file = f'{self.dir}/{self.settings["model_file"]}'
+        else:
+            model_file = self.dir
+
         self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(vocab_file)
-        self.model = self.load_model(f'{self.dir}/{self.settings["model_file"]}')
+        self.model = self.load_model(model_file)
 
     def softmax_pos_prob(self, x):
         # Compute softmax probability of the second (positive) class
@@ -623,8 +631,10 @@ class BCXLMRoberta(BaseModel):
                        batch_size=self.settings["batch_size"],
                        callbacks=[earlystop],
                        verbose=verbose)
-        self.model.save_pretrained(self.dir)
-        self.tokenizer.save_pretrained(self.dir)
+        self.model.save_pretrained(self.dir + '/'
+                                   + self.settings["model_file"])
+        self.tokenizer.save_pretrained(self.dir + '/'
+                                       + self.settings["vocab_file"])
         y_true = dev_generator.y
         with redirect_stdout(sys.stderr):
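
The save side mirrors the loader: train() now writes the fine-tuned model and tokenizer into the model_file / vocab_file subdirectories, so packages produced after this commit match the first branch of the lookup in load(). A minimal sketch, assuming settings carries both keys; the "model.tf" / "vocab" values reuse the defaults removed from __init__ above, and "en-fr" is again only an illustrative package directory.

# Sketch only: `model` and `tokenizer` stand in for Hugging Face objects that expose
# save_pretrained(), which takes a directory path; the layout matches what load() expects.
def save_package(model, tokenizer, model_dir, settings):
    model.save_pretrained(model_dir + '/' + settings["model_file"])      # e.g. en-fr/model.tf
    tokenizer.save_pretrained(model_dir + '/' + settings["vocab_file"])  # e.g. en-fr/vocab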