diff options
author | ZJaume <jzaragoza@prompsit.com> | 2022-07-20 15:15:02 +0300 |
---|---|---|
committer | Jaume Zaragoza <ZJaume@users.noreply.github.com> | 2022-07-27 15:20:55 +0300 |
commit | 3fbb1d3d122b3ee0d5df6f7b7eda89dd0b006b66 (patch) | |
tree | 9f6b01949411bf8919819b3dc4d92b57d3d97199 | |
parent | 54a67245c43832325707c16e3ca521e429bc72e3 (diff) |
Overwrite XLMRConfig class
This makes the classes more compatible with HF API and to be able to
load them later more easily.
-rw-r--r-- | bicleaner_ai/layers.py | 16 | ||||
-rw-r--r-- | bicleaner_ai/models.py | 43 |
2 files changed, 40 insertions, 19 deletions
diff --git a/bicleaner_ai/layers.py b/bicleaner_ai/layers.py index 7d19aa0..95264f0 100644 --- a/bicleaner_ai/layers.py +++ b/bicleaner_ai/layers.py @@ -67,18 +67,22 @@ class TransformerBlock(layers.Layer): ffn_output = self.dropout2(ffn_output, training=training) return self.layernorm2(out1 + ffn_output) -class BCClassificationHead(layers.Layer): - """Head for sentence-level classification tasks.""" +class BicleanerAIClassificationHead(layers.Layer): + """ + Head for Bicleaner sentence classification tasks. + It reads BicleanerAIConfig to be able to change + classifier layer parameters (size, activation and dropout) + """ - def __init__(self, config, hidden_size, dropout, activation, **kwargs): + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.dense = layers.Dense( - hidden_size, + config.head_hidden_size, kernel_initializer=get_initializer(config.initializer_range), - activation=activation, + activation=config.head_activation, name="dense", ) - self.dropout = layers.Dropout(dropout) + self.dropout = layers.Dropout(config.head_dropout) self.out_proj = layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py index c5d7ea6..ea00bca 100644 --- a/bicleaner_ai/models.py +++ b/bicleaner_ai/models.py @@ -1,4 +1,8 @@ -from transformers import TFXLMRobertaForSequenceClassification, XLMRobertaTokenizerFast +from transformers import ( + TFXLMRobertaForSequenceClassification, + XLMRobertaConfig, + XLMRobertaTokenizerFast +) from transformers.modeling_tf_outputs import TFSequenceClassifierOutput from transformers.optimization_tf import create_optimizer from tensorflow.keras.optimizers.schedules import InverseTimeDecay @@ -29,7 +33,7 @@ try: from .layers import ( TransformerBlock, TokenAndPositionEmbedding, - BCClassificationHead) + BicleanerAIClassificationHead) except (SystemError, ImportError): import decomposable_attention from metrics import FScore, MatthewsCorrCoef @@ -40,7 +44,7 @@ except (SystemError, ImportError): from layers import ( TransformerBlock, TokenAndPositionEmbedding, - BCClassificationHead) + BicleanerAIClassificationHead) def calibrate_output(y_true, y_pred): ''' Platt calibration @@ -522,7 +526,7 @@ class BCXLMRoberta(BaseModel): def load_model(self, model_file): settings = self.settings - tf_model = TFXLMRBicleaner.from_pretrained( + tf_model = TFXLMRBicleanerAI.from_pretrained( model_file, num_labels=settings["n_classes"], head_hidden_size=settings["n_hidden"], @@ -635,13 +639,26 @@ class BCXLMRoberta(BaseModel): return y_true, y_pred -class TFXLMRBicleaner(TFXLMRobertaForSequenceClassification): - """Model for sentence-level classification tasks.""" +class XLMRBicleanerAIConfig(XLMRobertaConfig): + ''' + Bicleaner AI XLMR configuration class + adds the config parameters for the classification layer + ''' + def __init__(self, head_hidden_size=2048, head_dropout=0.1, + head_activation='relu', **kwargs): + super().__init__(**kwargs) + self.head_hidden_size = head_hidden_size + self.head_dropout = head_dropout + self.head_activation = head_activation + +class TFXLMRBicleanerAI(TFXLMRobertaForSequenceClassification): + ''' + Model for Bicleaner sentence-level classification tasks. + Overwrites XLMRoberta classifiacion layer. + ''' + config_class = XLMRBicleanerAIConfig - def __init__(self, config, head_hidden_size, head_dropout, head_activation): - super().__init__(config) - self.classifier = BCClassificationHead(config, - head_hidden_size, - head_dropout, - head_activation, - name='bc_classification_head') + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.classifier = BicleanerAIClassificationHead(config, + name='bicleaner_ai_classification_head') |