Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/bitextor/bicleaner-ai.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZJaume <jzaragoza@prompsit.com>2022-07-20 15:15:02 +0300
committerJaume Zaragoza <ZJaume@users.noreply.github.com>2022-07-27 15:20:55 +0300
commit3fbb1d3d122b3ee0d5df6f7b7eda89dd0b006b66 (patch)
tree9f6b01949411bf8919819b3dc4d92b57d3d97199
parent54a67245c43832325707c16e3ca521e429bc72e3 (diff)
Overwrite XLMRConfig class
This makes the classes more compatible with HF API and to be able to load them later more easily.
-rw-r--r--bicleaner_ai/layers.py16
-rw-r--r--bicleaner_ai/models.py43
2 files changed, 40 insertions, 19 deletions
diff --git a/bicleaner_ai/layers.py b/bicleaner_ai/layers.py
index 7d19aa0..95264f0 100644
--- a/bicleaner_ai/layers.py
+++ b/bicleaner_ai/layers.py
@@ -67,18 +67,22 @@ class TransformerBlock(layers.Layer):
ffn_output = self.dropout2(ffn_output, training=training)
return self.layernorm2(out1 + ffn_output)
-class BCClassificationHead(layers.Layer):
- """Head for sentence-level classification tasks."""
+class BicleanerAIClassificationHead(layers.Layer):
+ """
+ Head for Bicleaner sentence classification tasks.
+ It reads BicleanerAIConfig to be able to change
+ classifier layer parameters (size, activation and dropout)
+ """
- def __init__(self, config, hidden_size, dropout, activation, **kwargs):
+ def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = layers.Dense(
- hidden_size,
+ config.head_hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
- activation=activation,
+ activation=config.head_activation,
name="dense",
)
- self.dropout = layers.Dropout(dropout)
+ self.dropout = layers.Dropout(config.head_dropout)
self.out_proj = layers.Dense(
config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py
index c5d7ea6..ea00bca 100644
--- a/bicleaner_ai/models.py
+++ b/bicleaner_ai/models.py
@@ -1,4 +1,8 @@
-from transformers import TFXLMRobertaForSequenceClassification, XLMRobertaTokenizerFast
+from transformers import (
+ TFXLMRobertaForSequenceClassification,
+ XLMRobertaConfig,
+ XLMRobertaTokenizerFast
+)
from transformers.modeling_tf_outputs import TFSequenceClassifierOutput
from transformers.optimization_tf import create_optimizer
from tensorflow.keras.optimizers.schedules import InverseTimeDecay
@@ -29,7 +33,7 @@ try:
from .layers import (
TransformerBlock,
TokenAndPositionEmbedding,
- BCClassificationHead)
+ BicleanerAIClassificationHead)
except (SystemError, ImportError):
import decomposable_attention
from metrics import FScore, MatthewsCorrCoef
@@ -40,7 +44,7 @@ except (SystemError, ImportError):
from layers import (
TransformerBlock,
TokenAndPositionEmbedding,
- BCClassificationHead)
+ BicleanerAIClassificationHead)
def calibrate_output(y_true, y_pred):
''' Platt calibration
@@ -522,7 +526,7 @@ class BCXLMRoberta(BaseModel):
def load_model(self, model_file):
settings = self.settings
- tf_model = TFXLMRBicleaner.from_pretrained(
+ tf_model = TFXLMRBicleanerAI.from_pretrained(
model_file,
num_labels=settings["n_classes"],
head_hidden_size=settings["n_hidden"],
@@ -635,13 +639,26 @@ class BCXLMRoberta(BaseModel):
return y_true, y_pred
-class TFXLMRBicleaner(TFXLMRobertaForSequenceClassification):
- """Model for sentence-level classification tasks."""
+class XLMRBicleanerAIConfig(XLMRobertaConfig):
+ '''
+ Bicleaner AI XLMR configuration class
+ adds the config parameters for the classification layer
+ '''
+ def __init__(self, head_hidden_size=2048, head_dropout=0.1,
+ head_activation='relu', **kwargs):
+ super().__init__(**kwargs)
+ self.head_hidden_size = head_hidden_size
+ self.head_dropout = head_dropout
+ self.head_activation = head_activation
+
+class TFXLMRBicleanerAI(TFXLMRobertaForSequenceClassification):
+ '''
+ Model for Bicleaner sentence-level classification tasks.
+ Overwrites XLMRoberta classifiacion layer.
+ '''
+ config_class = XLMRBicleanerAIConfig
- def __init__(self, config, head_hidden_size, head_dropout, head_activation):
- super().__init__(config)
- self.classifier = BCClassificationHead(config,
- head_hidden_size,
- head_dropout,
- head_activation,
- name='bc_classification_head')
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.classifier = BicleanerAIClassificationHead(config,
+ name='bicleaner_ai_classification_head')