author     John Bauer <horatio@gmail.com>    2022-11-01 19:03:20 +0300
committer  John Bauer <horatio@gmail.com>    2022-11-01 19:03:20 +0300
commit     22d6a887a32d2b26de094a9954e4a12afad80bee
tree       b48a8d6fa9c646ecf85ca8d1639edce889472d4f
parent     6a60761de75191498e7c12fbc5d9ddc2a8f0c12b
Add an argument for partitioning / not partitioning lattn
 stanza/models/constituency_parser.py         | 1
 stanza/tests/constituency/test_lstm_model.py | 8
 2 files changed, 9 insertions, 0 deletions
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 43a402bf..83174ee3 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -422,6 +422,7 @@ def parse_args(args=None):
     parser.add_argument('--lattn_pwff', default=True, action='store_true', help='Whether or not to use a Position-wise Feed-forward Layer')
     parser.add_argument('--lattn_q_as_matrix', default=False, action='store_true', help='Whether or not Label Attention uses learned query vectors. False means it does')
     parser.add_argument('--lattn_partitioned', default=True, action='store_true', help='Whether or not it is partitioned')
+    parser.add_argument('--no_lattn_partitioned', default=True, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned')
     parser.add_argument('--lattn_combine_as_self', default=False, action='store_true', help='Whether or not the layer uses concatenation. False means it does')
     # currently unused - always assume 1/2 of pattn
     #parser.add_argument('--lattn_d_positional', default=512, type=int, help='Dimension for the positional embedding')
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index 0ef7c9c0..10d56ddb 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -254,6 +254,14 @@ def test_forward_labeled_attention(pretrain_file):
     model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_combined_input')
     run_forward_checks(model)
 
+def test_lattn_partitioned(pretrain_file):
+    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_partitioned')
+    run_forward_checks(model)
+
+    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned')
+    run_forward_checks(model)
+
+
 def test_lattn_projection(pretrain_file):
     """
     Test with & without labeled attention layers
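
The commit uses argparse's paired on/off flag pattern: --lattn_partitioned and --no_lattn_partitioned both write to the same dest, one with store_true and one with store_false, so whichever flag appears on the command line wins, and the shared default (True) holds when neither is given. A minimal standalone sketch of how the pair resolves, assuming only the two add_argument calls from the diff (the bare ArgumentParser is illustrative; in the repo these calls live inside parse_args() in stanza/models/constituency_parser.py):

import argparse

# Illustrative parser; the two add_argument calls below are copied
# verbatim from the diff.
parser = argparse.ArgumentParser()
parser.add_argument('--lattn_partitioned', default=True, action='store_true', help='Whether or not it is partitioned')
parser.add_argument('--no_lattn_partitioned', default=True, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned')

# Both actions share dest='lattn_partitioned', so the last flag seen
# on the command line decides; with no flag, the default True holds.
assert parser.parse_args([]).lattn_partitioned is True
assert parser.parse_args(['--lattn_partitioned']).lattn_partitioned is True
assert parser.parse_args(['--no_lattn_partitioned']).lattn_partitioned is False

On Python 3.9+, argparse.BooleanOptionalAction expresses the same pair in a single call, though it auto-generates a dashed --no-lattn_partitioned spelling rather than the underscore form added here; the explicit two-call form also works on older interpreters.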