
github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>    2022-11-01 19:03:20 +0300
committer  John Bauer <horatio@gmail.com>    2022-11-01 19:03:20 +0300
commit     22d6a887a32d2b26de094a9954e4a12afad80bee (patch)
tree       b48a8d6fa9c646ecf85ca8d1639edce889472d4f
parent     6a60761de75191498e7c12fbc5d9ddc2a8f0c12b (diff)
Add an argument for partitioning / not partitioning lattn
-rw-r--r--  stanza/models/constituency_parser.py          | 1 +
-rw-r--r--  stanza/tests/constituency/test_lstm_model.py  | 8 ++++++++
2 files changed, 9 insertions(+), 0 deletions(-)
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 43a402bf..83174ee3 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -422,6 +422,7 @@ def parse_args(args=None):
     parser.add_argument('--lattn_pwff', default=True, action='store_true', help='Whether or not to use a Position-wise Feed-forward Layer')
     parser.add_argument('--lattn_q_as_matrix', default=False, action='store_true', help='Whether or not Label Attention uses learned query vectors. False means it does')
     parser.add_argument('--lattn_partitioned', default=True, action='store_true', help='Whether or not it is partitioned')
+    parser.add_argument('--no_lattn_partitioned', default=True, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned')
     parser.add_argument('--lattn_combine_as_self', default=False, action='store_true', help='Whether or not the layer uses concatenation. False means it does')
     # currently unused - always assume 1/2 of pattn
     #parser.add_argument('--lattn_d_positional', default=512, type=int, help='Dimension for the positional embedding')
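
The added option follows the common argparse idiom for a boolean that can be switched off from the command line: a store_true flag paired with a store_false flag that writes to the same dest, so --lattn_partitioned keeps the default of True and --no_lattn_partitioned sets it to False. A minimal standalone sketch of the pattern (plain argparse, independent of stanza):

    import argparse

    parser = argparse.ArgumentParser()
    # Positive flag: the default is already True, so passing it mainly
    # documents intent on the command line.
    parser.add_argument('--lattn_partitioned', default=True, action='store_true',
                        help='Use partitioned label attention')
    # Negative flag: stores False into the same destination. Its default=True
    # mirrors the commit but is effectively redundant, since the paired flag's
    # default already populates the namespace.
    parser.add_argument('--no_lattn_partitioned', default=True, action='store_false',
                        dest='lattn_partitioned',
                        help='Disable partitioned label attention')

    print(parser.parse_args([]).lattn_partitioned)                           # True
    print(parser.parse_args(['--no_lattn_partitioned']).lattn_partitioned)   # False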
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index 0ef7c9c0..10d56ddb 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -254,6 +254,14 @@ def test_forward_labeled_attention(pretrain_file):
     model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_combined_input')
     run_forward_checks(model)
 
+def test_lattn_partitioned(pretrain_file):
+    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_partitioned')
+    run_forward_checks(model)
+
+    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned')
+    run_forward_checks(model)
+
+
 def test_lattn_projection(pretrain_file):
     """
     Test with & without labeled attention layers
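
The new test builds the model once with each flag and runs the forward checks, so both the partitioned and unpartitioned label attention paths are exercised. A hypothetical alternative using pytest's parametrize, reusing the build_model and run_forward_checks helpers already defined in this test file, would cover the same two cases:

    import pytest

    @pytest.mark.parametrize('flag', ['--lattn_partitioned', '--no_lattn_partitioned'])
    def test_lattn_partitioned(pretrain_file, flag):
        model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', flag)
        run_forward_checks(model)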