diff options
-rw-r--r-- | transquest/algo/sentence_level/monotransquest/run_model.py | 12 |
1 files changed, 1 insertions, 11 deletions
diff --git a/transquest/algo/sentence_level/monotransquest/run_model.py b/transquest/algo/sentence_level/monotransquest/run_model.py index e4a9623..6e7fa2c 100644 --- a/transquest/algo/sentence_level/monotransquest/run_model.py +++ b/transquest/algo/sentence_level/monotransquest/run_model.py @@ -1456,17 +1456,7 @@ class MonoTransQuestModel: model_outputs = preds else: model_outputs = preds - if multi_label: - if isinstance(args.threshold, list): - threshold_values = args.threshold - preds = [ - [self._threshold(pred, threshold_values[i]) for i, pred in enumerate(example)] - for example in preds - ] - else: - preds = [[self._threshold(pred, args.threshold) for pred in example] for example in preds] - else: - preds = np.argmax(preds, axis=1) + preds = np.argmax(preds, axis=1) if self.args.labels_map and not self.args.regression: inverse_labels_map = {value: key for key, value in self.args.labels_map.items()} |