github.com/marian-nmt/marian.git
Diffstat (limited to 'src/data/corpus_base.cpp')
-rw-r--r--  src/data/corpus_base.cpp  77
1 file changed, 72 insertions(+), 5 deletions(-)
diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp
index 9d95a121..20301103 100644
--- a/src/data/corpus_base.cpp
+++ b/src/data/corpus_base.cpp
@@ -12,7 +12,24 @@ typedef std::vector<float> MaskBatch;
typedef std::pair<WordBatch, MaskBatch> WordMask;
typedef std::vector<WordMask> SentBatch;
-CorpusIterator::CorpusIterator() : pos_(-1), tup_(0) {}
+void SentenceTupleImpl::setWeights(const std::vector<float>& weights) {
+ if(weights.size() != 1) { // this assumes a single sentence-level weight is always fine
+ ABORT_IF(empty(), "Source and target sequences should be added to a tuple before data weights");
+ auto numWeights = weights.size();
+ auto numTrgWords = back().size();
+ // word-level weights may or may not contain a weight for EOS tokens
+ if(numWeights != numTrgWords && numWeights != numTrgWords - 1)
+ LOG(warn,
+ "[warn] "
+ "Number of weights ({}) does not match the number of target words ({}) in line #{}",
+ numWeights,
+ numTrgWords,
+ id_);
+ }
+ weights_ = weights;
+}
+
+CorpusIterator::CorpusIterator() : pos_(-1) {}
CorpusIterator::CorpusIterator(CorpusBase* corpus)
: corpus_(corpus), pos_(0), tup_(corpus_->next()) {}
@@ -23,7 +40,7 @@ void CorpusIterator::increment() {
}
bool CorpusIterator::equal(CorpusIterator const& other) const {
- return this->pos_ == other.pos_ || (this->tup_.empty() && other.tup_.empty());
+ return this->pos_ == other.pos_ || (!this->tup_.valid() && !other.tup_.valid());
}
const SentenceTuple& CorpusIterator::dereference() const {
@@ -390,7 +407,7 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate, size_t seed)
void CorpusBase::addWordsToSentenceTuple(const std::string& line,
size_t batchIndex,
- SentenceTuple& tup) const {
+ SentenceTupleImpl& tup) const {
// This turns a string in to a sequence of numerical word ids. Depending
// on the vocabulary type, this can be non-trivial, e.g. when SentencePiece
// is used.
@@ -411,7 +428,7 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line,
}
void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
- SentenceTuple& tup) const {
+ SentenceTupleImpl& tup) const {
ABORT_IF(rightLeft_,
"Guided alignment and right-left model cannot be used "
"together at the moment");
@@ -420,7 +437,7 @@ void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
tup.setAlignment(align);
}
-void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTuple& tup) const {
+void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const {
auto elements = utils::split(line, " ");
if(!elements.empty()) {
@@ -549,6 +566,7 @@ size_t CorpusBase::getNumberOfTSVInputFields(Ptr<Options> options) {
return 0;
}
+<<<<<<< HEAD
void SentenceTuple::setWeights(const std::vector<float>& weights) {
if(weights.size() != 1) { // this assumes a single sentence-level weight is always fine
ABORT_IF(empty(), "Source and target sequences should be added to a tuple before data weights");
@@ -564,6 +582,55 @@ void SentenceTuple::setWeights(const std::vector<float>& weights) {
id_);
}
weights_ = weights;
+=======
+// experimental: hide inline-fix source tokens from cross attention
+std::vector<float> SubBatch::crossMaskWithInlineFixSourceSuppressed() const
+{
+ const auto& srcVocab = *vocab();
+
+ auto factoredVocab = vocab()->tryAs<FactoredVocab>();
+ size_t inlineFixGroupIndex = 0, inlineFixSrc = 0;
+ auto hasInlineFixFactors = factoredVocab && factoredVocab->tryGetFactor(FactoredVocab_INLINE_FIX_WHAT_serialized, /*out*/ inlineFixGroupIndex, /*out*/ inlineFixSrc);
+
+ auto fixSrcId = srcVocab[FactoredVocab_FIX_SRC_ID_TAG];
+ auto fixTgtId = srcVocab[FactoredVocab_FIX_TGT_ID_TAG];
+ auto fixEndId = srcVocab[FactoredVocab_FIX_END_ID_TAG];
+ auto unkId = srcVocab.getUnkId();
+ auto hasInlineFixTags = fixSrcId != unkId && fixTgtId != unkId && fixEndId != unkId;
+
+ auto m = mask(); // default return value, which we will modify in-place below in case we need to
+ if (hasInlineFixFactors || hasInlineFixTags) {
+ LOG_ONCE(info, "[data] Suppressing cross-attention into inline-fix source tokens");
+
+ // example: force French translation of name "frank" to always be "franck"
+ // - hasInlineFixFactors: "frank|is franck|it", "frank|is" cannot be cross-attended to
+ // - hasInlineFixTags: "<IOPEN> frank <IDELIM> franck <ICLOSE>", "frank" and all tags cannot be cross-attended to
+ auto dimBatch = batchSize(); // number of sentences in the batch
+ auto dimWidth = batchWidth(); // number of words in the longest sentence in the batch
+ const auto& d = data();
+ size_t numWords = 0;
+ for (size_t b = 0; b < dimBatch; b++) { // loop over batch entries
+ bool inside = false;
+ for (size_t s = 0; s < dimWidth; s++) { // loop over source positions
+ auto i = locate(/*batchIdx=*/b, /*wordPos=*/s);
+ if (!m[i])
+ break;
+ numWords++;
+ // keep track of entering/exiting the inline-fix source tags
+ auto w = d[i];
+ if (w == fixSrcId)
+ inside = true;
+ else if (w == fixTgtId)
+ inside = false;
+ bool wHasSrcIdFactor = hasInlineFixFactors && factoredVocab->getFactor(w, inlineFixGroupIndex) == inlineFixSrc;
+ if (inside || w == fixSrcId || w == fixTgtId || w == fixEndId || wHasSrcIdFactor)
+ m[i] = 0.0f; // decoder must not look at embedded source, nor the markup tokens
+ }
+ }
+ ABORT_IF(batchWords() != 0/*n/a*/ && numWords != batchWords(), "batchWords() inconsistency??");
+ }
+ return m;
+>>>>>>> master
}
} // namespace data
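
Note: the weight handling moved into SentenceTupleImpl::setWeights above boils down to one rule: a single value is treated as a sentence-level weight and accepted as-is, while anything longer is treated as word-level weights and is only expected to match the target length, with or without a weight for the EOS token. The following is a minimal standalone sketch of that rule, not Marian code: LOG/ABORT_IF and the Words container are replaced by hypothetical stand-ins (plain stderr output and a token count).

#include <cstdio>
#include <vector>

// Hypothetical stand-in (not Marian's API): validate data weights for one
// sentence pair. A single weight is a sentence-level weight and is always
// accepted; otherwise the weights are word-level and should match the number
// of target words, with or without an extra weight for the EOS token.
void checkWeights(const std::vector<float>& weights,
                  size_t numTrgWords,  // target length including EOS
                  size_t lineId) {
  if(weights.size() == 1)
    return;  // sentence-level weight, nothing to check
  size_t numWeights = weights.size();
  if(numWeights != numTrgWords && numWeights != numTrgWords - 1)
    std::fprintf(stderr,
                 "[warn] Number of weights (%zu) does not match the number of "
                 "target words (%zu) in line #%zu\n",
                 numWeights, numTrgWords, lineId);
}

int main() {
  checkWeights({0.5f}, 4, 1);           // sentence-level weight: accepted
  checkWeights({1.f, 1.f, 2.f}, 4, 2);  // word-level, no EOS weight: accepted
  checkWeights({1.f, 1.f}, 4, 3);       // mismatch: warning is printed
}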
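
The new SubBatch::crossMaskWithInlineFixSourceSuppressed on the master side of the conflict zeroes mask positions so the decoder cannot cross-attend to inline-fix source tokens or to the markup tags around them. Below is a simplified sketch of the tag-based path only, not Marian's API: the factored-vocabulary variant, locate() and the real SubBatch layout are replaced by a plain batch-major vector and hypothetical integer tag ids.

#include <vector>

using WordId = int;

// Hypothetical stand-in: walk each sentence of a batch-major mask and zero out
// every position that is an inline-fix tag or lies between the source-open and
// target-delimiter tags, so the embedded source segment cannot be attended to.
std::vector<float> suppressInlineFixSource(std::vector<float> mask,          // 1 = real token, 0 = padding
                                           const std::vector<WordId>& words, // word ids, index = b * dimWidth + s
                                           size_t dimBatch, size_t dimWidth,
                                           WordId fixSrcId, WordId fixTgtId, WordId fixEndId) {
  for(size_t b = 0; b < dimBatch; ++b) {
    bool inside = false;  // currently between the source-open and target-delimiter tags?
    for(size_t s = 0; s < dimWidth; ++s) {
      size_t i = b * dimWidth + s;
      if(mask[i] == 0.f)
        break;  // past the end of this sentence
      WordId w = words[i];
      if(w == fixSrcId)       // entering the embedded source segment
        inside = true;
      else if(w == fixTgtId)  // switching to the forced target segment
        inside = false;
      if(inside || w == fixSrcId || w == fixTgtId || w == fixEndId)
        mask[i] = 0.f;  // hide embedded source and all markup tokens
    }
  }
  return mask;
}

int main() {
  // One sentence, e.g. "<IOPEN> frank <IDELIM> franck <ICLOSE> ." with made-up ids.
  std::vector<WordId> words = {100, 7, 101, 8, 102, 9};
  std::vector<float> mask(6, 1.f);
  auto m = suppressInlineFixSource(mask, words, /*dimBatch=*/1, /*dimWidth=*/6,
                                   /*fixSrcId=*/100, /*fixTgtId=*/101, /*fixEndId=*/102);
  // m is now {0, 0, 0, 1, 0, 1}: only "franck" and "." remain attendable.
}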