Diffstat (limited to 'src/data/batch_generator.h')
-rw-r--r--  src/data/batch_generator.h | 35 +++++++++++++++++++++++------------
1 file changed, 23 insertions(+), 12 deletions(-)
diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h
index a248db23..ea977468 100644
--- a/src/data/batch_generator.h
+++ b/src/data/batch_generator.h
@@ -2,6 +2,7 @@
 
 #include "common/options.h"
 #include "common/signal_handling.h"
+#include "common/timer.h"
 #include "data/batch_stats.h"
 #include "data/rng_engine.h"
 #include "training/training_state.h"
@@ -92,6 +93,8 @@ private:
 
   // this runs on a bg thread; sequencing is handled by caller, but locking is done in here
   std::deque<BatchPtr> fetchBatches() {
+    timer::Timer total;
+
     typedef typename Sample::value_type Item;
     auto itemCmp = [](const Item& sa, const Item& sb) { return sa.size() < sb.size(); }; // sort by element length, not content
 
@@ -135,19 +138,29 @@ private:
       if(current_ != data_->end())
         ++current_;
     }
-    size_t sets = 0;
-    while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data
+
+    Samples maxiBatchTemp;
+    while(current_ != data_->end() && maxiBatchTemp.size() < maxSize) { // loop over data
       if (saveAndExitRequested()) // stop generating batches
         return std::deque<BatchPtr>();
-      maxiBatch->push(*current_);
-      sets = current_->size();
+
+      maxiBatchTemp.push_back(*current_);
+
       // do not consume more than required for the maxi batch as this causes
       // that line-by-line translation is delayed by one sentence
-      bool last = maxiBatch->size() == maxSize;
+      bool last = maxiBatchTemp.size() == maxSize;
       if(!last)
         ++current_; // this actually reads the next line and pre-processes it
     }
-    size_t numSentencesRead = maxiBatch->size();
+    size_t numSentencesRead = maxiBatchTemp.size();
+
+    size_t sets = 0;
+    for(auto&& s : maxiBatchTemp) {
+      if(!s.empty()) {
+        sets = s.size();
+        maxiBatch->push(s);
+      }
+    }
 
     // construct the actual batches and place them in the queue
     Samples batchVector;
@@ -163,6 +176,7 @@ private:
     BatchStats::const_iterator cachedStatsIter;
     if (stats_)
       cachedStatsIter = stats_->begin();
+
     while(!maxiBatch->empty()) { // while there are sentences in the queue
       if (saveAndExitRequested()) // stop generating batches
         return std::deque<BatchPtr>();
@@ -178,12 +192,7 @@ private:
           lengths[i] = batchVector.back()[i].size(); // record max lengths so far
 
         maxBatchSize = stats_->findBatchSize(lengths, cachedStatsIter);
-        // this optimization makes no difference indeed
-#if 0 // sanity check: would we find the same entry if searching from the start?
-        auto it = stats_->lower_bound(lengths);
-        auto maxBatchSize1 = stats_->findBatchSize(lengths, it);
-        ABORT_IF(maxBatchSize != maxBatchSize1, "findBatchSize iter caching logic is borked");
-#endif
+
         makeBatch = batchVector.size() >= maxBatchSize;
         // if last added sentence caused a bump then we likely have bad padding, so rather move it into the next batch
         if(batchVector.size() > maxBatchSize) {
@@ -231,6 +240,8 @@ private:
 
     LOG(debug, "[data] fetched {} batches with {} sentences. Per batch: {} sentences, {} labels.",
         tempBatches.size(), numSentencesRead, (double)totalSent / (double)totalDenom, (double)totalLabels / (double)totalDenom);
+    LOG(debug, "[data] fetching batches took {:.2f} seconds, {:.2f} sents/s", total.elapsed(), (double)numSentencesRead / total.elapsed());
+
    return tempBatches;
  }
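
The patch makes two independent changes to fetchBatches(): incoming samples are first staged in a temporary vector so that empty ones can be filtered out before entering the length-sorted maxi-batch queue, and the whole fetch is timed so sentence throughput can be logged. The standalone C++ sketch below illustrates both patterns outside of Marian, under stated assumptions: Sample, kMaxiBatchSize, and the toy corpus are hypothetical stand-ins, and std::chrono replaces Marian's timer::Timer.

    // Minimal sketch of the patch's two patterns; not Marian's actual classes.
    #include <chrono>
    #include <cstdio>
    #include <queue>
    #include <string>
    #include <vector>

    using Sample = std::vector<std::string>;  // one sentence per parallel stream ("set")

    // Order samples by length of their first stream; with std::priority_queue the
    // longest sample surfaces first, mirroring the itemCmp comparator in the diff.
    struct ByLength {
      bool operator()(const Sample& a, const Sample& b) const {
        return a.front().size() < b.front().size();
      }
    };

    int main() {
      const size_t kMaxiBatchSize = 4;  // hypothetical maxi-batch cap
      std::vector<Sample> corpus = {{"a short line"}, {},  // empty sample to be filtered
                                    {"a considerably longer line"}, {"mid-sized line"}, {"tail"}};

      auto start = std::chrono::steady_clock::now();  // plays the role of timer::Timer total

      // Phase 1: read up to kMaxiBatchSize samples into a temporary buffer,
      // deferring the decision of what actually enters the queue.
      std::vector<Sample> maxiBatchTemp;
      for (size_t i = 0; i < corpus.size() && maxiBatchTemp.size() < kMaxiBatchSize; ++i)
        maxiBatchTemp.push_back(corpus[i]);
      size_t numSentencesRead = maxiBatchTemp.size();

      // Phase 2: push only non-empty samples; 'sets' records the number of
      // parallel streams, taken from the last non-empty sample as in the patch.
      std::priority_queue<Sample, std::vector<Sample>, ByLength> maxiBatch;
      size_t sets = 0;
      for (auto&& s : maxiBatchTemp) {
        if (!s.empty()) {
          sets = s.size();
          maxiBatch.push(s);
        }
      }

      double elapsed =
          std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count();
      std::printf("read %zu sentences (%zu sets, %zu queued), %.2f sents/s\n",
                  numSentencesRead, sets, maxiBatch.size(),
                  elapsed > 0 ? (double)numSentencesRead / elapsed : 0.0);
    }

Staging through a temporary buffer keeps numSentencesRead equal to what was actually consumed from the input stream while the queue only ever sees well-formed samples, which is presumably why 'sets' is now derived after the read loop rather than inside it.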