diff options
-rw-r--r-- | CHANGELOG.md | 1 | ||||
-rw-r--r-- | src/data/corpus.cpp | 17 | ||||
-rw-r--r-- | src/data/corpus.h | 2 | ||||
-rw-r--r-- | src/data/corpus_base.cpp | 12 | ||||
-rw-r--r-- | src/data/corpus_base.h | 12 |
5 files changed, 37 insertions, 7 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a0b9927..05658fe1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option ### Fixed +- Do not set guided alignments for case augmented data if vocab is not factored - Various fixes to enable LSH in Quicksand - Added support to MPIWrappest::bcast (and similar) for count of type size_t - Adding new validation metrics when training is restarted and --reset-valid-stalled is used diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp index e8ce850b..d8a364b2 100644 --- a/src/data/corpus.cpp +++ b/src/data/corpus.cpp @@ -7,6 +7,7 @@ #include "common/filesystem.h" #include "data/corpus.h" +#include "data/factored_vocab.h" namespace marian { namespace data { @@ -26,13 +27,16 @@ Corpus::Corpus(std::vector<std::string> paths, allCapsEvery_(options_->get<size_t>("all-caps-every", 0)), titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) {} -void Corpus::preprocessLine(std::string& line, size_t streamId) { +void Corpus::preprocessLine(std::string& line, size_t streamId, bool& altered) { + bool isFactoredVocab = vocabs_.back()->tryAs<FactoredVocab>() != nullptr; + altered = false; if (allCapsEvery_ != 0 && pos_ % allCapsEvery_ == 0 && !inference_) { line = vocabs_[streamId]->toUpper(line); if (streamId == 0) LOG_ONCE(info, "[data] Source all-caps'ed line to: {}", line); else LOG_ONCE(info, "[data] Target all-caps'ed line to: {}", line); + altered = isFactoredVocab ? false : true; // FS vocab does not really "alter" the token lemma for all caps } else if (titleCaseEvery_ != 0 && pos_ % titleCaseEvery_ == 1 && !inference_ && streamId == 0) { // Only applied to stream 0 (source) since this feature is aimed at robustness against @@ -43,6 +47,7 @@ void Corpus::preprocessLine(std::string& line, size_t streamId) { LOG_ONCE(info, "[data] Source English-title-case'd line to: {}", line); else LOG_ONCE(info, "[data] Target English-title-case'd line to: {}", line); + altered = isFactoredVocab ? false : true; // FS vocab does not really "alter" the token lemma for title casing } } @@ -103,7 +108,10 @@ SentenceTuple Corpus::next() { ++shift; } else { size_t vocabId = j - shift; - preprocessLine(fields[j], vocabId); + bool altered; + preprocessLine(fields[j], vocabId, /*out=*/altered); + if (altered) + tup.markAltered(); addWordsToSentenceTuple(fields[j], vocabId, tup); } } @@ -116,7 +124,10 @@ SentenceTuple Corpus::next() { addWeightsToSentenceTuple(fields[weightFileIdx_], tup); } else { - preprocessLine(line, i); + bool altered; + preprocessLine(line, i, /*out=*/altered); + if (altered) + tup.markAltered(); addWordsToSentenceTuple(line, i, tup); } } diff --git a/src/data/corpus.h b/src/data/corpus.h index 70e7cdfb..e8e9a9fd 100644 --- a/src/data/corpus.h +++ b/src/data/corpus.h @@ -30,7 +30,7 @@ private: // for pre-processing size_t allCapsEvery_{0}; // if set, convert every N-th input sentence (after randomization) to all-caps (source and target) size_t titleCaseEvery_{0}; // ditto for title case (source only) - void preprocessLine(std::string& line, size_t streamId); + void preprocessLine(std::string& line, size_t streamId, bool& altered); // altered => whether the segmentation was altered in marian public: // @TODO: check if translate can be replaced by an option in options diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp index 5be4298b..5f9a9ee3 100644 --- a/src/data/corpus_base.cpp +++ b/src/data/corpus_base.cpp @@ -447,9 +447,15 @@ void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch, std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f); for(int b = 0; b < dimBatch; ++b) { - for(auto p : batchVector[b].getAlignment()) { - size_t idx = p.srcPos * dimBatch * trgWords + b * trgWords + p.tgtPos; - aligns[idx] = 1.f; + + // If the batch vector is altered within marian by, for example, case augmentation, + // the guided alignments we received for this tuple cease to be valid. + // Hence skip setting alignments for that sentence tuple.. + if (!batchVector[b].isAltered()) { + for(auto p : batchVector[b].getAlignment()) { + size_t idx = p.srcPos * dimBatch * trgWords + b * trgWords + p.tgtPos; + aligns[idx] = 1.f; + } } } batch->setGuidedAlignment(std::move(aligns)); diff --git a/src/data/corpus_base.h b/src/data/corpus_base.h index 8e5e1334..251df5bc 100644 --- a/src/data/corpus_base.h +++ b/src/data/corpus_base.h @@ -28,6 +28,7 @@ private: std::vector<Words> tuple_; // [stream index][step index] std::vector<float> weights_; // [stream index] WordAlignment alignment_; + bool altered_ = false; public: typedef Words value_type; @@ -45,6 +46,17 @@ public: size_t getId() const { return id_; } /** + * @brief Returns whether this Tuple was altered or augmented from what + * was provided to Marian in input. + */ + bool isAltered() const { return altered_; } + + /** + * @brief Mark that this Tuple was internally altered or augmented by Marian + */ + void markAltered() { altered_ = true; } + + /** * @brief Adds a new sentence at the end of the tuple. * * @param words A vector of word indices. |