diff options
Diffstat (limited to 'onmt/modules/conv_multi_step_attention.py')
-rw-r--r-- | onmt/modules/conv_multi_step_attention.py | 15 |
1 file changed, 2 insertions, 13 deletions
diff --git a/onmt/modules/conv_multi_step_attention.py b/onmt/modules/conv_multi_step_attention.py index 545df1c9..917ffabc 100644 --- a/onmt/modules/conv_multi_step_attention.py +++ b/onmt/modules/conv_multi_step_attention.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from onmt.utils.misc import aeq SCALE_WEIGHT = 0.5 ** 0.5 @@ -37,8 +36,8 @@ class ConvMultiStepAttention(nn.Module): encoder_out_combine): """ Args: - base_target_emb: target emb tensor - input_from_dec: output of decode conv + base_target_emb: target emb tensor (batch, channel, height, width) + input_from_dec: output of dec conv (batch, channel, height, width) encoder_out_top: the key matrix for calculation of attetion weight, which is the top output of encode conv encoder_out_combine: @@ -46,22 +45,12 @@ class ConvMultiStepAttention(nn.Module): which is the combination of base emb and top output of encode """ - # checks - # batch, channel, height, width = base_target_emb.size() batch, _, height, _ = base_target_emb.size() - # batch_, channel_, height_, width_ = input_from_dec.size() batch_, _, height_, _ = input_from_dec.size() - aeq(batch, batch_) - aeq(height, height_) - # enc_batch, enc_channel, enc_height = encoder_out_top.size() enc_batch, _, enc_height = encoder_out_top.size() - # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size() enc_batch_, _, enc_height_ = encoder_out_combine.size() - aeq(enc_batch, enc_batch_) - aeq(enc_height, enc_height_) - preatt = seq_linear(self.linear_in, input_from_dec) target = (base_target_emb + preatt) * SCALE_WEIGHT target = torch.squeeze(target, 3) |