author | Frank Seide <fseide@microsoft.com> | 2019-07-23 00:53:24 +0300
---|---|---
committer | Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2019-07-23 00:53:25 +0300
commit | ddcbe825f0af3f548dd33daedfe340f00c9a1c0c (patch) |
tree | 49ca99d7ee9828ac760887a69d4e71ec3c6459bf /src |
parent | c341a6c78b905fab1f8a059826458d7392af7d45 (diff) |
parent | 236d153ce4de973a374b827cbc3c4191fec4124a (diff) |
Merged PR 8650: fix for decoding with factors for untied embeddings
This fixes Zhongkai's issue when decoding with factored vocabularies and non-tied embeddings.
Related work items: #98842
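In brief, the fix changes how BeamSearch learns about the target vocabulary: instead of receiving bare EOS/UNK word ids and recovering the FactoredVocab from the batch (which breaks when source and target embeddings are untied), callers now pass the target vocabulary object itself, and the ids and the factored vocab are derived from it inside search(). A condensed view of the interface change, abridged from the diff below:

    // Before: raw ids passed in, factored vocab recovered from the batch.
    //   BeamSearch(options, scorers, Word trgEosId, Word trgUnkId = Word::NONE);
    //   auto factoredVocab = batch->back()->vocab()->tryAs<FactoredVocab>();
    // After: the target vocab travels with the decoder.
    BeamSearch(Ptr<Options> options,
               const std::vector<Ptr<Scorer>>& scorers,
               Ptr<Vocab> trgVocab);
    // ...inside search():
    //   auto factoredVocab = trgVocab_->tryAs<FactoredVocab>();
    //   const auto trgEosId = trgVocab_->getEosId();
    //   const auto trgUnkId = trgVocab_->getUnkId();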
Diffstat (limited to 'src')
-rwxr-xr-x | src/common/config_parser.cpp | 2
-rwxr-xr-x | src/data/factored_vocab.cpp | 11
-rwxr-xr-x[-rw-r--r--] | src/graph/node_operators_unary.h | 4
-rwxr-xr-x | src/layers/generic.cpp | 22
-rwxr-xr-x[-rw-r--r--] | src/microsoft/quicksand.cpp | 13
-rwxr-xr-x[-rw-r--r--] | src/microsoft/quicksand.h | 8
-rwxr-xr-x | src/training/validator.h | 6
-rwxr-xr-x[-rw-r--r--] | src/translator/beam_search.h | 29
-rwxr-xr-x[-rw-r--r--] | src/translator/translator.h | 4
9 files changed, 57 insertions, 42 deletions
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index bd6650c6..e3441848 100755
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -671,7 +671,7 @@ void ConfigParser::addSuboptionsInputLength(cli::CLIWrapper& cli) {
       "Maximum length of a sentence in a training sentence pair",
       defaultMaxLength);
   cli.add<bool>("--max-length-crop",
-      "Crop a sentence to max-length instead of ommitting it if longer than max-length");
+      "Crop a sentence to max-length instead of omitting it if longer than max-length");
   // clang-format on
 }
diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp
index 37fbf4c7..1f536160 100755
--- a/src/data/factored_vocab.cpp
+++ b/src/data/factored_vocab.cpp
@@ -27,6 +27,7 @@ namespace marian {
     //LOG(info, "[vocab] Attempting to load model a second time; skipping (assuming shared vocab)");
     return size();
   }
+  LOG(info, "[vocab] Loading vocab spec file {}", modelPath);
 
   // load factor-vocab file and parse it
   std::vector<std::vector<std::string>> factorMapTokenized;
@@ -571,11 +572,11 @@ void FactoredVocab::constructNormalizationInfoForVocab() {
 static void unescapeHexEscapes(std::string& utf8Lemma) {
   if (utf8Lemma.find('\\') == std::string::npos)
     return; // nothing to do
-  auto lemma = utils::utf8ToUtf16String(utf8Lemma); // \u.... implies we must operate on UTF-16 level
+  auto lemma = utils::utf8ToUtf16String(utf8Lemma); // \u.... implies we must operate on UTF-16 level (not UCS-4)
   auto pos = lemma.find('\\');
   while (pos != std::string::npos) {
     ABORT_IF(pos + 1 >= lemma.size() || (lemma[pos+1] != 'x' && lemma[pos + 1] != 'u'), "Malformed escape in factored encoding: {}", utf8Lemma);
-    int numDigits = 2 + (lemma[pos + 1] == 'u');
+    int numDigits = 2 + 2 * (lemma[pos + 1] == 'u'); // 2 for \x, 4 for \u
     ABORT_IF(pos + 2 + numDigits > lemma.size(), "Malformed escape in factored encoding: {}", utf8Lemma);
     auto digits = utils::utf8FromUtf16String(lemma.substr(pos + 2, numDigits));
     auto c = std::strtoul(digits.c_str(), nullptr, 16);
@@ -607,13 +608,15 @@ std::string FactoredVocab::surfaceForm(const Words& sentence) const /*override f
     auto has = [&](const char* factor) { return tokenSet.find(factor) != tokenSet.end(); };
     // spacing
     bool hasGlueRight = has("gr+") || has("wen") || has("cen");
-    bool hasGlueLeft = has("gl+") || has("wbn") || has("cbn");
+    bool hasGlueLeft = has("gl+") || has("wbn") || has("cbn") || has("wi");
     bool insertSpaceBefore = !prevHadGlueRight && !hasGlueLeft;
     if (insertSpaceBefore)
       res.push_back(' ');
     prevHadGlueRight = hasGlueRight;
     // capitalization
     unescapeHexEscapes(lemma); // unescape \x.. and \u....
+    if (utils::beginsWith(lemma, "\xE2\x96\x81")) // remove leading _ (\u2581, for DistinguishInitialAndInternalPieces mode)
+      lemma = lemma.substr(3);
     if (has("ci")) lemma = utils::utf8Capitalized(lemma);
     else if (has("ca")) lemma = utils::utf8ToUpper (lemma);
     else if (has("cn")) lemma = utils::utf8ToLower (lemma);
@@ -730,7 +733,7 @@ Ptr<IVocab> createFactoredVocab(const std::string& vocabPath) {
   static std::map<std::string, Ptr<IVocab>> s_cache;
   auto iter = s_cache.find(vocabPath);
   if (iter != s_cache.end()) {
-    LOG(info, "[vocab] Reusing existing vocabulary object in memory");
+    LOG(info, "[vocab] Reusing existing vocabulary object in memory (vocab size {})", iter->second->size());
     return iter->second;
   }
   auto vocab = New<FactoredVocab>();
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 36273d6f..94200435 100644..100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -477,6 +477,10 @@ struct ReduceNodeOp : public UnaryNodeOp {
 
   NodeOps backwardOps() override {
     using namespace functional;
+#if 1 // @BUGBUG: This is a workaround for not correctly propagating non-trainable information. @TODO: Do this the right and general way.
+    if (adj_ == nullptr)
+      return {};
+#endif
     switch (opCode_) {
       case ReduceNodeOpCode::sum:
         return {NodeOp(Add(_1, child(0)->grad(), adj_))};
diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp
index 30060432..9fedaf25 100755
--- a/src/layers/generic.cpp
+++ b/src/layers/generic.cpp
@@ -246,8 +246,12 @@ namespace marian {
       b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros);
 
-      const int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+      /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
       if (lemmaDimEmb > 0) {
+#define HARDMAX_HACK
+#ifdef HARDMAX_HACK
+        lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
+#endif
         auto range = factoredVocab_->getGroupRange(0);
         auto lemmaVocabDim = (int)(range.second - range.first);
         lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, inits::glorot_uniform); // [L x U] L=lemmaDimEmb; transposed for speed
@@ -299,7 +303,7 @@ namespace marian {
           allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
         // optionally add a soft embedding of lemma back to create some lemma dependency
         // @TODO: if this works, move it into lazyConstruct
-        const int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+        /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
         if (lemmaDimEmb < 0 && g == 0) {
           ABORT_IF(shortlist_ && lemmaDimEmb != 0, "Lemma-dependent bias with short list is not yet implemented");
           LOG_ONCE(info, "[embedding] using lemma-dependent bias");
@@ -310,11 +314,21 @@ namespace marian {
         if (lemmaDimEmb > 0 && g == 0) {
           LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
           // compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE
-          factorLogits = logsoftmax(factorLogits);
+          auto factorLogSoftmax = logsoftmax(factorLogits);
+          auto factorSoftmax = exp(factorLogSoftmax);
+#ifdef HARDMAX_HACK
+          bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation)
+          if (hardmax) {
+            lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
+            LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
+            auto maxVal = max(factorSoftmax, -1);
+            factorSoftmax = eq(factorSoftmax, maxVal);
+          }
+#endif
           // re-embedding lookup, soft-indexed by softmax
           if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
             cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
-          auto e = dot(exp(factorLogits), cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L]
+          auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L]
           // project it back to regular hidden dim
           int inputDim = input1->shape()[-1];
           auto name = options_->get<std::string>("prefix");
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp
index 2dc2efc4..4a9a032e 100644..100755
--- a/src/microsoft/quicksand.cpp
+++ b/src/microsoft/quicksand.cpp
@@ -55,9 +55,8 @@ private:
 public:
   BeamSearchDecoder(Ptr<Options> options,
                     const std::vector<const void*>& ptrs,
-                    const std::vector<Ptr<IVocabWrapper>>& vocabs,
-                    WordIndex eos)
-      : IBeamSearchDecoder(options, ptrs, eos) {
+                    const std::vector<Ptr<IVocabWrapper>>& vocabs)
+      : IBeamSearchDecoder(options, ptrs) {
 
     // copy the vocabs
     for (auto vi : vocabs)
@@ -145,7 +144,7 @@ public:
     batch->setSentenceIds(sentIds);
 
     // decode
-    auto search = New<BeamSearch>(options_, scorers_, marian::Word::fromWordIndex(eos_));
+    auto search = New<BeamSearch>(options_, scorers_, vocabs_[1]);
     Histories histories = search->search(graph_, batch);
 
     // convert to QuickSAND format
@@ -189,8 +188,10 @@ public:
 Ptr<IBeamSearchDecoder> newDecoder(Ptr<Options> options,
                                    const std::vector<const void*>& ptrs,
                                    const std::vector<Ptr<IVocabWrapper>>& vocabs,
-                                   WordIndex eos) {
-  return New<BeamSearchDecoder>(options, ptrs, vocabs, eos);
+                                   WordIndex eosDummy) { // @TODO: remove this parameter
+  ABORT_IF(marian::Word::fromWordIndex(eosDummy) != std::dynamic_pointer_cast<VocabWrapper>(vocabs[1])->getVocab()->getEosId(), "Inconsistent eos vs. vocabs_[1]");
+
+  return New<BeamSearchDecoder>(options, ptrs, vocabs/*, eos*/);
 }
 
 std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocabPaths) {
diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h
index c4f539ab..ffab6c0d 100644..100755
--- a/src/microsoft/quicksand.h
+++ b/src/microsoft/quicksand.h
@@ -42,13 +42,11 @@ class IBeamSearchDecoder {
 protected:
   Ptr<Options> options_;
   std::vector<const void*> ptrs_;
-  WordIndex eos_;
 
 public:
   IBeamSearchDecoder(Ptr<Options> options,
-                     const std::vector<const void*>& ptrs,
-                     WordIndex eos)
-      : options_(options), ptrs_(ptrs), eos_(eos) {}
+                     const std::vector<const void*>& ptrs)
+      : options_(options), ptrs_(ptrs) {}
 
   virtual QSNBestBatch decode(const QSBatch& qsBatch,
                               size_t maxLength,
@@ -61,7 +59,7 @@ public:
 Ptr<IBeamSearchDecoder> newDecoder(Ptr<Options> options,
                                    const std::vector<const void*>& ptrs,
                                    const std::vector<Ptr<IVocabWrapper>>& vocabs,
-                                   WordIndex eos);
+                                   WordIndex eos/*dummy --@TODO: remove*/);
 
 // load src and tgt vocabs
 std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocabPaths);
diff --git a/src/training/validator.h b/src/training/validator.h
index 7736318b..330b889e 100755
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -534,8 +534,7 @@ public:
 
     auto search = New<BeamSearch>(options_,
                                   std::vector<Ptr<Scorer>>{scorer},
-                                  vocabs_.back()->getEosId(),
-                                  vocabs_.back()->getUnkId());
+                                  vocabs_.back());
     auto histories = search->search(graph, batch);
 
     for(auto history : histories) {
@@ -673,8 +672,7 @@ public:
 
     auto search = New<BeamSearch>(options_,
                                   std::vector<Ptr<Scorer>>{scorer},
-                                  vocabs_.back()->getEosId(),
-                                  vocabs_.back()->getUnkId());
+                                  vocabs_.back());
     auto histories = search->search(graph, batch);
 
     size_t no = 0;
diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h
index 393ff6df..d0a2cb11 100644..100755
--- a/src/translator/beam_search.h
+++ b/src/translator/beam_search.h
@@ -16,23 +16,20 @@ private:
   Ptr<Options> options_;
   std::vector<Ptr<Scorer>> scorers_;
   size_t beamSize_;
-  Word trgEosId_{Word::NONE};
-  Word trgUnkId_{Word::NONE};
+  Ptr<Vocab> trgVocab_;
 
   static constexpr auto INVALID_PATH_SCORE = -9999; // (@TODO: change to -9999.0 once C++ allows that)
 
 public:
   BeamSearch(Ptr<Options> options,
              const std::vector<Ptr<Scorer>>& scorers,
-             Word trgEosId,
-             Word trgUnkId = Word::NONE)
+             Ptr<Vocab> trgVocab)
       : options_(options),
         scorers_(scorers),
         beamSize_(options_->has("beam-size") ? options_->get<size_t>("beam-size") : 3),
-        trgEosId_(trgEosId),
-        trgUnkId_(trgUnkId) {}
+        trgVocab_(trgVocab) {}
 
   // combine new expandedPathScores and previous beams into new set of beams
   Beams toHyps(const std::vector<unsigned int>& nBestKeys, // [dimBatch, beamSize] flattened -> ((batchIdx, beamHypIdx) flattened, word idx) flattened
@@ -197,11 +194,12 @@ public:
 
   // remove all beam entries that have reached EOS
   Beams purgeBeams(const Beams& beams) {
+    const auto trgEosId = trgVocab_->getEosId();
     Beams newBeams;
     for(auto beam : beams) {
       Beam newBeam;
       for(auto hyp : beam) {
-        if(hyp->getWord() != trgEosId_) {
+        if(hyp->getWord() != trgEosId) {
           newBeam.push_back(hyp);
         }
       }
@@ -213,10 +211,7 @@ public:
   //**********************************************************************
   // main decoding function
   Histories search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
-    ABORT_IF(batch->back()->vocab() && batch->back()->vocab()->getEosId() != trgEosId_,
-             "Batch uses different EOS token than was passed to BeamSearch originally");
-
-    auto factoredVocab = batch->back()->vocab()->tryAs<FactoredVocab>();
+    auto factoredVocab = trgVocab_->tryAs<FactoredVocab>();
 #if 0 // use '1' here to disable factored decoding, e.g. for comparisons
     factoredVocab.reset();
 #endif
@@ -225,6 +220,8 @@ public:
       factoredVocab.reset();
 
     const int dimBatch = (int)batch->size();
+    const auto trgEosId = trgVocab_->getEosId();
+    const auto trgUnkId = trgVocab_->getUnkId();
 
     auto getNBestList = createGetNBestListFn(beamSize_, dimBatch, graph->getDeviceId());
@@ -249,7 +246,7 @@ public:
     Beams beams(dimBatch, Beam(beamSize_, New<Hypothesis>())); // array [dimBatch] of array [localBeamSize] of Hypothesis
 
     for(int i = 0; i < dimBatch; ++i)
-      histories[i]->add(beams[i], trgEosId_);
+      histories[i]->add(beams[i], trgEosId);
 
     // the decoding process updates the following state information in each output time step:
     //  - beams: array [dimBatch] of array [localBeamSize] of Hypothesis
@@ -307,7 +304,7 @@ public:
             prevScores.push_back(canExpand ? hyp->getPathScore() : INVALID_PATH_SCORE);
           } else { // pad to localBeamSize (dummy hypothesis)
             hypIndices.push_back(0);
-            prevWords.push_back(trgEosId_); // (unused, but must be valid)
+            prevWords.push_back(trgEosId); // (unused, but must be valid)
             prevScores.push_back((float)INVALID_PATH_SCORE);
           }
         }
@@ -371,8 +368,8 @@ public:
 
       //**********************************************************************
       // suppress specific symbols if not at right positions
-      if(trgUnkId_ != Word::NONE && options_->has("allow-unk") && !options_->get<bool>("allow-unk") && factorGroup == 0)
-        suppressWord(expandedPathScores, factoredVocab ? factoredVocab->getUnkIndex() : trgUnkId_.toWordIndex());
+      if(trgUnkId != Word::NONE && options_->has("allow-unk") && !options_->get<bool>("allow-unk") && factorGroup == 0)
+        suppressWord(expandedPathScores, factoredVocab ? factoredVocab->getUnkIndex() : trgUnkId.toWordIndex());
       for(auto state : states)
         state->blacklist(expandedPathScores, batch);
@@ -411,7 +408,7 @@ public:
         if(!beams[i].empty()) {
           if (histories[i]->size() >= options_->get<float>("max-length-factor") * batch->front()->batchWidth())
             maxLengthReached = true;
-          histories[i]->add(beams[i], trgEosId_, purgedNewBeams[i].empty() || maxLengthReached);
+          histories[i]->add(beams[i], trgEosId, purgedNewBeams[i].empty() || maxLengthReached);
         }
       }
       if (maxLengthReached) // early exit if max length limit was reached
diff --git a/src/translator/translator.h b/src/translator/translator.h
index d2518f5d..43bb5ac3 100644..100755
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -103,7 +103,7 @@ public:
         scorers = scorers_[id % numDevices_];
       }
 
-      auto search = New<Search>(options_, scorers, trgVocab_->getEosId(), trgVocab_->getUnkId());
+      auto search = New<Search>(options_, scorers, trgVocab_);
       auto histories = search->search(graph, batch);
 
       for(auto history : histories) {
@@ -212,7 +212,7 @@ public:
         scorers = scorers_[id % numDevices_];
       }
 
-      auto search = New<Search>(options_, scorers, trgVocab_->getEosId(), trgVocab_->getUnkId());
+      auto search = New<Search>(options_, scorers, trgVocab_);
       auto histories = search->search(graph, batch);
 
       for(auto history : histories) {
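Two of the smaller changes above are easy to sanity-check in isolation. The following is a minimal standalone sketch (illustrative only, not Marian code; the helper names are invented) of the escape-digit fix in unescapeHexEscapes and of the least-significant-bit convention that the HARDMAX_HACK in generic.cpp applies to lemma-dim-emb:

    #include <cassert>

    // Escape-digit fix: "\xAB" carries 2 hex digits, "\u2581" carries 4.
    // The old expression 2 + (kind == 'u') yielded 3 for \u escapes and
    // truncated the code point; the fixed expression yields 4.
    int numEscapeDigits(char kind) { // kind is 'x' or 'u'
      return 2 + 2 * (kind == 'u');
    }

    // HARDMAX_HACK convention: an odd lemma-dim-emb value switches the lemma
    // re-embedding from softmax to hardmax; the actual embedding dimension is
    // the value with its least-significant bit cleared.
    bool useHardmax(int lemmaDimEmb) { return (lemmaDimEmb & 1) != 0; }
    int  actualDim(int lemmaDimEmb)  { return lemmaDimEmb & 0xfffffffe; }

    int main() {
      assert(numEscapeDigits('x') == 2);
      assert(numEscapeDigits('u') == 4); // was 3 before the fix
      assert(useHardmax(513) && actualDim(513) == 512);  // 513 -> hardmax at dim 512
      assert(!useHardmax(512) && actualDim(512) == 512); // 512 -> plain softmax
      return 0;
    }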