author     John Bauer <horatio@gmail.com>    2022-11-01 23:28:50 +0300
committer  John Bauer <horatio@gmail.com>    2022-11-01 23:46:29 +0300
commit     5fd6ff3667e5dc931436f6c8dca1c2d722a82bf3 (patch)
tree       8fc75d89c605991aed9415e34f62c2ff0186a9f2
parent     22d6a887a32d2b26de094a9954e4a12afad80bee (diff)
lattn_partitioned == False should affect the input proj dimension as well
-rw-r--r--   stanza/models/constituency/label_attention.py   20
-rw-r--r--   stanza/tests/constituency/test_lstm_model.py      3
2 files changed, 16 insertions, 7 deletions
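
With --no_lattn_partitioned, no positional columns are split off the word embeddings, so the input projection has to map the full d_model down to d_input_proj rather than d_model - d_positional. A minimal sketch of the two shapes follows; the dimensions (1024/512/768) are illustrative only and are not taken from this commit.

    import torch
    from torch import nn

    # Illustrative numbers only: a 1024-dim encoder output whose last 512
    # columns are the positional half when the representation is partitioned.
    d_model, d_positional, d_input_proj = 1024, 512, 768

    # Partitioned: project only the content columns; the positional columns
    # are sliced off and concatenated back on afterwards.
    proj = nn.Linear(d_model - d_positional, d_input_proj - d_positional, bias=False)
    sentence = torch.randn(7, d_model)                 # 7 words in one sentence
    projected = torch.cat((proj(sentence[:, :-d_positional]),
                           sentence[:, -d_positional:]), dim=1)
    assert projected.shape == (7, d_input_proj)

    # Unpartitioned (the case this commit fixes): there are no positional
    # columns to strip, so the projection must consume all of d_model.
    proj = nn.Linear(d_model, d_input_proj, bias=False)
    projected = proj(sentence)
    assert projected.shape == (7, d_input_proj)
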
diff --git a/stanza/models/constituency/label_attention.py b/stanza/models/constituency/label_attention.py
index 0304ae92..2eac168b 100644
--- a/stanza/models/constituency/label_attention.py
+++ b/stanza/models/constituency/label_attention.py
@@ -642,12 +642,15 @@ class LabelAttentionModule(nn.Module):
         super().__init__()
 
         self.ff_dim = d_proj * d_l
-        self.d_positional = d_positional if d_positional else 0
+        if not lattn_partitioned:
+            self.d_positional = 0
+        else:
+            self.d_positional = d_positional if d_positional else 0
 
         if d_input_proj:
-            if d_input_proj <= d_positional:
-                raise ValueError("Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d" % (d_input_proj, d_positional))
-            self.input_projection = nn.Linear(d_model - d_positional, d_input_proj - d_positional, bias=False)
+            if d_input_proj <= self.d_positional:
+                raise ValueError("Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d" % (d_input_proj, self.d_positional))
+            self.input_projection = nn.Linear(d_model - self.d_positional, d_input_proj - self.d_positional, bias=False)
             d_input = d_input_proj
         else:
             self.input_projection = None
@@ -679,9 +682,12 @@ class LabelAttentionModule(nn.Module):
 
     def forward(self, word_embeddings, tagged_word_lists):
         if self.input_projection:
-            word_embeddings = [torch.cat((self.input_projection(sentence[:, :-self.d_positional]),
-                                          sentence[:, -self.d_positional:]), dim=1)
-                               for sentence in word_embeddings]
+            if self.d_positional > 0:
+                word_embeddings = [torch.cat((self.input_projection(sentence[:, :-self.d_positional]),
+                                              sentence[:, -self.d_positional:]), dim=1)
+                                   for sentence in word_embeddings]
+            else:
+                word_embeddings = [self.input_projection(sentence) for sentence in word_embeddings]
         # Extract Labeled Representation
         packed_len = sum(sentence.shape[0] for sentence in word_embeddings)
         batch_idxs = np.zeros(packed_len, dtype=int)
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index 10d56ddb..868ecc86 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -271,6 +271,9 @@ def test_lattn_projection(pretrain_file):
     model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '256')
     run_forward_checks(model)
 
+    model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned', '--lattn_d_input_proj', '256')
+    run_forward_checks(model)
+
     model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '768')
     run_forward_checks(model)
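
The explicit d_positional > 0 branch added to forward() is not just cosmetic: Python treats -0 as 0, so with a zero positional width the old slicing would feed an empty tensor to the projection and pass the whole embedding through unchanged. A quick illustration (tensor sizes are hypothetical):

    import torch

    x = torch.randn(7, 1024)   # one sentence, hypothetical embedding width
    d_positional = 0

    # -0 == 0: the "content" slice is empty, the "positional" slice is everything
    print(x[:, :-d_positional].shape)   # torch.Size([7, 0])
    print(x[:, -d_positional:].shape)   # torch.Size([7, 1024])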