github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>    2022-11-01 23:28:50 +0300
committer John Bauer <horatio@gmail.com>    2022-11-01 23:46:29 +0300
commit    5fd6ff3667e5dc931436f6c8dca1c2d722a82bf3 (patch)
tree      8fc75d89c605991aed9415e34f62c2ff0186a9f2
parent    22d6a887a32d2b26de094a9954e4a12afad80bee (diff)
lattn_partitioned == False should affect the input proj dimension as well
-rw-r--r--  stanza/models/constituency/label_attention.py | 20
-rw-r--r--  stanza/tests/constituency/test_lstm_model.py  |  3
2 files changed, 16 insertions, 7 deletions
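The change makes --no_lattn_partitioned affect how the input projection is sized: when the label attention module is not partitioned there is no positional slice to preserve, so d_positional is treated as 0 and the projection spans the full d_model. A minimal standalone sketch of that sizing rule (the helper name and the d_positional=512 value are illustrative, not the stanza API):

import torch
import torch.nn as nn

def build_input_projection(d_model, d_input_proj, d_positional, lattn_partitioned):
    # Without partitioning there is no positional block kept aside, so the
    # projection has to consume (and emit) the full width.
    d_positional = d_positional if (lattn_partitioned and d_positional) else 0
    if d_input_proj <= d_positional:
        raise ValueError("Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d"
                         % (d_input_proj, d_positional))
    return nn.Linear(d_model - d_positional, d_input_proj - d_positional, bias=False), d_positional

# Non-partitioned case: all 1024 input dims are projected down to 256.
proj, d_pos = build_input_projection(1024, 256, 512, lattn_partitioned=False)
print(proj)   # Linear(in_features=1024, out_features=256, bias=False)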
diff --git a/stanza/models/constituency/label_attention.py b/stanza/models/constituency/label_attention.py
index 0304ae92..2eac168b 100644
--- a/stanza/models/constituency/label_attention.py
+++ b/stanza/models/constituency/label_attention.py
@@ -642,12 +642,15 @@ class LabelAttentionModule(nn.Module):
         super().__init__()
         self.ff_dim = d_proj * d_l
-        self.d_positional = d_positional if d_positional else 0
+        if not lattn_partitioned:
+            self.d_positional = 0
+        else:
+            self.d_positional = d_positional if d_positional else 0
         if d_input_proj:
-            if d_input_proj <= d_positional:
-                raise ValueError("Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d" % (d_input_proj, d_positional))
-            self.input_projection = nn.Linear(d_model - d_positional, d_input_proj - d_positional, bias=False)
+            if d_input_proj <= self.d_positional:
+                raise ValueError("Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d" % (d_input_proj, self.d_positional))
+            self.input_projection = nn.Linear(d_model - self.d_positional, d_input_proj - self.d_positional, bias=False)
             d_input = d_input_proj
         else:
             self.input_projection = None
@@ -679,9 +682,12 @@ class LabelAttentionModule(nn.Module):
     def forward(self, word_embeddings, tagged_word_lists):
         if self.input_projection:
-            word_embeddings = [torch.cat((self.input_projection(sentence[:, :-self.d_positional]),
-                                          sentence[:, -self.d_positional:]), dim=1)
-                               for sentence in word_embeddings]
+            if self.d_positional > 0:
+                word_embeddings = [torch.cat((self.input_projection(sentence[:, :-self.d_positional]),
+                                              sentence[:, -self.d_positional:]), dim=1)
+                                   for sentence in word_embeddings]
+            else:
+                word_embeddings = [self.input_projection(sentence) for sentence in word_embeddings]
         # Extract Labeled Representation
         packed_len = sum(sentence.shape[0] for sentence in word_embeddings)
         batch_idxs = np.zeros(packed_len, dtype=int)
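Read together with the constructor change, the new forward branch simply drops the torch.cat when there is no positional slice to re-attach. A small self-contained sketch of the two branches (plain tensors with illustrative sizes, not the LabelAttentionModule itself):

import torch
import torch.nn as nn

d_model, d_input_proj, d_positional = 1024, 256, 0    # d_positional == 0 is the non-partitioned case
input_projection = nn.Linear(d_model - d_positional, d_input_proj - d_positional, bias=False)

word_embeddings = [torch.randn(n, d_model) for n in (5, 9)]   # two sentences, 5 and 9 words
if d_positional > 0:
    # Partitioned: project only the content part, pass the positional tail through untouched.
    word_embeddings = [torch.cat((input_projection(sentence[:, :-d_positional]),
                                  sentence[:, -d_positional:]), dim=1)
                       for sentence in word_embeddings]
else:
    # Non-partitioned: the projection consumes the whole embedding.
    word_embeddings = [input_projection(sentence) for sentence in word_embeddings]
print([tuple(e.shape) for e in word_embeddings])   # [(5, 256), (9, 256)]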
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index 10d56ddb..868ecc86 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -271,6 +271,9 @@ def test_lattn_projection(pretrain_file):
     model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '256')
     run_forward_checks(model)
+    model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned', '--lattn_d_input_proj', '256')
+    run_forward_checks(model)
+
     model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '768')
     run_forward_checks(model)