author     John Bauer <horatio@gmail.com>   2022-10-14 10:03:48 +0300
committer  John Bauer <horatio@gmail.com>   2022-11-07 23:56:04 +0300
commit     f163c4c7eea19ad4a93afaf93d30183cb3c7b93c
tree       ab2fd11bffcd86b495b2a5139d687424346af92a
parent     a07cf172e2ba3c8c82b4195dec9d72364ad4c5e0
Add a variant of multihead attention where there's a single key, or one key per label
Seems competitive with MAX but not a big improvement of any kind
Adds an optional position encoding to the KEY / UNTIED_KEY constituency compositions
(using reduce_position as a parameter for the size, 0 -> no position)
reduce_position needs to be an unsaved module so that it doesn't barf if it gets reloaded later with a different size
Includes comments on a couple of variants that didn't work: a linear after the attention, or a double position vector
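
As a rough sketch of the mechanism described above (not the actual Stanza code), the snippet below shows a single-key, multi-head reduction over one constituent's children: each child supplies a query and a value, a learned per-head key scores the children, and a softmax over the children weights their values. The class name SingleKeyReduce and its constructor are hypothetical; the optional position encoding and the per-label UNTIED_KEY variant from the diff are omitted.

# Minimal sketch (hypothetical names) of the single-key reduction
import torch
import torch.nn as nn

class SingleKeyReduce(nn.Module):
    """Reduce the N children of one constituent to a single vector.

    One learned key per head is shared across all children (the KEY case).
    """
    def __init__(self, hidden_size, heads):
        super().__init__()
        assert hidden_size % heads == 0
        self.heads = heads
        self.head_dim = hidden_size // heads
        # queries and values come from the children; the key is a single
        # learned vector per head
        self.reduce_query = nn.Linear(hidden_size, hidden_size, bias=False)
        self.reduce_value = nn.Linear(hidden_size, hidden_size)
        self.reduce_key = nn.Parameter(torch.randn(heads, self.head_dim, 1))

    def forward(self, children):
        # children: (N, hidden_size) for the N children of one constituent
        n = children.shape[0]
        # project and split into heads: (heads, N, head_dim)
        q = self.reduce_query(children).reshape(n, self.heads, self.head_dim).transpose(0, 1)
        v = self.reduce_value(children).reshape(n, self.heads, self.head_dim).transpose(0, 1)
        scores = torch.matmul(q, self.reduce_key)               # (heads, N, 1)
        weights = torch.softmax(scores, dim=1).transpose(1, 2)  # (heads, 1, N)
        reduced = torch.matmul(weights, v)                       # (heads, 1, head_dim)
        return reduced.reshape(-1)                               # (hidden_size,)

reducer = SingleKeyReduce(hidden_size=128, heads=8)
children = torch.randn(5, 128)      # five children of a new constituent
print(reducer(children).shape)      # torch.Size([128])

For UNTIED_KEY, the key would instead carry one (heads, head_dim, 1) slice per constituent label, with the slice for the new constituent's label selected before the matmul; the reduce_position option in the diff additionally concatenates a sinusoidal encoding of each child's index before reduce_query and reduce_value are applied.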
-rw-r--r--   stanza/models/constituency/lstm_model.py       72
-rw-r--r--   stanza/models/constituency_parser.py             4
-rw-r--r--   stanza/tests/constituency/test_lstm_model.py    23
3 files changed, 98 insertions(+), 1 deletion(-)
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 9f08535a..f0ec86e9 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -45,6 +45,7 @@ from stanza.models.constituency.lstm_tree_stack import LSTMTreeStack
 from stanza.models.constituency.parse_transitions import TransitionScheme
 from stanza.models.constituency.parse_tree import Tree
 from stanza.models.constituency.partitioned_transformer import PartitionedTransformerModule
+from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
 from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack
 from stanza.models.constituency.tree_stack import TreeStack
 from stanza.models.constituency.utils import build_nonlinearity, initialize_linear, TextTooLongError
@@ -170,6 +171,31 @@ class StackHistory(Enum):
 # UNTIED_MAX   0.9592
 # Furthermore, starting from a finished MAX model and restarting
 # by splitting the MAX layer into multiple pieces did not improve.
+#
+# KEY has a single Key which is used for a facsimile of ATTN
+#   each incoming subtree has its values weighted by a Query
+#   then the Key is used to calculate a softmax
+#   finally, a Value is used to scale the subtrees
+#   reduce_heads is used to determine the number of heads
+#   There is an option to use or not use position information
+#   using a sinusoidal position embedding
+# UNTIED_KEY is the same, but has a different key
+#   for each possible constituent
+# On a VI dataset:
+#   MAX                  0.82064
+#   KEY (pos, 8)         0.81739
+#   UNTIED_KEY (pos, 8)  0.82046
+#   UNTIED_KEY (pos, 4)  0.81742
+# Attempted to add a linear to mix the attn heads together,
+#   but that was awful: 0.81567
+# Adding two position vectors, one in each direction, did not help:
+#   UNTIED_KEY (2x pos, 8)  0.8188
+# To redo that experiment, double the width of reduce_query and
+#   reduce_value, then call reduce_position on nhx, flip it,
+#   and call reduce_position again
+# Evidently the experiments to try should be:
+#   no pos at all
+#   more heads
 class ConstituencyComposition(Enum):
     BILSTM = 1
     MAX = 2
@@ -179,6 +205,8 @@ class ConstituencyComposition(Enum):
     ATTN = 6
     TREE_LSTM_CX = 7
     UNTIED_MAX = 8
+    KEY = 9
+    UNTIED_KEY = 10
 
 class LSTMModel(BaseModel, nn.Module):
     def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, transitions, constituents, tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):
@@ -225,7 +253,7 @@ class LSTMModel(BaseModel, nn.Module):
         self.hidden_size = self.args['hidden_size']
 
         self.constituency_composition = self.args.get("constituency_composition", ConstituencyComposition.BILSTM)
-        if self.constituency_composition == ConstituencyComposition.ATTN:
+        if self.constituency_composition in (ConstituencyComposition.ATTN, ConstituencyComposition.KEY, ConstituencyComposition.UNTIED_KEY):
             self.reduce_heads = self.args['reduce_heads']
             if self.hidden_size % self.reduce_heads != 0:
                 self.hidden_size = self.hidden_size + self.reduce_heads - (self.hidden_size % self.reduce_heads)
@@ -494,6 +522,24 @@ class LSTMModel(BaseModel, nn.Module):
             initialize_linear(self.reduce_bigram, self.args['nonlinearity'], self.hidden_size)
         elif self.constituency_composition == ConstituencyComposition.ATTN:
             self.reduce_attn = nn.MultiheadAttention(self.hidden_size, self.reduce_heads)
+        elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
+            if self.args['reduce_position']:
+                # unsaved module so that if it grows, we don't save
+                # the larger version unnecessarily
+                # under any normal circumstances, the growth will
+                # happen early in training when the model is not
+                # behaving well, then will not be needed once the
+                # model learns not to make super degenerate
+                # constituents
+                self.add_unsaved_module("reduce_position", ConcatSinusoidalEncoding(self.args['reduce_position'], 50))
+            else:
+                self.add_unsaved_module("reduce_position", nn.Identity())
+            self.reduce_query = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size, bias=False)
+            self.reduce_value = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size)
+            if self.constituency_composition == ConstituencyComposition.KEY:
+                self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
+            else:
+                self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(len(constituent_opens), self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
         elif self.constituency_composition == ConstituencyComposition.TREE_LSTM:
             self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
         elif self.constituency_composition == ConstituencyComposition.TREE_LSTM_CX:
@@ -781,6 +827,8 @@ class LSTMModel(BaseModel, nn.Module):
         children_lists is a list of children that go under each of the new nodes
         lists of each are used so that we can stack operations
         """
+        # at the end of each of these operations, we expect lstm_hx.shape
+        # is (L, N, hidden_size) for N lists of children
         if (self.constituency_composition == ConstituencyComposition.BILSTM
             or self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
             node_hx = [[child.value.tree_hx.squeeze(0) for child in children] for children in children_lists]
@@ -864,6 +912,28 @@ class LSTMModel(BaseModel, nn.Module):
             hx = torch.stack(unpacked_hx, axis=0)
             lstm_hx = self.nonlinearity(hx).unsqueeze(0)
             lstm_cx = None
+        elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
+            node_hx = [torch.stack([child.value.tree_hx for child in children]) for children in children_lists]
+            # add a position vector to each node_hx
+            node_hx = [self.reduce_position(x.reshape(x.shape[0], -1)) for x in node_hx]
+            query_hx = [self.reduce_query(nhx) for nhx in node_hx]
+            # reshape query for MHA
+            query_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in query_hx]
+            if self.constituency_composition == ConstituencyComposition.KEY:
+                queries = [torch.matmul(nhx, self.reduce_key) for nhx in query_hx]
+            else:
+                label_indices = [self.constituent_open_map[label] for label in labels]
+                queries = [torch.matmul(nhx, self.reduce_key[label_idx]) for nhx, label_idx in zip(query_hx, label_indices)]
+            # softmax each head
+            weights = [torch.nn.functional.softmax(nhx, dim=1).transpose(1, 2) for nhx in queries]
+            value_hx = [self.reduce_value(nhx) for nhx in node_hx]
+            value_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in value_hx]
+            # use the softmaxes to add up the heads
+            unpacked_hx = [torch.matmul(weight, nhx).squeeze(1) for weight, nhx in zip(weights, value_hx)]
+            unpacked_hx = [nhx.reshape(-1) for nhx in unpacked_hx]
+            hx = torch.stack(unpacked_hx, axis=0).unsqueeze(0)
+            lstm_hx = self.nonlinearity(hx)
+            lstm_cx = None
         elif self.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
             label_hx = [self.lstm_input_dropout(self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]])) for label in labels]
             label_hx = torch.stack(label_hx).unsqueeze(0)
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 83174ee3..98a166ad 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -390,6 +390,7 @@ def parse_args(args=None):
 
     parser.add_argument('--constituency_composition', default=ConstituencyComposition.MAX, type=lambda x: ConstituencyComposition[x.upper()], help='How to build a new composition from its children. {}'.format(", ".join(x.name for x in ConstituencyComposition)))
     parser.add_argument('--reduce_heads', default=8, type=int, help='Number of attn heads to use when reducing children into a parent tree (constituency_composition == attn)')
+    parser.add_argument('--reduce_position', default=None, type=int, help="Dimension of position vector to use when reducing children. None means 1/4 hidden_size, 0 means don't use (constituency_composition == key | untied_key)")
 
     parser.add_argument('--relearn_structure', action='store_true', help='Starting from an existing checkpoint, add or remove pattn / lattn. One thing that works well is to train an initial model using adadelta with no pattn, then add pattn with adamw')
     parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path')
@@ -462,6 +463,9 @@ def parse_args(args=None):
 
     if args.stage1_learning_rate is None:
         args.stage1_learning_rate = DEFAULT_LEARNING_RATES["adadelta"]
+    if args.reduce_position is None:
+        args.reduce_position = args.hidden_size // 4
+
     if args.num_tree_lstm_layers is None:
         if args.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
             args.num_tree_lstm_layers = 2
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index bf3be952..6276c19f 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -200,6 +200,12 @@ def test_forward_constituency_composition(pretrain_file):
     model = build_model(pretrain_file, '--constituency_composition', 'max')
     run_forward_checks(model, num_states=2)
 
+    model = build_model(pretrain_file, '--constituency_composition', 'key')
+    run_forward_checks(model, num_states=2)
+
+    model = build_model(pretrain_file, '--constituency_composition', 'untied_key')
+    run_forward_checks(model, num_states=2)
+
     model = build_model(pretrain_file, '--constituency_composition', 'untied_max')
     run_forward_checks(model, num_states=2)
 
@@ -218,6 +224,23 @@ def test_forward_constituency_composition(pretrain_file):
     model = build_model(pretrain_file, '--constituency_composition', 'attn')
     run_forward_checks(model, num_states=2)
 
+def test_forward_key_position(pretrain_file):
+    """
+    Test KEY and UNTIED_KEY either with or without reduce_position
+    """
+    model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '0')
+    run_forward_checks(model, num_states=2)
+
+    model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '32')
+    run_forward_checks(model, num_states=2)
+
+    model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '0')
+    run_forward_checks(model, num_states=2)
+
+    model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '32')
+    run_forward_checks(model, num_states=2)
+
+
 def test_forward_attn_hidden_size(pretrain_file):
     """
     Test that when attn is used with hidden sizes not evenly divisible by reduce_heads, the model reconfigures the hidden_size
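
The tests above exercise reduce_position values of 0 and 32. As a rough, assumption-labeled stand-in for what that option feeds into (the real module is stanza.models.constituency.positional_encoding.ConcatSinusoidalEncoding, whose implementation is not shown in this diff), the toy module below concatenates a sinusoidal encoding of each child's index onto its hidden vector, which is why reduce_query and reduce_value take hidden_size + reduce_position inputs.

# Toy stand-in, for illustration only: concatenates (rather than adds) a
# sinusoidal position encoding.  Unlike the real unsaved module, it does not
# grow when a constituent has more children than max_len.
import math
import torch
import torch.nn as nn

class ToyConcatSinusoidalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super().__init__()
        assert d_model % 2 == 0
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (N, hidden_size) -> (N, hidden_size + d_model), where row i gets
        # the encoding of position i appended to it
        return torch.cat([x, self.pe[:x.shape[0]]], dim=-1)

enc = ToyConcatSinusoidalEncoding(d_model=32, max_len=50)
children = torch.randn(4, 128)
print(enc(children).shape)          # torch.Size([4, 160])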