author     John Bauer <horatio@gmail.com>   2022-11-13 07:38:54 +0300
committer  John Bauer <horatio@gmail.com>   2022-11-13 07:50:01 +0300
commit     218388766e983c8ee6540ff20ea3b09670181b7a (patch)
tree       dd12b00975d08523bf02cdfde4f7a63e29e15888
parent     c357d5a269aea55e65a730a42ea91247b6d96691 (diff)
Potentially use a MHA layer in the tagger (branch: tagger_mha)
-rw-r--r--  stanza/models/pos/model.py | 49
-rw-r--r--  stanza/models/tagger.py    |  6
2 files changed, 54 insertions(+), 1 deletion(-)
diff --git a/stanza/models/pos/model.py b/stanza/models/pos/model.py
index 43ba9da4..b425783e 100644
--- a/stanza/models/pos/model.py
+++ b/stanza/models/pos/model.py
@@ -14,6 +14,10 @@ from stanza.models.common.hlstm import HighwayLSTM
 from stanza.models.common.dropout import WordDropout
 from stanza.models.common.vocab import CompositeVocab
 from stanza.models.common.char_model import CharacterModel
+from stanza.models.constituency.positional_encoding import AddSinusoidalEncoding
+# this is from a pip installable package, if we want to use something
+# more stable than our own position encoding
+# from positional_encodings import PositionalEncoding1D
 
 logger = logging.getLogger('stanza')
 
@@ -83,7 +87,18 @@ class Tagger(nn.Module):
             add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix), freeze=True))
             self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
             input_size += self.args['transformed_dim']
-
+
+        self.mha = None
+        if self.args.get('attention', False):
+            #add_unsaved_module('position_encoding', PositionalEncoding1D(input_size))
+            add_unsaved_module('position_encoding', AddSinusoidalEncoding(input_size))
+            self.layer_norm = nn.LayerNorm(input_size)
+            self.mha = nn.MultiheadAttention(input_size, self.args['attention_heads'], batch_first=True)
+            # alternatives:
+            #self.encoder_layer = nn.TransformerEncoderLayer(input_size, 1, input_size, batch_first=True)
+            #self.transformer = nn.TransformerEncoder(encoder_layer, 1)
+            # also, ideally we could put pattn or lattn here
+
         # recurrent layers
         self.taggerlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)
         self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
@@ -187,6 +202,38 @@ class Tagger(nn.Module):
         lstm_inputs = self.drop(lstm_inputs)
 
         lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)
+        if self.mha is not None:
+            orig_inputs, lstm_sizes = pad_packed_sequence(lstm_inputs, batch_first=True)
+            # these print statements are super noisy, but if the
+            # weights are miscalibrated, they show the model blowing
+            # up almost immediately
+            #print("orig input: %.4f" % torch.linalg.norm(orig_inputs), end=" ")
+            # this would work if using PositionalEncoding1D from the positional_encodings package
+            # pos_inputs = self.position_encoding(orig_inputs)
+            # lstm_inputs = orig_inputs + pos_inputs * 0.1
+            # higher POS factors than this just fail horribly
+            lstm_inputs = self.position_encoding(orig_inputs, 0.1)
+            #print("pos input: %.4f total input: %.4f" % (torch.linalg.norm(lstm_inputs - orig_inputs), torch.linalg.norm(lstm_inputs)), end=" ")
+
+            # build an attention mask as the batch may have differing lengths
+            attn_mask = torch.zeros(lstm_inputs.shape[0], lstm_inputs.shape[1], dtype=torch.bool, device=lstm_inputs.device)
+            for lstm_idx, lstm_size in enumerate(lstm_sizes):
+                attn_mask[lstm_idx, lstm_size:] = True  # True should mean don't attend
+
+            attn_outputs, attn_weights = self.mha(lstm_inputs, lstm_inputs, lstm_inputs, key_padding_mask=attn_mask)
+            # multiplying by 0.02 is very unsatisfactory, though
+            # it means very little derivative flows from the attention
+            # to the layers below
+            # not multiplying makes it blow up almost immediately
+            # weirdly, initializing made no difference.  the problem occurs
+            # almost immediately as the first updates cause the out_proj
+            # to jump to a much larger value
+            lstm_inputs = attn_outputs * 0.02
+            # could do a residual link here
+            #lstm_inputs = orig_inputs + lstm_inputs
+            #print(" attn output: %.4f" % (torch.linalg.norm(lstm_inputs)))
+            lstm_inputs = pack(lstm_inputs)
+
         lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(self.taggerlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.taggerlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))
         lstm_outputs = lstm_outputs.data
 
diff --git a/stanza/models/tagger.py b/stanza/models/tagger.py
index f0af9ee0..07cb4dbd 100644
--- a/stanza/models/tagger.py
+++ b/stanza/models/tagger.py
@@ -63,6 +63,9 @@ def parse_args(args=None):
     parser.add_argument('--rec_dropout', type=float, default=0, help="Recurrent dropout")
     parser.add_argument('--char_rec_dropout', type=float, default=0, help="Recurrent dropout")
 
+    parser.add_argument('--attention', default=False, action='store_true', help='Use an MHA on the inputs before the LSTM')
+    parser.add_argument('--attention_heads', type=int, default=8, help='Number of attention heads to use')
+
     # TODO: refactor charlm arguments for models which use it?
     parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off character model.")
     parser.add_argument('--char_bidirectional', dest='char_bidirectional', action='store_true', help="Use a bidirectional version of the non-pretrained charlm. Doesn't help much, makes the models larger")
@@ -206,6 +209,9 @@ def train(args):
     foundation_cache = FoundationCache()
     trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, use_cuda=args['cuda'], foundation_cache=foundation_cache)
 
+    if args['log_norms']:
+        trainer.model.log_norms()
+
     global_step = 0
     max_steps = args['max_steps']
     dev_score_history = []
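
For anyone who wants to poke at the new code path in isolation, here is a minimal, self-contained sketch of the attention front-end this commit wires in ahead of the HighwayLSTM: unpack the padded batch, add a scaled sinusoidal position signal, run masked self-attention, damp the output, and re-pack. `SinusoidalEncoding` is a hypothetical stand-in for stanza's `AddSinusoidalEncoding` (its interface is only inferred from the `(tensor, scale)` call in the diff), and `AttentionFrontEnd` is not a class from the commit, just an illustration of the wiring:

```python
import math

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class SinusoidalEncoding(nn.Module):
    """Hypothetical stand-in for stanza's AddSinusoidalEncoding:
    adds a scaled fixed sinusoidal signal to a (batch, seq, dim) tensor."""
    def forward(self, x, scale=1.0):
        batch, seq_len, dim = x.shape
        position = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32, device=x.device)
                             * (-math.log(10000.0) / dim))
        pe = torch.zeros(seq_len, dim, device=x.device)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return x + pe.unsqueeze(0) * scale


class AttentionFrontEnd(nn.Module):
    """Illustration of the diff's wiring: position encoding -> masked MHA
    -> damped output, applied to the padded inputs before the LSTM."""
    def __init__(self, input_size, num_heads):
        super().__init__()
        self.position_encoding = SinusoidalEncoding()
        self.mha = nn.MultiheadAttention(input_size, num_heads, batch_first=True)

    def forward(self, packed_inputs):
        inputs, lengths = pad_packed_sequence(packed_inputs, batch_first=True)
        inputs = self.position_encoding(inputs, 0.1)  # small scale, as in the diff

        # mask padding positions: True means "do not attend to this key"
        attn_mask = torch.zeros(inputs.shape[0], inputs.shape[1],
                                dtype=torch.bool, device=inputs.device)
        for idx, size in enumerate(lengths):
            attn_mask[idx, size:] = True

        attn_outputs, _ = self.mha(inputs, inputs, inputs, key_padding_mask=attn_mask)
        # the 0.02 damping from the diff keeps the attention output small
        outputs = attn_outputs * 0.02
        return pack_padded_sequence(outputs, lengths, batch_first=True)


if __name__ == "__main__":
    torch.manual_seed(0)
    lengths = torch.tensor([5, 3])
    padded = torch.randn(2, 5, 16)
    packed = pack_padded_sequence(padded, lengths, batch_first=True)
    out = AttentionFrontEnd(16, num_heads=4)(packed)
    print(out.data.shape)  # torch.Size([8, 16]): sum of lengths x dim
```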
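The hedged inline comment "True should mean don't attend" is in fact the documented convention for `torch.nn.MultiheadAttention`: in a boolean `key_padding_mask`, positions marked True are excluded from attention. A quick check confirms masked keys receive exactly zero weight:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(8, 2, batch_first=True)
x = torch.randn(1, 4, 8)
mask = torch.tensor([[False, False, True, True]])  # last two tokens are padding
_, weights = mha(x, x, x, key_padding_mask=mask)
# averaged attention weights have shape (batch, query, key);
# softmax over -inf scores gives the masked keys weight exactly 0
print(weights[0, :, 2:])  # all zeros
```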
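Note that the new path is opt-in: assuming the usual stanza training entry point (`python -m stanza.models.tagger` with the standard data arguments), adding `--attention --attention_heads 8` builds the MHA block, while leaving `--attention` off keeps `self.mha` as `None` and the tagger's architecture unchanged.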