diff options
author | ZJaume <jzaragoza@prompsit.com> | 2022-06-16 15:30:16 +0300 |
---|---|---|
committer | ZJaume <jzaragoza@prompsit.com> | 2022-06-16 15:34:45 +0300 |
commit | 0f0268142d4d7327edc93e0ee957782839e37092 (patch) | |
tree | 67686249eb7b2146039ac251e24b0ad2eec2d2c6 /bicleaner_ai | |
parent | 11a034fec9be69ec038e3dd3067d3214521d17bb (diff) |
Force no verbosity in predict by default
Diffstat (limited to 'bicleaner_ai')
-rw-r--r-- | bicleaner_ai/models.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/bicleaner_ai/models.py b/bicleaner_ai/models.py index e06c162..df7ebda 100644 --- a/bicleaner_ai/models.py +++ b/bicleaner_ai/models.py @@ -115,7 +115,8 @@ class ModelInterface(ABC): pass @abstractmethod - def predict(self, x1, x2, batch_size=None, calibrated=False): + def predict(self, x1, x2, batch_size=None, calibrated=False, + raw=False, verbose=0): pass @abstractmethod @@ -206,14 +207,15 @@ class BaseModel(ModelInterface): '''Returns a compiled Keras model instance''' raise NotImplementedError("Subclass must implement its model architecture") - def predict(self, x1, x2, batch_size=None, calibrated=False, raw=False): + def predict(self, x1, x2, batch_size=None, calibrated=False, + raw=False, verbose=0): '''Predicts from sequence generator''' if batch_size is None: batch_size = self.settings["batch_size"] generator = self.get_generator(batch_size, shuffle=False) generator.load((x1, x2, None)) - y_pred = self.model.predict(generator) + y_pred = self.model.predict(generator, verbose=verbose) # Obtain logits if model returns HF output if isinstance(y_pred, TFSequenceClassifierOutput): y_pred = y_pred.logits |