diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-03-24 02:10:46 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-03-24 02:10:46 +0300 |
commit | c61b154e9f7677e95b2e8879ce7c06358147b6e1 (patch) | |
tree | 88bf4e1ed86d996b70320d4b45adb0f72d4ed4ad | |
parent | f5182190dd197c4cb5a29be4058a15461a0b282b (diff) |
056: Code Refactoringv1.0.1
-rw-r--r-- | transquest/algo/sentence_level/monotransquest/run_model.py | 12 |
1 files changed, 1 insertions, 11 deletions
diff --git a/transquest/algo/sentence_level/monotransquest/run_model.py b/transquest/algo/sentence_level/monotransquest/run_model.py index e4a9623..6e7fa2c 100644 --- a/transquest/algo/sentence_level/monotransquest/run_model.py +++ b/transquest/algo/sentence_level/monotransquest/run_model.py @@ -1456,17 +1456,7 @@ class MonoTransQuestModel: model_outputs = preds else: model_outputs = preds - if multi_label: - if isinstance(args.threshold, list): - threshold_values = args.threshold - preds = [ - [self._threshold(pred, threshold_values[i]) for i, pred in enumerate(example)] - for example in preds - ] - else: - preds = [[self._threshold(pred, args.threshold) for pred in example] for example in preds] - else: - preds = np.argmax(preds, axis=1) + preds = np.argmax(preds, axis=1) if self.args.labels_map and not self.args.regression: inverse_labels_map = {value: key for key, value in self.args.labels_map.items()} |