github.com/marian-nmt/marian.git
author     Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>   2021-11-25 05:33:49 +0300
committer  Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>   2021-11-25 05:33:49 +0300
commit     8b8d1b11e28a421b348703d702c9c5206061df9d (patch)
tree       01f5313e706bdba5c3317d5a713f8ab6224ec2da
parent     c85d0608483789d446361ea28d95f7d7c9545f2d (diff)
Merged PR 21553: Parallelize data reading for training
This parallelizes data reading. On very fast GPUs and with small models, training speed can be starved by batch creation that is too slow. Use --data-threads 8 or more; the default is currently 1 for backward compatibility.
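For example, a hypothetical training invocation enabling eight reader threads could look as follows (model, corpus, and vocabulary paths are placeholders; --data-threads is the option introduced by this PR):

    ./marian --model model.npz \
             --train-sets corpus.src corpus.trg \
             --vocabs vocab.src.spm vocab.trg.spm \
             --data-threads 8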
 src/common/config_parser.cpp |   7
 src/common/utils.cpp         |   8
 src/data/batch_generator.h   |  35
 src/data/corpus.cpp          | 152
 src/data/corpus.h            |   3
 src/data/corpus_base.cpp     |  44
 src/data/corpus_base.h       | 105
 src/data/corpus_nbest.cpp    |   7
 src/data/corpus_sqlite.cpp   |   6
 src/data/text_input.cpp      |   6
10 files changed, 251 insertions(+), 122 deletions(-)
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 59b328e9..3d79f8af 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -883,6 +883,10 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) {
if(mode_ == cli::mode::training) {
cli.add<bool>("--shuffle-in-ram",
"Keep shuffled corpus in RAM, do not write to temp file");
+
+ cli.add<size_t>("--data-threads",
+ "Number of concurrent threads to use during data reading and processing", 1);
+
// @TODO: Consider making the next two options options of the vocab instead, to make it more local in scope.
cli.add<size_t>("--all-caps-every",
"When forming minibatches, preprocess every Nth line on the fly to all-caps. Assumes UTF-8");
@@ -901,6 +905,9 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) {
cli.add<bool>("--mini-batch-round-up",
"Round up batch size to next power of 2 for more efficient training, but this can make batch size less stable. Disable with --mini-batch-round-up=false",
true);
+ } else {
+ cli.add<size_t>("--data-threads",
+ "Number of concurrent threads to use during data reading and processing", 1);
}
// clang-format on
}
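Since marian options can equally be given in a YAML configuration file, the same setting would presumably be written as:

    # config.yml (sketch)
    data-threads: 8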
diff --git a/src/common/utils.cpp b/src/common/utils.cpp
index 72624041..99fc790a 100644
--- a/src/common/utils.cpp
+++ b/src/common/utils.cpp
@@ -70,22 +70,20 @@ void split(const std::string& line,
// the function guarantees that the output has as many elements as requested
void splitTsv(const std::string& line, std::vector<std::string>& fields, size_t numFields) {
fields.clear();
+ fields.resize(numFields); // make sure there are as many elements as requested
size_t begin = 0;
size_t pos = 0;
for(size_t i = 0; i < numFields; ++i) {
pos = line.find('\t', begin);
if(pos == std::string::npos) {
- fields.push_back(line.substr(begin));
+ fields[i] = line.substr(begin);
break;
}
- fields.push_back(line.substr(begin, pos - begin));
+ fields[i] = line.substr(begin, pos - begin);
begin = pos + 1;
}
- if(fields.size() < numFields) // make sure there is as many elements as requested
- fields.resize(numFields);
-
ABORT_IF(pos != std::string::npos, "Excessive field(s) in the tab-separated line: '{}'", line);
}
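To illustrate the guarantee stated in the comment, here is a self-contained sketch of the patched logic (splitTsvSketch is a hypothetical stand-in and ABORT_IF is replaced by a plain assert): the output always holds exactly numFields elements, missing trailing fields stay empty, and excess fields abort.

    #include <cassert>
    #include <string>
    #include <vector>

    // Mirrors the patched splitTsv: resize up front, then fill fields in place.
    void splitTsvSketch(const std::string& line, std::vector<std::string>& fields, size_t numFields) {
      fields.clear();
      fields.resize(numFields);           // output always has numFields elements
      size_t begin = 0, pos = 0;
      for(size_t i = 0; i < numFields; ++i) {
        pos = line.find('\t', begin);
        if(pos == std::string::npos) {
          fields[i] = line.substr(begin); // last available field; the rest stay empty
          break;
        }
        fields[i] = line.substr(begin, pos - begin);
        begin = pos + 1;
      }
      assert(pos == std::string::npos && "excessive field(s) in the tab-separated line");
    }

    int main() {
      std::vector<std::string> fields;
      splitTsvSketch("src\ttrg", fields, 3); // -> {"src", "trg", ""}
      assert(fields.size() == 3 && fields[2].empty());
      return 0;
    }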
diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h
index a248db23..ea977468 100644
--- a/src/data/batch_generator.h
+++ b/src/data/batch_generator.h
@@ -2,6 +2,7 @@
#include "common/options.h"
#include "common/signal_handling.h"
+#include "common/timer.h"
#include "data/batch_stats.h"
#include "data/rng_engine.h"
#include "training/training_state.h"
@@ -92,6 +93,8 @@ private:
// this runs on a bg thread; sequencing is handled by caller, but locking is done in here
std::deque<BatchPtr> fetchBatches() {
+ timer::Timer total;
+
typedef typename Sample::value_type Item;
auto itemCmp = [](const Item& sa, const Item& sb) { return sa.size() < sb.size(); }; // sort by element length, not content
@@ -135,19 +138,29 @@ private:
if(current_ != data_->end())
++current_;
}
- size_t sets = 0;
- while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data
+
+ Samples maxiBatchTemp;
+ while(current_ != data_->end() && maxiBatchTemp.size() < maxSize) { // loop over data
if (saveAndExitRequested()) // stop generating batches
return std::deque<BatchPtr>();
- maxiBatch->push(*current_);
- sets = current_->size();
+
+ maxiBatchTemp.push_back(*current_);
+
// do not consume more than required for the maxi batch, as this causes
// line-by-line translation to be delayed by one sentence
- bool last = maxiBatch->size() == maxSize;
+ bool last = maxiBatchTemp.size() == maxSize;
if(!last)
++current_; // this actually reads the next line and pre-processes it
}
- size_t numSentencesRead = maxiBatch->size();
+ size_t numSentencesRead = maxiBatchTemp.size();
+
+ size_t sets = 0;
+ for(auto&& s : maxiBatchTemp) {
+ if(!s.empty()) {
+ sets = s.size();
+ maxiBatch->push(s);
+ }
+ }
// construct the actual batches and place them in the queue
Samples batchVector;
@@ -163,6 +176,7 @@ private:
BatchStats::const_iterator cachedStatsIter;
if (stats_)
cachedStatsIter = stats_->begin();
+
while(!maxiBatch->empty()) { // while there are sentences in the queue
if (saveAndExitRequested()) // stop generating batches
return std::deque<BatchPtr>();
@@ -178,12 +192,7 @@ private:
lengths[i] = batchVector.back()[i].size(); // record max lengths so far
maxBatchSize = stats_->findBatchSize(lengths, cachedStatsIter);
- // this optimization makes no difference indeed
-#if 0 // sanity check: would we find the same entry if searching from the start?
- auto it = stats_->lower_bound(lengths);
- auto maxBatchSize1 = stats_->findBatchSize(lengths, it);
- ABORT_IF(maxBatchSize != maxBatchSize1, "findBatchSize iter caching logic is borked");
-#endif
+
makeBatch = batchVector.size() >= maxBatchSize;
// if last added sentence caused a bump then we likely have bad padding, so rather move it into the next batch
if(batchVector.size() > maxBatchSize) {
@@ -231,6 +240,8 @@ private:
LOG(debug, "[data] fetched {} batches with {} sentences. Per batch: {} sentences, {} labels.",
tempBatches.size(), numSentencesRead,
(double)totalSent / (double)totalDenom, (double)totalLabels / (double)totalDenom);
+ LOG(debug, "[data] fetching batches took {:.2f} seconds, {:.2f} sents/s", total.elapsed(), (double)numSentencesRead / total.elapsed());
+
return tempBatches;
}
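The likely rationale for the maxiBatchTemp detour: once samples are future-backed (see the corpus.cpp changes below), pushing them directly into the length-sorted maxi-batch would force each future inside the length comparison, serializing the reads. Collecting the handles in a plain vector first lets the worker threads run ahead; sizes are only inspected in the second phase. A self-contained sketch of the two-phase shape (all names are simplified stand-ins for the real Sample/SentenceTuple machinery):

    #include <cstddef>
    #include <queue>
    #include <string>
    #include <vector>

    using Sample = std::vector<std::string>; // one sentence per stream

    int main() {
      std::vector<Sample> input = {{"a b c"}, {}, {"d e", "f"}};
      size_t maxSize = 8, pos = 0;

      // Phase 1: collect up to maxSize samples without touching their contents;
      // a future-backed Sample is therefore not forced (and cannot block) here.
      std::vector<Sample> temp;
      while(pos < input.size() && temp.size() < maxSize)
        temp.push_back(input[pos++]);

      // Phase 2: inspect sizes (this is where futures would resolve) and push
      // only non-empty samples into the length-sorted maxi-batch.
      auto bySize = [](const Sample& a, const Sample& b) { return a.size() < b.size(); };
      std::priority_queue<Sample, std::vector<Sample>, decltype(bySize)> maxiBatch(bySize);
      size_t sets = 0;
      for(auto&& s : temp)
        if(!s.empty()) { sets = s.size(); maxiBatch.push(s); }

      (void)sets; // sets == number of streams in the last non-empty sample
      return 0;
    }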
diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp
index d8a364b2..643a7de9 100644
--- a/src/data/corpus.cpp
+++ b/src/data/corpus.cpp
@@ -14,18 +14,30 @@ namespace data {
Corpus::Corpus(Ptr<Options> options, bool translate /*= false*/, size_t seed /*= Config::seed*/)
: CorpusBase(options, translate, seed),
- shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)),
- allCapsEvery_(options_->get<size_t>("all-caps-every", 0)),
- titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) {}
+ shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)),
+ allCapsEvery_(options_->get<size_t>("all-caps-every", 0)),
+ titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) {
+
+ auto numThreads = options_->get<size_t>("data-threads", 1);
+ if(numThreads > 1)
+ threadPool_.reset(new ThreadPool(numThreads));
+
+}
Corpus::Corpus(std::vector<std::string> paths,
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> options,
size_t seed /*= Config::seed*/)
: CorpusBase(paths, vocabs, options, seed),
- shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)),
- allCapsEvery_(options_->get<size_t>("all-caps-every", 0)),
- titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) {}
+ shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)),
+ allCapsEvery_(options_->get<size_t>("all-caps-every", 0)),
+ titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) {
+
+ auto numThreads = options_->get<size_t>("data-threads", 1);
+ if(numThreads > 1)
+ threadPool_.reset(new ThreadPool(numThreads));
+
+}
void Corpus::preprocessLine(std::string& line, size_t streamId, bool& altered) {
bool isFactoredVocab = vocabs_.back()->tryAs<FactoredVocab>() != nullptr;
@@ -52,16 +64,10 @@ void Corpus::preprocessLine(std::string& line, size_t streamId, bool& altered) {
}
SentenceTuple Corpus::next() {
- // Used for handling TSV inputs
- // Determine the total number of fields including alignments or weights
- auto tsvNumAllFields = tsvNumInputFields_;
- if(alignFileIdx_ > -1)
- ++tsvNumAllFields;
- if(weightFileIdx_ > -1)
- ++tsvNumAllFields;
- std::vector<std::string> fields(tsvNumAllFields);
-
- for(;;) { // (this is a retry loop for skipping invalid sentences)
+ size_t numStreams = corpusInRAM_.empty() ? files_.size() : corpusInRAM_.size();
+ std::vector<std::string> fields(numStreams);
+
+ while(true) { // retry loop
// get index of the current sentence
size_t curId = pos_; // note: at end, pos_ == total size
// if corpus has been shuffled, ids_ contains sentence indexes
@@ -69,83 +75,91 @@ SentenceTuple Corpus::next() {
curId = ids_[pos_];
pos_++;
- // fill up the sentence tuple with sentences from all input files
- SentenceTuple tup(curId);
size_t eofsHit = 0;
- size_t numStreams = corpusInRAM_.empty() ? files_.size() : corpusInRAM_.size();
- for(size_t i = 0; i < numStreams; ++i) {
- std::string line;
-
+ for(size_t i = 0; i < numStreams; ++i) { // loop over all streams
// fetch line, from cached copy in RAM or actual file
if (!corpusInRAM_.empty()) {
if (curId < corpusInRAM_[i].size())
- line = corpusInRAM_[i][curId];
+ fields[i] = corpusInRAM_[i][curId];
else {
eofsHit++;
continue;
}
}
else {
- bool gotLine = io::getline(*files_[i], line).good();
+ bool gotLine = io::getline(*files_[i], fields[i]).good();
if(!gotLine) {
eofsHit++;
continue;
}
}
+ }
- if(i > 0 && i == alignFileIdx_) {
- addAlignmentToSentenceTuple(line, tup);
- } else if(i > 0 && i == weightFileIdx_) {
- addWeightsToSentenceTuple(line, tup);
- } else {
- if(tsv_) { // split TSV input and add each field into the sentence tuple
- utils::splitTsv(line, fields, tsvNumAllFields);
- size_t shift = 0;
- for(size_t j = 0; j < tsvNumAllFields; ++j) {
- // index j needs to be shifted to get the proper vocab index if guided-alignment or
- // data-weighting are preceding source or target sequences in TSV input
- if(j == alignFileIdx_ || j == weightFileIdx_) {
- ++shift;
- } else {
- size_t vocabId = j - shift;
- bool altered;
- preprocessLine(fields[j], vocabId, /*out=*/altered);
- if (altered)
- tup.markAltered();
- addWordsToSentenceTuple(fields[j], vocabId, tup);
- }
- }
-
- // weights are added last to the sentence tuple, because this runs a validation that needs
- // length of the target sequence
- if(alignFileIdx_ > -1)
- addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
- if(weightFileIdx_ > -1)
- addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
+ if(eofsHit == numStreams)
+ return SentenceTuple(); // uninitialized SentenceTuple which will be invalid when tested
+ ABORT_IF(eofsHit != 0, "not all input files have the same number of lines");
+
+ auto makeSentenceTuple = [this](size_t curId, std::vector<std::string> fields) {
+ if(tsv_) {
+ // with TSV input data, there is only one input stream, hence we only have one field
+ // which needs to be tokenized into tab-separated fields
+ ABORT_IF(fields.size() != 1, "Reading TSV file, but we don't have exactly one stream??");
+ size_t numAllFields = tsvNumInputFields_;
+ if(alignFileIdx_ > -1)
+ ++numAllFields;
+ if(weightFileIdx_ > -1)
+ ++numAllFields;
+ // replace single-element fields array with extracted tsv fields
+ std::vector<std::string> tmpFields;
+ utils::splitTsv(fields[0], tmpFields, numAllFields); // this verifies the number of fields
+ fields.swap(tmpFields);
+ }
+
+ // fill up the sentence tuple with sentences from all input files
+ SentenceTupleImpl tup(curId);
+ size_t shift = 0;
+ for(size_t i = 0; i < fields.size(); ++i) {
+ // index i needs to be shifted to get the proper vocab index if guided-alignment or
+ // data-weighting fields precede source or target sequences in TSV input
+ if(i == alignFileIdx_ || i == weightFileIdx_) {
+ ++shift;
} else {
+ size_t vocabId = i - shift;
bool altered;
- preprocessLine(line, i, /*out=*/altered);
+ preprocessLine(fields[i], vocabId, /*out=*/altered);
if (altered)
tup.markAltered();
- addWordsToSentenceTuple(line, i, tup);
+ addWordsToSentenceTuple(fields[i], vocabId, tup);
}
+
+ // weights are added last to the sentence tuple, because this runs a validation that needs
+ // length of the target sequence
+ if(alignFileIdx_ > -1)
+ addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
+ if(weightFileIdx_ > -1)
+ addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
}
- }
-
- if (eofsHit == numStreams)
- return SentenceTuple(0);
- ABORT_IF(eofsHit != 0, "not all input files have the same number of lines");
- // check if all streams are valid, that is, non-empty and no longer than maximum allowed length
- if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
- return words.size() > 0 && words.size() <= maxLength_;
- }))
- return tup;
+ // check if all streams are valid, that is, non-empty and no longer than maximum allowed length
+ if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
+ return words.size() > 0 && words.size() <= maxLength_;
+ })) {
+ return tup;
+ } else {
+ return SentenceTupleImpl(); // return an empty tuple if the above test does not pass
+ }
+ };
+
+ if(threadPool_) { // use thread pool if available
+ return SentenceTuple(threadPool_->enqueue(makeSentenceTuple, curId, fields));
+ } else { // otherwise launch here and just pass the result into the wrapper
+ auto tup = makeSentenceTuple(curId, fields);
+ if(!tup.empty())
+ return SentenceTuple(tup);
+ }
- // otherwise skip this sentence and try the next one
- // @TODO: tail recursion?
- }
+ } // end of retry loop
}
// reset and initialize shuffled reading
@@ -167,6 +181,8 @@ void Corpus::reset() {
pos_ = 0;
for (size_t i = 0; i < paths_.size(); ++i) {
if(paths_[i] == "stdin" || paths_[i] == "-") {
+ std::cin.tie(0);
+ std::ios_base::sync_with_stdio(false);
files_[i].reset(new std::istream(std::cin.rdbuf()));
// Probably not necessary, unless there are some buffers
// that we want flushed.
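The core pattern of the corpus.cpp change: the reader loop only fetches raw lines (cheap and sequential), while tokenization and preprocessing are shipped to a pool and a future to the finished tuple is returned immediately. A minimal self-contained sketch, with std::async standing in for the project's ThreadPool and all other names hypothetical:

    #include <future>
    #include <iostream>
    #include <string>
    #include <vector>

    struct Tuple { size_t id; std::vector<std::string> tokens; };

    // Stand-in for makeSentenceTuple: the expensive per-line work.
    Tuple makeTuple(size_t id, std::vector<std::string> fields) {
      Tuple t{id, {}};
      for(auto& f : fields)   // pretend this is tokenization / vocab lookup
        t.tokens.push_back(f);
      return t;
    }

    int main() {
      std::vector<std::vector<std::string>> lines = {{"a b", "x y"}, {"c", "z"}};
      std::vector<std::future<Tuple>> pending;
      for(size_t id = 0; id < lines.size(); ++id)           // producer: enqueue work
        pending.push_back(std::async(std::launch::async, makeTuple, id, lines[id]));
      for(auto& f : pending) {                              // consumer: resolve lazily
        Tuple t = f.get();
        std::cout << t.id << ": " << t.tokens.size() << " streams\n";
      }
      return 0;
    }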
diff --git a/src/data/corpus.h b/src/data/corpus.h
index e8e9a9fd..281d43a2 100644
--- a/src/data/corpus.h
+++ b/src/data/corpus.h
@@ -4,6 +4,7 @@
#include <iostream>
#include <random>
+#include "3rd_party/threadpool.h"
#include "common/definitions.h"
#include "common/file_stream.h"
#include "common/options.h"
@@ -20,6 +21,8 @@ class Corpus : public CorpusBase {
private:
std::vector<UPtr<io::TemporaryFile>> tempFiles_;
std::vector<size_t> ids_;
+
+ UPtr<ThreadPool> threadPool_; // thread pool for parallelized data reading
// for shuffle-in-ram
bool shuffleInRAM_{false};
diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp
index 5f9a9ee3..bfce31bf 100644
--- a/src/data/corpus_base.cpp
+++ b/src/data/corpus_base.cpp
@@ -12,7 +12,24 @@ typedef std::vector<float> MaskBatch;
typedef std::pair<WordBatch, MaskBatch> WordMask;
typedef std::vector<WordMask> SentBatch;
-CorpusIterator::CorpusIterator() : pos_(-1), tup_(0) {}
+void SentenceTupleImpl::setWeights(const std::vector<float>& weights) {
+ if(weights.size() != 1) { // this assumes a single sentence-level weight is always fine
+ ABORT_IF(empty(), "Source and target sequences should be added to a tuple before data weights");
+ auto numWeights = weights.size();
+ auto numTrgWords = back().size();
+ // word-level weights may or may not contain a weight for EOS tokens
+ if(numWeights != numTrgWords && numWeights != numTrgWords - 1)
+ LOG(warn,
+ "[warn] "
+ "Number of weights ({}) does not match the number of target words ({}) in line #{}",
+ numWeights,
+ numTrgWords,
+ id_);
+ }
+ weights_ = weights;
+}
+
+CorpusIterator::CorpusIterator() : pos_(-1) {}
CorpusIterator::CorpusIterator(CorpusBase* corpus)
: corpus_(corpus), pos_(0), tup_(corpus_->next()) {}
@@ -23,7 +40,7 @@ void CorpusIterator::increment() {
}
bool CorpusIterator::equal(CorpusIterator const& other) const {
- return this->pos_ == other.pos_ || (this->tup_.empty() && other.tup_.empty());
+ return this->pos_ == other.pos_ || (!this->tup_.valid() && !other.tup_.valid());
}
const SentenceTuple& CorpusIterator::dereference() const {
@@ -390,7 +407,7 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate, size_t seed)
void CorpusBase::addWordsToSentenceTuple(const std::string& line,
size_t batchIndex,
- SentenceTuple& tup) const {
+ SentenceTupleImpl& tup) const {
// This turns a string in to a sequence of numerical word ids. Depending
// on the vocabulary type, this can be non-trivial, e.g. when SentencePiece
// is used.
@@ -411,7 +428,7 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line,
}
void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
- SentenceTuple& tup) const {
+ SentenceTupleImpl& tup) const {
ABORT_IF(rightLeft_,
"Guided alignment and right-left model cannot be used "
"together at the moment");
@@ -420,7 +437,7 @@ void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
tup.setAlignment(align);
}
-void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTuple& tup) const {
+void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const {
auto elements = utils::split(line, " ");
if(!elements.empty()) {
@@ -549,23 +566,6 @@ size_t CorpusBase::getNumberOfTSVInputFields(Ptr<Options> options) {
return 0;
}
-void SentenceTuple::setWeights(const std::vector<float>& weights) {
- if(weights.size() != 1) { // this assumes a single sentence-level weight is always fine
- ABORT_IF(empty(), "Source and target sequences should be added to a tuple before data weights");
- auto numWeights = weights.size();
- auto numTrgWords = back().size();
- // word-level weights may or may not contain a weight for EOS tokens
- if(numWeights != numTrgWords && numWeights != numTrgWords - 1)
- LOG(warn,
- "[warn] "
- "Number of weights ({}) does not match the number of target words ({}) in line #{}",
- numWeights,
- numTrgWords,
- id_);
- }
- weights_ = weights;
-}
-
// experimental: hide inline-fix source tokens from cross attention
std::vector<float> SubBatch::crossMaskWithInlineFixSourceSuppressed() const
{
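For reference, setWeights above receives the parsed contents of a space-separated weights line (a single sentence-level weight is always accepted). For a target of three words plus EOS, either of these illustrative lines passes the count check, since the EOS weight may be included or omitted:

    0.5 1.0 0.8        (one weight per target word, EOS weight omitted)
    0.5 1.0 0.8 1.0    (one weight per target word, EOS weight included)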
diff --git a/src/data/corpus_base.h b/src/data/corpus_base.h
index 251df5bc..82a01286 100644
--- a/src/data/corpus_base.h
+++ b/src/data/corpus_base.h
@@ -11,6 +11,8 @@
#include "data/rng_engine.h"
#include "data/vocab.h"
+#include <future>
+
namespace marian {
namespace data {
@@ -22,7 +24,7 @@ namespace data {
* construction of marian::data::CorpusBatch objects. They are not a part of
* marian::data::CorpusBatch.
*/
-class SentenceTuple {
+class SentenceTupleImpl {
private:
size_t id_;
std::vector<Words> tuple_; // [stream index][step index]
@@ -34,11 +36,16 @@ public:
typedef Words value_type;
/**
+ * @brief Creates an empty tuple with 0 id (default constructor).
+ */
+ SentenceTupleImpl() : id_(0) {}
+
+ /**
* @brief Creates an empty tuple with the given Id.
*/
- SentenceTuple(size_t id) : id_(id) {}
+ SentenceTupleImpl(size_t id) : id_(id) {}
- ~SentenceTuple() { tuple_.clear(); }
+ ~SentenceTupleImpl() {}
/**
* @brief Returns the sentence's ID.
@@ -114,6 +121,92 @@ public:
void setAlignment(const WordAlignment& alignment) { alignment_ = alignment; }
};
+class SentenceTuple {
+private:
+ std::shared_ptr<std::future<SentenceTupleImpl>> fImpl_;
+ mutable std::shared_ptr<SentenceTupleImpl> impl_;
+
+public:
+ typedef Words value_type;
+
+ /**
+ * @brief Creates an empty tuple with no associated future.
+ */
+ SentenceTuple() {}
+
+ SentenceTuple(const SentenceTupleImpl& tupImpl)
+ : impl_(std::make_shared<SentenceTupleImpl>(tupImpl)) {}
+
+ SentenceTuple(std::future<SentenceTupleImpl>&& fImpl)
+ : fImpl_(new std::future<SentenceTupleImpl>(std::move(fImpl))) {}
+
+ SentenceTupleImpl& get() const {
+ if(!impl_) {
+ ABORT_IF(!fImpl_ || !fImpl_->valid(), "No future tuple associated with SentenceTuple");
+ impl_ = std::make_shared<SentenceTupleImpl>(fImpl_->get());
+ }
+ return *impl_;
+ }
+
+ /**
+ * @brief Returns the sentence's ID.
+ */
+ size_t getId() const { return get().getId(); }
+
+ /**
+ * @brief Returns whether this Tuple was altered or augmented from what
+ * was provided to Marian in input.
+ */
+ bool isAltered() const { return get().isAltered(); }
+
+ /**
+ * @brief The size of the tuple, e.g. two for parallel data with a source and
+ * target sentences.
+ */
+ size_t size() const { return get().size(); }
+
+ /**
+ * @brief confirms that the tuple has been populated with data
+ */
+ bool valid() const {
+ return fImpl_ || impl_;
+ }
+
+ /**
+ * @brief The i-th tuple sentence.
+ *
+ * @param i Tuple's index.
+ */
+ Words& operator[](size_t i) { return get()[i]; }
+ const Words& operator[](size_t i) const { return get()[i]; }
+
+ /**
+ * @brief The last tuple sentence, i.e. the target sentence.
+ */
+ Words& back() { return get().back(); }
+ const Words& back() const { return get().back(); }
+
+ /**
+ * @brief Checks whether the tuple is empty.
+ */
+ bool empty() const { return get().empty(); }
+
+ auto begin() const -> decltype(get().begin()) { return get().begin(); }
+ auto end() const -> decltype(get().end()) { return get().end(); }
+
+ auto rbegin() const -> decltype(get().rbegin()) { return get().rbegin(); }
+ auto rend() const -> decltype(get().rend()) { return get().rend(); }
+
+ /**
+ * @brief Get sentence weights.
+ *
+ * For sentence-level weights the vector contains only one element.
+ */
+ const std::vector<float>& getWeights() const { return get().getWeights(); }
+
+ const WordAlignment& getAlignment() const { return get().getAlignment(); }
+};
+
/**
* @brief Batch of sentences represented as word indices with masking.
*/
@@ -586,17 +679,17 @@ protected:
* @brief Helper function converting a line of text into words using the i-th
* vocabulary and adding them to the sentence tuple.
*/
- void addWordsToSentenceTuple(const std::string& line, size_t batchIndex, SentenceTuple& tup) const;
+ void addWordsToSentenceTuple(const std::string& line, size_t batchIndex, SentenceTupleImpl& tup) const;
/**
* @brief Helper function parsing a line with word alignments and adding them
* to the sentence tuple.
*/
- void addAlignmentToSentenceTuple(const std::string& line, SentenceTuple& tup) const;
+ void addAlignmentToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const;
/**
* @brief Helper function parsing a line of weights and adding them to the
* sentence tuple.
*/
- void addWeightsToSentenceTuple(const std::string& line, SentenceTuple& tup) const;
+ void addWeightsToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const;
void addAlignmentsToBatch(Ptr<CorpusBatch> batch, const std::vector<Sample>& batchVector);
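The SentenceTuple wrapper above boils down to a copyable, lazily-forced future: shared pointers make copies cheap, and the wrapped future is resolved at most once, on first access. A generic self-contained sketch of the same idea (Lazy and all names are illustrative, not part of the codebase):

    #include <future>
    #include <memory>

    template <typename T>
    class Lazy {
    private:
      std::shared_ptr<std::future<T>> fut_;  // pending work, if any
      mutable std::shared_ptr<T> val_;       // cached result after first get()
    public:
      Lazy() {}                              // invalid: no work attached
      explicit Lazy(std::future<T>&& f)
        : fut_(std::make_shared<std::future<T>>(std::move(f))) {}
      bool valid() const { return fut_ != nullptr || val_ != nullptr; }
      T& get() const {
        if(!val_)                            // first access: block exactly once
          val_ = std::make_shared<T>(fut_->get());
        return *val_;
      }
    };

    int main() {
      Lazy<int> li(std::async(std::launch::async, [] { return 42; }));
      return (li.valid() && li.get() == 42) ? 0 : 1;
    }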
diff --git a/src/data/corpus_nbest.cpp b/src/data/corpus_nbest.cpp
index d5a48d8d..8029d351 100644
--- a/src/data/corpus_nbest.cpp
+++ b/src/data/corpus_nbest.cpp
@@ -43,7 +43,7 @@ SentenceTuple CorpusNBest::next() {
pos_++;
// fill up the sentence tuple with sentences from all input files
- SentenceTuple tup(curId);
+ SentenceTupleImpl tup(curId);
std::string line;
lastLines_.resize(files_.size() - 1);
@@ -74,9 +74,10 @@ SentenceTuple CorpusNBest::next() {
if(cont && std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
return words.size() > 0 && words.size() <= maxLength_;
}))
- return tup;
+ return SentenceTuple(tup);
}
- return SentenceTuple(0);
+
+ return SentenceTuple();
}
void CorpusNBest::reset() {
diff --git a/src/data/corpus_sqlite.cpp b/src/data/corpus_sqlite.cpp
index 297847c0..f7c577f2 100644
--- a/src/data/corpus_sqlite.cpp
+++ b/src/data/corpus_sqlite.cpp
@@ -109,7 +109,7 @@ SentenceTuple CorpusSQLite::next() {
while(select_->executeStep()) {
// fill up the sentence tuple with sentences from all input files
size_t curId = select_->getColumn(0).getInt();
- SentenceTuple tup(curId);
+ SentenceTupleImpl tup(curId);
for(size_t i = 0; i < files_.size(); ++i) {
auto line = select_->getColumn((int)(i + 1));
@@ -126,9 +126,9 @@ SentenceTuple CorpusSQLite::next() {
if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
return words.size() > 0 && words.size() <= maxLength_;
}))
- return tup;
+ return SentenceTuple(tup);
}
- return SentenceTuple(0);
+ return SentenceTuple();
}
void CorpusSQLite::shuffle() {
diff --git a/src/data/text_input.cpp b/src/data/text_input.cpp
index 958190fc..b1f4cdd4 100644
--- a/src/data/text_input.cpp
+++ b/src/data/text_input.cpp
@@ -40,7 +40,7 @@ SentenceTuple TextInput::next() {
size_t curId = pos_++;
// fill up the sentence tuple with source and/or target sentences
- SentenceTuple tup(curId);
+ SentenceTupleImpl tup(curId);
for(size_t i = 0; i < files_.size(); ++i) {
std::string line;
if(io::getline(*files_[i], line)) {
@@ -57,9 +57,9 @@ SentenceTuple TextInput::next() {
}
if(tup.size() == files_.size()) // check if each input file provided an example
- return tup;
+ return SentenceTuple(tup);
else if(tup.size() == 0) // if no file provided examples we are done
- return SentenceTuple(0);
+ return SentenceTuple();
else // neither all nor none => we have at least one missing entry
ABORT("There are missing entries in the text tuples.");
}