| author | John Bauer <horatio@gmail.com> | 2022-03-02 05:41:19 +0300 |
|---|---|---|
| committer | John Bauer <horatio@gmail.com> | 2022-03-09 12:06:52 +0300 |
| commit | d2bc7f5f80ee2316dbc7dfef5f720b291430d8b8 (patch) | |
| tree | 91eab135c831be188a9e335f9cb39504d9e90b1a | |
| parent | e8db0c841bde42bc8b0187062cfa666338d6c53c (diff) | |
Add a simple MHA to the model (branch: con_simple_transformer)
Use simple_attn to replace the inputs
```
-rw-r--r--  stanza/models/constituency/lstm_model.py        | 15
-rw-r--r--  stanza/models/constituency/simple_attention.py  | 51
2 files changed, 66 insertions, 0 deletions
```
```diff
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 76cc7615..a51f6b2a 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -45,6 +45,7 @@
 from stanza.models.constituency.label_attention import LabelAttentionModule
 from stanza.models.constituency.parse_transitions import TransitionScheme
 from stanza.models.constituency.parse_tree import Tree
 from stanza.models.constituency.partitioned_transformer import PartitionedTransformerModule
+from stanza.models.constituency.simple_attention import SimpleAttentionModule
 from stanza.models.constituency.tree_stack import TreeStack
 from stanza.models.constituency.utils import build_nonlinearity, initialize_linear, TextTooLongError
@@ -259,6 +260,15 @@ class LSTMModel(BaseModel, nn.Module):
         self.bert_tokenizer = None
         self.is_phobert = False
 
+        # TODO: make these options
+        self.simple_attn = SimpleAttentionModule(4,
+                                                 8,
+                                                 self.word_input_size,
+                                                 self.word_input_size,
+                                                 128,
+                                                 self.word_input_size // 2)
+        self.word_input_size = self.simple_attn.d_model
+
         self.partitioned_transformer_module = None
         if self.args['pattn_num_heads'] > 0 and self.args['pattn_num_layers'] > 0:
             # Initializations of parameters for the Partitioned Attention
@@ -643,6 +653,11 @@ class LSTMModel(BaseModel, nn.Module):
             bert_embeddings = [be[1:-1] for be in bert_embeddings]
             all_word_inputs = [torch.cat((x, y), axis=1) for x, y in zip(all_word_inputs, bert_embeddings)]
 
+        if self.simple_attn is not None:
+            # TODO: if this helps, batch the operations
+            attention_results = [self.simple_attn(x.unsqueeze(0)).squeeze(0) for x in all_word_inputs]
+            all_word_inputs = attention_results
+
         # Extract partitioned representation
         if self.partitioned_transformer_module is not None:
             partitioned_embeddings = self.partitioned_transformer_module(None, all_word_inputs)
diff --git a/stanza/models/constituency/simple_attention.py b/stanza/models/constituency/simple_attention.py
new file mode 100644
index 00000000..02acdcc1
--- /dev/null
+++ b/stanza/models/constituency/simple_attention.py
@@ -0,0 +1,51 @@
+import logging
+
+import torch.nn as nn
+
+from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
+
+logger = logging.getLogger('stanza')
+
+class SimpleAttentionModule(nn.Module):
+    def __init__(self,
+                 n_layers,
+                 n_heads,
+                 d_input,
+                 d_model,
+                 d_timing,
+                 d_feed_forward):
+        super().__init__()
+
+        if d_model <= d_timing:
+            d_model += d_timing
+            logger.warning("d_model <= d_timing. changing d_model to %d", d_model)
+
+        if d_model % n_heads != 0:
+            d_model = d_model + n_heads - d_model % n_heads
+            logger.warning("d_model %% n_heads != 0. changing d_model to %d", d_model)
+
+        self.d_model = d_model
+        self.attn_proj = nn.Linear(d_input, d_model - d_timing)
+        self.attn_timing = ConcatSinusoidalEncoding(d_model=d_timing)
+        self.attn_layers = nn.ModuleList([nn.MultiheadAttention(d_model, n_heads, batch_first=True)
+                                          for _ in range(n_layers)])
+        self.linear_in = nn.ModuleList([nn.Linear(d_model, d_feed_forward)
+                                        for _ in range(n_layers)])
+        self.linear_out = nn.ModuleList([nn.Linear(d_feed_forward, d_model)
+                                         for _ in range(n_layers)])
+        self.nonlinearity = nn.ReLU()
+
+    def forward(self, x):
+        x = self.attn_proj(x)
+        x = self.attn_timing(x)
+
+        for layer, ff_in, ff_out in zip(self.attn_layers, self.linear_in, self.linear_out):
+            # TODO: residual dropout if this is working at all
+            x_attn = layer(x, x, x)[0]
+            x = x + x_attn
+            # TODO: layer norms?
+            x_ff = self.nonlinearity(ff_out(self.nonlinearity(ff_in(x))))
+            x = x + x_ff
+
+        return x
```
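For readers who want to try the new module in isolation, here is a minimal usage sketch. It assumes a stanza checkout that includes this commit (so `SimpleAttentionModule` and its `ConcatSinusoidalEncoding` dependency are importable); `word_input_size = 300` is picked purely for illustration, and the argument order and unsqueeze/squeeze pattern mirror the calls in `lstm_model.py` above.

```python
# Usage sketch only: assumes a stanza checkout containing this commit so the
# import below resolves; word_input_size = 300 is an illustrative stand-in for
# whatever embedding width the model actually computes.
import torch

from stanza.models.constituency.simple_attention import SimpleAttentionModule

word_input_size = 300

# Same hard-coded arguments as the TODO-marked call in LSTMModel.__init__:
# n_layers=4, n_heads=8, d_input, d_model, d_timing=128, d_feed_forward
simple_attn = SimpleAttentionModule(4, 8,
                                    word_input_size,
                                    word_input_size,
                                    128,
                                    word_input_size // 2)

# The forward pass above feeds one sentence at a time, unsqueezed to a batch of 1
words = torch.randn(20, word_input_size)              # one sentence of 20 words
attended = simple_attn(words.unsqueeze(0)).squeeze(0)

# 300 is not divisible by 8 heads, so the constructor rounds d_model up:
# 300 + 8 - (300 % 8) = 304, readable afterwards as simple_attn.d_model
print(attended.shape)                                 # torch.Size([20, 304])
```

This also shows why the model immediately overwrites `self.word_input_size` with `self.simple_attn.d_model`: the constructor may widen `d_model` to fit the head count and the timing dimensions, so downstream layers must size themselves against the adjusted value.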