Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-03-24 02:10:46 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-03-24 02:10:46 +0300
commitc61b154e9f7677e95b2e8879ce7c06358147b6e1 (patch)
tree88bf4e1ed86d996b70320d4b45adb0f72d4ed4ad
parentf5182190dd197c4cb5a29be4058a15461a0b282b (diff)
056: Code Refactoringv1.0.1
-rw-r--r--transquest/algo/sentence_level/monotransquest/run_model.py12
1 files changed, 1 insertions, 11 deletions
diff --git a/transquest/algo/sentence_level/monotransquest/run_model.py b/transquest/algo/sentence_level/monotransquest/run_model.py
index e4a9623..6e7fa2c 100644
--- a/transquest/algo/sentence_level/monotransquest/run_model.py
+++ b/transquest/algo/sentence_level/monotransquest/run_model.py
@@ -1456,17 +1456,7 @@ class MonoTransQuestModel:
model_outputs = preds
else:
model_outputs = preds
- if multi_label:
- if isinstance(args.threshold, list):
- threshold_values = args.threshold
- preds = [
- [self._threshold(pred, threshold_values[i]) for i, pred in enumerate(example)]
- for example in preds
- ]
- else:
- preds = [[self._threshold(pred, args.threshold) for pred in example] for example in preds]
- else:
- preds = np.argmax(preds, axis=1)
+ preds = np.argmax(preds, axis=1)
if self.args.labels_map and not self.args.regression:
inverse_labels_map = {value: key for key, value in self.args.labels_map.items()}