1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
|
#include "cpu/dl4mt/encoder_decoder.h"
#include <vector>
#include <yaml-cpp/yaml.h>
#include "common/sentences.h"
#include "cpu/dl4mt/encoder.h"
#include "cpu/dl4mt/decoder.h"
namespace amunmt {
namespace CPU {
namespace dl4mt {
// Local shorthand for the per-step state type this scorer keeps inside a generic State.
using EDState = EncoderDecoderState;
/// @brief Builds the CPU dl4mt encoder-decoder scorer.
/// @param god    forwarded to CPUEncoderDecoderBase.
/// @param name   scorer name, forwarded to CPUEncoderDecoderBase.
/// @param config YAML configuration node, forwarded to CPUEncoderDecoderBase.
/// @param tab    forwarded to CPUEncoderDecoderBase; Encode() later reads source
///               words via tab_ (presumably the input-tab index — TODO confirm).
/// @param model  weights; stored in model_ and shared by the encoder and decoder.
EncoderDecoder::EncoderDecoder(const God &god,
const std::string& name,
const YAML::Node& config,
unsigned tab,
const dl4mt::Weights& model)
: CPUEncoderDecoderBase(god, name, config, tab),
model_(model),
// NOTE: encoder_/decoder_ are built from the *member* model_, so this relies on
// model_ being declared before encoder_ and decoder_ in the class definition
// (members are initialized in declaration order, not initializer-list order).
encoder_(new dl4mt::Encoder(model_)),
decoder_(new dl4mt::Decoder(model_))
{}
void EncoderDecoder::Decode(const State& in, State& out, const std::vector<unsigned>&) {
const EDState& edIn = in.get<EDState>();
EDState& edOut = out.get<EDState>();
decoder_->Decode(edOut.GetStates(), edIn.GetStates(),
edIn.GetEmbeddings(), SourceContext_);
}
/// @brief Initializes @p state for the first decoding step of a batch.
///
/// Asks the decoder to fill in the starting hidden states (derived from the
/// encoded source context) and the starting target embeddings.
/// @param state      state object to initialize in place.
/// @param batchSize  number of entries to initialize, forwarded to the decoder.
void EncoderDecoder::BeginSentenceState(State& state, unsigned batchSize) {
  auto& start = state.get<EDState>();

  decoder_->EmptyState(start.GetStates(), SourceContext_, batchSize);
  decoder_->EmptyEmbedding(start.GetEmbeddings(), batchSize);
}
/// @brief Encodes the source side into SourceContext_.
///
/// NOTE(review): only the first sentence (index 0) of @p sources is encoded;
/// this presumably assumes the CPU path is run with one sentence per batch —
/// confirm against callers before batching more.
/// @param sources  input batch; words are taken from column tab_.
void EncoderDecoder::Encode(const Sentences& sources) {
  const auto& sentence = sources.Get(0);
  encoder_->Encode(sentence.GetWords(tab_), SourceContext_);
}
void EncoderDecoder::AssembleBeamState(const State& in,
const Beam& beam,
State& out) {
std::vector<unsigned> beamWords;
std::vector<unsigned> beamStateIds;
for(auto h : beam) {
beamWords.push_back(h->GetWord());
beamStateIds.push_back(h->GetPrevStateIndex());
}
const EDState& edIn = in.get<EDState>();
EDState& edOut = out.get<EDState>();
edOut.GetStates() = mblas::Assemble<mblas::byRow, mblas::Tensor>(edIn.GetStates(), beamStateIds);
decoder_->Lookup(edOut.GetEmbeddings(), beamWords);
}
/// @brief Fills @p Attention with the decoder's attention, as provided by
///        the decoder's own GetAttention(Tensor&) overload.
void EncoderDecoder::GetAttention(mblas::Tensor& Attention)
{
  decoder_->GetAttention(Attention);
}
/// @brief Returns a mutable reference to the decoder's attention tensor.
/// @return the reference returned by the decoder; it is not copied.
mblas::Tensor& EncoderDecoder::GetAttention()
{
  auto& attention = decoder_->GetAttention();
  return attention;
}
/// @brief Target vocabulary size, as reported by the decoder.
unsigned EncoderDecoder::GetVocabSize() const
{
  const unsigned vocabSize = decoder_->GetVocabSize();
  return vocabSize;
}
/// @brief Forwards a vocabulary filter to the decoder.
/// @param filterIds  word ids passed unchanged to the decoder's Filter().
void EncoderDecoder::Filter(const std::vector<unsigned>& filterIds)
{
  decoder_->Filter(filterIds);
}
/// @brief Returns the decoder's probability tensor by reference (no copy).
BaseTensor& EncoderDecoder::GetProbs()
{
  BaseTensor& probs = decoder_->GetProbs();
  return probs;
}
}
}
}
|