| author | Valentin Berkes <16121857+funboarder13920@users.noreply.github.com> | 2021-01-08 18:03:17 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-01-08 18:03:17 +0300 |
| commit | 12ec5704c0084e683fca5752502a0351731467ea (patch) | |
| tree | 53a8a57157da2771e85b7da59fc0b940cf108fba | |
| parent | 7a2174455ef7ccc22c99f06da07ffb46e85d022f (diff) | |
Fix gpt2 architecture and inference (#1971)
| mode | file | changes |
|---|---|---|
| -rw-r--r-- | data/data_lm/gen-beam-sol.txt | 12 |
| -rw-r--r-- | data/data_lm/gen-sampling-sol.txt | 14 |
| -rw-r--r-- | onmt/bin/build_vocab.py | 2 |
| -rw-r--r-- | onmt/decoders/transformer.py | 310 |
| -rw-r--r-- | onmt/encoders/transformer.py | 21 |
| -rw-r--r-- | onmt/modules/average_attn.py | 10 |
| -rw-r--r-- | onmt/modules/embeddings.py | 16 |
| -rw-r--r-- | onmt/modules/position_ffn.py | 20 |
| -rw-r--r-- | onmt/opts.py | 9 |
| -rwxr-xr-x | onmt/tests/rebuild_test_models.sh | 2 |
| -rw-r--r-- | onmt/tests/test_model_lm.pt | bin 28410923 -> 28404135 bytes |
12 files changed, 230 insertions, 205 deletions
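
The central architectural change in this commit is that the activation of the position-wise feed-forward layer becomes configurable, so a GPT-2-style model can use GELU instead of ReLU. As orientation before the full diff, here is a minimal, self-contained sketch of the pattern introduced in `onmt/modules/position_ffn.py`; the class and dictionary names follow the diff, but the snippet is illustrative rather than a copy of the shipped module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ActivationFunction(object):
    relu = "relu"
    gelu = "gelu"


ACTIVATION_FUNCTIONS = {
    ActivationFunction.relu: F.relu,
    ActivationFunction.gelu: F.gelu,
}


class PositionwiseFeedForward(nn.Module):
    """Two-layer FFN with a configurable activation (pre-norm + residual)."""

    def __init__(self, d_model, d_ff, dropout=0.1,
                 activation_fn=ActivationFunction.relu):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout_1 = nn.Dropout(dropout)
        self.activation = ACTIVATION_FUNCTIONS[activation_fn]
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        # pre-norm, activation of choice, dropout, then residual connection
        inter = self.dropout_1(self.activation(self.w_1(self.layer_norm(x))))
        return self.dropout_2(self.w_2(inter)) + x


# A GPT-2-style block would pick gelu:
ffn = PositionwiseFeedForward(512, 2048, activation_fn=ActivationFunction.gelu)
out = ffn(torch.randn(4, 10, 512))  # (batch, length, d_model)
```

On the command line this surfaces as the new `-pos_ffn_activation_fn` option (choices `relu`/`gelu`, default `relu`) added in `onmt/opts.py` below.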
diff --git a/data/data_lm/gen-beam-sol.txt b/data/data_lm/gen-beam-sol.txt
index 227f6df4..e8d50d9c 100644
--- a/data/data_lm/gen-beam-sol.txt
+++ b/data/data_lm/gen-beam-sol.txt
@@ -1,7 +1,7 @@
-you !
-ignored .
-elections .
+<unk> refined presents ...
+Mathias Hlubek , Chief Financial Officer of Deutsche Börse .
+integrated in Israel . .
-<unk> works .
-codec to be available soon .
-forty years ago .
+<unk> of <unk> of <unk> of <unk> of <unk> of <unk> of <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> / <unk> <unk> <unk> <unk> <unk> <unk> / <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>
+are looking for any information .
+<unk> tests at the " tools for every month .
diff --git a/data/data_lm/gen-sampling-sol.txt b/data/data_lm/gen-sampling-sol.txt
index 0c2bf3f5..dda9a0d8 100644
--- a/data/data_lm/gen-sampling-sol.txt
+++ b/data/data_lm/gen-sampling-sol.txt
@@ -1,7 +1,7 @@
-you !
-that the Irish problem in the Commission should not only a topical problem .
-elections , Israel will be developed to start of the crisis , at the crisis , at the crisis , at the crisis , during the crisis , at the crisis .
-and <unk>
-the July 2003 , has been developed to win <unk> public - thus been developed countries .
-might have been <unk> .
-<unk> , I think we are going to be able to make it .
+<unk> chocolate <unk> refined presents <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> ensures an <unk> ensures an <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> ensures no longer to <unk> ensures an air <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>
+hello while passing on the hotel , as well as it is good hotel , as it is a good hotel , as it is good hotel , as it is a good hotel , as it is a good hotel .
+<unk> in Israel was made of <unk> in Israel .
+and <unk> butter and 0 to 6 languages .
+<unk> of <unk> of <unk> of <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> <unk> <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>
+are looking for the information .
+the <unk> tests at the " tools for you can get to obtain intermediate diploma / successive dishes of your BMW <unk> for young title symbols over the most of your experience for young <unk> .
diff --git a/onmt/bin/build_vocab.py b/onmt/bin/build_vocab.py
index 462469be..e106d921 100644
--- a/onmt/bin/build_vocab.py
+++ b/onmt/bin/build_vocab.py
@@ -40,7 +40,7 @@ def build_vocab_main(opts):
     def save_counter(counter, save_path):
         check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
-        with open(save_path, "w",encoding="utf8") as fo:
+        with open(save_path, "w", encoding="utf8") as fo:
             for tok, count in counter.most_common():
                 fo.write(tok + "\t" + str(count) + "\n")
diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py
index 48f1fff3..a50e4a8e 100644
--- a/onmt/decoders/transformer.py
+++ b/onmt/decoders/transformer.py
@@ -9,13 +9,74 @@ import torch.nn as nn
 from onmt.decoders.decoder import DecoderBase
 from onmt.modules import MultiHeadedAttention, AverageAttention
 from onmt.modules.position_ffn import PositionwiseFeedForward
+from onmt.modules.position_ffn import ActivationFunction
 from onmt.utils.misc import sequence_mask


 class TransformerDecoderLayerBase(nn.Module):
-    def __init__(self):
+    def __init__(
+        self,
+        d_model,
+        heads,
+        d_ff,
+        dropout,
+        attention_dropout,
+        self_attn_type="scaled-dot",
+        max_relative_positions=0,
+        aan_useffn=False,
+        full_context_alignment=False,
+        alignment_heads=0,
+        pos_ffn_activation_fn=ActivationFunction.relu,
+    ):
+        """
+        Args:
+            d_model (int): the dimension of keys/values/queries in
+                :class:`MultiHeadedAttention`, also the input size of
+                the first-layer of the :class:`PositionwiseFeedForward`.
+            heads (int): the number of heads for MultiHeadedAttention.
+            d_ff (int): the second-layer of the
+                :class:`PositionwiseFeedForward`.
+            dropout (float): dropout in residual, self-attn(dot) and
+                feed-forward
+            attention_dropout (float): dropout in context_attn (and
+                self-attn(avg))
+            self_attn_type (string): type of self-attention scaled-dot,
+                average
+            max_relative_positions (int):
+                Max distance between inputs in relative positions
+                representations
+            aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+            full_context_alignment (bool):
+                whether enable an extra full context decoder forward for
+                alignment
+            alignment_heads (int):
+                N. of cross attention heads to use for alignment guiding
+            pos_ffn_activation_fn (ActivationFunction):
+                activation function choice for PositionwiseFeedForward layer
+
+        """
         super(TransformerDecoderLayerBase, self).__init__()
+        if self_attn_type == "scaled-dot":
+            self.self_attn = MultiHeadedAttention(
+                heads,
+                d_model,
+                dropout=attention_dropout,
+                max_relative_positions=max_relative_positions,
+            )
+        elif self_attn_type == "average":
+            self.self_attn = AverageAttention(
+                d_model, dropout=attention_dropout, aan_useffn=aan_useffn
+            )
+
+        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout,
+                                                    pos_ffn_activation_fn
+                                                    )
+        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
+        self.drop = nn.Dropout(dropout)
+        self.full_context_alignment = full_context_alignment
+        self.alignment_heads = alignment_heads
+
     def forward(self, *args, **kwargs):
         """Extend `_forward` for (possibly) multiple decoder pass:
         Always a default (future masked) decoder forward pass,
@@ -51,11 +112,51 @@ class TransformerDecoderLayerBase(nn.Module):
             attn_align = attns.mean(dim=1)
         return output, top_attn, attn_align

+    def update_dropout(self, dropout, attention_dropout):
+        self.self_attn.update_dropout(attention_dropout)
+        self.feed_forward.update_dropout(dropout)
+        self.drop.p = dropout
+
     def _forward(self, *args, **kwargs):
         raise NotImplementedError

-    def update_dropout(self, dropout, attention_dropout):
-        raise NotImplementedError
+    def _compute_dec_mask(self, tgt_pad_mask, future):
+        tgt_len = tgt_pad_mask.size(-1)
+        if not future:  # apply future_mask, result mask in (B, T, T)
+            future_mask = torch.ones(
+                [tgt_len, tgt_len],
+                device=tgt_pad_mask.device,
+                dtype=torch.uint8,
+            )
+            future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
+            # BoolTensor was introduced in pytorch 1.2
+            try:
+                future_mask = future_mask.bool()
+            except AttributeError:
+                pass
+            dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
+        else:  # only mask padding, result mask in (B, 1, T)
+            dec_mask = tgt_pad_mask
+        return dec_mask
+
+    def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
+        if isinstance(self.self_attn, MultiHeadedAttention):
+            return self.self_attn(
+                inputs_norm,
+                inputs_norm,
+                inputs_norm,
+                mask=dec_mask,
+                layer_cache=layer_cache,
+                attn_type="self",
+            )
+        elif isinstance(self.self_attn, AverageAttention):
+            return self.self_attn(
+                inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
+            )
+        else:
+            raise ValueError(
+                f"self attention {type(self.self_attn)} not supported"
+            )


 class TransformerDecoderLayer(TransformerDecoderLayerBase):
@@ -76,23 +177,6 @@ class TransformerDecoderLayer(TransformerDecoderLayerBase):
             A --> E
             E --> F(out)

-
-    Args:
-        d_model (int): the dimension of keys/values/queries in
-            :class:`MultiHeadedAttention`, also the input size of
-            the first-layer of the :class:`PositionwiseFeedForward`.
-        heads (int): the number of heads for MultiHeadedAttention.
-        d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
-        dropout (float): dropout in residual, self-attn(dot) and feed-forward
-        attention_dropout (float): dropout in context_attn (and self-attn(avg))
-        self_attn_type (string): type of self-attention scaled-dot, average
-        max_relative_positions (int):
-            Max distance between inputs in relative positions representations
-        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
-        full_context_alignment (bool):
-            whether enable an extra full context decoder forward for alignment
-        alignment_heads (int):
-            N. of cross attention heads to use for alignment guiding
     """

     def __init__(
@@ -107,36 +191,35 @@ class TransformerDecoderLayer(TransformerDecoderLayerBase):
         aan_useffn=False,
         full_context_alignment=False,
         alignment_heads=0,
+        pos_ffn_activation_fn=ActivationFunction.relu,
     ):
-        super(TransformerDecoderLayer, self).__init__()
-
-        if self_attn_type == "scaled-dot":
-            self.self_attn = MultiHeadedAttention(
-                heads,
-                d_model,
-                dropout=attention_dropout,
-                max_relative_positions=max_relative_positions,
-            )
-        elif self_attn_type == "average":
-            self.self_attn = AverageAttention(
-                d_model, dropout=attention_dropout, aan_useffn=aan_useffn
-            )
-
+        """
+        Args:
+            See TransformerDecoderLayerBase
+        """
+        super(TransformerDecoderLayer, self).__init__(
+            d_model,
+            heads,
+            d_ff,
+            dropout,
+            attention_dropout,
+            self_attn_type,
+            max_relative_positions,
+            aan_useffn,
+            full_context_alignment,
+            alignment_heads,
+            pos_ffn_activation_fn=pos_ffn_activation_fn,
+        )
         self.context_attn = MultiHeadedAttention(
             heads, d_model, dropout=attention_dropout
         )
-        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
-        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
         self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
-        self.drop = nn.Dropout(dropout)
-        self.full_context_alignment = full_context_alignment
-        self.alignment_heads = alignment_heads

     def update_dropout(self, dropout, attention_dropout):
-        self.self_attn.update_dropout(attention_dropout)
+        super(TransformerDecoderLayer, self).update_dropout(
+            dropout, attention_dropout
+        )
         self.context_attn.update_dropout(attention_dropout)
-        self.feed_forward.update_dropout(dropout)
-        self.drop.p = dropout

     def _forward(
         self,
@@ -170,39 +253,15 @@ class TransformerDecoderLayer(TransformerDecoderLayerBase):
         """
         dec_mask = None

-        if step is None:
-            tgt_len = tgt_pad_mask.size(-1)
-            if not future:  # apply future_mask, result mask in (B, T, T)
-                future_mask = torch.ones(
-                    [tgt_len, tgt_len],
-                    device=tgt_pad_mask.device,
-                    dtype=torch.uint8,
-                )
-                future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
-                # BoolTensor was introduced in pytorch 1.2
-                try:
-                    future_mask = future_mask.bool()
-                except AttributeError:
-                    pass
-                dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
-            else:  # only mask padding, result mask in (B, 1, T)
-                dec_mask = tgt_pad_mask
-
-        input_norm = self.layer_norm_1(inputs)
+        if inputs.size(1) > 1:
+            # masking is necessary when sequence length is greater than one
+            dec_mask = self._compute_dec_mask(tgt_pad_mask, future)

-        if isinstance(self.self_attn, MultiHeadedAttention):
-            query, _ = self.self_attn(
-                input_norm,
-                input_norm,
-                input_norm,
-                mask=dec_mask,
-                layer_cache=layer_cache,
-                attn_type="self",
-            )
-        elif isinstance(self.self_attn, AverageAttention):
-            query, _ = self.self_attn(
-                input_norm, mask=dec_mask, layer_cache=layer_cache, step=step
-            )
+        inputs_norm = self.layer_norm_1(inputs)
+
+        query, _ = self._forward_self_attn(
+            inputs_norm, dec_mask, layer_cache, step
+        )

         query = self.drop(query) + inputs
@@ -257,6 +316,7 @@ class TransformerDecoderBase(DecoderBase):
             opt.full_context_alignment,
             opt.alignment_layer,
             alignment_heads=opt.alignment_heads,
+            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
         )

     def init_state(self, src, memory_bank, enc_hidden):
@@ -345,6 +405,7 @@ class TransformerDecoder(TransformerDecoderBase):
         full_context_alignment,
         alignment_layer,
         alignment_heads,
+        pos_ffn_activation_fn=ActivationFunction.relu,
     ):
         super(TransformerDecoder, self).__init__(
             d_model, copy_attn, embeddings, alignment_layer
@@ -363,6 +424,7 @@ class TransformerDecoder(TransformerDecoderBase):
                     aan_useffn=aan_useffn,
                     full_context_alignment=full_context_alignment,
                     alignment_heads=alignment_heads,
+                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                 )
                 for i in range(num_layers)
             ]
@@ -460,57 +522,9 @@ class TransformerLMDecoderLayer(TransformerDecoderLayerBase):

     Args:
-        d_model (int): the dimension of keys/values/queries in
-            :class:`MultiHeadedAttention`, also the input size of
-            the first-layer of the :class:`PositionwiseFeedForward`.
-        heads (int): the number of heads for MultiHeadedAttention.
-        d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
-        dropout (float): dropout in residual, self-attn(dot) and feed-forward
-        attention_dropout (float): dropout in context_attn (and self-attn(avg))
-        self_attn_type (string): type of self-attention scaled-dot, average
-        max_relative_positions (int):
-            Max distance between inputs in relative positions representations
-        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
-        full_context_alignment (bool):
-            whether enable an extra full context decoder forward for alignment
-        alignment_heads (int):
-            N. of cross attention heads to use for alignment guiding
+        See TransformerDecoderLayerBase
     """

-    def __init__(
-        self,
-        d_model,
-        heads,
-        d_ff,
-        dropout,
-        attention_dropout,
-        self_attn_type="scaled-dot",
-        max_relative_positions=0,
-        aan_useffn=False,
-        full_context_alignment=False,
-        alignment_heads=0,
-    ):
-        super(TransformerLMDecoderLayer, self).__init__()
-
-        if self_attn_type == "scaled-dot":
-            self.self_attn = MultiHeadedAttention(
-                heads,
-                d_model,
-                dropout=attention_dropout,
-                max_relative_positions=max_relative_positions,
-            )
-        elif self_attn_type == "average":
-            self.self_attn = AverageAttention(
-                d_model, dropout=attention_dropout, aan_useffn=aan_useffn
-            )
-
-        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
-        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
-        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
-        self.drop = nn.Dropout(dropout)
-        self.full_context_alignment = full_context_alignment
-        self.alignment_heads = alignment_heads
-
     def _forward(
         self, inputs, tgt_pad_mask, layer_cache=None, step=None, future=False
     ):
@@ -534,51 +548,21 @@ class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
         """
         dec_mask = None

-        if step is None:
-            tgt_len = tgt_pad_mask.size(-1)
-            if not future:  # apply future_mask, result mask in (B, T, T)
-                future_mask = torch.ones(
-                    [tgt_len, tgt_len],
-                    device=tgt_pad_mask.device,
-                    dtype=torch.uint8,
-                )
-                future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
-                # BoolTensor was introduced in pytorch 1.2
-                try:
-                    future_mask = future_mask.bool()
-                except AttributeError:
-                    pass
-                dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
-            else:  # only mask padding, result mask in (B, 1, T)
-                dec_mask = tgt_pad_mask
+        if inputs.size(1) > 1:
+            # masking is necessary when sequence length is greater than one
+            dec_mask = self._compute_dec_mask(tgt_pad_mask, future)

         inputs_norm = self.layer_norm_1(inputs)

-        if isinstance(self.self_attn, MultiHeadedAttention):
-            query, attns = self.self_attn(
-                inputs_norm,
-                inputs_norm,
-                inputs_norm,
-                mask=dec_mask,
-                layer_cache=layer_cache,
-                attn_type="self",
-            )
-        elif isinstance(self.self_attn, AverageAttention):
-            query, attns = self.self_attn(
-                inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
-            )
-        output = self.drop(query) + inputs
-
-        output_feedforward = self.feed_forward(self.layer_norm_2(output))
+        query, attns = self._forward_self_attn(
+            inputs_norm, dec_mask, layer_cache, step
+        )

-        output_norm = self.drop(output_feedforward) + output
+        output = self.drop(query) + inputs

-        return output_norm, attns
+        output_feedforward = self.feed_forward(output)

-    def update_dropout(self, dropout, attention_dropout):
-        self.self_attn.update_dropout(attention_dropout)
-        self.feed_forward.update_dropout(dropout)
-        self.drop.p = dropout
+        return output_feedforward, attns


 class TransformerLMDecoder(TransformerDecoderBase):
@@ -628,6 +612,7 @@ class TransformerLMDecoder(TransformerDecoderBase):
         full_context_alignment=None,
         alignment_layer=None,
         alignment_heads=None,
+        pos_ffn_activation_fn=ActivationFunction.relu,
     ):
         super(TransformerLMDecoder, self).__init__(
             d_model, copy_attn, embeddings, None
@@ -645,6 +630,7 @@ class TransformerLMDecoder(TransformerDecoderBase):
                     aan_useffn=aan_useffn,
                     full_context_alignment=None,
                     alignment_heads=None,
+                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                 )
                 for i in range(num_layers)
             ]
diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py
index 5c29d8de..afb796b3 100644
--- a/onmt/encoders/transformer.py
+++ b/onmt/encoders/transformer.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 from onmt.encoders.encoder import EncoderBase
 from onmt.modules import MultiHeadedAttention
 from onmt.modules.position_ffn import PositionwiseFeedForward
+from onmt.modules.position_ffn import ActivationFunction
 from onmt.utils.misc import sequence_mask


@@ -21,16 +22,20 @@ class TransformerEncoderLayer(nn.Module):
         heads (int): the number of head for MultiHeadedAttention.
         d_ff (int): the second-layer of the PositionwiseFeedForward.
         dropout (float): dropout probability(0-1.0).
+        pos_ffn_activation_fn (ActivationFunction):
+            activation function choice for PositionwiseFeedForward layer
     """

     def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
-                 max_relative_positions=0):
+                 max_relative_positions=0,
+                 pos_ffn_activation_fn=ActivationFunction.relu):
         super(TransformerEncoderLayer, self).__init__()

         self.self_attn = MultiHeadedAttention(
             heads, d_model, dropout=attention_dropout,
             max_relative_positions=max_relative_positions)
-        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
+        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout,
+                                                    pos_ffn_activation_fn)
         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
         self.dropout = nn.Dropout(dropout)

@@ -80,6 +85,8 @@ class TransformerEncoder(EncoderBase):
         dropout (float): dropout parameters
         embeddings (onmt.modules.Embeddings):
             embeddings to use, should have positional encodings
+        pos_ffn_activation_fn (ActivationFunction):
+            activation function choice for PositionwiseFeedForward layer

     Returns:
         (torch.FloatTensor, torch.FloatTensor):
@@ -89,14 +96,16 @@
     """

     def __init__(self, num_layers, d_model, heads, d_ff, dropout,
-                 attention_dropout, embeddings, max_relative_positions):
+                 attention_dropout, embeddings, max_relative_positions,
+                 pos_ffn_activation_fn=ActivationFunction.relu):
         super(TransformerEncoder, self).__init__()

         self.embeddings = embeddings
         self.transformer = nn.ModuleList(
             [TransformerEncoderLayer(
                 d_model, heads, d_ff, dropout, attention_dropout,
-                max_relative_positions=max_relative_positions)
+                max_relative_positions=max_relative_positions,
+                pos_ffn_activation_fn=pos_ffn_activation_fn)
              for i in range(num_layers)])
         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

@@ -112,7 +121,9 @@ class TransformerEncoder(EncoderBase):
             opt.attention_dropout[0] if type(opt.attention_dropout)
             is list else opt.attention_dropout,
             embeddings,
-            opt.max_relative_positions)
+            opt.max_relative_positions,
+            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+            )

     def forward(self, src, lengths=None):
         """See :func:`EncoderBase.forward()`"""
diff --git a/onmt/modules/average_attn.py b/onmt/modules/average_attn.py
index 29c4259e..7d8dda5d 100644
--- a/onmt/modules/average_attn.py
+++ b/onmt/modules/average_attn.py
@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn

 from onmt.modules.position_ffn import PositionwiseFeedForward
+from onmt.modules.position_ffn import ActivationFunction


 class AverageAttention(nn.Module):
@@ -17,15 +18,20 @@ class AverageAttention(nn.Module):
        model_dim (int): the dimension of keys/values/queries,
            must be divisible by head_count
        dropout (float): dropout parameter
+       pos_ffn_activation_fn (ActivationFunction):
+           activation function choice for PositionwiseFeedForward layer
    """

-    def __init__(self, model_dim, dropout=0.1, aan_useffn=False):
+    def __init__(self, model_dim, dropout=0.1, aan_useffn=False,
+                 pos_ffn_activation_fn=ActivationFunction.relu):
         self.model_dim = model_dim
         self.aan_useffn = aan_useffn
         super(AverageAttention, self).__init__()
         if aan_useffn:
             self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
-                                                         dropout)
+                                                         dropout,
+                                                         pos_ffn_activation_fn
+                                                         )
         self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)

     def cumulative_average_mask(self, batch_size, inputs_len, device):
diff --git a/onmt/modules/embeddings.py b/onmt/modules/embeddings.py
index 8a97a77c..9561a66e 100644
--- a/onmt/modules/embeddings.py
+++ b/onmt/modules/embeddings.py
@@ -52,15 +52,13 @@ class PositionalEncoding(nn.Module):
         """
         emb = emb * math.sqrt(self.dim)
-        if step is None:
-            if self.pe.size(0) < emb.size(0):
-                raise SequenceTooLongError(
-                    f"Sequence is {emb.size(0)} but PositionalEncoding is"
-                    f" limited to {self.pe.size(0)}. See max_len argument."
-                )
-            emb = emb + self.pe[:emb.size(0)]
-        else:
-            emb = emb + self.pe[step]
+        step = step or 0
+        if self.pe.size(0) < step + emb.size(0):
+            raise SequenceTooLongError(
+                f"Sequence is {emb.size(0) + step} but PositionalEncoding is"
+                f" limited to {self.pe.size(0)}. See max_len argument."
+            )
+        emb = emb + self.pe[step:emb.size(0)+step]
         emb = self.dropout(emb)
         return emb
diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py
index fb8df80a..6757742d 100644
--- a/onmt/modules/position_ffn.py
+++ b/onmt/modules/position_ffn.py
@@ -1,6 +1,18 @@
 """Position feed-forward network from "Attention is All You Need"."""

 import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ActivationFunction(object):
+    relu = "relu"
+    gelu = "gelu"
+
+
+ACTIVATION_FUNCTIONS = {
+    ActivationFunction.relu: F.relu,
+    ActivationFunction.gelu: F.gelu,
+}


 class PositionwiseFeedForward(nn.Module):
@@ -11,15 +23,17 @@ class PositionwiseFeedForward(nn.Module):
         d_ff (int): the hidden layer size of the second-layer of the FNN.
         dropout (float): dropout probability in :math:`[0, 1)`.
+        activation_fn (ActivationFunction): activation function used.
""" - def __init__(self, d_model, d_ff, dropout=0.1): + def __init__(self, d_model, d_ff, dropout=0.1, + activation_fn=ActivationFunction.relu): super(PositionwiseFeedForward, self).__init__() self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) self.dropout_1 = nn.Dropout(dropout) - self.relu = nn.ReLU() + self.activation = ACTIVATION_FUNCTIONS[activation_fn] self.dropout_2 = nn.Dropout(dropout) def forward(self, x): @@ -32,7 +46,7 @@ class PositionwiseFeedForward(nn.Module): (FloatTensor): Output ``(batch_size, input_len, model_dim)``. """ - inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) + inter = self.dropout_1(self.activation(self.w_1(self.layer_norm(x)))) output = self.dropout_2(self.w_2(inter)) return output + x diff --git a/onmt/opts.py b/onmt/opts.py index 65be0e93..0ec07177 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -6,6 +6,8 @@ import configargparse from onmt.models.sru import CheckSRU from onmt.transforms import AVAILABLE_TRANSFORMS from onmt.constants import ModelTask +from onmt.modules.position_ffn import ACTIVATION_FUNCTIONS +from onmt.modules.position_ffn import ActivationFunction def config_opts(parser): @@ -291,6 +293,13 @@ def model_opts(parser): help="Size of windows in the cnn, the kernel_size is " "(cnn_kernel_width, 1) in conv layer") + group.add('--pos_ffn_activation_fn', '-pos_ffn_activation_fn', + type=str, default=ActivationFunction.relu, + choices=ACTIVATION_FUNCTIONS.keys(), help='The activation' + ' function to use in PositionwiseFeedForward layer. Choices are' + f' {ACTIVATION_FUNCTIONS.keys()}. Default to' + f' {ActivationFunction.relu}.') + group.add('--input_feed', '-input_feed', type=int, default=1, help="Feed the context vector at each time step as " "additional input (via concatenation with the word " diff --git a/onmt/tests/rebuild_test_models.sh b/onmt/tests/rebuild_test_models.sh index e6225a46..e4b2aa51 100755 --- a/onmt/tests/rebuild_test_models.sh +++ b/onmt/tests/rebuild_test_models.sh @@ -114,7 +114,7 @@ $my_python train.py -config data/lm_data.yaml -save_model /tmp/tmp \ -attention_dropout 0.1 -heads 2 -position_encoding -param_init 0 -warmup_steps 100 \ -param_init_glorot -adam_beta2 0.998 -src_vocab data/data_lm/data.vocab.src # -mv /tmp/tmp*2000.pt onmt/tests/test_model_lm_2.pt +mv /tmp/tmp*2000.pt onmt/tests/test_model_lm.pt rm /tmp/tmp*.pt fi # diff --git a/onmt/tests/test_beam_search.py b/onmt/tests/test_beam_search.py index bbd24f4b..08d072ae 100644 --- a/onmt/tests/test_beam_search.py +++ b/onmt/tests/test_beam_search.py @@ -574,14 +574,14 @@ class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase): scores_finish = torch.log_softmax(torch.tensor( [[0, 0, 10000, 0, 5000, .51, .2, 0], # beam 0 shouldn't cont [100000, 100001, 0, 0, 0, 0, 0, 0], - [0,100000, 0, 0, 0, 5000, 0, 0], + [0, 100000, 0, 0, 0, 5000, 0, 0], [0, 0, 0, .2, .2, .2, .2, .2], [0, 0, 0, 0, .2, .2, .2, .2]] # beam 4 -> beam 1 should die ), dim=1) scores_finish = scores_finish.repeat(self.BATCH_SZ, 1) scores_finish[:self.BEAM_SZ, beam.eos] = 0 - beam.advance( scores_finish, None) - + beam.advance(scores_finish, None) + any_finished = beam.is_finished.any() if any_finished: beam.update_finished() @@ -601,9 +601,9 @@ class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase): self.third_step(beam, expected_beam_scores, 1) n_steps = beam.alive_seq.shape[-1] - 1 - self.assertTrue(beam.memory_lengths.equal(n_steps+fn_map_state(src_lengths, dim=0))) - - + 
self.assertTrue(beam.memory_lengths.equal(n_steps+fn_map_state( + src_lengths, dim=0))) + def test_beam_lm_update_memory_length_when_finished(self): beam = BeamSearchLM( self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST, @@ -613,8 +613,9 @@ class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase): device_init = torch.zeros(1, 1) src_lengths = torch.randint(0, 30, (self.BATCH_SZ,)) fn_map_state, _, _, _ = beam.initialize(device_init, src_lengths) - expected_beam_scores = self.init_step(beam, 1) + _ = self.init_step(beam, 1) self.finish_first_beam_step(beam) - + n_steps = beam.alive_seq.shape[-1] - 1 - self.assertTrue(beam.memory_lengths.equal(n_steps+fn_map_state(src_lengths[1:], dim=0))) + self.assertTrue(beam.memory_lengths.equal(n_steps+fn_map_state( + src_lengths[1:], dim=0))) diff --git a/onmt/tests/test_model_lm.pt b/onmt/tests/test_model_lm.pt Binary files differindex ef7af433..84a5729f 100644 --- a/onmt/tests/test_model_lm.pt +++ b/onmt/tests/test_model_lm.pt |
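
The inference-side fixes above follow a common pattern: the decoder only builds a target mask when more than one position is fed at once (a single cached decoding step needs no future mask), and the positional encoding table is indexed starting at the current decoding step rather than always at position 0. The sketch below illustrates that behavior with free-standing helper functions; `compute_dec_mask` and `add_positional_encoding` are assumed stand-in names for this illustration, not the actual methods shipped in the modules above.

```python
import math
import torch


def compute_dec_mask(tgt_pad_mask, future=False):
    """Combine a padding mask with an optional causal ("future") mask.

    tgt_pad_mask: BoolTensor (batch, 1, tgt_len), True at padded positions.
    """
    tgt_len = tgt_pad_mask.size(-1)
    if not future:
        # upper-triangular future mask, broadcast against the pad mask
        future_mask = torch.ones(
            [tgt_len, tgt_len], device=tgt_pad_mask.device, dtype=torch.uint8
        ).triu_(1).view(1, tgt_len, tgt_len).bool()
        return tgt_pad_mask | future_mask      # (batch, tgt_len, tgt_len)
    return tgt_pad_mask                        # (batch, 1, tgt_len)


def add_positional_encoding(emb, pe, step=None):
    """Add sinusoidal positions starting at `step` (0 for a full sequence).

    emb: (seq_len, batch, dim); pe: (max_len, 1, dim) precomputed table.
    """
    step = step or 0
    if pe.size(0) < step + emb.size(0):
        raise ValueError("sequence is longer than the precomputed pe table")
    return emb * math.sqrt(emb.size(-1)) + pe[step:step + emb.size(0)]


# Training / scoring: the whole target is fed, so a full causal mask is built
# and positions 0..seq_len-1 are added.
# Incremental decoding: at step t the input has length 1, no mask is needed
# (dec_mask stays None in the layers above) and only pe[t] is added.
```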