diff options
author | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-11-25 05:33:49 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-11-25 05:33:49 +0300 |
commit | 8b8d1b11e28a421b348703d702c9c5206061df9d (patch) | |
tree | 01f5313e706bdba5c3317d5a713f8ab6224ec2da | |
parent | c85d0608483789d446361ea28d95f7d7c9545f2d (diff) |
Merged PR 21553: Parallelize data reading for training
This parallelizes data reading. On very fast GPUs and with small models, training speed can be starved by batch creation that is too slow. Use --data-threads 8 or more; the default is currently set to 1 for backward compatibility.
-rw-r--r-- | src/common/config_parser.cpp | 7 | ||||
-rw-r--r-- | src/common/utils.cpp | 8 | ||||
-rw-r--r-- | src/data/batch_generator.h | 35 | ||||
-rw-r--r-- | src/data/corpus.cpp | 152 | ||||
-rw-r--r-- | src/data/corpus.h | 3 | ||||
-rw-r--r-- | src/data/corpus_base.cpp | 44 | ||||
-rw-r--r-- | src/data/corpus_base.h | 105 | ||||
-rw-r--r-- | src/data/corpus_nbest.cpp | 7 | ||||
-rw-r--r-- | src/data/corpus_sqlite.cpp | 6 | ||||
-rw-r--r-- | src/data/text_input.cpp | 6 |
10 files changed, 251 insertions, 122 deletions
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 59b328e9..3d79f8af 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -883,6 +883,10 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) { if(mode_ == cli::mode::training) { cli.add<bool>("--shuffle-in-ram", "Keep shuffled corpus in RAM, do not write to temp file"); + + cli.add<size_t>("--data-threads", + "Number of concurrent threads to use during data reading and processing", 1); + // @TODO: Consider making the next two options options of the vocab instead, to make it more local in scope. cli.add<size_t>("--all-caps-every", "When forming minibatches, preprocess every Nth line on the fly to all-caps. Assumes UTF-8"); @@ -901,6 +905,9 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) { cli.add<bool>("--mini-batch-round-up", "Round up batch size to next power of 2 for more efficient training, but this can make batch size less stable. Disable with --mini-batch-round-up=false", true); + } else { + cli.add<size_t>("--data-threads", + "Number of concurrent threads to use during data reading and processing", 1); } // clang-format on } diff --git a/src/common/utils.cpp b/src/common/utils.cpp index 72624041..99fc790a 100644 --- a/src/common/utils.cpp +++ b/src/common/utils.cpp @@ -70,22 +70,20 @@ void split(const std::string& line, // the function guarantees that the output has as many elements as requested void splitTsv(const std::string& line, std::vector<std::string>& fields, size_t numFields) { fields.clear(); + fields.resize(numFields); // make sure there is as many elements as requested size_t begin = 0; size_t pos = 0; for(size_t i = 0; i < numFields; ++i) { pos = line.find('\t', begin); if(pos == std::string::npos) { - fields.push_back(line.substr(begin)); + fields[i] = line.substr(begin); break; } - fields.push_back(line.substr(begin, pos - begin)); + fields[i] = line.substr(begin, pos - begin); begin = pos + 1; } - 
if(fields.size() < numFields) // make sure there is as many elements as requested - fields.resize(numFields); - ABORT_IF(pos != std::string::npos, "Excessive field(s) in the tab-separated line: '{}'", line); } diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index a248db23..ea977468 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -2,6 +2,7 @@ #include "common/options.h" #include "common/signal_handling.h" +#include "common/timer.h" #include "data/batch_stats.h" #include "data/rng_engine.h" #include "training/training_state.h" @@ -92,6 +93,8 @@ private: // this runs on a bg thread; sequencing is handled by caller, but locking is done in here std::deque<BatchPtr> fetchBatches() { + timer::Timer total; + typedef typename Sample::value_type Item; auto itemCmp = [](const Item& sa, const Item& sb) { return sa.size() < sb.size(); }; // sort by element length, not content @@ -135,19 +138,29 @@ private: if(current_ != data_->end()) ++current_; } - size_t sets = 0; - while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data + + Samples maxiBatchTemp; + while(current_ != data_->end() && maxiBatchTemp.size() < maxSize) { // loop over data if (saveAndExitRequested()) // stop generating batches return std::deque<BatchPtr>(); - maxiBatch->push(*current_); - sets = current_->size(); + + maxiBatchTemp.push_back(*current_); + // do not consume more than required for the maxi batch as this causes // that line-by-line translation is delayed by one sentence - bool last = maxiBatch->size() == maxSize; + bool last = maxiBatchTemp.size() == maxSize; if(!last) ++current_; // this actually reads the next line and pre-processes it } - size_t numSentencesRead = maxiBatch->size(); + size_t numSentencesRead = maxiBatchTemp.size(); + + size_t sets = 0; + for(auto&& s : maxiBatchTemp) { + if(!s.empty()) { + sets = s.size(); + maxiBatch->push(s); + } + } // construct the actual batches and place them in the queue Samples 
batchVector; @@ -163,6 +176,7 @@ private: BatchStats::const_iterator cachedStatsIter; if (stats_) cachedStatsIter = stats_->begin(); + while(!maxiBatch->empty()) { // while there are sentences in the queue if (saveAndExitRequested()) // stop generating batches return std::deque<BatchPtr>(); @@ -178,12 +192,7 @@ private: lengths[i] = batchVector.back()[i].size(); // record max lengths so far maxBatchSize = stats_->findBatchSize(lengths, cachedStatsIter); - // this optimization makes no difference indeed -#if 0 // sanity check: would we find the same entry if searching from the start? - auto it = stats_->lower_bound(lengths); - auto maxBatchSize1 = stats_->findBatchSize(lengths, it); - ABORT_IF(maxBatchSize != maxBatchSize1, "findBatchSize iter caching logic is borked"); -#endif + makeBatch = batchVector.size() >= maxBatchSize; // if last added sentence caused a bump then we likely have bad padding, so rather move it into the next batch if(batchVector.size() > maxBatchSize) { @@ -231,6 +240,8 @@ private: LOG(debug, "[data] fetched {} batches with {} sentences. 
Per batch: {} sentences, {} labels.", tempBatches.size(), numSentencesRead, (double)totalSent / (double)totalDenom, (double)totalLabels / (double)totalDenom); + LOG(debug, "[data] fetching batches took {:.2f} seconds, {:.2f} sents/s", total.elapsed(), (double)numSentencesRead / total.elapsed()); + return tempBatches; } diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp index d8a364b2..643a7de9 100644 --- a/src/data/corpus.cpp +++ b/src/data/corpus.cpp @@ -14,18 +14,30 @@ namespace data { Corpus::Corpus(Ptr<Options> options, bool translate /*= false*/, size_t seed /*= Config:seed*/) : CorpusBase(options, translate, seed), - shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)), - allCapsEvery_(options_->get<size_t>("all-caps-every", 0)), - titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) {} + shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)), + allCapsEvery_(options_->get<size_t>("all-caps-every", 0)), + titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) { + + auto numThreads = options_->get<size_t>("data-threads", 1); + if(numThreads > 1) + threadPool_.reset(new ThreadPool(numThreads)); + +} Corpus::Corpus(std::vector<std::string> paths, std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, size_t seed /*= Config:seed*/) : CorpusBase(paths, vocabs, options, seed), - shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)), - allCapsEvery_(options_->get<size_t>("all-caps-every", 0)), - titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) {} + shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)), + allCapsEvery_(options_->get<size_t>("all-caps-every", 0)), + titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) { + + auto numThreads = options_->get<size_t>("data-threads", 1); + if(numThreads > 1) + threadPool_.reset(new ThreadPool(numThreads)); + +} void Corpus::preprocessLine(std::string& line, size_t streamId, bool& altered) { bool isFactoredVocab = 
vocabs_.back()->tryAs<FactoredVocab>() != nullptr; @@ -52,16 +64,10 @@ void Corpus::preprocessLine(std::string& line, size_t streamId, bool& altered) { } SentenceTuple Corpus::next() { - // Used for handling TSV inputs - // Determine the total number of fields including alignments or weights - auto tsvNumAllFields = tsvNumInputFields_; - if(alignFileIdx_ > -1) - ++tsvNumAllFields; - if(weightFileIdx_ > -1) - ++tsvNumAllFields; - std::vector<std::string> fields(tsvNumAllFields); - - for(;;) { // (this is a retry loop for skipping invalid sentences) + size_t numStreams = corpusInRAM_.empty() ? files_.size() : corpusInRAM_.size(); + std::vector<std::string> fields(numStreams); + + while(true) { // retry loop // get index of the current sentence size_t curId = pos_; // note: at end, pos_ == total size // if corpus has been shuffled, ids_ contains sentence indexes @@ -69,83 +75,91 @@ SentenceTuple Corpus::next() { curId = ids_[pos_]; pos_++; - // fill up the sentence tuple with sentences from all input files - SentenceTuple tup(curId); size_t eofsHit = 0; - size_t numStreams = corpusInRAM_.empty() ? 
files_.size() : corpusInRAM_.size(); - for(size_t i = 0; i < numStreams; ++i) { - std::string line; - + for(size_t i = 0; i < numStreams; ++i) { // looping of all streams // fetch line, from cached copy in RAM or actual file if (!corpusInRAM_.empty()) { if (curId < corpusInRAM_[i].size()) - line = corpusInRAM_[i][curId]; + fields[i] = corpusInRAM_[i][curId]; else { eofsHit++; continue; } } else { - bool gotLine = io::getline(*files_[i], line).good(); + bool gotLine = io::getline(*files_[i], fields[i]).good(); if(!gotLine) { eofsHit++; continue; } } + } - if(i > 0 && i == alignFileIdx_) { - addAlignmentToSentenceTuple(line, tup); - } else if(i > 0 && i == weightFileIdx_) { - addWeightsToSentenceTuple(line, tup); - } else { - if(tsv_) { // split TSV input and add each field into the sentence tuple - utils::splitTsv(line, fields, tsvNumAllFields); - size_t shift = 0; - for(size_t j = 0; j < tsvNumAllFields; ++j) { - // index j needs to be shifted to get the proper vocab index if guided-alignment or - // data-weighting are preceding source or target sequences in TSV input - if(j == alignFileIdx_ || j == weightFileIdx_) { - ++shift; - } else { - size_t vocabId = j - shift; - bool altered; - preprocessLine(fields[j], vocabId, /*out=*/altered); - if (altered) - tup.markAltered(); - addWordsToSentenceTuple(fields[j], vocabId, tup); - } - } - - // weights are added last to the sentence tuple, because this runs a validation that needs - // length of the target sequence - if(alignFileIdx_ > -1) - addAlignmentToSentenceTuple(fields[alignFileIdx_], tup); - if(weightFileIdx_ > -1) - addWeightsToSentenceTuple(fields[weightFileIdx_], tup); + if(eofsHit == numStreams) + return SentenceTuple(); // unintialized SentenceTuple which will be invalid when tested + ABORT_IF(eofsHit != 0, "not all input files have the same number of lines"); + + auto makeSentenceTuple = [this](size_t curId, std::vector<std::string> fields) { + if(tsv_) { + // with tsv inputs data, there is only one input 
stream, hence we only have one field + // which needs to be tokenized into tab-separated fields + ABORT_IF(fields.size() != 1, "Reading TSV file, but we have don't have exactly one stream??"); + size_t numAllFields = tsvNumInputFields_; + if(alignFileIdx_ > -1) + ++numAllFields; + if(weightFileIdx_ > -1) + ++numAllFields; + // replace single-element fields array with extracted tsv fields + std::vector<std::string> tmpFields; + utils::splitTsv(fields[0], tmpFields, numAllFields); // this verifies the number of fields + fields.swap(tmpFields); + } + + // fill up the sentence tuple with sentences from all input files + SentenceTupleImpl tup(curId); + size_t shift = 0; + for(size_t i = 0; i < fields.size(); ++i) { + // index j needs to be shifted to get the proper vocab index if guided-alignment or + // data-weighting are preceding source or target sequences in TSV input + if(i == alignFileIdx_ || i == weightFileIdx_) { + ++shift; } else { + size_t vocabId = i - shift; bool altered; - preprocessLine(line, i, /*out=*/altered); + preprocessLine(fields[i], vocabId, /*out=*/altered); if (altered) tup.markAltered(); - addWordsToSentenceTuple(line, i, tup); + addWordsToSentenceTuple(fields[i], vocabId, tup); } + + // weights are added last to the sentence tuple, because this runs a validation that needs + // length of the target sequence + if(alignFileIdx_ > -1) + addAlignmentToSentenceTuple(fields[alignFileIdx_], tup); + if(weightFileIdx_ > -1) + addWeightsToSentenceTuple(fields[weightFileIdx_], tup); } - } - - if (eofsHit == numStreams) - return SentenceTuple(0); - ABORT_IF(eofsHit != 0, "not all input files have the same number of lines"); - // check if all streams are valid, that is, non-empty and no longer than maximum allowed length - if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) { - return words.size() > 0 && words.size() <= maxLength_; - })) - return tup; + // check if all streams are valid, that is, non-empty and no longer than maximum allowed 
length + if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) { + return words.size() > 0 && words.size() <= maxLength_; + })) { + return tup; + } else { + return SentenceTupleImpl(); // return an empty tuple if above test does not pass + } + }; + + if(threadPool_) { // use thread pool if available + return SentenceTuple(threadPool_->enqueue(makeSentenceTuple, curId, fields)); + } else { // otherwise launch here and just pass the result into the wrapper + auto tup = makeSentenceTuple(curId, fields); + if(!tup.empty()) + return SentenceTuple(tup); + } - // otherwise skip this sentence and try the next one - // @TODO: tail recursion? - } + } // end of retry loop } // reset and initialize shuffled reading @@ -167,6 +181,8 @@ void Corpus::reset() { pos_ = 0; for (size_t i = 0; i < paths_.size(); ++i) { if(paths_[i] == "stdin" || paths_[i] == "-") { + std::cin.tie(0); + std::ios_base::sync_with_stdio(false); files_[i].reset(new std::istream(std::cin.rdbuf())); // Probably not necessary, unless there are some buffers // that we want flushed. 
diff --git a/src/data/corpus.h b/src/data/corpus.h index e8e9a9fd..281d43a2 100644 --- a/src/data/corpus.h +++ b/src/data/corpus.h @@ -4,6 +4,7 @@ #include <iostream> #include <random> +#include "3rd_party/threadpool.h" #include "common/definitions.h" #include "common/file_stream.h" #include "common/options.h" @@ -20,6 +21,8 @@ class Corpus : public CorpusBase { private: std::vector<UPtr<io::TemporaryFile>> tempFiles_; std::vector<size_t> ids_; + + UPtr<ThreadPool> threadPool_; // thread pool for parallelized data reading // for shuffle-in-ram bool shuffleInRAM_{false}; diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp index 5f9a9ee3..bfce31bf 100644 --- a/src/data/corpus_base.cpp +++ b/src/data/corpus_base.cpp @@ -12,7 +12,24 @@ typedef std::vector<float> MaskBatch; typedef std::pair<WordBatch, MaskBatch> WordMask; typedef std::vector<WordMask> SentBatch; -CorpusIterator::CorpusIterator() : pos_(-1), tup_(0) {} +void SentenceTupleImpl::setWeights(const std::vector<float>& weights) { + if(weights.size() != 1) { // this assumes a single sentence-level weight is always fine + ABORT_IF(empty(), "Source and target sequences should be added to a tuple before data weights"); + auto numWeights = weights.size(); + auto numTrgWords = back().size(); + // word-level weights may or may not contain a weight for EOS tokens + if(numWeights != numTrgWords && numWeights != numTrgWords - 1) + LOG(warn, + "[warn] " + "Number of weights ({}) does not match the number of target words ({}) in line #{}", + numWeights, + numTrgWords, + id_); + } + weights_ = weights; +} + +CorpusIterator::CorpusIterator() : pos_(-1) {} CorpusIterator::CorpusIterator(CorpusBase* corpus) : corpus_(corpus), pos_(0), tup_(corpus_->next()) {} @@ -23,7 +40,7 @@ void CorpusIterator::increment() { } bool CorpusIterator::equal(CorpusIterator const& other) const { - return this->pos_ == other.pos_ || (this->tup_.empty() && other.tup_.empty()); + return this->pos_ == other.pos_ || (!this->tup_.valid() 
&& !other.tup_.valid()); } const SentenceTuple& CorpusIterator::dereference() const { @@ -390,7 +407,7 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate, size_t seed) void CorpusBase::addWordsToSentenceTuple(const std::string& line, size_t batchIndex, - SentenceTuple& tup) const { + SentenceTupleImpl& tup) const { // This turns a string in to a sequence of numerical word ids. Depending // on the vocabulary type, this can be non-trivial, e.g. when SentencePiece // is used. @@ -411,7 +428,7 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line, } void CorpusBase::addAlignmentToSentenceTuple(const std::string& line, - SentenceTuple& tup) const { + SentenceTupleImpl& tup) const { ABORT_IF(rightLeft_, "Guided alignment and right-left model cannot be used " "together at the moment"); @@ -420,7 +437,7 @@ void CorpusBase::addAlignmentToSentenceTuple(const std::string& line, tup.setAlignment(align); } -void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTuple& tup) const { +void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const { auto elements = utils::split(line, " "); if(!elements.empty()) { @@ -549,23 +566,6 @@ size_t CorpusBase::getNumberOfTSVInputFields(Ptr<Options> options) { return 0; } -void SentenceTuple::setWeights(const std::vector<float>& weights) { - if(weights.size() != 1) { // this assumes a single sentence-level weight is always fine - ABORT_IF(empty(), "Source and target sequences should be added to a tuple before data weights"); - auto numWeights = weights.size(); - auto numTrgWords = back().size(); - // word-level weights may or may not contain a weight for EOS tokens - if(numWeights != numTrgWords && numWeights != numTrgWords - 1) - LOG(warn, - "[warn] " - "Number of weights ({}) does not match the number of target words ({}) in line #{}", - numWeights, - numTrgWords, - id_); - } - weights_ = weights; -} - // experimental: hide inline-fix source tokens from 
cross attention std::vector<float> SubBatch::crossMaskWithInlineFixSourceSuppressed() const { diff --git a/src/data/corpus_base.h b/src/data/corpus_base.h index 251df5bc..82a01286 100644 --- a/src/data/corpus_base.h +++ b/src/data/corpus_base.h @@ -11,6 +11,8 @@ #include "data/rng_engine.h" #include "data/vocab.h" +#include <future> + namespace marian { namespace data { @@ -22,7 +24,7 @@ namespace data { * construction of marian::data::CorpusBatch objects. They are not a part of * marian::data::CorpusBatch. */ -class SentenceTuple { +class SentenceTupleImpl { private: size_t id_; std::vector<Words> tuple_; // [stream index][step index] @@ -34,11 +36,16 @@ public: typedef Words value_type; /** + * @brief Creates an empty tuple with 0 id (default constructor). + */ + SentenceTupleImpl() : id_(0) {} + + /** * @brief Creates an empty tuple with the given Id. */ - SentenceTuple(size_t id) : id_(id) {} + SentenceTupleImpl(size_t id) : id_(id) {} - ~SentenceTuple() { tuple_.clear(); } + ~SentenceTupleImpl() {} /** * @brief Returns the sentence's ID. @@ -114,6 +121,92 @@ public: void setAlignment(const WordAlignment& alignment) { alignment_ = alignment; } }; +class SentenceTuple { +private: + std::shared_ptr<std::future<SentenceTupleImpl>> fImpl_; + mutable std::shared_ptr<SentenceTupleImpl> impl_; + +public: + typedef Words value_type; + + /** + * @brief Creates an empty tuple with no associated future. + */ + SentenceTuple() {} + + SentenceTuple(const SentenceTupleImpl& tupImpl) + : impl_(std::make_shared<SentenceTupleImpl>(tupImpl)) {} + + SentenceTuple(std::future<SentenceTupleImpl>&& fImpl) + : fImpl_(new std::future<SentenceTupleImpl>(std::move(fImpl))) {} + + SentenceTupleImpl& get() const { + if(!impl_) { + ABORT_IF(!fImpl_ || !fImpl_->valid(), "No future tuple associated with SentenceTuple"); + impl_ = std::make_shared<SentenceTupleImpl>(fImpl_->get()); + } + return *impl_; + } + + /** + * @brief Returns the sentence's ID. 
+ */ + size_t getId() const { return get().getId(); } + + /** + * @brief Returns whether this Tuple was altered or augmented from what + * was provided to Marian in input. + */ + bool isAltered() const { return get().isAltered(); } + + /** + * @brief The size of the tuple, e.g. two for parallel data with a source and + * target sentences. + */ + size_t size() const { return get().size(); } + + /** + * @brief confirms that the tuple has been populated with data + */ + bool valid() const { + return fImpl_ || impl_; + } + + /** + * @brief The i-th tuple sentence. + * + * @param i Tuple's index. + */ + Words& operator[](size_t i) { return get()[i]; } + const Words& operator[](size_t i) const { return get()[i]; } + + /** + * @brief The last tuple sentence, i.e. the target sentence. + */ + Words& back() { return get().back(); } + const Words& back() const { return get().back(); } + + /** + * @brief Checks whether the tuple is empty. + */ + bool empty() const { return get().empty(); } + + auto begin() const -> decltype(get().begin()) { return get().begin(); } + auto end() const -> decltype(get().end()) { return get().end(); } + + auto rbegin() const -> decltype(get().rbegin()) { return get().rbegin(); } + auto rend() const -> decltype(get().rend()) { return get().rend(); } + + /** + * @brief Get sentence weights. + * + * For sentence-level weights the vector contains only one element. + */ + const std::vector<float>& getWeights() const { return get().getWeights(); } + + const WordAlignment& getAlignment() const { return get().getAlignment(); } +}; + /** * @brief Batch of sentences represented as word indices with masking. */ @@ -586,17 +679,17 @@ protected: * @brief Helper function converting a line of text into words using the i-th * vocabulary and adding them to the sentence tuple. 
*/ - void addWordsToSentenceTuple(const std::string& line, size_t batchIndex, SentenceTuple& tup) const; + void addWordsToSentenceTuple(const std::string& line, size_t batchIndex, SentenceTupleImpl& tup) const; /** * @brief Helper function parsing a line with word alignments and adding them * to the sentence tuple. */ - void addAlignmentToSentenceTuple(const std::string& line, SentenceTuple& tup) const; + void addAlignmentToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const; /** * @brief Helper function parsing a line of weights and adding them to the * sentence tuple. */ - void addWeightsToSentenceTuple(const std::string& line, SentenceTuple& tup) const; + void addWeightsToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const; void addAlignmentsToBatch(Ptr<CorpusBatch> batch, const std::vector<Sample>& batchVector); diff --git a/src/data/corpus_nbest.cpp b/src/data/corpus_nbest.cpp index d5a48d8d..8029d351 100644 --- a/src/data/corpus_nbest.cpp +++ b/src/data/corpus_nbest.cpp @@ -43,7 +43,7 @@ SentenceTuple CorpusNBest::next() { pos_++; // fill up the sentence tuple with sentences from all input files - SentenceTuple tup(curId); + SentenceTupleImpl tup(curId); std::string line; lastLines_.resize(files_.size() - 1); @@ -74,9 +74,10 @@ SentenceTuple CorpusNBest::next() { if(cont && std::all_of(tup.begin(), tup.end(), [=](const Words& words) { return words.size() > 0 && words.size() <= maxLength_; })) - return tup; + return SentenceTuple(tup); } - return SentenceTuple(0); + + return SentenceTuple(); } void CorpusNBest::reset() { diff --git a/src/data/corpus_sqlite.cpp b/src/data/corpus_sqlite.cpp index 297847c0..f7c577f2 100644 --- a/src/data/corpus_sqlite.cpp +++ b/src/data/corpus_sqlite.cpp @@ -109,7 +109,7 @@ SentenceTuple CorpusSQLite::next() { while(select_->executeStep()) { // fill up the sentence tuple with sentences from all input files size_t curId = select_->getColumn(0).getInt(); - SentenceTuple tup(curId); + 
SentenceTupleImpl tup(curId); for(size_t i = 0; i < files_.size(); ++i) { auto line = select_->getColumn((int)(i + 1)); @@ -126,9 +126,9 @@ SentenceTuple CorpusSQLite::next() { if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) { return words.size() > 0 && words.size() <= maxLength_; })) - return tup; + return SentenceTuple(tup); } - return SentenceTuple(0); + return SentenceTuple(); } void CorpusSQLite::shuffle() { diff --git a/src/data/text_input.cpp b/src/data/text_input.cpp index 958190fc..b1f4cdd4 100644 --- a/src/data/text_input.cpp +++ b/src/data/text_input.cpp @@ -40,7 +40,7 @@ SentenceTuple TextInput::next() { size_t curId = pos_++; // fill up the sentence tuple with source and/or target sentences - SentenceTuple tup(curId); + SentenceTupleImpl tup(curId); for(size_t i = 0; i < files_.size(); ++i) { std::string line; if(io::getline(*files_[i], line)) { @@ -57,9 +57,9 @@ SentenceTuple TextInput::next() { } if(tup.size() == files_.size()) // check if each input file provided an example - return tup; + return SentenceTuple(tup); else if(tup.size() == 0) // if no file provided examples we are done - return SentenceTuple(0); + return SentenceTuple(); else // neither all nor none => we have at least on missing entry ABORT("There are missing entries in the text tuples."); } |