diff options
Diffstat (limited to 'bicleaner_ai/models.py')
-rw-r--r-- | bicleaner_ai/models.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py index 6c3db5d..b76067b 100644 --- a/bicleaner_ai/models.py +++ b/bicleaner_ai/models.py @@ -486,7 +486,7 @@ class BCXLMRoberta(BaseModel): self.tokenizer = None self.settings = { - "model": 'jplu/tf-xlm-roberta-base', + "base_model": 'jplu/tf-xlm-roberta-base', "batch_size": 16, "maxlen": 150, "n_classes": 2, @@ -573,7 +573,7 @@ class BCXLMRoberta(BaseModel): logging.info("Loading training set") self.tokenizer = XLMRobertaTokenizerFast.from_pretrained( - self.settings["model"]) + self.settings["base_model"]) train_generator = self.get_generator(self.settings["batch_size"], shuffle=True) train_generator.load(train_set) @@ -594,12 +594,13 @@ class BCXLMRoberta(BaseModel): strategy = tf.distribute.MirroredStrategy() num_devices = strategy.num_replicas_in_sync with strategy.scope(): - self.model = self.load_model(self.settings["model"]) + self.model = self.load_model(self.settings["base_model"]) self.model.compile(optimizer=self.settings["optimizer"], loss=SparseCategoricalCrossentropy( from_logits=True), metrics=[FScore(argmax=True), MatthewsCorrCoef(argmax=True)]) + self.model.config._name_or_path = self.settings["model_name"] if logging.getLogger().level == logging.DEBUG: self.model.summary() |