github.com/marian-nmt/marian.git

Diffstat (limited to 'src/data/batch_generator.h')
-rw-r--r--  src/data/batch_generator.h  35
1 file changed, 23 insertions(+), 12 deletions(-)
diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h
index a248db23..ea977468 100644
--- a/src/data/batch_generator.h
+++ b/src/data/batch_generator.h
@@ -2,6 +2,7 @@
#include "common/options.h"
#include "common/signal_handling.h"
+#include "common/timer.h"
#include "data/batch_stats.h"
#include "data/rng_engine.h"
#include "training/training_state.h"
@@ -92,6 +93,8 @@ private:
// this runs on a bg thread; sequencing is handled by caller, but locking is done in here
std::deque<BatchPtr> fetchBatches() {
+ timer::Timer total;
+
typedef typename Sample::value_type Item;
auto itemCmp = [](const Item& sa, const Item& sb) { return sa.size() < sb.size(); }; // sort by element length, not content
@@ -135,19 +138,29 @@ private:
if(current_ != data_->end())
++current_;
}
- size_t sets = 0;
- while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data
+
+ Samples maxiBatchTemp;
+ while(current_ != data_->end() && maxiBatchTemp.size() < maxSize) { // loop over data
if (saveAndExitRequested()) // stop generating batches
return std::deque<BatchPtr>();
- maxiBatch->push(*current_);
- sets = current_->size();
+
+ maxiBatchTemp.push_back(*current_);
+
// do not consume more than required for the maxi batch, as that would
// delay line-by-line translation by one sentence
- bool last = maxiBatch->size() == maxSize;
+ bool last = maxiBatchTemp.size() == maxSize;
if(!last)
++current_; // this actually reads the next line and pre-processes it
}
- size_t numSentencesRead = maxiBatch->size();
+ size_t numSentencesRead = maxiBatchTemp.size();
+
+ size_t sets = 0;
+ for(auto&& s : maxiBatchTemp) {
+ if(!s.empty()) {
+ sets = s.size();
+ maxiBatch->push(s);
+ }
+ }
// construct the actual batches and place them in the queue
Samples batchVector;
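
The hunk above replaces the old pattern of pushing every line straight into the length-sorted maxiBatch (and taking sets from the line just read) with a two-phase fill: lines are first collected into maxiBatchTemp, numSentencesRead counts everything that was read, and only non-empty samples are pushed into maxiBatch, with sets taken from the last non-empty sample. A minimal sketch of that pattern, using simplified stand-in types rather than Marian's actual Samples/Sample types and comparator:

    // Sketch (not Marian's real types) of the two-phase maxi-batch fill above:
    // read every line into a plain vector first, then push only non-empty samples
    // into the length-sorted queue, so empty inputs are counted but never batched.
    #include <cstddef>
    #include <queue>
    #include <string>
    #include <vector>

    using Sample = std::vector<std::string>;   // one entry per input stream (e.g. source, target)

    int main() {
      std::vector<Sample> input = {{"a b c"}, {}, {"d e"}};   // second line is empty
      std::vector<Sample> maxiBatchTemp;                      // phase 1: everything that was read

      auto byLength = [](const Sample& a, const Sample& b) { return a[0].size() < b[0].size(); };
      std::priority_queue<Sample, std::vector<Sample>, decltype(byLength)> maxiBatch(byLength);

      for (const auto& s : input)
        maxiBatchTemp.push_back(s);
      std::size_t numSentencesRead = maxiBatchTemp.size();    // counts the empty line too

      std::size_t sets = 0;
      for (const auto& s : maxiBatchTemp)                     // phase 2: skip empty samples
        if (!s.empty()) {
          sets = s.size();                                    // number of parallel streams
          maxiBatch.push(s);
        }
      // numSentencesRead == 3, maxiBatch.size() == 2, sets == 1
      return 0;
    }
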
@@ -163,6 +176,7 @@ private:
BatchStats::const_iterator cachedStatsIter;
if (stats_)
cachedStatsIter = stats_->begin();
+
while(!maxiBatch->empty()) { // while there are sentences in the queue
if (saveAndExitRequested()) // stop generating batches
return std::deque<BatchPtr>();
@@ -178,12 +192,7 @@ private:
lengths[i] = batchVector.back()[i].size(); // record max lengths so far
maxBatchSize = stats_->findBatchSize(lengths, cachedStatsIter);
- // this optimization makes no difference indeed
-#if 0 // sanity check: would we find the same entry if searching from the start?
- auto it = stats_->lower_bound(lengths);
- auto maxBatchSize1 = stats_->findBatchSize(lengths, it);
- ABORT_IF(maxBatchSize != maxBatchSize1, "findBatchSize iter caching logic is borked");
-#endif
+
makeBatch = batchVector.size() >= maxBatchSize;
// if the last added sentence caused a bump, we likely have bad padding, so move it into the next batch instead
if(batchVector.size() > maxBatchSize) {
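
For context, the loop around this hunk sizes batches dynamically from precomputed statistics: after each sentence it records the maximum length per stream, asks BatchStats how many sentences of that shape fit, and if the newest sentence pushed the batch over that limit it is carried into the next batch to avoid excessive padding (the disabled sanity check that re-derived the same value from scratch is dropped above). A rough sketch of that decision, with a hypothetical findBatchSize based on a flat label budget instead of the real BatchStats lookup and iterator caching:

    // Hypothetical stand-in for BatchStats::findBatchSize: assume a fixed label
    // budget and ask how many sentences of the current max length fit into it.
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    std::size_t findBatchSize(const std::vector<std::size_t>& lengths, std::size_t labelBudget = 512) {
      std::size_t longest = 0;
      for (std::size_t len : lengths) longest = std::max(longest, len);
      return longest ? labelBudget / longest : labelBudget;
    }

    int main() {
      std::vector<std::vector<std::size_t>> batch;             // per-sentence stream lengths
      std::vector<std::size_t> lengths(2, 0);                  // running max length per stream
      std::vector<std::vector<std::size_t>> incoming = {{10, 12}, {11, 9}, {200, 210}};

      for (const auto& sent : incoming) {
        batch.push_back(sent);
        for (std::size_t i = 0; i < lengths.size(); ++i)
          lengths[i] = std::max(lengths[i], sent[i]);          // record max lengths so far
        std::size_t maxBatchSize = findBatchSize(lengths);
        if (batch.size() > maxBatchSize) {                     // newest sentence caused a bump:
          batch.pop_back();                                    // carry it over to the next batch
          break;                                               // ... emit `batch`, start a new one
        }
      }
      return 0;
    }
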
@@ -231,6 +240,8 @@ private:
LOG(debug, "[data] fetched {} batches with {} sentences. Per batch: {} sentences, {} labels.",
tempBatches.size(), numSentencesRead,
(double)totalSent / (double)totalDenom, (double)totalLabels / (double)totalDenom);
+ LOG(debug, "[data] fetching batches took {:.2f} seconds, {:.2f} sents/s", total.elapsed(), (double)numSentencesRead / total.elapsed());
+
return tempBatches;
}
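
The timer added at the top of fetchBatches() and queried here reports a per-call throughput figure next to the existing batch statistics. A minimal sketch of the same measurement using std::chrono directly (the patch itself constructs Marian's timer::Timer from common/timer.h at the start and reads it with elapsed() at the end); the sentence count, the sleep standing in for real work, and the printf formatting are only placeholders:

    #include <chrono>
    #include <cstddef>
    #include <cstdio>
    #include <thread>

    int main() {
      auto start = std::chrono::steady_clock::now();           // timer::Timer total; in the patch

      std::size_t numSentencesRead = 1000;                     // stand-in for reading/batching work
      std::this_thread::sleep_for(std::chrono::milliseconds(50));

      std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - start;
      std::printf("[data] fetching batches took %.2f seconds, %.2f sents/s\n",
                  elapsed.count(), numSentencesRead / elapsed.count());
      return 0;
    }
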