github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>  2022-10-14 10:03:48 +0300
committer  John Bauer <horatio@gmail.com>  2022-11-07 23:56:04 +0300
commit     f163c4c7eea19ad4a93afaf93d30183cb3c7b93c (patch)
tree       ab2fd11bffcd86b495b2a5139d687424346af92a
parent     a07cf172e2ba3c8c82b4195dec9d72364ad4c5e0 (diff)
Add a variant of multihead attention where there's one key, or one key per label

Seems competitive with MAX, but not a big improvement of any kind.

Adds an optional position encoding to the KEY / UNTIED_KEY constituency
compositions (using reduce_position as a parameter for the size; 0 -> no
position). reduce_position needs to be an unsaved module so that it doesn't
barf if it gets reloaded later with a different size.

Includes comments on a couple of variants that didn't work: a linear after
the attention, or a double position vector.
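For orientation, a minimal sketch of the mechanism this commit adds: instead of full self-attention over a constituent's children, each head learns a single key (or one key per constituent label) that scores the children, and the softmax of those scores pools the children into one parent vector. Sizes below are illustrative, not the model's defaults.

    import torch

    H, heads, L = 512, 8, 4                  # hidden size, attention heads, children
    head_dim = H // heads
    key = torch.randn(heads, head_dim, 1)    # one learned key shared by all labels
    children = torch.randn(heads, L, head_dim)   # projected child representations
    scores = children @ key                      # (heads, L, 1)
    weights = torch.softmax(scores, dim=1)       # one weight per child, per head
    parent = (weights.transpose(1, 2) @ children).reshape(-1)   # (H,) composed parent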
-rw-r--r--  stanza/models/constituency/lstm_model.py      72
-rw-r--r--  stanza/models/constituency_parser.py           4
-rw-r--r--  stanza/tests/constituency/test_lstm_model.py  23
3 files changed, 98 insertions(+), 1 deletion(-)
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 9f08535a..f0ec86e9 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -45,6 +45,7 @@ from stanza.models.constituency.lstm_tree_stack import LSTMTreeStack
from stanza.models.constituency.parse_transitions import TransitionScheme
from stanza.models.constituency.parse_tree import Tree
from stanza.models.constituency.partitioned_transformer import PartitionedTransformerModule
+from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack
from stanza.models.constituency.tree_stack import TreeStack
from stanza.models.constituency.utils import build_nonlinearity, initialize_linear, TextTooLongError
@@ -170,6 +171,31 @@ class StackHistory(Enum):
# UNTIED_MAX 0.9592
# Furthermore, starting from a finished MAX model and restarting
# by splitting the MAX layer into multiple pieces did not improve.
+#
+# KEY has a single Key which is used for a facsimile of ATTN
+# each incoming subtree has its values weighted by a Query
+# then the Key is used to calculate a softmax
+# finally, a Value is used to scale the subtrees
+# reduce_heads is used to determine the number of heads
+# There is an option to use or not use position information
+# using a sinusoidal position embedding
+# UNTIED_KEY is the same, but has a different key
+# for each possible constituent
+# On a VI dataset:
+# MAX 0.82064
+# KEY (pos, 8) 0.81739
+# UNTIED_KEY (pos, 8) 0.82046
+# UNTIED_KEY (pos, 4) 0.81742
+# Attempted to add a linear to mix the attn heads together,
+# but that was awful: 0.81567
+# Adding two position vectors, one in each direction, did not help:
+# UNTIED_KEY (2x pos, 8) 0.8188
+# To redo that experiment, double the width of reduce_query and
+# reduce_value, then call reduce_position on nhx, flip it,
+# and call reduce_position again
+# Evidently the experiments to try should be:
+# no pos at all
+# more heads
class ConstituencyComposition(Enum):
    BILSTM = 1
    MAX = 2
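The comment block above describes the two variants; as a hedged sketch of the key parameter shapes they imply (mirroring the register_parameter calls later in this diff, with an illustrative label count):

    import torch

    heads, hidden, n_labels = 8, 512, 40     # n_labels is illustrative
    head_dim = hidden // heads
    key_tied = torch.randn(heads, head_dim, 1)              # KEY: one key for every constituent
    key_untied = torch.randn(n_labels, heads, head_dim, 1)  # UNTIED_KEY: one key per constituent label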
@@ -179,6 +205,8 @@ class ConstituencyComposition(Enum):
    ATTN = 6
    TREE_LSTM_CX = 7
    UNTIED_MAX = 8
+    KEY = 9
+    UNTIED_KEY = 10

class LSTMModel(BaseModel, nn.Module):
    def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, transitions, constituents, tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):
@@ -225,7 +253,7 @@ class LSTMModel(BaseModel, nn.Module):
        self.hidden_size = self.args['hidden_size']
        self.constituency_composition = self.args.get("constituency_composition", ConstituencyComposition.BILSTM)

-        if self.constituency_composition == ConstituencyComposition.ATTN:
+        if self.constituency_composition in (ConstituencyComposition.ATTN, ConstituencyComposition.KEY, ConstituencyComposition.UNTIED_KEY):
            self.reduce_heads = self.args['reduce_heads']
            if self.hidden_size % self.reduce_heads != 0:
                self.hidden_size = self.hidden_size + self.reduce_heads - (self.hidden_size % self.reduce_heads)
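The adjustment above rounds hidden_size up to the next multiple of reduce_heads so every attention head gets an equal slice; a quick standalone check of the arithmetic:

    def round_up_to_heads(hidden_size, reduce_heads):
        # same arithmetic as above: pad hidden_size to a multiple of reduce_heads
        if hidden_size % reduce_heads != 0:
            hidden_size = hidden_size + reduce_heads - (hidden_size % reduce_heads)
        return hidden_size

    assert round_up_to_heads(100, 8) == 104
    assert round_up_to_heads(512, 8) == 512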
@@ -494,6 +522,24 @@ class LSTMModel(BaseModel, nn.Module):
            initialize_linear(self.reduce_bigram, self.args['nonlinearity'], self.hidden_size)
        elif self.constituency_composition == ConstituencyComposition.ATTN:
            self.reduce_attn = nn.MultiheadAttention(self.hidden_size, self.reduce_heads)
+        elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
+            if self.args['reduce_position']:
+                # unsaved module so that if it grows, we don't save
+                # the larger version unnecessarily
+                # under any normal circumstances, the growth will
+                # happen early in training when the model is not
+                # behaving well, then will not be needed once the
+                # model learns not to make super degenerate
+                # constituents
+                self.add_unsaved_module("reduce_position", ConcatSinusoidalEncoding(self.args['reduce_position'], 50))
+            else:
+                self.add_unsaved_module("reduce_position", nn.Identity())
+            self.reduce_query = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size, bias=False)
+            self.reduce_value = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size)
+            if self.constituency_composition == ConstituencyComposition.KEY:
+                self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
+            else:
+                self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(len(constituent_opens), self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
        elif self.constituency_composition == ConstituencyComposition.TREE_LSTM:
            self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
        elif self.constituency_composition == ConstituencyComposition.TREE_LSTM_CX:
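ConcatSinusoidalEncoding's implementation is not part of this diff. As a hedged illustration of what the unsaved reduce_position module is presumed to do (build a fixed sinusoidal table, regrow it on demand for longer inputs, and concatenate it to the child features), a self-contained sketch; the real stanza module and add_unsaved_module may differ in detail:

    import math
    import torch
    from torch import nn

    class ConcatSinusoidalSketch(nn.Module):
        """Sketch only: concatenate a sinusoidal position table to the input.

        Assumes an even d_model; max_len=50 matches the call above."""
        def __init__(self, d_model, max_len=50):
            super().__init__()
            self.d_model = d_model
            self.register_buffer("pe", self._build(max_len))

        def _build(self, max_len):
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model))
            pe = torch.zeros(max_len, self.d_model)
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            return pe

        def forward(self, x):
            # x: (num_children, hidden); output: (num_children, hidden + d_model)
            if x.shape[0] > self.pe.shape[0]:
                # the table grows in place, which is why the module is unsaved:
                # a checkpoint reloaded with a different size would otherwise barf
                self.pe = self._build(x.shape[0]).to(x.device)
            return torch.cat([x, self.pe[:x.shape[0]]], dim=-1)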
@@ -781,6 +827,8 @@ class LSTMModel(BaseModel, nn.Module):
        children_lists is a list of children that go under each of the new nodes
        lists of each are used so that we can stack operations
        """
+        # at the end of each of these operations, we expect lstm_hx.shape
+        # is (L, N, hidden_size) for N lists of children
        if (self.constituency_composition == ConstituencyComposition.BILSTM or
            self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
            node_hx = [[child.value.tree_hx.squeeze(0) for child in children] for children in children_lists]
@@ -864,6 +912,28 @@ class LSTMModel(BaseModel, nn.Module):
            hx = torch.stack(unpacked_hx, axis=0)
            lstm_hx = self.nonlinearity(hx).unsqueeze(0)
            lstm_cx = None
+        elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
+            node_hx = [torch.stack([child.value.tree_hx for child in children]) for children in children_lists]
+            # add a position vector to each node_hx
+            node_hx = [self.reduce_position(x.reshape(x.shape[0], -1)) for x in node_hx]
+            query_hx = [self.reduce_query(nhx) for nhx in node_hx]
+            # reshape query for MHA
+            query_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in query_hx]
+            if self.constituency_composition == ConstituencyComposition.KEY:
+                queries = [torch.matmul(nhx, self.reduce_key) for nhx in query_hx]
+            else:
+                label_indices = [self.constituent_open_map[label] for label in labels]
+                queries = [torch.matmul(nhx, self.reduce_key[label_idx]) for nhx, label_idx in zip(query_hx, label_indices)]
+            # softmax each head
+            weights = [torch.nn.functional.softmax(nhx, dim=1).transpose(1, 2) for nhx in queries]
+            value_hx = [self.reduce_value(nhx) for nhx in node_hx]
+            value_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in value_hx]
+            # use the softmaxes to add up the heads
+            unpacked_hx = [torch.matmul(weight, nhx).squeeze(1) for weight, nhx in zip(weights, value_hx)]
+            unpacked_hx = [nhx.reshape(-1) for nhx in unpacked_hx]
+            hx = torch.stack(unpacked_hx, axis=0).unsqueeze(0)
+            lstm_hx = self.nonlinearity(hx)
+            lstm_cx = None
        elif self.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
            label_hx = [self.lstm_input_dropout(self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]])) for label in labels]
            label_hx = torch.stack(label_hx).unsqueeze(0)
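To make the tensor shapes in the KEY branch above easier to follow, a standalone walkthrough for a single list of L children (sizes are illustrative; the UNTIED_KEY variant only swaps in reduce_key[label_idx]):

    import torch

    L, heads, hidden, pos = 5, 8, 512, 128   # children, heads, hidden_size, reduce_position
    head_dim = hidden // heads
    reduce_query = torch.nn.Linear(hidden + pos, hidden, bias=False)
    reduce_value = torch.nn.Linear(hidden + pos, hidden)
    reduce_key = torch.randn(heads, head_dim, 1)

    nhx = torch.randn(L, hidden + pos)                                  # children with position features
    q = reduce_query(nhx).reshape(L, heads, head_dim).transpose(0, 1)   # (heads, L, head_dim)
    scores = torch.matmul(q, reduce_key)                                # (heads, L, 1)
    weights = torch.softmax(scores, dim=1).transpose(1, 2)              # (heads, 1, L), softmax over children
    v = reduce_value(nhx).reshape(L, heads, head_dim).transpose(0, 1)   # (heads, L, head_dim)
    parent = torch.matmul(weights, v).squeeze(1).reshape(-1)            # (hidden,) composed parent vector
    assert parent.shape == (hidden,)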
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 83174ee3..98a166ad 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -390,6 +390,7 @@ def parse_args(args=None):
    parser.add_argument('--constituency_composition', default=ConstituencyComposition.MAX, type=lambda x: ConstituencyComposition[x.upper()],
                        help='How to build a new composition from its children. {}'.format(", ".join(x.name for x in ConstituencyComposition)))
    parser.add_argument('--reduce_heads', default=8, type=int, help='Number of attn heads to use when reducing children into a parent tree (constituency_composition == attn)')
+    parser.add_argument('--reduce_position', default=None, type=int, help="Dimension of position vector to use when reducing children. None means 1/4 hidden_size, 0 means don't use (constituency_composition == key | untied_key)")
    parser.add_argument('--relearn_structure', action='store_true', help='Starting from an existing checkpoint, add or remove pattn / lattn. One thing that works well is to train an initial model using adadelta with no pattn, then add pattn with adamw')
    parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path')
@@ -462,6 +463,9 @@ def parse_args(args=None):
    if args.stage1_learning_rate is None:
        args.stage1_learning_rate = DEFAULT_LEARNING_RATES["adadelta"]

+    if args.reduce_position is None:
+        args.reduce_position = args.hidden_size // 4
+
    if args.num_tree_lstm_layers is None:
        if args.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
            args.num_tree_lstm_layers = 2
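A small standalone restatement of the flag's semantics, matching the default-filling above (None derives the size from hidden_size; 0 disables the position encoding entirely):

    def resolve_reduce_position(reduce_position, hidden_size):
        # None -> default to a quarter of hidden_size; 0 -> no position encoding
        if reduce_position is None:
            return hidden_size // 4
        return reduce_position

    assert resolve_reduce_position(None, 512) == 128   # default
    assert resolve_reduce_position(0, 512) == 0        # disabled
    assert resolve_reduce_position(32, 512) == 32      # explicit size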
diff --git a/stanza/tests/constituency/test_lstm_model.py b/stanza/tests/constituency/test_lstm_model.py
index bf3be952..6276c19f 100644
--- a/stanza/tests/constituency/test_lstm_model.py
+++ b/stanza/tests/constituency/test_lstm_model.py
@@ -200,6 +200,12 @@ def test_forward_constituency_composition(pretrain_file):
    model = build_model(pretrain_file, '--constituency_composition', 'max')
    run_forward_checks(model, num_states=2)

+    model = build_model(pretrain_file, '--constituency_composition', 'key')
+    run_forward_checks(model, num_states=2)
+
+    model = build_model(pretrain_file, '--constituency_composition', 'untied_key')
+    run_forward_checks(model, num_states=2)
+
    model = build_model(pretrain_file, '--constituency_composition', 'untied_max')
    run_forward_checks(model, num_states=2)
@@ -218,6 +224,23 @@ def test_forward_constituency_composition(pretrain_file):
    model = build_model(pretrain_file, '--constituency_composition', 'attn')
    run_forward_checks(model, num_states=2)

+def test_forward_key_position(pretrain_file):
+    """
+    Test KEY and UNTIED_KEY either with or without reduce_position
+    """
+    model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '0')
+    run_forward_checks(model, num_states=2)
+
+    model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '32')
+    run_forward_checks(model, num_states=2)
+
+    model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '0')
+    run_forward_checks(model, num_states=2)
+
+    model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '32')
+    run_forward_checks(model, num_states=2)
+
+
def test_forward_attn_hidden_size(pretrain_file):
    """
    Test that when attn is used with hidden sizes not evenly divisible by reduce_heads, the model reconfigures the hidden_size