Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-05-01 05:59:35 +0300
committerJohn Bauer <horatio@gmail.com>2022-05-01 05:59:35 +0300
commit88934d2971b3968ed54ea17a23d494df1fb555c2 (patch)
treee51451683b3994dbbe51f9a025bb16aee34de8cd
parentcf89ca21c70263eaf80c98cef5a084515ca6cc89 (diff)
simplify a bitsentiment_charlm
-rw-r--r--stanza/models/classifiers/cnn_classifier.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/stanza/models/classifiers/cnn_classifier.py b/stanza/models/classifiers/cnn_classifier.py
index 425d8278..7353a58e 100644
--- a/stanza/models/classifiers/cnn_classifier.py
+++ b/stanza/models/classifiers/cnn_classifier.py
@@ -297,10 +297,8 @@ class CNNClassifier(nn.Module):
char_reps_backward = self.build_char_reps(inputs, max_phrase_len, self.backward_charlm, self.charmodel_backward_projection, begin_paddings, device)
all_inputs.append(char_reps_backward)
- if len(all_inputs) > 1:
- input_vectors = torch.cat(all_inputs, dim=2)
- else:
- input_vectors = all_inputs[0]
+ # still works even if there's just one item
+ input_vectors = torch.cat(all_inputs, dim=2)
# reshape to fit the input tensors
x = input_vectors.unsqueeze(1)