github.com/marian-nmt/marian.git
author     Frank Seide <fseide@microsoft.com>                             2019-07-23 00:53:24 +0300
committer  Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>   2019-07-23 00:53:25 +0300
commit     ddcbe825f0af3f548dd33daedfe340f00c9a1c0c (patch)
tree       49ca99d7ee9828ac760887a69d4e71ec3c6459bf /src
parent     c341a6c78b905fab1f8a059826458d7392af7d45 (diff)
parent     236d153ce4de973a374b827cbc3c4191fec4124a (diff)
Merged PR 8650: fix for decoding with factors for untied embeddings
This fixes Zhongkai's issue when decoding with factors and untied embeddings. Related work items: #98842
Diffstat (limited to 'src')
-rwxr-xr-x                src/common/config_parser.cpp       |  2
-rwxr-xr-x                src/data/factored_vocab.cpp        | 11
-rwxr-xr-x [-rw-r--r--]   src/graph/node_operators_unary.h   |  4
-rwxr-xr-x                src/layers/generic.cpp             | 22
-rwxr-xr-x [-rw-r--r--]   src/microsoft/quicksand.cpp        | 13
-rwxr-xr-x [-rw-r--r--]   src/microsoft/quicksand.h          |  8
-rwxr-xr-x                src/training/validator.h           |  6
-rwxr-xr-x [-rw-r--r--]   src/translator/beam_search.h       | 29
-rwxr-xr-x [-rw-r--r--]   src/translator/translator.h        |  4
9 files changed, 57 insertions(+), 42 deletions(-)
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index bd6650c6..e3441848 100755
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -671,7 +671,7 @@ void ConfigParser::addSuboptionsInputLength(cli::CLIWrapper& cli) {
"Maximum length of a sentence in a training sentence pair",
defaultMaxLength);
cli.add<bool>("--max-length-crop",
- "Crop a sentence to max-length instead of ommitting it if longer than max-length");
+ "Crop a sentence to max-length instead of omitting it if longer than max-length");
// clang-format on
}
diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp
index 37fbf4c7..1f536160 100755
--- a/src/data/factored_vocab.cpp
+++ b/src/data/factored_vocab.cpp
@@ -27,6 +27,7 @@ namespace marian {
//LOG(info, "[vocab] Attempting to load model a second time; skipping (assuming shared vocab)");
return size();
}
+ LOG(info, "[vocab] Loading vocab spec file {}", modelPath);
// load factor-vocab file and parse it
std::vector<std::vector<std::string>> factorMapTokenized;
@@ -571,11 +572,11 @@ void FactoredVocab::constructNormalizationInfoForVocab() {
static void unescapeHexEscapes(std::string& utf8Lemma) {
if (utf8Lemma.find('\\') == std::string::npos)
return; // nothing to do
- auto lemma = utils::utf8ToUtf16String(utf8Lemma); // \u.... implies we must operate on UTF-16 level
+ auto lemma = utils::utf8ToUtf16String(utf8Lemma); // \u.... implies we must operate on UTF-16 level (not UCS-4)
auto pos = lemma.find('\\');
while (pos != std::string::npos) {
ABORT_IF(pos + 1 >= lemma.size() || (lemma[pos+1] != 'x' && lemma[pos + 1] != 'u'), "Malformed escape in factored encoding: {}", utf8Lemma);
- int numDigits = 2 + (lemma[pos + 1] == 'u');
+ int numDigits = 2 + 2 * (lemma[pos + 1] == 'u'); // 2 for \x, 4 for \u
ABORT_IF(pos + 2 + numDigits > lemma.size(), "Malformed escape in factored encoding: {}", utf8Lemma);
auto digits = utils::utf8FromUtf16String(lemma.substr(pos + 2, numDigits));
auto c = std::strtoul(digits.c_str(), nullptr, 16);
@@ -607,13 +608,15 @@ std::string FactoredVocab::surfaceForm(const Words& sentence) const /*override f
auto has = [&](const char* factor) { return tokenSet.find(factor) != tokenSet.end(); };
// spacing
bool hasGlueRight = has("gr+") || has("wen") || has("cen");
- bool hasGlueLeft = has("gl+") || has("wbn") || has("cbn");
+ bool hasGlueLeft = has("gl+") || has("wbn") || has("cbn") || has("wi");
bool insertSpaceBefore = !prevHadGlueRight && !hasGlueLeft;
if (insertSpaceBefore)
res.push_back(' ');
prevHadGlueRight = hasGlueRight;
// capitalization
unescapeHexEscapes(lemma); // unescape \x.. and \u....
+ if (utils::beginsWith(lemma, "\xE2\x96\x81")) // remove leading _ (\u2581, for DistinguishInitialAndInternalPieces mode)
+ lemma = lemma.substr(3);
if (has("ci")) lemma = utils::utf8Capitalized(lemma);
else if (has("ca")) lemma = utils::utf8ToUpper (lemma);
else if (has("cn")) lemma = utils::utf8ToLower (lemma);
@@ -730,7 +733,7 @@ Ptr<IVocab> createFactoredVocab(const std::string& vocabPath) {
static std::map<std::string, Ptr<IVocab>> s_cache;
auto iter = s_cache.find(vocabPath);
if (iter != s_cache.end()) {
- LOG(info, "[vocab] Reusing existing vocabulary object in memory");
+ LOG(info, "[vocab] Reusing existing vocabulary object in memory (vocab size {})", iter->second->size());
return iter->second;
}
auto vocab = New<FactoredVocab>();
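
The unescapeHexEscapes fix above corrects how many hex digits are read after an escape marker: \xNN carries two digits and \uNNNN carries four, whereas the old expression yielded only three for \u. A self-contained sketch of the corrected digit handling (hypothetical helper name, not Marian's utils API):

    #include <cstdlib>
    #include <string>

    // Decode one escape starting at pos in a UTF-16 string. Assumes the caller
    // already verified that s[pos] == '\\', s[pos+1] is 'x' or 'u', and that
    // enough digits follow -- the ABORT_IF checks in the patch above.
    unsigned long decodeEscape(const std::u16string& s, size_t pos) {
        int numDigits = 2 + 2 * (s[pos + 1] == u'u');     // 2 for \x, 4 for \u
        std::string digits;
        for (int i = 0; i < numDigits; ++i)
            digits += static_cast<char>(s[pos + 2 + i]);  // hex digits are ASCII
        return std::strtoul(digits.c_str(), nullptr, 16); // UTF-16 code unit value
    }

The surfaceForm changes in the same file additionally treat the "wi" (word-initial) factor as gluing to the left and strip a leading U+2581 ("▁", the three UTF-8 bytes \xE2\x96\x81) produced by DistinguishInitialAndInternalPieces mode before applying the capitalization factors.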
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 36273d6f..94200435 100644..100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -477,6 +477,10 @@ struct ReduceNodeOp : public UnaryNodeOp {
NodeOps backwardOps() override {
using namespace functional;
+#if 1 // @BUGBUG: This is a workaround for not correctly propagating non-trainable information. @TODO: Do this the right and general way.
+ if (adj_ == nullptr)
+ return {};
+#endif
switch (opCode_) {
case ReduceNodeOpCode::sum:
return {NodeOp(Add(_1, child(0)->grad(), adj_))};
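The ReduceNodeOp change guards the backward pass when the node's adjoint was never allocated, which happens when nothing above the node is trainable; as the @BUGBUG note says, this is a workaround rather than a general fix for propagating non-trainable information. A minimal self-contained sketch of the same guard in a generic autodiff-style node (hypothetical types, not Marian's graph classes):

    #include <functional>
    #include <memory>
    #include <vector>

    struct Tensor {};                       // stand-in for a gradient buffer
    using NodeOp = std::function<void()>;   // one backward operation

    struct ReduceLikeNode {
        std::shared_ptr<Tensor> adj_;       // incoming gradient; may stay null
        std::vector<NodeOp> backwardOps() {
            if (adj_ == nullptr)            // non-trainable path: nothing to propagate
                return {};                  // emit no gradient ops at all
            return { [this] { /* accumulate adj_ into child gradients */ } };
        }
    };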
diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp
index 30060432..9fedaf25 100755
--- a/src/layers/generic.cpp
+++ b/src/layers/generic.cpp
@@ -246,8 +246,12 @@ namespace marian {
b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros);
- const int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+ /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
if (lemmaDimEmb > 0) {
+#define HARDMAX_HACK
+#ifdef HARDMAX_HACK
+ lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
+#endif
auto range = factoredVocab_->getGroupRange(0);
auto lemmaVocabDim = (int)(range.second - range.first);
lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, inits::glorot_uniform); // [L x U] L=lemmaDimEmb; transposed for speed
@@ -299,7 +303,7 @@ namespace marian {
allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
// optionally add a soft embedding of lemma back to create some lemma dependency
// @TODO: if this works, move it into lazyConstruct
- const int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+ /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
if (lemmaDimEmb < 0 && g == 0) {
ABORT_IF(shortlist_ && lemmaDimEmb != 0, "Lemma-dependent bias with short list is not yet implemented");
LOG_ONCE(info, "[embedding] using lemma-dependent bias");
@@ -310,11 +314,21 @@ namespace marian {
if (lemmaDimEmb > 0 && g == 0) {
LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
// compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE
- factorLogits = logsoftmax(factorLogits);
+ auto factorLogSoftmax = logsoftmax(factorLogits);
+ auto factorSoftmax = exp(factorLogSoftmax);
+#ifdef HARDMAX_HACK
+ bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation)
+ if (hardmax) {
+ lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
+ LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
+ auto maxVal = max(factorSoftmax, -1);
+ factorSoftmax = eq(factorSoftmax, maxVal);
+ }
+#endif
// re-embedding lookup, soft-indexed by softmax
if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
- auto e = dot(exp(factorLogits), cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L]
+ auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L]
// project it back to regular hidden dim
int inputDim = input1->shape()[-1];
auto name = options_->get<std::string>("prefix");
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp
index 2dc2efc4..4a9a032e 100644..100755
--- a/src/microsoft/quicksand.cpp
+++ b/src/microsoft/quicksand.cpp
@@ -55,9 +55,8 @@ private:
public:
BeamSearchDecoder(Ptr<Options> options,
const std::vector<const void*>& ptrs,
- const std::vector<Ptr<IVocabWrapper>>& vocabs,
- WordIndex eos)
- : IBeamSearchDecoder(options, ptrs, eos) {
+ const std::vector<Ptr<IVocabWrapper>>& vocabs)
+ : IBeamSearchDecoder(options, ptrs) {
// copy the vocabs
for (auto vi : vocabs)
@@ -145,7 +144,7 @@ public:
batch->setSentenceIds(sentIds);
// decode
- auto search = New<BeamSearch>(options_, scorers_, marian::Word::fromWordIndex(eos_));
+ auto search = New<BeamSearch>(options_, scorers_, vocabs_[1]);
Histories histories = search->search(graph_, batch);
// convert to QuickSAND format
@@ -189,8 +188,10 @@ public:
Ptr<IBeamSearchDecoder> newDecoder(Ptr<Options> options,
const std::vector<const void*>& ptrs,
const std::vector<Ptr<IVocabWrapper>>& vocabs,
- WordIndex eos) {
- return New<BeamSearchDecoder>(options, ptrs, vocabs, eos);
+ WordIndex eosDummy) { // @TODO: remove this parameter
+ ABORT_IF(marian::Word::fromWordIndex(eosDummy) != std::dynamic_pointer_cast<VocabWrapper>(vocabs[1])->getVocab()->getEosId(), "Inconsistent eos vs. vocabs_[1]");
+
+ return New<BeamSearchDecoder>(options, ptrs, vocabs/*, eos*/);
}
std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocabPaths) {
diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h
index c4f539ab..ffab6c0d 100644..100755
--- a/src/microsoft/quicksand.h
+++ b/src/microsoft/quicksand.h
@@ -42,13 +42,11 @@ class IBeamSearchDecoder {
protected:
Ptr<Options> options_;
std::vector<const void*> ptrs_;
- WordIndex eos_;
public:
IBeamSearchDecoder(Ptr<Options> options,
- const std::vector<const void*>& ptrs,
- WordIndex eos)
- : options_(options), ptrs_(ptrs), eos_(eos) {}
+ const std::vector<const void*>& ptrs)
+ : options_(options), ptrs_(ptrs) {}
virtual QSNBestBatch decode(const QSBatch& qsBatch,
size_t maxLength,
@@ -61,7 +59,7 @@ public:
Ptr<IBeamSearchDecoder> newDecoder(Ptr<Options> options,
const std::vector<const void*>& ptrs,
const std::vector<Ptr<IVocabWrapper>>& vocabs,
- WordIndex eos);
+ WordIndex eos/*dummy --@TODO: remove*/);
// load src and tgt vocabs
std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocabPaths);
diff --git a/src/training/validator.h b/src/training/validator.h
index 7736318b..330b889e 100755
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -534,8 +534,7 @@ public:
auto search = New<BeamSearch>(options_,
std::vector<Ptr<Scorer>>{scorer},
- vocabs_.back()->getEosId(),
- vocabs_.back()->getUnkId());
+ vocabs_.back());
auto histories = search->search(graph, batch);
for(auto history : histories) {
@@ -673,8 +672,7 @@ public:
auto search = New<BeamSearch>(options_,
std::vector<Ptr<Scorer>>{scorer},
- vocabs_.back()->getEosId(),
- vocabs_.back()->getUnkId());
+ vocabs_.back());
auto histories = search->search(graph, batch);
size_t no = 0;
diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h
index 393ff6df..d0a2cb11 100644..100755
--- a/src/translator/beam_search.h
+++ b/src/translator/beam_search.h
@@ -16,23 +16,20 @@ private:
Ptr<Options> options_;
std::vector<Ptr<Scorer>> scorers_;
size_t beamSize_;
- Word trgEosId_{Word::NONE};
- Word trgUnkId_{Word::NONE};
+ Ptr<Vocab> trgVocab_;
static constexpr auto INVALID_PATH_SCORE = -9999; // (@TODO: change to -9999.0 once C++ allows that)
public:
BeamSearch(Ptr<Options> options,
const std::vector<Ptr<Scorer>>& scorers,
- Word trgEosId,
- Word trgUnkId = Word::NONE)
+ Ptr<Vocab> trgVocab)
: options_(options),
scorers_(scorers),
beamSize_(options_->has("beam-size")
? options_->get<size_t>("beam-size")
: 3),
- trgEosId_(trgEosId),
- trgUnkId_(trgUnkId) {}
+ trgVocab_(trgVocab) {}
// combine new expandedPathScores and previous beams into new set of beams
Beams toHyps(const std::vector<unsigned int>& nBestKeys, // [dimBatch, beamSize] flattened -> ((batchIdx, beamHypIdx) flattened, word idx) flattened
@@ -197,11 +194,12 @@ public:
// remove all beam entries that have reached EOS
Beams purgeBeams(const Beams& beams) {
+ const auto trgEosId = trgVocab_->getEosId();
Beams newBeams;
for(auto beam : beams) {
Beam newBeam;
for(auto hyp : beam) {
- if(hyp->getWord() != trgEosId_) {
+ if(hyp->getWord() != trgEosId) {
newBeam.push_back(hyp);
}
}
@@ -213,10 +211,7 @@ public:
//**********************************************************************
// main decoding function
Histories search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
- ABORT_IF(batch->back()->vocab() && batch->back()->vocab()->getEosId() != trgEosId_,
- "Batch uses different EOS token than was passed to BeamSearch originally");
-
- auto factoredVocab = batch->back()->vocab()->tryAs<FactoredVocab>();
+ auto factoredVocab = trgVocab_->tryAs<FactoredVocab>();
#if 0 // use '1' here to disable factored decoding, e.g. for comparisons
factoredVocab.reset();
#endif
@@ -225,6 +220,8 @@ public:
factoredVocab.reset();
const int dimBatch = (int)batch->size();
+ const auto trgEosId = trgVocab_->getEosId();
+ const auto trgUnkId = trgVocab_->getUnkId();
auto getNBestList = createGetNBestListFn(beamSize_, dimBatch, graph->getDeviceId());
@@ -249,7 +246,7 @@ public:
Beams beams(dimBatch, Beam(beamSize_, New<Hypothesis>())); // array [dimBatch] of array [localBeamSize] of Hypothesis
for(int i = 0; i < dimBatch; ++i)
- histories[i]->add(beams[i], trgEosId_);
+ histories[i]->add(beams[i], trgEosId);
// the decoding process updates the following state information in each output time step:
// - beams: array [dimBatch] of array [localBeamSize] of Hypothesis
@@ -307,7 +304,7 @@ public:
prevScores.push_back(canExpand ? hyp->getPathScore() : INVALID_PATH_SCORE);
} else { // pad to localBeamSize (dummy hypothesis)
hypIndices.push_back(0);
- prevWords.push_back(trgEosId_); // (unused, but must be valid)
+ prevWords.push_back(trgEosId); // (unused, but must be valid)
prevScores.push_back((float)INVALID_PATH_SCORE);
}
}
@@ -371,8 +368,8 @@ public:
//**********************************************************************
// suppress specific symbols if not at right positions
- if(trgUnkId_ != Word::NONE && options_->has("allow-unk") && !options_->get<bool>("allow-unk") && factorGroup == 0)
- suppressWord(expandedPathScores, factoredVocab ? factoredVocab->getUnkIndex() : trgUnkId_.toWordIndex());
+ if(trgUnkId != Word::NONE && options_->has("allow-unk") && !options_->get<bool>("allow-unk") && factorGroup == 0)
+ suppressWord(expandedPathScores, factoredVocab ? factoredVocab->getUnkIndex() : trgUnkId.toWordIndex());
for(auto state : states)
state->blacklist(expandedPathScores, batch);
@@ -411,7 +408,7 @@ public:
if(!beams[i].empty()) {
if (histories[i]->size() >= options_->get<float>("max-length-factor") * batch->front()->batchWidth())
maxLengthReached = true;
- histories[i]->add(beams[i], trgEosId_, purgedNewBeams[i].empty() || maxLengthReached);
+ histories[i]->add(beams[i], trgEosId, purgedNewBeams[i].empty() || maxLengthReached);
}
}
if (maxLengthReached) // early exit if max length limit was reached
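The BeamSearch refactor replaces the separately passed trgEosId/trgUnkId with the target vocabulary itself: EOS and UNK are looked up from trgVocab_ inside search(), and the factored vocabulary is obtained from the same object instead of from the batch, so the old ABORT_IF guarding against a batch/decoder EOS mismatch becomes unnecessary. A minimal self-contained illustration of the pattern with placeholder types (not Marian's classes):

    #include <memory>
    #include <utility>

    struct Vocab {                          // placeholder vocabulary with fixed ids
        int eosId() const { return 0; }
        int unkId() const { return 1; }
    };

    class Searcher {
        std::shared_ptr<Vocab> trgVocab_;   // single source of truth for special tokens
    public:
        explicit Searcher(std::shared_ptr<Vocab> trgVocab) : trgVocab_(std::move(trgVocab)) {}
        void search() {
            const int eos = trgVocab_->eosId();  // derived where needed, not stored separately
            const int unk = trgVocab_->unkId();
            // ... decoding loop uses eos/unk locally ...
            (void)eos; (void)unk;
        }
    };

The callers updated in validator.h and translator.h accordingly pass vocabs_.back() or trgVocab_ instead of the two id arguments, as the remaining hunks show.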
diff --git a/src/translator/translator.h b/src/translator/translator.h
index d2518f5d..43bb5ac3 100644..100755
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -103,7 +103,7 @@ public:
scorers = scorers_[id % numDevices_];
}
- auto search = New<Search>(options_, scorers, trgVocab_->getEosId(), trgVocab_->getUnkId());
+ auto search = New<Search>(options_, scorers, trgVocab_);
auto histories = search->search(graph, batch);
for(auto history : histories) {
@@ -212,7 +212,7 @@ public:
scorers = scorers_[id % numDevices_];
}
- auto search = New<Search>(options_, scorers, trgVocab_->getEosId(), trgVocab_->getUnkId());
+ auto search = New<Search>(options_, scorers, trgVocab_);
auto histories = search->search(graph, batch);
for(auto history : histories) {