github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>    2022-03-02 05:41:19 +0300
committer  John Bauer <horatio@gmail.com>    2022-03-09 12:06:52 +0300
commit     d2bc7f5f80ee2316dbc7dfef5f720b291430d8b8 (patch)
tree       91eab135c831be188a9e335f9cb39504d9e90b1a
parent     e8db0c841bde42bc8b0187062cfa666338d6c53c (diff)

Add a simple MHA to the model (con_simple_transformer)
Use simple_attn to replace the inputs
-rw-r--r--  stanza/models/constituency/lstm_model.py       | 15
-rw-r--r--  stanza/models/constituency/simple_attention.py | 51
2 files changed, 66 insertions(+), 0 deletions(-)
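
Each layer the patch stacks is a residual multi-head self-attention block followed by a residual two-layer feed-forward block, with no layer norm or dropout yet (see the TODO comments in the diff). A minimal sketch of one such layer in plain PyTorch, using hypothetical dimensions rather than the parser's real word_input_size:

import torch
import torch.nn as nn

# Hypothetical sizes for illustration; the patch derives them from word_input_size.
d_model, n_heads, d_feed_forward = 304, 8, 152

attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
ff_in = nn.Linear(d_model, d_feed_forward)
ff_out = nn.Linear(d_feed_forward, d_model)
relu = nn.ReLU()

x = torch.randn(1, 20, d_model)        # one 20-word sentence, batch_first
x = x + attn(x, x, x)[0]               # residual self-attention
x = x + relu(ff_out(relu(ff_in(x))))   # residual feed-forward, ReLU inside and out
print(x.shape)                         # torch.Size([1, 20, 304])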
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 76cc7615..a51f6b2a 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -45,6 +45,7 @@ from stanza.models.constituency.label_attention import LabelAttentionModule
 from stanza.models.constituency.parse_transitions import TransitionScheme
 from stanza.models.constituency.parse_tree import Tree
 from stanza.models.constituency.partitioned_transformer import PartitionedTransformerModule
+from stanza.models.constituency.simple_attention import SimpleAttentionModule
 from stanza.models.constituency.tree_stack import TreeStack
 from stanza.models.constituency.utils import build_nonlinearity, initialize_linear, TextTooLongError
@@ -259,6 +260,15 @@ class LSTMModel(BaseModel, nn.Module):
             self.bert_tokenizer = None
             self.is_phobert = False
 
+        # TODO: make these options
+        self.simple_attn = SimpleAttentionModule(4,
+                                                 8,
+                                                 self.word_input_size,
+                                                 self.word_input_size,
+                                                 128,
+                                                 self.word_input_size // 2)
+        self.word_input_size = self.simple_attn.d_model
+
         self.partitioned_transformer_module = None
         if self.args['pattn_num_heads'] > 0 and self.args['pattn_num_layers'] > 0:
             # Initializations of parameters for the Partitioned Attention
@@ -643,6 +653,11 @@ class LSTMModel(BaseModel, nn.Module):
                 bert_embeddings = [be[1:-1] for be in bert_embeddings]
             all_word_inputs = [torch.cat((x, y), axis=1) for x, y in zip(all_word_inputs, bert_embeddings)]
 
+        if self.simple_attn is not None:
+            # TODO: if this helps, batch the operations
+            attention_results = [self.simple_attn(x.unsqueeze(0)).squeeze(0) for x in all_word_inputs]
+            all_word_inputs = attention_results
+
         # Extract partitioned representation
         if self.partitioned_transformer_module is not None:
             partitioned_embeddings = self.partitioned_transformer_module(None, all_word_inputs)
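
The hunk above runs the attention one sentence at a time: all_word_inputs is a list of [n_words, word_input_size] tensors, so each is unsqueezed into a batch of one for the batch_first attention layers and squeezed back afterwards. A small sketch of that shape round-trip, with nn.Identity standing in for the real module and hypothetical sentence sizes:

import torch
import torch.nn as nn

simple_attn = nn.Identity()   # stand-in for self.simple_attn; the shapes are the point here
all_word_inputs = [torch.randn(17, 304), torch.randn(9, 304)]

# Mirrors the list comprehension in the patch: one sentence per call,
# batched to [1, n_words, d] and un-batched back to [n_words, d].
attention_results = [simple_attn(x.unsqueeze(0)).squeeze(0) for x in all_word_inputs]
print([t.shape for t in attention_results])   # [torch.Size([17, 304]), torch.Size([9, 304])]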
diff --git a/stanza/models/constituency/simple_attention.py b/stanza/models/constituency/simple_attention.py
new file mode 100644
index 00000000..02acdcc1
--- /dev/null
+++ b/stanza/models/constituency/simple_attention.py
@@ -0,0 +1,51 @@
+import logging
+
+import torch.nn as nn
+
+from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
+
+logger = logging.getLogger('stanza')
+
+class SimpleAttentionModule(nn.Module):
+    def __init__(self,
+                 n_layers,
+                 n_heads,
+                 d_input,
+                 d_model,
+                 d_timing,
+                 d_feed_forward):
+        super().__init__()
+
+        if d_model <= d_timing:
+            d_model += d_timing
+            logger.warning("d_model <= d_timing. changing d_model to %d", d_model)
+
+        if d_model % n_heads != 0:
+            d_model = d_model + n_heads - d_model % n_heads
+            logger.warning("d_model %% n_heads != 0. changing d_model to %d", d_model)
+
+        self.d_model = d_model
+        self.attn_proj = nn.Linear(d_input, d_model - d_timing)
+        self.attn_timing = ConcatSinusoidalEncoding(d_model=d_timing)
+        self.attn_layers = nn.ModuleList([nn.MultiheadAttention(d_model, n_heads, batch_first=True)
+                                          for _ in range(n_layers)])
+        self.linear_in = nn.ModuleList([nn.Linear(d_model, d_feed_forward)
+                                        for _ in range(n_layers)])
+        self.linear_out = nn.ModuleList([nn.Linear(d_feed_forward, d_model)
+                                         for _ in range(n_layers)])
+        self.nonlinearity = nn.ReLU()
+
+    def forward(self, x):
+        x = self.attn_proj(x)
+        x = self.attn_timing(x)
+
+        for layer, ff_in, ff_out in zip(self.attn_layers, self.linear_in, self.linear_out):
+            # TODO: residual dropout if this is working at all
+            x_attn = layer(x, x, x)[0]
+            x = x + x_attn
+            # TODO: layer norms?
+            x_ff = self.nonlinearity(ff_out(self.nonlinearity(ff_in(x))))
+            x = x + x_ff
+
+        return x
+
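
As a quick smoke test, a hedged usage sketch of the new module with a hypothetical word_input_size of 300 (the real value depends on the embeddings in use). It assumes a stanza checkout containing this patch, and that ConcatSinusoidalEncoding concatenates d_timing sinusoidal features onto its input, which is what the d_model - d_timing projection above relies on:

import torch
from stanza.models.constituency.simple_attention import SimpleAttentionModule

# Same argument order as the lstm_model.py hunk: n_layers=4, n_heads=8,
# d_input=300, d_model=300, d_timing=128, d_feed_forward=150 (hypothetical sizes).
simple_attn = SimpleAttentionModule(4, 8, 300, 300, 128, 150)

# 300 % 8 != 0, so the constructor rounds d_model up to 304 and logs a warning;
# lstm_model.py then sets word_input_size = simple_attn.d_model.
print(simple_attn.d_model)        # 304

words = torch.randn(1, 17, 300)   # a batch of one 17-word sentence
print(simple_attn(words).shape)   # torch.Size([1, 17, 304])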