From ff0a56b34001a3fe2eaef64c3e20dcaf7b31a175 Mon Sep 17 00:00:00 2001 From: ZJaume Date: Wed, 24 Aug 2022 13:51:24 +0000 Subject: Restore classifier layer loading for old models --- bicleaner_ai/models.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py index 15dc19a..3327e47 100644 --- a/bicleaner_ai/models.py +++ b/bicleaner_ai/models.py @@ -12,6 +12,7 @@ from tensorflow.keras.losses import SparseCategoricalCrossentropy, BinaryCrossen from tensorflow.keras.metrics import Precision, Recall from tensorflow.keras.optimizers import Adam from tensorflow.keras.models import load_model +from tensorflow.python.keras.saving import hdf5_format from tensorflow.keras import layers from contextlib import redirect_stdout from glove import Corpus, Glove @@ -21,7 +22,9 @@ import sentencepiece as sp import tensorflow as tf import numpy as np import logging +import h5py import sys +import os try: from . import decomposable_attention @@ -659,5 +662,19 @@ class TFXLMRBicleanerAI(TFXLMRobertaForSequenceClassification): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) + + # Set the appropiate name for the classifier head layer + # to avoid old models not being loaded correctly + name = 'bicleaner_ai_classification_head' + # If it's not a path, we are loading from hub, use the new name + if os.path.isdir(config._name_or_path): + # Inspect model file to guess if it has old head name + with h5py.File(config._name_or_path + '/tf_model.h5', 'r') as h5: + layers = set(hdf5_format. + load_attributes_from_hdf5_group(h5, "layer_names")) + if 'bc_classification_head' in layers: + name = 'bc_classification_head' + self.classifier = BicleanerAIClassificationHead(config, - name='bicleaner_ai_classification_head') + name=name) + -- cgit v1.2.3