
github.com/OpenNMT/OpenNMT-py.git
Diffstat (limited to 'onmt/modules/conv_multi_step_attention.py')
-rw-r--r--  onmt/modules/conv_multi_step_attention.py | 15
1 file changed, 2 insertions(+), 13 deletions(-)
diff --git a/onmt/modules/conv_multi_step_attention.py b/onmt/modules/conv_multi_step_attention.py
index 545df1c9..917ffabc 100644
--- a/onmt/modules/conv_multi_step_attention.py
+++ b/onmt/modules/conv_multi_step_attention.py
@@ -2,7 +2,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from onmt.utils.misc import aeq
 
 
 SCALE_WEIGHT = 0.5 ** 0.5
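
For reference, the import removed above is OpenNMT-py's small assert-equal helper, whose only callers in this file were the shape checks deleted below. A minimal sketch of what aeq does, inferred from its usage here rather than copied verbatim from onmt.utils.misc:

def aeq(*args):
    # Assert that all arguments are equal; used as a cheap
    # sanity check on tensor dimensions during development.
    first = args[0]
    assert all(arg == first for arg in args[1:]), \
        "Not all arguments have the same value: %s" % str(args)

With the checks gone, mismatched shapes now surface later, as errors from the tensor operations themselves.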
@@ -37,8 +36,8 @@ class ConvMultiStepAttention(nn.Module):
                 encoder_out_combine):
         """
         Args:
-            base_target_emb: target emb tensor
-            input_from_dec: output of decode conv
+            base_target_emb: target emb tensor (batch, channel, height, width)
+            input_from_dec: output of dec conv (batch, channel, height, width)
             encoder_out_top: the key matrix for calculation of attention weight,
                 which is the top output of encode conv
             encoder_out_combine:
@@ -46,22 +45,12 @@ class ConvMultiStepAttention(nn.Module):
                 which is the combination of base emb and top output of encode
         """
 
-        # checks
-        # batch, channel, height, width = base_target_emb.size()
         batch, _, height, _ = base_target_emb.size()
-        # batch_, channel_, height_, width_ = input_from_dec.size()
         batch_, _, height_, _ = input_from_dec.size()
-        aeq(batch, batch_)
-        aeq(height, height_)
 
-        # enc_batch, enc_channel, enc_height = encoder_out_top.size()
         enc_batch, _, enc_height = encoder_out_top.size()
-        # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size()
         enc_batch_, _, enc_height_ = encoder_out_combine.size()
-        aeq(enc_batch, enc_batch_)
-        aeq(enc_height, enc_height_)
-
         preatt = seq_linear(self.linear_in, input_from_dec)
         target = (base_target_emb + preatt) * SCALE_WEIGHT
         target = torch.squeeze(target, 3)
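
For context, the surviving lines above are the start of the standard ConvS2S multi-step attention: the decoder conv output is combined with the target embedding, scored against the encoder keys with a softmax over source positions, and used for a weighted sum over the encoder values. A minimal, self-contained sketch of that computation (shapes follow the updated docstring, with height read as target length; conv_attention_step is a hypothetical name, not this module's API, and the exact transposes in the real forward() may differ):

import torch
import torch.nn.functional as F

SCALE_WEIGHT = 0.5 ** 0.5  # same scaling constant as the module

def conv_attention_step(base_target_emb, preatt,
                        encoder_out_top, encoder_out_combine):
    # base_target_emb, preatt: (batch, channel, tgt_len, 1)
    # encoder_out_top:     (batch, channel, src_len) -- attention keys
    # encoder_out_combine: (batch, channel, src_len) -- attention values
    target = (base_target_emb + preatt) * SCALE_WEIGHT
    target = torch.squeeze(target, 3)        # (batch, channel, tgt_len)
    target = torch.transpose(target, 1, 2)   # (batch, tgt_len, channel)
    scores = torch.bmm(target, encoder_out_top)   # (batch, tgt_len, src_len)
    attn = F.softmax(scores, dim=2)               # normalize over source positions
    context = torch.bmm(attn, encoder_out_combine.transpose(1, 2))
    return context, attn                          # context: (batch, tgt_len, channel)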