github.com/OpenNMT/OpenNMT-py.git
author    Valentin Berkes <16121857+funboarder13920@users.noreply.github.com>  2021-01-08 18:03:17 +0300
committer GitHub <noreply@github.com>  2021-01-08 18:03:17 +0300
commit    12ec5704c0084e683fca5752502a0351731467ea (patch)
tree      53a8a57157da2771e85b7da59fc0b940cf108fba
parent    7a2174455ef7ccc22c99f06da07ffb46e85d022f (diff)
Fix gpt2 architecture and inference (#1971)
-rw-r--r--  data/data_lm/gen-beam-sol.txt      |  12
-rw-r--r--  data/data_lm/gen-sampling-sol.txt  |  14
-rw-r--r--  onmt/bin/build_vocab.py            |   2
-rw-r--r--  onmt/decoders/transformer.py       | 310
-rw-r--r--  onmt/encoders/transformer.py       |  21
-rw-r--r--  onmt/modules/average_attn.py       |  10
-rw-r--r--  onmt/modules/embeddings.py         |  16
-rw-r--r--  onmt/modules/position_ffn.py       |  20
-rw-r--r--  onmt/opts.py                       |   9
-rwxr-xr-x  onmt/tests/rebuild_test_models.sh  |   2
-rw-r--r--  onmt/tests/test_beam_search.py     |  19
-rw-r--r--  onmt/tests/test_model_lm.pt        | bin 28410923 -> 28404135 bytes
12 files changed, 230 insertions, 205 deletions
diff --git a/data/data_lm/gen-beam-sol.txt b/data/data_lm/gen-beam-sol.txt
index 227f6df4..e8d50d9c 100644
--- a/data/data_lm/gen-beam-sol.txt
+++ b/data/data_lm/gen-beam-sol.txt
@@ -1,7 +1,7 @@
-you !
-ignored .
-elections .
+<unk> refined presents ...
+Mathias Hlubek , Chief Financial Officer of Deutsche Börse .
+integrated in Israel .
.
-<unk> works .
-codec to be available soon .
-forty years ago .
+<unk> of <unk> of <unk> of <unk> of <unk> of <unk> of <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> / <unk> <unk> <unk> <unk> <unk> <unk> / <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>
+are looking for any information .
+<unk> tests at the &quot; tools for every month .
diff --git a/data/data_lm/gen-sampling-sol.txt b/data/data_lm/gen-sampling-sol.txt
index 0c2bf3f5..dda9a0d8 100644
--- a/data/data_lm/gen-sampling-sol.txt
+++ b/data/data_lm/gen-sampling-sol.txt
@@ -1,7 +1,7 @@
-you !
-that the Irish problem in the Commission should not only a topical problem .
-elections , Israel will be developed to start of the crisis , at the crisis , at the crisis , at the crisis , during the crisis , at the crisis .
-and <unk>
-the July 2003 , has been developed to win <unk> public - thus been developed countries .
-might have been <unk> .
-<unk> , I think we are going to be able to make it .
+<unk> chocolate <unk> refined presents <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> ensures an <unk> ensures an <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> ensures no longer to <unk> ensures an air <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>
+hello while passing on the hotel , as well as it is good hotel , as it is a good hotel , as it is good hotel , as it is a good hotel , as it is a good hotel .
+<unk> in Israel was made of <unk> in Israel .
+and <unk> butter and 0 to 6 languages .
+<unk> of <unk> of <unk> of <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> of <unk> <unk> <unk> of <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>
+are looking for the information .
+the <unk> tests at the &quot; tools for you can get to obtain intermediate diploma / successive dishes of your BMW <unk> for young title symbols over the most of your experience for young <unk> .
diff --git a/onmt/bin/build_vocab.py b/onmt/bin/build_vocab.py
index 462469be..e106d921 100644
--- a/onmt/bin/build_vocab.py
+++ b/onmt/bin/build_vocab.py
@@ -40,7 +40,7 @@ def build_vocab_main(opts):
def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
- with open(save_path, "w",encoding="utf8") as fo:
+ with open(save_path, "w", encoding="utf8") as fo:
for tok, count in counter.most_common():
fo.write(tok + "\t" + str(count) + "\n")
diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py
index 48f1fff3..a50e4a8e 100644
--- a/onmt/decoders/transformer.py
+++ b/onmt/decoders/transformer.py
@@ -9,13 +9,74 @@ import torch.nn as nn
from onmt.decoders.decoder import DecoderBase
from onmt.modules import MultiHeadedAttention, AverageAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
+from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask
class TransformerDecoderLayerBase(nn.Module):
- def __init__(self):
+ def __init__(
+ self,
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type="scaled-dot",
+ max_relative_positions=0,
+ aan_useffn=False,
+ full_context_alignment=False,
+ alignment_heads=0,
+ pos_ffn_activation_fn=ActivationFunction.relu,
+ ):
+ """
+ Args:
+ d_model (int): the dimension of keys/values/queries in
+ :class:`MultiHeadedAttention`, also the input size of
+ the first-layer of the :class:`PositionwiseFeedForward`.
+ heads (int): the number of heads for MultiHeadedAttention.
+ d_ff (int): the second-layer of the
+ :class:`PositionwiseFeedForward`.
+ dropout (float): dropout in residual, self-attn(dot) and
+ feed-forward
+ attention_dropout (float): dropout in context_attn (and
+ self-attn(avg))
+ self_attn_type (string): type of self-attention scaled-dot,
+ average
+ max_relative_positions (int):
+ Max distance between inputs in relative positions
+ representations
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+ full_context_alignment (bool):
+ whether enable an extra full context decoder forward for
+ alignment
+ alignment_heads (int):
+ N. of cross attention heads to use for alignment guiding
+ pos_ffn_activation_fn (ActivationFunction):
+ activation function choice for PositionwiseFeedForward layer
+
+ """
super(TransformerDecoderLayerBase, self).__init__()
+ if self_attn_type == "scaled-dot":
+ self.self_attn = MultiHeadedAttention(
+ heads,
+ d_model,
+ dropout=attention_dropout,
+ max_relative_positions=max_relative_positions,
+ )
+ elif self_attn_type == "average":
+ self.self_attn = AverageAttention(
+ d_model, dropout=attention_dropout, aan_useffn=aan_useffn
+ )
+
+ self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout,
+ pos_ffn_activation_fn
+ )
+ self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
+ self.drop = nn.Dropout(dropout)
+ self.full_context_alignment = full_context_alignment
+ self.alignment_heads = alignment_heads
+
def forward(self, *args, **kwargs):
"""Extend `_forward` for (possibly) multiple decoder pass:
Always a default (future masked) decoder forward pass,
@@ -51,11 +112,51 @@ class TransformerDecoderLayerBase(nn.Module):
attn_align = attns.mean(dim=1)
return output, top_attn, attn_align
+ def update_dropout(self, dropout, attention_dropout):
+ self.self_attn.update_dropout(attention_dropout)
+ self.feed_forward.update_dropout(dropout)
+ self.drop.p = dropout
+
def _forward(self, *args, **kwargs):
raise NotImplementedError
- def update_dropout(self, dropout, attention_dropout):
- raise NotImplementedError
+ def _compute_dec_mask(self, tgt_pad_mask, future):
+ tgt_len = tgt_pad_mask.size(-1)
+ if not future: # apply future_mask, result mask in (B, T, T)
+ future_mask = torch.ones(
+ [tgt_len, tgt_len],
+ device=tgt_pad_mask.device,
+ dtype=torch.uint8,
+ )
+ future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
+ # BoolTensor was introduced in pytorch 1.2
+ try:
+ future_mask = future_mask.bool()
+ except AttributeError:
+ pass
+ dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
+ else: # only mask padding, result mask in (B, 1, T)
+ dec_mask = tgt_pad_mask
+ return dec_mask
+
+ def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
+ if isinstance(self.self_attn, MultiHeadedAttention):
+ return self.self_attn(
+ inputs_norm,
+ inputs_norm,
+ inputs_norm,
+ mask=dec_mask,
+ layer_cache=layer_cache,
+ attn_type="self",
+ )
+ elif isinstance(self.self_attn, AverageAttention):
+ return self.self_attn(
+ inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
+ )
+ else:
+ raise ValueError(
+ f"self attention {type(self.self_attn)} not supported"
+ )
class TransformerDecoderLayer(TransformerDecoderLayerBase):
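
Note: the future-mask construction now lives in the shared `_compute_dec_mask` helper above. A minimal standalone sketch of what it produces for a 3-token target with no padding and `future=False` (uint8 is used here for brevity; the real code converts the mask to bool when the running PyTorch supports it):

    import torch

    tgt_pad_mask = torch.zeros(1, 1, 3, dtype=torch.uint8)   # (B, 1, T), no padding
    tgt_len = tgt_pad_mask.size(-1)

    # strictly upper-triangular mask: position t may not attend to positions > t
    future_mask = torch.ones([tgt_len, tgt_len], dtype=torch.uint8)
    future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)

    dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)        # (B, T, T)
    print(dec_mask.int())
    # tensor([[[0, 1, 1],
    #          [0, 0, 1],
    #          [0, 0, 0]]], dtype=torch.int32)
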
@@ -76,23 +177,6 @@ class TransformerDecoderLayer(TransformerDecoderLayerBase):
A --> E
E --> F(out)
-
- Args:
- d_model (int): the dimension of keys/values/queries in
- :class:`MultiHeadedAttention`, also the input size of
- the first-layer of the :class:`PositionwiseFeedForward`.
- heads (int): the number of heads for MultiHeadedAttention.
- d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
- dropout (float): dropout in residual, self-attn(dot) and feed-forward
- attention_dropout (float): dropout in context_attn (and self-attn(avg))
- self_attn_type (string): type of self-attention scaled-dot, average
- max_relative_positions (int):
- Max distance between inputs in relative positions representations
- aan_useffn (bool): Turn on the FFN layer in the AAN decoder
- full_context_alignment (bool):
- whether enable an extra full context decoder forward for alignment
- alignment_heads (int):
- N. of cross attention heads to use for alignment guiding
"""
def __init__(
@@ -107,36 +191,35 @@ class TransformerDecoderLayer(TransformerDecoderLayerBase):
aan_useffn=False,
full_context_alignment=False,
alignment_heads=0,
+ pos_ffn_activation_fn=ActivationFunction.relu,
):
- super(TransformerDecoderLayer, self).__init__()
-
- if self_attn_type == "scaled-dot":
- self.self_attn = MultiHeadedAttention(
- heads,
- d_model,
- dropout=attention_dropout,
- max_relative_positions=max_relative_positions,
- )
- elif self_attn_type == "average":
- self.self_attn = AverageAttention(
- d_model, dropout=attention_dropout, aan_useffn=aan_useffn
- )
-
+ """
+ Args:
+ See TransformerDecoderLayerBase
+ """
+ super(TransformerDecoderLayer, self).__init__(
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type,
+ max_relative_positions,
+ aan_useffn,
+ full_context_alignment,
+ alignment_heads,
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
+ )
self.context_attn = MultiHeadedAttention(
heads, d_model, dropout=attention_dropout
)
- self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
- self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
- self.drop = nn.Dropout(dropout)
- self.full_context_alignment = full_context_alignment
- self.alignment_heads = alignment_heads
def update_dropout(self, dropout, attention_dropout):
- self.self_attn.update_dropout(attention_dropout)
+ super(TransformerDecoderLayer, self).update_dropout(
+ dropout, attention_dropout
+ )
self.context_attn.update_dropout(attention_dropout)
- self.feed_forward.update_dropout(dropout)
- self.drop.p = dropout
def _forward(
self,
@@ -170,39 +253,15 @@ class TransformerDecoderLayer(TransformerDecoderLayerBase):
"""
dec_mask = None
- if step is None:
- tgt_len = tgt_pad_mask.size(-1)
- if not future: # apply future_mask, result mask in (B, T, T)
- future_mask = torch.ones(
- [tgt_len, tgt_len],
- device=tgt_pad_mask.device,
- dtype=torch.uint8,
- )
- future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
- # BoolTensor was introduced in pytorch 1.2
- try:
- future_mask = future_mask.bool()
- except AttributeError:
- pass
- dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
- else: # only mask padding, result mask in (B, 1, T)
- dec_mask = tgt_pad_mask
-
- input_norm = self.layer_norm_1(inputs)
+ if inputs.size(1) > 1:
+ # masking is necessary when sequence length is greater than one
+ dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
- if isinstance(self.self_attn, MultiHeadedAttention):
- query, _ = self.self_attn(
- input_norm,
- input_norm,
- input_norm,
- mask=dec_mask,
- layer_cache=layer_cache,
- attn_type="self",
- )
- elif isinstance(self.self_attn, AverageAttention):
- query, _ = self.self_attn(
- input_norm, mask=dec_mask, layer_cache=layer_cache, step=step
- )
+ inputs_norm = self.layer_norm_1(inputs)
+
+ query, _ = self._forward_self_attn(
+ inputs_norm, dec_mask, layer_cache, step
+ )
query = self.drop(query) + inputs
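
Note: the old guard (`step is None`) skipped masking whenever a decoding step was given, even when the first inference step feeds a multi-token prefix; the new guard masks whenever more than one position is present, which is the behaviour the language-model inference fixed in this commit relies on. A trivial sketch of the condition (shapes only, illustrative names):

    import torch

    def needs_mask(inputs):
        # inputs: (batch, seq_len, d_model); mirrors `inputs.size(1) > 1` above
        return inputs.size(1) > 1

    prompt   = torch.zeros(2, 5, 512)   # first LM inference step: whole prompt at once
    next_tok = torch.zeros(2, 1, 512)   # later steps: one new token per call
    assert needs_mask(prompt)           # future/padding mask required
    assert not needs_mask(next_tok)     # single position, nothing to mask
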
@@ -257,6 +316,7 @@ class TransformerDecoderBase(DecoderBase):
opt.full_context_alignment,
opt.alignment_layer,
alignment_heads=opt.alignment_heads,
+ pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
)
def init_state(self, src, memory_bank, enc_hidden):
@@ -345,6 +405,7 @@ class TransformerDecoder(TransformerDecoderBase):
full_context_alignment,
alignment_layer,
alignment_heads,
+ pos_ffn_activation_fn=ActivationFunction.relu,
):
super(TransformerDecoder, self).__init__(
d_model, copy_attn, embeddings, alignment_layer
@@ -363,6 +424,7 @@ class TransformerDecoder(TransformerDecoderBase):
aan_useffn=aan_useffn,
full_context_alignment=full_context_alignment,
alignment_heads=alignment_heads,
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
)
for i in range(num_layers)
]
@@ -460,57 +522,9 @@ class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
Args:
- d_model (int): the dimension of keys/values/queries in
- :class:`MultiHeadedAttention`, also the input size of
- the first-layer of the :class:`PositionwiseFeedForward`.
- heads (int): the number of heads for MultiHeadedAttention.
- d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
- dropout (float): dropout in residual, self-attn(dot) and feed-forward
- attention_dropout (float): dropout in context_attn (and self-attn(avg))
- self_attn_type (string): type of self-attention scaled-dot, average
- max_relative_positions (int):
- Max distance between inputs in relative positions representations
- aan_useffn (bool): Turn on the FFN layer in the AAN decoder
- full_context_alignment (bool):
- whether enable an extra full context decoder forward for alignment
- alignment_heads (int):
- N. of cross attention heads to use for alignment guiding
+ See TransformerDecoderLayerBase
"""
- def __init__(
- self,
- d_model,
- heads,
- d_ff,
- dropout,
- attention_dropout,
- self_attn_type="scaled-dot",
- max_relative_positions=0,
- aan_useffn=False,
- full_context_alignment=False,
- alignment_heads=0,
- ):
- super(TransformerLMDecoderLayer, self).__init__()
-
- if self_attn_type == "scaled-dot":
- self.self_attn = MultiHeadedAttention(
- heads,
- d_model,
- dropout=attention_dropout,
- max_relative_positions=max_relative_positions,
- )
- elif self_attn_type == "average":
- self.self_attn = AverageAttention(
- d_model, dropout=attention_dropout, aan_useffn=aan_useffn
- )
-
- self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
- self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
- self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
- self.drop = nn.Dropout(dropout)
- self.full_context_alignment = full_context_alignment
- self.alignment_heads = alignment_heads
-
def _forward(
self, inputs, tgt_pad_mask, layer_cache=None, step=None, future=False
):
@@ -534,51 +548,21 @@ class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
"""
dec_mask = None
- if step is None:
- tgt_len = tgt_pad_mask.size(-1)
- if not future: # apply future_mask, result mask in (B, T, T)
- future_mask = torch.ones(
- [tgt_len, tgt_len],
- device=tgt_pad_mask.device,
- dtype=torch.uint8,
- )
- future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
- # BoolTensor was introduced in pytorch 1.2
- try:
- future_mask = future_mask.bool()
- except AttributeError:
- pass
- dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
- else: # only mask padding, result mask in (B, 1, T)
- dec_mask = tgt_pad_mask
+ if inputs.size(1) > 1:
+ # masking is necessary when sequence length is greater than one
+ dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
inputs_norm = self.layer_norm_1(inputs)
- if isinstance(self.self_attn, MultiHeadedAttention):
- query, attns = self.self_attn(
- inputs_norm,
- inputs_norm,
- inputs_norm,
- mask=dec_mask,
- layer_cache=layer_cache,
- attn_type="self",
- )
- elif isinstance(self.self_attn, AverageAttention):
- query, attns = self.self_attn(
- inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
- )
- output = self.drop(query) + inputs
-
- output_feedforward = self.feed_forward(self.layer_norm_2(output))
+ query, attns = self._forward_self_attn(
+ inputs_norm, dec_mask, layer_cache, step
+ )
- output_norm = self.drop(output_feedforward) + output
+ output = self.drop(query) + inputs
- return output_norm, attns
+ output_feedforward = self.feed_forward(output)
- def update_dropout(self, dropout, attention_dropout):
- self.self_attn.update_dropout(attention_dropout)
- self.feed_forward.update_dropout(dropout)
- self.drop.p = dropout
+ return output_feedforward, attns
class TransformerLMDecoder(TransformerDecoderBase):
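
Note: the extra `layer_norm_2` + residual around the feed-forward call is dropped because `PositionwiseFeedForward` (see onmt/modules/position_ffn.py below) already applies its own pre-LayerNorm and residual connection; keeping both normalized the activations twice and added the residual twice, which appears to be the "gpt2 architecture" part of this fix. A minimal, self-contained sketch (illustrative class, not the library's API) of the feed-forward sub-block the layer now delegates to:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PreNormFFN(nn.Module):
        """Pre-LayerNorm feed-forward with its own residual, mirroring
        PositionwiseFeedForward.forward shown further below."""
        def __init__(self, d_model, d_ff):
            super().__init__()
            self.norm = nn.LayerNorm(d_model, eps=1e-6)
            self.w_1 = nn.Linear(d_model, d_ff)
            self.w_2 = nn.Linear(d_ff, d_model)

        def forward(self, x):
            return self.w_2(F.relu(self.w_1(self.norm(x)))) + x

    # The corrected LM layer is then just: out = feed_forward(drop(self_attn(...)) + inputs),
    # with no second LayerNorm or residual wrapped around the feed-forward call.
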
@@ -628,6 +612,7 @@ class TransformerLMDecoder(TransformerDecoderBase):
full_context_alignment=None,
alignment_layer=None,
alignment_heads=None,
+ pos_ffn_activation_fn=ActivationFunction.relu,
):
super(TransformerLMDecoder, self).__init__(
d_model, copy_attn, embeddings, None
@@ -645,6 +630,7 @@ class TransformerLMDecoder(TransformerDecoderBase):
aan_useffn=aan_useffn,
full_context_alignment=None,
alignment_heads=None,
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
)
for i in range(num_layers)
]
diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py
index 5c29d8de..afb796b3 100644
--- a/onmt/encoders/transformer.py
+++ b/onmt/encoders/transformer.py
@@ -7,6 +7,7 @@ import torch.nn as nn
from onmt.encoders.encoder import EncoderBase
from onmt.modules import MultiHeadedAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
+from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask
@@ -21,16 +22,20 @@ class TransformerEncoderLayer(nn.Module):
heads (int): the number of head for MultiHeadedAttention.
d_ff (int): the second-layer of the PositionwiseFeedForward.
dropout (float): dropout probability(0-1.0).
+ pos_ffn_activation_fn (ActivationFunction):
+ activation function choice for PositionwiseFeedForward layer
"""
def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
- max_relative_positions=0):
+ max_relative_positions=0,
+ pos_ffn_activation_fn=ActivationFunction.relu):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiHeadedAttention(
heads, d_model, dropout=attention_dropout,
max_relative_positions=max_relative_positions)
- self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
+ self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout,
+ pos_ffn_activation_fn)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.dropout = nn.Dropout(dropout)
@@ -80,6 +85,8 @@ class TransformerEncoder(EncoderBase):
dropout (float): dropout parameters
embeddings (onmt.modules.Embeddings):
embeddings to use, should have positional encodings
+ pos_ffn_activation_fn (ActivationFunction):
+ activation function choice for PositionwiseFeedForward layer
Returns:
(torch.FloatTensor, torch.FloatTensor):
@@ -89,14 +96,16 @@ class TransformerEncoder(EncoderBase):
"""
def __init__(self, num_layers, d_model, heads, d_ff, dropout,
- attention_dropout, embeddings, max_relative_positions):
+ attention_dropout, embeddings, max_relative_positions,
+ pos_ffn_activation_fn=ActivationFunction.relu):
super(TransformerEncoder, self).__init__()
self.embeddings = embeddings
self.transformer = nn.ModuleList(
[TransformerEncoderLayer(
d_model, heads, d_ff, dropout, attention_dropout,
- max_relative_positions=max_relative_positions)
+ max_relative_positions=max_relative_positions,
+ pos_ffn_activation_fn=pos_ffn_activation_fn)
for i in range(num_layers)])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
@@ -112,7 +121,9 @@ class TransformerEncoder(EncoderBase):
opt.attention_dropout[0] if type(opt.attention_dropout)
is list else opt.attention_dropout,
embeddings,
- opt.max_relative_positions)
+ opt.max_relative_positions,
+ pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+ )
def forward(self, src, lengths=None):
"""See :func:`EncoderBase.forward()`"""
diff --git a/onmt/modules/average_attn.py b/onmt/modules/average_attn.py
index 29c4259e..7d8dda5d 100644
--- a/onmt/modules/average_attn.py
+++ b/onmt/modules/average_attn.py
@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
from onmt.modules.position_ffn import PositionwiseFeedForward
+from onmt.modules.position_ffn import ActivationFunction
class AverageAttention(nn.Module):
@@ -17,15 +18,20 @@ class AverageAttention(nn.Module):
model_dim (int): the dimension of keys/values/queries,
must be divisible by head_count
dropout (float): dropout parameter
+ pos_ffn_activation_fn (ActivationFunction):
+ activation function choice for PositionwiseFeedForward layer
"""
- def __init__(self, model_dim, dropout=0.1, aan_useffn=False):
+ def __init__(self, model_dim, dropout=0.1, aan_useffn=False,
+ pos_ffn_activation_fn=ActivationFunction.relu):
self.model_dim = model_dim
self.aan_useffn = aan_useffn
super(AverageAttention, self).__init__()
if aan_useffn:
self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
- dropout)
+ dropout,
+ pos_ffn_activation_fn
+ )
self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)
def cumulative_average_mask(self, batch_size, inputs_len, device):
diff --git a/onmt/modules/embeddings.py b/onmt/modules/embeddings.py
index 8a97a77c..9561a66e 100644
--- a/onmt/modules/embeddings.py
+++ b/onmt/modules/embeddings.py
@@ -52,15 +52,13 @@ class PositionalEncoding(nn.Module):
"""
emb = emb * math.sqrt(self.dim)
- if step is None:
- if self.pe.size(0) < emb.size(0):
- raise SequenceTooLongError(
- f"Sequence is {emb.size(0)} but PositionalEncoding is"
- f" limited to {self.pe.size(0)}. See max_len argument."
- )
- emb = emb + self.pe[:emb.size(0)]
- else:
- emb = emb + self.pe[step]
+ step = step or 0
+ if self.pe.size(0) < step + emb.size(0):
+ raise SequenceTooLongError(
+ f"Sequence is {emb.size(0) + step} but PositionalEncoding is"
+ f" limited to {self.pe.size(0)}. See max_len argument."
+ )
+ emb = emb + self.pe[step:emb.size(0)+step]
emb = self.dropout(emb)
return emb
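
Note: `PositionalEncoding.forward` now offsets into the table by `step`, so a multi-token prefix at step 0 and later single-token steps both receive the correct positions, and any overflow past `max_len` raises instead of indexing out of range. A standalone numeric sketch of the slicing (toy table; the sqrt(dim) scaling and dropout are omitted):

    import torch

    max_len, dim = 8, 4
    # toy stand-in for self.pe: row i holds the value i in every column
    pe = torch.arange(max_len, dtype=torch.float).unsqueeze(1).repeat(1, dim)

    def add_pe(emb, step=None):
        # emb: (seq_len, batch, dim); mirrors the updated forward above
        step = step or 0
        if pe.size(0) < step + emb.size(0):
            raise ValueError(f"Sequence is {emb.size(0) + step} but table is limited to {pe.size(0)}")
        return emb + pe[step:emb.size(0) + step].unsqueeze(1)

    prompt = torch.zeros(3, 1, dim)
    print(add_pe(prompt)[:, 0, 0])                            # tensor([0., 1., 2.])  positions 0..2
    print(add_pe(torch.zeros(1, 1, dim), step=3)[:, 0, 0])    # tensor([3.])  continues at position 3
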
diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py
index fb8df80a..6757742d 100644
--- a/onmt/modules/position_ffn.py
+++ b/onmt/modules/position_ffn.py
@@ -1,6 +1,18 @@
"""Position feed-forward network from "Attention is All You Need"."""
import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ActivationFunction(object):
+ relu = "relu"
+ gelu = "gelu"
+
+
+ACTIVATION_FUNCTIONS = {
+ ActivationFunction.relu: F.relu,
+ ActivationFunction.gelu: F.gelu,
+}
class PositionwiseFeedForward(nn.Module):
@@ -11,15 +23,17 @@ class PositionwiseFeedForward(nn.Module):
d_ff (int): the hidden layer size of the second-layer
of the FNN.
dropout (float): dropout probability in :math:`[0, 1)`.
+ activation_fn (ActivationFunction): activation function used.
"""
- def __init__(self, d_model, d_ff, dropout=0.1):
+ def __init__(self, d_model, d_ff, dropout=0.1,
+ activation_fn=ActivationFunction.relu):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.dropout_1 = nn.Dropout(dropout)
- self.relu = nn.ReLU()
+ self.activation = ACTIVATION_FUNCTIONS[activation_fn]
self.dropout_2 = nn.Dropout(dropout)
def forward(self, x):
@@ -32,7 +46,7 @@ class PositionwiseFeedForward(nn.Module):
(FloatTensor): Output ``(batch_size, input_len, model_dim)``.
"""
- inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x))))
+ inter = self.dropout_1(self.activation(self.w_1(self.layer_norm(x))))
output = self.dropout_2(self.w_2(inter))
return output + x
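
Note: `ACTIVATION_FUNCTIONS` simply maps the option string to the corresponding `torch.nn.functional` callable. A short standalone check of the two choices (gelu values are approximate):

    import torch
    import torch.nn.functional as F

    ACTIVATION_FUNCTIONS = {"relu": F.relu, "gelu": F.gelu}   # same mapping as above

    x = torch.tensor([-1.0, 0.0, 1.0])
    print(ACTIVATION_FUNCTIONS["relu"](x))   # tensor([0., 0., 1.])
    print(ACTIVATION_FUNCTIONS["gelu"](x))   # ~ tensor([-0.1587, 0.0000, 0.8413])
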
diff --git a/onmt/opts.py b/onmt/opts.py
index 65be0e93..0ec07177 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -6,6 +6,8 @@ import configargparse
from onmt.models.sru import CheckSRU
from onmt.transforms import AVAILABLE_TRANSFORMS
from onmt.constants import ModelTask
+from onmt.modules.position_ffn import ACTIVATION_FUNCTIONS
+from onmt.modules.position_ffn import ActivationFunction
def config_opts(parser):
@@ -291,6 +293,13 @@ def model_opts(parser):
help="Size of windows in the cnn, the kernel_size is "
"(cnn_kernel_width, 1) in conv layer")
+ group.add('--pos_ffn_activation_fn', '-pos_ffn_activation_fn',
+ type=str, default=ActivationFunction.relu,
+ choices=ACTIVATION_FUNCTIONS.keys(), help='The activation'
+ ' function to use in PositionwiseFeedForward layer. Choices are'
+ f' {ACTIVATION_FUNCTIONS.keys()}. Default to'
+ f' {ActivationFunction.relu}.')
+
group.add('--input_feed', '-input_feed', type=int, default=1,
help="Feed the context vector at each time step as "
"additional input (via concatenation with the word "
diff --git a/onmt/tests/rebuild_test_models.sh b/onmt/tests/rebuild_test_models.sh
index e6225a46..e4b2aa51 100755
--- a/onmt/tests/rebuild_test_models.sh
+++ b/onmt/tests/rebuild_test_models.sh
@@ -114,7 +114,7 @@ $my_python train.py -config data/lm_data.yaml -save_model /tmp/tmp \
-attention_dropout 0.1 -heads 2 -position_encoding -param_init 0 -warmup_steps 100 \
-param_init_glorot -adam_beta2 0.998 -src_vocab data/data_lm/data.vocab.src
#
-mv /tmp/tmp*2000.pt onmt/tests/test_model_lm_2.pt
+mv /tmp/tmp*2000.pt onmt/tests/test_model_lm.pt
rm /tmp/tmp*.pt
fi
#
diff --git a/onmt/tests/test_beam_search.py b/onmt/tests/test_beam_search.py
index bbd24f4b..08d072ae 100644
--- a/onmt/tests/test_beam_search.py
+++ b/onmt/tests/test_beam_search.py
@@ -574,14 +574,14 @@ class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase):
scores_finish = torch.log_softmax(torch.tensor(
[[0, 0, 10000, 0, 5000, .51, .2, 0], # beam 0 shouldn't cont
[100000, 100001, 0, 0, 0, 0, 0, 0],
- [0,100000, 0, 0, 0, 5000, 0, 0],
+ [0, 100000, 0, 0, 0, 5000, 0, 0],
[0, 0, 0, .2, .2, .2, .2, .2],
[0, 0, 0, 0, .2, .2, .2, .2]] # beam 4 -> beam 1 should die
), dim=1)
scores_finish = scores_finish.repeat(self.BATCH_SZ, 1)
scores_finish[:self.BEAM_SZ, beam.eos] = 0
- beam.advance( scores_finish, None)
-
+ beam.advance(scores_finish, None)
+
any_finished = beam.is_finished.any()
if any_finished:
beam.update_finished()
@@ -601,9 +601,9 @@ class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase):
self.third_step(beam, expected_beam_scores, 1)
n_steps = beam.alive_seq.shape[-1] - 1
- self.assertTrue(beam.memory_lengths.equal(n_steps+fn_map_state(src_lengths, dim=0)))
-
-
+ self.assertTrue(beam.memory_lengths.equal(n_steps+fn_map_state(
+ src_lengths, dim=0)))
+
def test_beam_lm_update_memory_length_when_finished(self):
beam = BeamSearchLM(
self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
@@ -613,8 +613,9 @@ class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase):
device_init = torch.zeros(1, 1)
src_lengths = torch.randint(0, 30, (self.BATCH_SZ,))
fn_map_state, _, _, _ = beam.initialize(device_init, src_lengths)
- expected_beam_scores = self.init_step(beam, 1)
+ _ = self.init_step(beam, 1)
self.finish_first_beam_step(beam)
-
+
n_steps = beam.alive_seq.shape[-1] - 1
- self.assertTrue(beam.memory_lengths.equal(n_steps+fn_map_state(src_lengths[1:], dim=0)))
+ self.assertTrue(beam.memory_lengths.equal(n_steps+fn_map_state(
+ src_lengths[1:], dim=0)))
diff --git a/onmt/tests/test_model_lm.pt b/onmt/tests/test_model_lm.pt
index ef7af433..84a5729f 100644
--- a/onmt/tests/test_model_lm.pt
+++ b/onmt/tests/test_model_lm.pt
Binary files differ