
github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-11-01 10:35:33 +0300
committer John Bauer <horatio@gmail.com>  2022-11-02 01:05:15 +0300
commit    dded88d3c155f61fc76a95717fa7f3114c75ccc6 (patch)
tree      828b42dcbec31ef8ab9c9f853cd779c9df6177e7
parent    8bda5dfb52fdbc84a86f6937e06c22c6f5206bed (diff)
Move the pattn & lattn after the word lstm. The position information should be implicit in the LSTM itself (word_lstm_pattn)
-rw-r--r--  stanza/models/constituency/lstm_model.py               | 101
-rw-r--r--  stanza/models/constituency/partitioned_transformer.py  |  18
-rw-r--r--  stanza/models/constituency_parser.py                   |   4
-rw-r--r--  stanza/tests/constituency/test_trainer.py              |  22
4 files changed, 68 insertions(+), 77 deletions(-)
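
In outline, the commit moves the bidirectional word LSTM to the front of the encoder and lets the partitioned/label attention blocks consume its output, so positional information comes from the LSTM rather than from explicit timing tables. A minimal sketch of the new ordering, with toy dimensions and a plain Linear standing in for the partitioned transformer (both are assumptions, not the repo's real sizes or module):

import torch
import torch.nn as nn

# Toy sizes; the real values come from the parser's command line args.
word_input_size, hidden_size, pattn_d_model = 300, 128, 64

# 1) Word embeddings now go through the word LSTM first.
word_lstm = nn.LSTM(input_size=word_input_size, hidden_size=hidden_size,
                    num_layers=2, bidirectional=True, dropout=0.2)

# 2) The attention block sees the LSTM output (2 * hidden_size wide);
#    a Linear stands in here for the partitioned transformer.
pattn_stub = nn.Linear(hidden_size * 2, pattn_d_model)

# 3) word_to_constituent now takes the concatenated width (word_transform_size).
word_transform_size = hidden_size * 2 + pattn_d_model
word_to_constituent = nn.Linear(word_transform_size, hidden_size)

sentences = [torch.randn(7, word_input_size), torch.randn(4, word_input_size)]
packed = nn.utils.rnn.pack_sequence(sentences, enforce_sorted=False)
word_output, _ = word_lstm(packed)
word_output, lens = nn.utils.rnn.pad_packed_sequence(word_output)  # seq x batch x 2*hidden

attended = pattn_stub(word_output)                       # seq x batch x pattn_d_model
combined = torch.cat([word_output, attended], dim=-1)    # seq x batch x word_transform_size
constituents = torch.relu(word_to_constituent(combined))
print(constituents.shape)                                # torch.Size([7, 2, 128])
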
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 6c10c472..16ff3b0d 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -339,6 +339,9 @@ class LSTMModel(BaseModel, nn.Module):
self.bert_layer_mix = None
self.word_input_size = self.word_input_size + self.bert_dim
+ self.word_lstm = nn.LSTM(input_size=self.word_input_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)
+
+ self.word_transform_size = self.hidden_size * 2
self.partitioned_transformer_module = None
self.pattn_d_model = 0
if LSTMModel.uses_pattn(self.args):
@@ -357,49 +360,45 @@ class LSTMModel(BaseModel, nn.Module):
ff_dropout=self.args['pattn_relu_dropout'],
residual_dropout=self.args['pattn_residual_dropout'],
attention_dropout=self.args['pattn_attention_dropout'],
- word_input_size=self.word_input_size,
+ word_input_size=self.hidden_size * 2,
bias=self.args['pattn_bias'],
morpho_emb_dropout=self.args['pattn_morpho_emb_dropout'],
timing=self.args['pattn_timing'],
encoder_max_len=self.args['pattn_encoder_max_len']
)
- self.word_input_size += self.pattn_d_model
+
+ self.word_transform_size += self.pattn_d_model
self.label_attention_module = None
if LSTMModel.uses_lattn(self.args):
- if self.partitioned_transformer_module is None:
- logger.error("Not using Labeled Attention, as the Partitioned Attention module is not used")
+ if self.partitioned_transformer_module is None and not self.args['lattn_combined_input']:
+ logger.warning("Switching to lattn_combined_input=True, as the partitioned transformer is not even active")
+ self.args['lattn_combined_input'] = True
+ if self.args['lattn_combined_input']:
+ self.lattn_d_input = self.word_transform_size
else:
- # TODO: think of a couple ways to use alternate inputs
- # for example, could pass in the word inputs with a positional embedding
- # that would also allow it to work in the case of no partitioned module
- if self.args['lattn_combined_input']:
- self.lattn_d_input = self.word_input_size
- else:
- self.lattn_d_input = self.pattn_d_model
- self.label_attention_module = LabelAttentionModule(self.lattn_d_input,
- self.args['lattn_d_input_proj'],
- self.args['lattn_d_kv'],
- self.args['lattn_d_kv'],
- self.args['lattn_d_l'],
- self.args['lattn_d_proj'],
- self.args['lattn_combine_as_self'],
- self.args['lattn_resdrop'],
- self.args['lattn_q_as_matrix'],
- self.args['lattn_residual_dropout'],
- self.args['lattn_attention_dropout'],
- self.pattn_d_model // 2,
- self.args['lattn_d_ff'],
- self.args['lattn_relu_dropout'],
- self.args['lattn_partitioned'])
- self.word_input_size = self.word_input_size + self.args['lattn_d_proj']*self.args['lattn_d_l']
-
- self.word_lstm = nn.LSTM(input_size=self.word_input_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)
+ self.lattn_d_input = self.pattn_d_model
+ self.label_attention_module = LabelAttentionModule(self.lattn_d_input,
+ self.args['lattn_d_input_proj'],
+ self.args['lattn_d_kv'],
+ self.args['lattn_d_kv'],
+ self.args['lattn_d_l'],
+ self.args['lattn_d_proj'],
+ self.args['lattn_combine_as_self'],
+ self.args['lattn_resdrop'],
+ self.args['lattn_q_as_matrix'],
+ self.args['lattn_residual_dropout'],
+ self.args['lattn_attention_dropout'],
+ self.pattn_d_model // 2,
+ self.args['lattn_d_ff'],
+ self.args['lattn_relu_dropout'],
+ self.args['lattn_partitioned'])
+ self.word_transform_size = self.word_transform_size + self.args['lattn_d_proj']*self.args['lattn_d_l']
# after putting the word_delta_tag input through the word_lstm, we get back
# hidden_size * 2 output with the front and back lstms concatenated.
# this transforms it into hidden_size with the values mixed together
- self.word_to_constituent = nn.Linear(self.hidden_size * 2, self.hidden_size * self.num_tree_lstm_layers)
+ self.word_to_constituent = nn.Linear(self.word_transform_size, self.hidden_size * self.num_tree_lstm_layers)
initialize_linear(self.word_to_constituent, self.args['nonlinearity'], self.hidden_size * 2)
self.transitions = sorted(list(transitions))
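
For concreteness, the size bookkeeping in the hunk above works out as follows under assumed hyperparameters (the numbers are illustrative, not the parser's defaults):

# assumed values, for illustration only
hidden_size = 128                 # word LSTM hidden size
pattn_d_model = 64                # partitioned transformer width
lattn_d_proj, lattn_d_l = 16, 4   # label attention projection size / label count

word_transform_size = hidden_size * 2            # bidirectional LSTM output
word_transform_size += pattn_d_model             # if the partitioned transformer is on
word_transform_size += lattn_d_proj * lattn_d_l  # if label attention is on
print(word_transform_size)                       # 384, the input width of word_to_constituent
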
@@ -539,8 +538,9 @@ class LSTMModel(BaseModel, nn.Module):
raise ValueError("Unexpected other parameter name {}".format(name))
for idx in range(len(self.constituent_opens)):
my_parameter[idx].data.copy_(other_parameter.data)
- elif name.startswith('word_lstm.weight_ih_l0'):
- # bottom layer shape may have changed from adding a new pattn / lattn block
+ elif name.startswith('word_to_constituent'):
+ # transformation from word_lstm to constituent
+ # might have changed from adding a new pattn / lattn block
my_parameter = self.get_parameter(name)
# -1 so that it can be converted easier to a different parameter
copy_size = min(other_parameter.data.shape[-1], my_parameter.data.shape[-1])
@@ -686,33 +686,36 @@ class LSTMModel(BaseModel, nn.Module):
all_word_inputs = [torch.cat((x, y), axis=1) for x, y in zip(all_word_inputs, bert_embeddings)]
+ all_word_inputs = [self.word_dropout(word_inputs) for word_inputs in all_word_inputs]
+ packed_word_input = torch.nn.utils.rnn.pack_sequence(all_word_inputs, enforce_sorted=False)
+ word_output, _ = self.word_lstm(packed_word_input)
+ # would like to do word_to_constituent here, but it seems PackedSequence doesn't support Linear
+ # word_output will now be sentence x batch x 2*hidden_size
+ word_output, word_output_lens = torch.nn.utils.rnn.pad_packed_sequence(word_output)
+ # now sentence x batch x hidden_size
+
# Extract partitioned representation
+ if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
+ sentence_outputs = [word_output[:len(tagged_words)+2, sentence_idx, :]
+ for sentence_idx, tagged_words in enumerate(tagged_word_lists)]
+ else:
+ sentence_outputs = [word_output[:len(tagged_words), sentence_idx, :]
+ for sentence_idx, tagged_words in enumerate(tagged_word_lists)]
+
if self.partitioned_transformer_module is not None:
- partitioned_embeddings = self.partitioned_transformer_module(None, all_word_inputs)
- all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, partitioned_embeddings)]
+ partitioned_embeddings = self.partitioned_transformer_module(None, sentence_outputs)
+ sentence_outputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(sentence_outputs, partitioned_embeddings)]
# Extract Labeled Representation
if self.label_attention_module is not None:
if self.args['lattn_combined_input']:
- labeled_representations = self.label_attention_module(all_word_inputs, tagged_word_lists)
+ labeled_representations = self.label_attention_module(sentence_outputs, tagged_word_lists)
else:
labeled_representations = self.label_attention_module(partitioned_embeddings, tagged_word_lists)
- all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, labeled_representations)]
-
- all_word_inputs = [self.word_dropout(word_inputs) for word_inputs in all_word_inputs]
- packed_word_input = torch.nn.utils.rnn.pack_sequence(all_word_inputs, enforce_sorted=False)
- word_output, _ = self.word_lstm(packed_word_input)
- # would like to do word_to_constituent here, but it seems PackedSequence doesn't support Linear
- # word_output will now be sentence x batch x 2*hidden_size
- word_output, word_output_lens = torch.nn.utils.rnn.pad_packed_sequence(word_output)
- # now sentence x batch x hidden_size
+ sentence_outputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(sentence_outputs, labeled_representations)]
word_queues = []
- for sentence_idx, tagged_words in enumerate(tagged_word_lists):
- if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
- sentence_output = word_output[:len(tagged_words)+2, sentence_idx, :]
- else:
- sentence_output = word_output[:len(tagged_words), sentence_idx, :]
+ for sentence_output, tagged_words in zip(sentence_outputs, tagged_word_lists):
sentence_output = self.word_to_constituent(sentence_output)
sentence_output = self.nonlinearity(sentence_output)
# TODO: this makes it so constituents downstream are
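
The word_to_constituent hunk above supports multistage training, where the layer's input width grows once pattn/lattn are switched on and only the overlapping slice of the old weights can be reused. A hedged sketch of that partial copy between two Linear layers of different input widths (the layer sizes are assumed):

import torch
import torch.nn as nn

old = nn.Linear(256, 128)   # trained before the attention blocks were added
new = nn.Linear(320, 128)   # wider input after pattn / lattn outputs are concatenated

# copy the overlapping input columns; the new columns keep their fresh init
copy_size = min(old.weight.shape[-1], new.weight.shape[-1])
with torch.no_grad():
    new.weight[:, :copy_size].copy_(old.weight[:, :copy_size])
    new.bias.copy_(old.bias)

print(new(torch.randn(5, 320)).shape)   # torch.Size([5, 128])
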
diff --git a/stanza/models/constituency/partitioned_transformer.py b/stanza/models/constituency/partitioned_transformer.py
index b46b4338..63380faf 100644
--- a/stanza/models/constituency/partitioned_transformer.py
+++ b/stanza/models/constituency/partitioned_transformer.py
@@ -245,18 +245,8 @@ class PartitionedTransformerModule(nn.Module):
activation=PartitionedReLU()
):
super().__init__()
- self.project_pretrained = nn.Linear(
- word_input_size, d_model // 2, bias=bias
- )
+ self.project_in = PartitionedLinear(word_input_size, d_model, bias=bias)
- self.pattention_morpho_emb_dropout = FeatureDropout(morpho_emb_dropout)
- if timing == 'sin':
- self.add_timing = ConcatSinusoidalEncoding(d_model=d_model // 2, max_len=encoder_max_len)
- elif timing == 'learned':
- self.add_timing = ConcatPositionalEncoding(d_model=d_model // 2, max_len=encoder_max_len)
- else:
- raise ValueError("Unhandled timing type: %s" % timing)
- self.transformer_input_norm = nn.LayerNorm(d_model)
self.pattn_encoder = PartitionedTransformerEncoder(
n_layers,
d_model=d_model,
@@ -296,11 +286,9 @@ class PartitionedTransformerModule(nn.Module):
)
# Project the pretrained embedding onto the desired dimension
- extra_content_annotations = self.project_pretrained(padded_embeddings)
+ encoder_in = self.project_in(padded_embeddings)
+ encoder_in = torch.cat(encoder_in, -1)
- # Add positional information through the table
- encoder_in = self.add_timing(self.pattention_morpho_emb_dropout(extra_content_annotations))
- encoder_in = self.transformer_input_norm(encoder_in)
# Put the partitioned input through the partitioned attention
annotations = self.pattn_encoder(encoder_in, valid_token_mask)
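
With the timing tables and input norm removed, the module now just projects the LSTM output into the content and position halves of the partitioned representation and concatenates them. A rough sketch of that idea, using two plain Linear layers as a stand-in for the repo's PartitionedLinear (whose exact interface is not reproduced here):

import torch
import torch.nn as nn

class PartitionedProjection(nn.Module):
    """Stand-in for PartitionedLinear: project the input into separate
    'content' and 'position' halves of a d_model-wide representation."""
    def __init__(self, word_input_size, d_model, bias=True):
        super().__init__()
        self.content = nn.Linear(word_input_size, d_model // 2, bias=bias)
        self.position = nn.Linear(word_input_size, d_model // 2, bias=bias)

    def forward(self, x):
        return self.content(x), self.position(x)

project_in = PartitionedProjection(word_input_size=256, d_model=64)
padded_embeddings = torch.randn(2, 9, 256)                # batch x seq_len x lstm width
encoder_in = torch.cat(project_in(padded_embeddings), -1)
print(encoder_in.shape)                                    # torch.Size([2, 9, 64])
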
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 83174ee3..1b916e21 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -421,8 +421,8 @@ def parse_args(args=None):
parser.add_argument('--lattn_resdrop', default=True, action='store_true', help='Whether or not to use Residual Dropout')
parser.add_argument('--lattn_pwff', default=True, action='store_true', help='Whether or not to use a Position-wise Feed-forward Layer')
parser.add_argument('--lattn_q_as_matrix', default=False, action='store_true', help='Whether or not Label Attention uses learned query vectors. False means it does')
- parser.add_argument('--lattn_partitioned', default=True, action='store_true', help='Whether or not it is partitioned')
- parser.add_argument('--no_lattn_partitioned', default=True, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned')
+ parser.add_argument('--lattn_partitioned', default=False, action='store_true', help='Whether or not it is partitioned')
+ parser.add_argument('--no_lattn_partitioned', default=False, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned')
parser.add_argument('--lattn_combine_as_self', default=False, action='store_true', help='Whether or not the layer uses concatenation. False means it does')
# currently unused - always assume 1/2 of pattn
#parser.add_argument('--lattn_d_positional', default=512, type=int, help='Dimension for the positional embedding')
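
The flag pair above is the usual argparse on/off pattern: both options write to the same lattn_partitioned destination, and the matching default= values decide the behaviour when neither flag is passed (False after this commit). A small self-contained illustration of the pattern:

import argparse

parser = argparse.ArgumentParser()
# store_true switches the feature on; store_false on the same dest switches it off
parser.add_argument('--lattn_partitioned', default=False, action='store_true')
parser.add_argument('--no_lattn_partitioned', default=False, action='store_false',
                    dest='lattn_partitioned')

print(parser.parse_args([]).lattn_partitioned)                          # False
print(parser.parse_args(['--lattn_partitioned']).lattn_partitioned)     # True
print(parser.parse_args(['--no_lattn_partitioned']).lattn_partitioned)  # False
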
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index f6874725..350a11b6 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -258,26 +258,26 @@ class TestTrainer:
args = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args)
each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
- word_input_sizes = defaultdict(list)
+ word_transform_sizes = defaultdict(list)
for i in range(1, 9):
model_name = each_name % i
assert os.path.exists(model_name)
tr = trainer.Trainer.load(model_name, load_optimizer=True)
assert tr.epochs_trained == i
- word_input_sizes[tr.model.word_input_size].append(i)
+ word_transform_sizes[tr.model.word_transform_size].append(i)
if use_lattn:
# there should be three stages: no attn, pattn, pattn+lattn
- assert len(word_input_sizes) == 3
- word_input_keys = sorted(word_input_sizes.keys())
- assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4]
- assert word_input_sizes[word_input_keys[1]] == [5, 6]
- assert word_input_sizes[word_input_keys[2]] == [7, 8]
+ assert len(word_transform_sizes) == 3
+ word_input_keys = sorted(word_transform_sizes.keys())
+ assert word_transform_sizes[word_input_keys[0]] == [1, 2, 3, 4]
+ assert word_transform_sizes[word_input_keys[1]] == [5, 6]
+ assert word_transform_sizes[word_input_keys[2]] == [7, 8]
else:
# with no lattn, there are two stages: no attn, pattn
- assert len(word_input_sizes) == 2
- word_input_keys = sorted(word_input_sizes.keys())
- assert word_input_sizes[word_input_keys[0]] == [1, 2, 3, 4]
- assert word_input_sizes[word_input_keys[1]] == [5, 6, 7, 8]
+ assert len(word_transform_sizes) == 2
+ word_input_keys = sorted(word_transform_sizes.keys())
+ assert word_transform_sizes[word_input_keys[0]] == [1, 2, 3, 4]
+ assert word_transform_sizes[word_input_keys[1]] == [5, 6, 7, 8]
def test_multistage_lattn(self, wordvec_pretrain_file):
"""