diff options
Diffstat (limited to 'bicleaner_ai/datagen.py')
-rw-r--r-- | bicleaner_ai/datagen.py | 22 |
1 files changed, 9 insertions, 13 deletions
diff --git a/bicleaner_ai/datagen.py b/bicleaner_ai/datagen.py
index b624ffc..fca8b46 100644
--- a/bicleaner_ai/datagen.py
+++ b/bicleaner_ai/datagen.py
@@ -117,18 +117,14 @@ class SentenceGenerator(tf.keras.utils.Sequence):
         # Build array of sample weights
         # If no parsable float is detected assume that there are the tags
         if len(data) >= 4 and data[3]:
-            try:
-                float(data[3][0])
-            except ValueError:
-                logging.debug("No float detected at 4th field of the data, "
-                              "ignoring data weights."
-                              f" File: {source}")
-                # Load the tags (4th field) if requested
-                if not ignore_tags:
-                    logging.debug(f"Loading tags for file {source}")
-                    self.tags = np.array(data[3], dtype=str)
-            else:
+            if data[3][0].replace('.', '', 1).isdigit():
+                logging.debug("Loading data weights")
                 self.weights = np.array(data[3], dtype=float)
+            elif not ignore_tags:
+                logging.debug(f"Loading tags for file {source}")
+                self.tags = np.array(data[3], dtype=str)
+            else:
+                logging.debug("Ignoring fourth column as it is not numeric")
 
         # Index samples
         self.num_samples = len(data[0])
@@ -178,7 +174,7 @@ class ConcatSentenceGenerator(SentenceGenerator):
                                       padding="post",
                                       truncating="post",
                                       maxlen=self.maxlen)
-            att_mask = None
+            return input_ids
         else:
             # Tokenize with Transformers tokenizer that concatenates internally
             dataset = self.encoder(text1, text2,
@@ -191,4 +187,4 @@ class ConcatSentenceGenerator(SentenceGenerator):
             input_ids = dataset["input_ids"]
             att_mask = dataset["attention_mask"]
 
-        return input_ids, att_mask
+            return input_ids, att_mask