diff options
Diffstat (limited to 'src/translator/translator.h')
-rw-r--r-- | src/translator/translator.h | 77 |
1 files changed, 44 insertions, 33 deletions
diff --git a/src/translator/translator.h b/src/translator/translator.h index db1f3d03..4084ced9 100644 --- a/src/translator/translator.h +++ b/src/translator/translator.h @@ -20,12 +20,7 @@ #include "translator/scorers.h" // currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled. -// @TODO: add this as an actual feature. -#define MMAP 0 - -#if MMAP #include "3rd_party/mio/mio.hpp" -#endif namespace marian { @@ -42,9 +37,8 @@ private: size_t numDevices_; -#if MMAP - std::vector<mio::mmap_source> mmaps_; -#endif + std::vector<mio::mmap_source> model_mmaps_; // map + std::vector<std::vector<io::Item>> model_items_; // non-mmap public: Translate(Ptr<Options> options) @@ -76,15 +70,21 @@ public: scorers_.resize(numDevices_); graphs_.resize(numDevices_); -#if MMAP auto models = options->get<std::vector<std::string>>("models"); - for(auto model : models) { - marian::filesystem::Path modelPath(model); - ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"), - "Non-binarized models cannot be mmapped"); - mmaps_.push_back(std::move(mio::mmap_source(model))); + if(options_->get<bool>("model-mmap", false)) { + for(auto model : models) { + ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped"); + LOG(info, "Loading model from {}", model); + model_mmaps_.push_back(mio::mmap_source(model)); + } + } + else { + for(auto model : models) { + LOG(info, "Loading model from {}", model); + auto items = io::loadItems(model); + model_items_.push_back(std::move(items)); + } } -#endif size_t id = 0; for(auto device : devices) { @@ -101,11 +101,14 @@ public: graph->reserveWorkspaceMB(options_->get<size_t>("workspace")); graphs_[id] = graph; -#if MMAP - auto scorers = createScorers(options_, mmaps_); -#else - auto scorers = createScorers(options_); -#endif + std::vector<Ptr<Scorer>> scorers; + if(options_->get<bool>("model-mmap", false)) { + scorers = createScorers(options_, model_mmaps_); + } + else { + scorers = createScorers(options_, model_items_); + } + for(auto scorer : scorers) { scorer->init(graph); if(shortlistGenerator_) @@ -146,11 +149,11 @@ public: std::mutex syncCounts; // timer and counters for total elapsed time and statistics - std::unique_ptr<timer::Timer> totTimer(new timer::Timer()); + std::unique_ptr<timer::Timer> totTimer(new timer::Timer()); size_t totBatches = 0; size_t totLines = 0; size_t totSourceTokens = 0; - + // timer and counters for elapsed time and statistics between updates std::unique_ptr<timer::Timer> curTimer(new timer::Timer()); size_t curBatches = 0; @@ -176,7 +179,7 @@ public: bg.prepare(); for(auto batch : bg) { auto task = [=, &syncCounts, - &totBatches, &totLines, &totSourceTokens, &totTimer, + &totBatches, &totLines, &totSourceTokens, &totTimer, &curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) { thread_local Ptr<ExpressionGraph> graph; thread_local std::vector<Ptr<Scorer>> scorers; @@ -200,12 +203,12 @@ public: } // if we asked for speed information display this - if(statFreq.n > 0) { + if(statFreq.n > 0) { std::lock_guard<std::mutex> lock(syncCounts); - totBatches++; + totBatches++; totLines += batch->size(); totSourceTokens += batch->front()->batchWords(); - + curBatches++; curLines += batch->size(); curSourceTokens += batch->front()->batchWords(); @@ -214,10 +217,10 @@ public: double totTime = totTimer->elapsed(); double curTime = curTimer->elapsed(); - LOG(info, - "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", + LOG(info, + "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime); - + // reset stats between updates curBatches = curLines = curSourceTokens = 0; curTimer.reset(new timer::Timer()); @@ -230,12 +233,12 @@ public: // make sure threads are joined before other local variables get de-allocated threadPool.join_all(); - + // display final speed numbers over total translation if intermediate displays were requested if(statFreq.n > 0) { double totTime = totTimer->elapsed(); - LOG(info, - "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", + LOG(info, + "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime); } } @@ -288,6 +291,14 @@ public: auto devices = Config::getDevices(options_); numDevices_ = devices.size(); + // preload models + std::vector<std::vector<io::Item>> model_items_; + auto models = options->get<std::vector<std::string>>("models"); + for(auto model : models) { + auto items = io::loadItems(model); + model_items_.push_back(std::move(items)); + } + // initialize scorers for(auto device : devices) { auto graph = New<ExpressionGraph>(true); @@ -303,7 +314,7 @@ public: graph->reserveWorkspaceMB(options_->get<size_t>("workspace")); graphs_.push_back(graph); - auto scorers = createScorers(options_); + auto scorers = createScorers(options_, model_items_); for(auto scorer : scorers) { scorer->init(graph); if(shortlistGenerator_) |