Welcome to mirror list, hosted at ThFree Co, Russian Federation.

encoder_decoder.cpp « dl4mt « cpu « amun « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0644182c1275af1618ac01fe16a3ead81228f7f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include "cpu/dl4mt/encoder_decoder.h"

#include <vector>
#include <yaml-cpp/yaml.h>

#include "common/sentences.h"
#include "cpu/dl4mt/encoder.h"
#include "cpu/dl4mt/decoder.h"


namespace amunmt {
namespace CPU {
namespace dl4mt {

// Short local name for the scorer's state type used throughout this file.
using EDState = EncoderDecoderState;

/**
 * Build a dl4mt encoder-decoder scorer over an already-loaded set of weights.
 *
 * @param god    global context, forwarded to CPUEncoderDecoderBase
 * @param name   scorer name, forwarded to CPUEncoderDecoderBase
 * @param config scorer configuration node, forwarded to CPUEncoderDecoderBase
 * @param tab    input factor/tab index, forwarded to CPUEncoderDecoderBase
 * @param model  weights shared by the encoder and decoder created here
 *
 * NOTE(review): encoder_ and decoder_ are constructed from model_. Members
 * are initialized in their header declaration order, not in the order they
 * appear below, so this relies on model_ being declared before encoder_ and
 * decoder_ in the class. Confirm in encoder_decoder.h.
 */
EncoderDecoder::EncoderDecoder(const God &god,
                               const std::string& name,
                               const YAML::Node& config,
                               unsigned tab,
                               const dl4mt::Weights& model)
  : CPUEncoderDecoderBase(god, name, config, tab)
  , model_(model)
  , encoder_(new dl4mt::Encoder(model_))
  , decoder_(new dl4mt::Decoder(model_))
{
}


void EncoderDecoder::Decode(const State& in, State& out, const std::vector<unsigned>&) {
  const EDState& edIn = in.get<EDState>();
  EDState& edOut = out.get<EDState>();

  decoder_->Decode(edOut.GetStates(), edIn.GetStates(),
                   edIn.GetEmbeddings(), SourceContext_);
}


/**
 * Prepare `state` for the first decoding step of a batch.
 *
 * Asks the decoder to fill the states, given SourceContext_, and then the
 * embeddings, both sized by `batchSize`.
 */
void EncoderDecoder::BeginSentenceState(State& state, unsigned batchSize) {
  EDState& initial = state.get<EDState>();

  // Order kept as in the original: states first, then embeddings.
  decoder_->EmptyState(initial.GetStates(), SourceContext_, batchSize);
  decoder_->EmptyEmbedding(initial.GetEmbeddings(), batchSize);
}


/**
 * Encode the source side and store the result in SourceContext_.
 *
 * The words of the first sentence in `sources`, for input factor `tab_`, are
 * passed to the encoder.
 *
 * NOTE(review): only sources.Get(0) is encoded. Presumably this CPU path
 * decodes one sentence at a time; confirm that callers never pass a larger
 * batch, and that an empty batch is impossible.
 */
void EncoderDecoder::Encode(const Sentences& sources) {
  auto&& firstSentence = sources.Get(0);
  encoder_->Encode(firstSentence.GetWords(tab_), SourceContext_);
}


void EncoderDecoder::AssembleBeamState(const State& in,
                                       const Beam& beam,
                                       State& out) {
  std::vector<unsigned> beamWords;
  std::vector<unsigned> beamStateIds;
  for(auto h : beam) {
      beamWords.push_back(h->GetWord());
      beamStateIds.push_back(h->GetPrevStateIndex());
  }

  const EDState& edIn = in.get<EDState>();
  EDState& edOut = out.get<EDState>();

  edOut.GetStates() = mblas::Assemble<mblas::byRow, mblas::Tensor>(edIn.GetStates(), beamStateIds);
  decoder_->Lookup(edOut.GetEmbeddings(), beamWords);
}


void EncoderDecoder::GetAttention(mblas::Tensor& Attention) {
  decoder_->GetAttention(Attention);
}


/**
 * Return the reference produced by the decoder's GetAttention().
 * The referenced tensor is presumably owned by the decoder, so it is valid
 * only while this scorer is alive. Verify against dl4mt::Decoder.
 */
mblas::Tensor& EncoderDecoder::GetAttention() {
  auto& decoder = *decoder_;
  mblas::Tensor& attention = decoder.GetAttention();
  return attention;
}


unsigned EncoderDecoder::GetVocabSize() const {
  return decoder_->GetVocabSize();
}


void EncoderDecoder::Filter(const std::vector<unsigned>& filterIds) {
  decoder_->Filter(filterIds);
}


/**
 * Return the decoder's probability tensor, exposed as a BaseTensor reference.
 * The tensor is presumably owned by the decoder. Do not keep the reference
 * beyond this scorer's lifetime.
 */
BaseTensor& EncoderDecoder::GetProbs()
{
  auto& decoder = *decoder_;
  BaseTensor& probs = decoder.GetProbs();
  return probs;
}

}
}
}