diff options
author | ZJaume <jzaragoza@prompsit.com> | 2021-11-04 13:28:07 +0300 |
---|---|---|
committer | ZJaume <jzaragoza@prompsit.com> | 2021-11-04 13:29:51 +0300 |
commit | 6ed2240286ce7c297e6b087ad83c7b1b61a8d181 (patch) | |
tree | 78d5aa500a415d2e3c105416647f54af761c3d9c | |
parent | 6b8210efdebe8ef1e7c8d324c2610f4c4124e2ca (diff) |
Add MCC as validation metric in XLMR, use default names for metrics
-rw-r--r-- | bicleaner_ai/metrics.py | 4 | ||||
-rw-r--r-- | bicleaner_ai/models.py | 8 |
2 files changed, 6 insertions, 6 deletions
diff --git a/bicleaner_ai/metrics.py b/bicleaner_ai/metrics.py index d995889..08838bd 100644 --- a/bicleaner_ai/metrics.py +++ b/bicleaner_ai/metrics.py @@ -13,7 +13,7 @@ class FScore(Metric): thresholds=None, top_k=None, class_id=None, - name=None, + name='f1', dtype=None, argmax=False): super(FScore, self).__init__(name=name, dtype=dtype) @@ -85,7 +85,7 @@ class MatthewsCorrCoef(Metric): thresholds=None, top_k=None, class_id=None, - name=None, + name='mcc', dtype=None, argmax=False): super(MatthewsCorrCoef, self).__init__(name=name, dtype=dtype) diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py index f6fa6ff..ac08139 100644 --- a/bicleaner_ai/models.py +++ b/bicleaner_ai/models.py @@ -193,8 +193,8 @@ class BaseModel(ModelInterface): #TODO create argmax precision and recall or use categorical acc #Precision(name='p'), #Recall(name='r'), - FScore(name='f1', argmax=self.settings["distilled"]), - MatthewsCorrCoef(name='mcc', argmax=self.settings["distilled"]), + FScore(argmax=self.settings["distilled"]), + MatthewsCorrCoef(argmax=self.settings["distilled"]), ] def get_generator(self, batch_size, shuffle): @@ -574,8 +574,8 @@ class BCXLMRoberta(BaseModel): self.model.compile(optimizer=self.settings["optimizer"], loss=SparseCategoricalCrossentropy( from_logits=True), - metrics=[FScore(name='f1', - argmax=True)]) + metrics=[FScore(argmax=True), + MatthewsCorrCoef(argmax=True)]) if logging.getLogger().level == logging.DEBUG: self.model.summary() self.model.fit(train_generator, |