diff options
Diffstat (limited to 'src/data/corpus_base.cpp')
-rw-r--r-- | src/data/corpus_base.cpp | 77 |
1 files changed, 72 insertions, 5 deletions
diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp index 9d95a121..20301103 100644 --- a/src/data/corpus_base.cpp +++ b/src/data/corpus_base.cpp @@ -12,7 +12,24 @@ typedef std::vector<float> MaskBatch; typedef std::pair<WordBatch, MaskBatch> WordMask; typedef std::vector<WordMask> SentBatch; -CorpusIterator::CorpusIterator() : pos_(-1), tup_(0) {} +void SentenceTupleImpl::setWeights(const std::vector<float>& weights) { + if(weights.size() != 1) { // this assumes a single sentence-level weight is always fine + ABORT_IF(empty(), "Source and target sequences should be added to a tuple before data weights"); + auto numWeights = weights.size(); + auto numTrgWords = back().size(); + // word-level weights may or may not contain a weight for EOS tokens + if(numWeights != numTrgWords && numWeights != numTrgWords - 1) + LOG(warn, + "[warn] " + "Number of weights ({}) does not match the number of target words ({}) in line #{}", + numWeights, + numTrgWords, + id_); + } + weights_ = weights; +} + +CorpusIterator::CorpusIterator() : pos_(-1) {} CorpusIterator::CorpusIterator(CorpusBase* corpus) : corpus_(corpus), pos_(0), tup_(corpus_->next()) {} @@ -23,7 +40,7 @@ void CorpusIterator::increment() { } bool CorpusIterator::equal(CorpusIterator const& other) const { - return this->pos_ == other.pos_ || (this->tup_.empty() && other.tup_.empty()); + return this->pos_ == other.pos_ || (!this->tup_.valid() && !other.tup_.valid()); } const SentenceTuple& CorpusIterator::dereference() const { @@ -390,7 +407,7 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate, size_t seed) void CorpusBase::addWordsToSentenceTuple(const std::string& line, size_t batchIndex, - SentenceTuple& tup) const { + SentenceTupleImpl& tup) const { // This turns a string in to a sequence of numerical word ids. Depending // on the vocabulary type, this can be non-trivial, e.g. when SentencePiece // is used. @@ -411,7 +428,7 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line, } void CorpusBase::addAlignmentToSentenceTuple(const std::string& line, - SentenceTuple& tup) const { + SentenceTupleImpl& tup) const { ABORT_IF(rightLeft_, "Guided alignment and right-left model cannot be used " "together at the moment"); @@ -420,7 +437,7 @@ void CorpusBase::addAlignmentToSentenceTuple(const std::string& line, tup.setAlignment(align); } -void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTuple& tup) const { +void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const { auto elements = utils::split(line, " "); if(!elements.empty()) { @@ -549,6 +566,7 @@ size_t CorpusBase::getNumberOfTSVInputFields(Ptr<Options> options) { return 0; } +<<<<<<< HEAD void SentenceTuple::setWeights(const std::vector<float>& weights) { if(weights.size() != 1) { // this assumes a single sentence-level weight is always fine ABORT_IF(empty(), "Source and target sequences should be added to a tuple before data weights"); @@ -564,6 +582,55 @@ void SentenceTuple::setWeights(const std::vector<float>& weights) { id_); } weights_ = weights; +======= +// experimental: hide inline-fix source tokens from cross attention +std::vector<float> SubBatch::crossMaskWithInlineFixSourceSuppressed() const +{ + const auto& srcVocab = *vocab(); + + auto factoredVocab = vocab()->tryAs<FactoredVocab>(); + size_t inlineFixGroupIndex = 0, inlineFixSrc = 0; + auto hasInlineFixFactors = factoredVocab && factoredVocab->tryGetFactor(FactoredVocab_INLINE_FIX_WHAT_serialized, /*out*/ inlineFixGroupIndex, /*out*/ inlineFixSrc); + + auto fixSrcId = srcVocab[FactoredVocab_FIX_SRC_ID_TAG]; + auto fixTgtId = srcVocab[FactoredVocab_FIX_TGT_ID_TAG]; + auto fixEndId = srcVocab[FactoredVocab_FIX_END_ID_TAG]; + auto unkId = srcVocab.getUnkId(); + auto hasInlineFixTags = fixSrcId != unkId && fixTgtId != unkId && fixEndId != unkId; + + auto m = mask(); // default return value, which we will modify in-place below in case we need to + if (hasInlineFixFactors || hasInlineFixTags) { + LOG_ONCE(info, "[data] Suppressing cross-attention into inline-fix source tokens"); + + // example: force French translation of name "frank" to always be "franck" + // - hasInlineFixFactors: "frank|is franck|it", "frank|is" cannot be cross-attended to + // - hasInlineFixTags: "<IOPEN> frank <IDELIM> franck <ICLOSE>", "frank" and all tags cannot be cross-attended to + auto dimBatch = batchSize(); // number of sentences in the batch + auto dimWidth = batchWidth(); // number of words in the longest sentence in the batch + const auto& d = data(); + size_t numWords = 0; + for (size_t b = 0; b < dimBatch; b++) { // loop over batch entries + bool inside = false; + for (size_t s = 0; s < dimWidth; s++) { // loop over source positions + auto i = locate(/*batchIdx=*/b, /*wordPos=*/s); + if (!m[i]) + break; + numWords++; + // keep track of entering/exiting the inline-fix source tags + auto w = d[i]; + if (w == fixSrcId) + inside = true; + else if (w == fixTgtId) + inside = false; + bool wHasSrcIdFactor = hasInlineFixFactors && factoredVocab->getFactor(w, inlineFixGroupIndex) == inlineFixSrc; + if (inside || w == fixSrcId || w == fixTgtId || w == fixEndId || wHasSrcIdFactor) + m[i] = 0.0f; // decoder must not look at embedded source, nor the markup tokens + } + } + ABORT_IF(batchWords() != 0/*n/a*/ && numWords != batchWords(), "batchWords() inconsistency??"); + } + return m; +>>>>>>> master } } // namespace data |