1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
|
#include "encoder.h"

#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/sentences.h"
using namespace std;
namespace amunmt {
namespace GPU {
/**
 * Builds the encoder from model weights and configuration.
 *
 * The embedding table comes from model.encEmbeddings_. The forward and
 * backward recurrent cells are chosen from `config` by InitForwardCell and
 * InitBackwardCell respectively.
 *
 * NOTE: members are initialized in their declaration order in encoder.h,
 * not in the order they are listed here.
 */
Encoder::Encoder(const Weights& model, const YAML::Node& config)
: embeddings_(model.encEmbeddings_),
forwardRnn_(InitForwardCell(model, config)),
backwardRnn_(InitBackwardCell(model, config))
{}
/**
 * Creates the recurrent cell for the forward (left-to-right) encoder RNN.
 *
 * The cell type is read from the "enc-cell" config key and defaults to "gru"
 * when that key is absent. Supported values: "gru", "lstm", "mlstm".
 *
 * @param model  model weights; the forward-direction weights for the chosen
 *               cell type are dereferenced, so they must be non-null.
 * @param config configuration node; only the "enc-cell" key is read here.
 * @return owning pointer to the newly created cell (never null).
 * @throws std::runtime_error if "enc-cell" names an unsupported cell type.
 */
std::unique_ptr<Cell> Encoder::InitForwardCell(const Weights& model, const YAML::Node& config) {
  const std::string celltype =
      config["enc-cell"] ? config["enc-cell"].as<std::string>() : "gru";

  if (celltype == "lstm") {
    return std::unique_ptr<Cell>(new LSTM<Weights::EncForwardLSTM>(*(model.encForwardLSTM_)));
  } else if (celltype == "mlstm") {
    return std::unique_ptr<Cell>(
        new Multiplicative<LSTM, Weights::EncForwardLSTM>(*model.encForwardMLSTM_));
  } else if (celltype == "gru") {
    return std::unique_ptr<Cell>(new GRU<Weights::EncForwardGRU>(*(model.encForwardGRU_)));
  }

  // Fail loudly. assert() is compiled out in release builds, and a null
  // cell would only crash later inside the RNN, far from the bad config.
  throw std::runtime_error("Unknown encoder cell type in 'enc-cell': " + celltype);
}
/**
 * Creates the recurrent cell for the backward (right-to-left) encoder RNN.
 *
 * The cell type is taken from "enc-cell-r" if present. Otherwise it falls
 * back to "enc-cell", and to "gru" if neither key is set. Supported values:
 * "gru", "lstm", "mlstm".
 *
 * @param model  model weights; the backward-direction weights for the chosen
 *               cell type are dereferenced, so they must be non-null.
 * @param config configuration node; the "enc-cell-r" and "enc-cell" keys are read.
 * @return owning pointer to the newly created cell (never null).
 * @throws std::runtime_error if the resolved cell type is unsupported.
 */
std::unique_ptr<Cell> Encoder::InitBackwardCell(const Weights& model, const YAML::Node& config) {
  const std::string enccell =
      config["enc-cell"] ? config["enc-cell"].as<std::string>() : "gru";
  const std::string celltype =
      config["enc-cell-r"] ? config["enc-cell-r"].as<std::string>() : enccell;

  if (celltype == "lstm") {
    return std::unique_ptr<Cell>(new LSTM<Weights::EncBackwardLSTM>(*(model.encBackwardLSTM_)));
  } else if (celltype == "mlstm") {
    return std::unique_ptr<Cell>(
        new Multiplicative<LSTM, Weights::EncBackwardLSTM>(*model.encBackwardMLSTM_));
  } else if (celltype == "gru") {
    return std::unique_ptr<Cell>(new GRU<Weights::EncBackwardGRU>(*(model.encBackwardGRU_)));
  }

  // Fail loudly. assert() is compiled out in release builds, and a null
  // cell would only crash later inside the RNN, far from the bad config.
  throw std::runtime_error(
      "Unknown backward encoder cell type in 'enc-cell-r'/'enc-cell': " + celltype);
}
/**
 * Returns the number of words in the longest sentence of `source`.
 *
 * Only the word list selected by `tab` is measured for each sentence.
 * Returns 0 for an empty batch, instead of failing on a lookup of
 * element 0.
 */
size_t GetMaxLength(const Sentences& source, size_t tab) {
  size_t maxLength = 0;
  for (size_t i = 0; i < source.size(); ++i) {
    maxLength = std::max(maxLength, source.at(i)->GetWords(tab).size());
  }
  return maxLength;
}
/**
 * Packs the word ids of a batch into a dense, zero-padded matrix.
 *
 * The result has `maxLen` rows (time steps) and `source.size()` columns
 * (sentences). matrix[i][j] is the i-th word of sentence j, taken from the
 * word list selected by `tab`, or 0 where sentence j has no i-th word.
 * Words beyond position `maxLen` are dropped rather than written out of
 * bounds.
 */
std::vector<std::vector<size_t>> GetBatchInput(const Sentences& source, size_t tab, size_t maxLen) {
  std::vector<std::vector<size_t>> matrix(maxLen, std::vector<size_t>(source.size(), 0));
  for (size_t j = 0; j < source.size(); ++j) {
    // Fetch the word list once per sentence rather than twice per word.
    // `source` keeps the sentence alive, so the reference stays valid.
    const auto& words = source.at(j)->GetWords(tab);
    const size_t len = std::min(words.size(), maxLen);
    for (size_t i = 0; i < len; ++i) {
      matrix[i][j] = words[i];
    }
  }
  return matrix;
}
/**
 * Encodes a batch of sentences with the forward and backward RNNs.
 *
 * @param source        batch of sentences; for each one, the word list
 *                      selected by `tab` is encoded.
 * @param tab           selects which word list of each sentence to use.
 * @param context       output; resized to (maxSentenceLength,
 *                      forward output size + backward output size, 1,
 *                      source.size()) and then handed to both RNN passes.
 *                      Presumably each pass writes its own slice of
 *                      columns (TODO confirm in the RNN's Encode).
 * @param sentencesMask output; resized to (maxSentenceLength, source.size(), 1, 1).
 *                      It receives 1 for positions covered by a sentence's
 *                      words and 0 for padding.
 *
 * Embeddings are written into the member buffer embeddedWords_. That buffer
 * only ever grows, so it may hold more entries than this batch needs. Both
 * RNN passes are therefore restricted to its first maxSentenceLength entries.
 */
void Encoder::Encode(const Sentences& source, size_t tab, mblas::Matrix& context,
mblas::IMatrix &sentencesMask)
{
size_t maxSentenceLength = GetMaxLength(source, tab);
//cerr << "1dMapping=" << mblas::Debug(dMapping, 2) << endl;
// Host-side mask: entry i * maxSentenceLength + j is 1 iff sentence i has a
// word at position j, and 0 (padding) otherwise.
// NOTE(review): this buffer is filled sentence-major, while the mask is sized
// (maxSentenceLength, source.size()). Confirm this matches the IMatrix layout
// that consumers of sentencesMask expect.
HostVector<uint> hMapping(maxSentenceLength * source.size(), 0);
for (size_t i = 0; i < source.size(); ++i) {
for (size_t j = 0; j < source.at(i)->GetWords(tab).size(); ++j) {
hMapping[i * maxSentenceLength + j] = 1;
}
}
// Resize the mask and upload the host buffer into its device storage.
sentencesMask.NewSize(maxSentenceLength, source.size(), 1, 1);
mblas::copy(thrust::raw_pointer_cast(hMapping.data()),
hMapping.size(),
sentencesMask.data(),
cudaMemcpyHostToDevice);
//cerr << "GetContext1=" << context.Debug(1) << endl;
// The context width is the sum of both directions' output sizes.
context.NewSize(maxSentenceLength,
forwardRnn_.GetStateLength().output + backwardRnn_.GetStateLength().output,
1,
source.size());
//cerr << "GetContext2=" << context.Debug(1) << endl;
// input[i] holds the word ids at time step i for every sentence (0-padded).
auto input = GetBatchInput(source, tab, maxSentenceLength);
for (size_t i = 0; i < input.size(); ++i) {
// Grow the reused embedding buffer only when this batch is longer than any before.
if (i >= embeddedWords_.size()) {
embeddedWords_.emplace_back();
}
embeddings_.Lookup(embeddedWords_[i], input[i]);
//cerr << "embeddedWords_=" << embeddedWords_.back().Debug(true) << endl;
}
//cerr << "GetContext3=" << context.Debug(1) << endl;
// Left-to-right pass over time steps [0, maxSentenceLength).
forwardRnn_.Encode(embeddedWords_.cbegin(),
embeddedWords_.cbegin() + maxSentenceLength,
context, source.size(), false);
//cerr << "GetContext4=" << context.Debug(1) << endl;
// Right-to-left pass over the same time steps. crend() - maxSentenceLength
// refers to element maxSentenceLength - 1, so this visits indices
// maxSentenceLength - 1 down to 0, ignoring any stale extra buffer entries.
// Only this pass receives `true` and the mask, presumably so padding is
// handled when starting from the right; TODO confirm against the RNN's Encode.
backwardRnn_.Encode(embeddedWords_.crend() - maxSentenceLength,
embeddedWords_.crend() ,
context, source.size(), true, &sentencesMask);
//cerr << "GetContext5=" << context.Debug(1) << endl;
}
}
}
|