Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/bitextor/bicleaner-ai.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZJaume <jzaragoza@prompsit.com>2022-06-16 15:30:16 +0300
committerZJaume <jzaragoza@prompsit.com>2022-06-16 15:34:45 +0300
commit0f0268142d4d7327edc93e0ee957782839e37092 (patch)
tree67686249eb7b2146039ac251e24b0ad2eec2d2c6 /bicleaner_ai
parent11a034fec9be69ec038e3dd3067d3214521d17bb (diff)
Force no verbosity in predict by default
Diffstat (limited to 'bicleaner_ai')
-rw-r--r--bicleaner_ai/models.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py
index e06c162..df7ebda 100644
--- a/bicleaner_ai/models.py
+++ b/bicleaner_ai/models.py
@@ -115,7 +115,8 @@ class ModelInterface(ABC):
pass
@abstractmethod
- def predict(self, x1, x2, batch_size=None, calibrated=False):
+ def predict(self, x1, x2, batch_size=None, calibrated=False,
+ raw=False, verbose=0):
pass
@abstractmethod
@@ -206,14 +207,15 @@ class BaseModel(ModelInterface):
'''Returns a compiled Keras model instance'''
raise NotImplementedError("Subclass must implement its model architecture")
- def predict(self, x1, x2, batch_size=None, calibrated=False, raw=False):
+ def predict(self, x1, x2, batch_size=None, calibrated=False,
+ raw=False, verbose=0):
'''Predicts from sequence generator'''
if batch_size is None:
batch_size = self.settings["batch_size"]
generator = self.get_generator(batch_size, shuffle=False)
generator.load((x1, x2, None))
- y_pred = self.model.predict(generator)
+ y_pred = self.model.predict(generator, verbose=verbose)
# Obtain logits if model returns HF output
if isinstance(y_pred, TFSequenceClassifierOutput):
y_pred = y_pred.logits