Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/bitextor/bicleaner-ai.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'bicleaner_ai/datagen.py')
-rw-r--r--bicleaner_ai/datagen.py22
1 files changed, 9 insertions, 13 deletions
diff --git a/bicleaner_ai/datagen.py b/bicleaner_ai/datagen.py
index b624ffc..fca8b46 100644
--- a/bicleaner_ai/datagen.py
+++ b/bicleaner_ai/datagen.py
@@ -117,18 +117,14 @@ class SentenceGenerator(tf.keras.utils.Sequence):
# Build array of sample weights
# If no parsable float is detected assume that there are the tags
if len(data) >= 4 and data[3]:
- try:
- float(data[3][0])
- except ValueError:
- logging.debug("No float detected at 4th field of the data, "
- "ignoring data weights."
- f" File: {source}")
- # Load the tags (4th field) if requested
- if not ignore_tags:
- logging.debug(f"Loading tags for file {source}")
- self.tags = np.array(data[3], dtype=str)
- else:
+ if data[3][0].replace('.', '', 1).isdigit():
+ logging.debug("Loading data weights")
self.weights = np.array(data[3], dtype=float)
+ elif not ignore_tags:
+ logging.debug(f"Loading tags for file {source}")
+ self.tags = np.array(data[3], dtype=str)
+ else:
+ logging.debug("Ignoring fourth column as it is not numeric")
# Index samples
self.num_samples = len(data[0])
@@ -178,7 +174,7 @@ class ConcatSentenceGenerator(SentenceGenerator):
padding="post",
truncating="post",
maxlen=self.maxlen)
- att_mask = None
+ return input_ids
else:
# Tokenize with Transformers tokenizer that concatenates internally
dataset = self.encoder(text1, text2,
@@ -191,4 +187,4 @@ class ConcatSentenceGenerator(SentenceGenerator):
input_ids = dataset["input_ids"]
att_mask = dataset["attention_mask"]
- return input_ids, att_mask
+ return input_ids, att_mask