Welcome to mirror list, hosted at ThFree Co, Russian Federation.

encoder.cpp « nematus « cpu « amun « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 240d760181182455fa33e8784ee58a76f47bc996 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#include "encoder.h"

using namespace std;

namespace amunmt {
namespace CPU {
namespace Nematus {

void Encoder::GetContext(const std::vector<unsigned>& words, mblas::Tensor& context) {
  std::vector<mblas::Tensor> embeddedWords;

  context.resize(words.size(),
                 forwardRnn_.GetStateLength() + backwardRnn_.GetStateLength());

  for (auto& w : words) {
    embeddedWords.emplace_back();
    mblas::Tensor &embed = embeddedWords.back();
    embeddings_.Lookup(embed, w);
  }

  forwardRnn_.GetContext(embeddedWords.cbegin(),
						 embeddedWords.cend(),
						 context, false);
  backwardRnn_.GetContext(embeddedWords.crbegin(),
						  embeddedWords.crend(),
						  context, true);
}

}  // namespace Nematus
}  // namespace CPU
}  // namespace amunmt