diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2018-02-27 02:33:34 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2018-02-27 02:33:34 +0300 |
commit | 43aeb5539ec0f0df9f16d0773561350ce9219fb4 (patch) | |
tree | f9eb0ea7c17d8430ced13ed41dcccd914ebddc94 /src/amun/cpu/nematus/encoder_decoder.cpp | |
parent | 1d98637ad98eb85f67d3734ab5c5169667a78aea (diff) |
Matrix -> Tensor
Diffstat (limited to 'src/amun/cpu/nematus/encoder_decoder.cpp')
-rw-r--r-- | src/amun/cpu/nematus/encoder_decoder.cpp | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/src/amun/cpu/nematus/encoder_decoder.cpp b/src/amun/cpu/nematus/encoder_decoder.cpp index 81a7e25a..b5a7cac9 100644 --- a/src/amun/cpu/nematus/encoder_decoder.cpp +++ b/src/amun/cpu/nematus/encoder_decoder.cpp @@ -64,17 +64,17 @@ void EncoderDecoder::AssembleBeamState(const State& in, const EDState& edIn = in.get<EDState>(); EDState& edOut = out.get<EDState>(); - edOut.GetStates() = mblas::Assemble<mblas::byRow, mblas::Matrix>(edIn.GetStates(), beamStateIds); + edOut.GetStates() = mblas::Assemble<mblas::byRow, mblas::Tensor>(edIn.GetStates(), beamStateIds); decoder_->Lookup(edOut.GetEmbeddings(), beamWords); } -void EncoderDecoder::GetAttention(mblas::Matrix& Attention) { +void EncoderDecoder::GetAttention(mblas::Tensor& Attention) { decoder_->GetAttention(Attention); } -mblas::Matrix& EncoderDecoder::GetAttention() { +mblas::Tensor& EncoderDecoder::GetAttention() { return decoder_->GetAttention(); } |