Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/bitextor/bicleaner-ai.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZJaume <jzaragoza@prompsit.com>2022-08-24 16:51:24 +0300
committerZJaume <jzaragoza@prompsit.com>2022-08-24 16:51:24 +0300
commitff0a56b34001a3fe2eaef64c3e20dcaf7b31a175 (patch)
treeff168e4b78a938514d9ebf1f935ab2c70197008b
parent67c96d9f41b61e9924f6417dd3ed299d21e698be (diff)
Restore classifier layer loading for old models
-rw-r--r--bicleaner_ai/models.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py
index 15dc19a..3327e47 100644
--- a/bicleaner_ai/models.py
+++ b/bicleaner_ai/models.py
@@ -12,6 +12,7 @@ from tensorflow.keras.losses import SparseCategoricalCrossentropy, BinaryCrossen
from tensorflow.keras.metrics import Precision, Recall
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
+from tensorflow.python.keras.saving import hdf5_format
from tensorflow.keras import layers
from contextlib import redirect_stdout
from glove import Corpus, Glove
@@ -21,7 +22,9 @@ import sentencepiece as sp
import tensorflow as tf
import numpy as np
import logging
+import h5py
import sys
+import os
try:
from . import decomposable_attention
@@ -659,5 +662,19 @@ class TFXLMRBicleanerAI(TFXLMRobertaForSequenceClassification):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
+
+ # Set the appropiate name for the classifier head layer
+ # to avoid old models not being loaded correctly
+ name = 'bicleaner_ai_classification_head'
+ # If it's not a path, we are loading from hub, use the new name
+ if os.path.isdir(config._name_or_path):
+ # Inspect model file to guess if it has old head name
+ with h5py.File(config._name_or_path + '/tf_model.h5', 'r') as h5:
+ layers = set(hdf5_format.
+ load_attributes_from_hdf5_group(h5, "layer_names"))
+ if 'bc_classification_head' in layers:
+ name = 'bc_classification_head'
+
self.classifier = BicleanerAIClassificationHead(config,
- name='bicleaner_ai_classification_head')
+ name=name)
+