Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/translator/translator.h')
-rw-r--r--src/translator/translator.h77
1 files changed, 44 insertions, 33 deletions
diff --git a/src/translator/translator.h b/src/translator/translator.h
index db1f3d03..4084ced9 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -20,12 +20,7 @@
#include "translator/scorers.h"
// currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled.
-// @TODO: add this as an actual feature.
-#define MMAP 0
-
-#if MMAP
#include "3rd_party/mio/mio.hpp"
-#endif
namespace marian {
@@ -42,9 +37,8 @@ private:
size_t numDevices_;
-#if MMAP
- std::vector<mio::mmap_source> mmaps_;
-#endif
+ std::vector<mio::mmap_source> model_mmaps_; // map
+ std::vector<std::vector<io::Item>> model_items_; // non-mmap
public:
Translate(Ptr<Options> options)
@@ -76,15 +70,21 @@ public:
scorers_.resize(numDevices_);
graphs_.resize(numDevices_);
-#if MMAP
auto models = options->get<std::vector<std::string>>("models");
- for(auto model : models) {
- marian::filesystem::Path modelPath(model);
- ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"),
- "Non-binarized models cannot be mmapped");
- mmaps_.push_back(std::move(mio::mmap_source(model)));
+ if(options_->get<bool>("model-mmap", false)) {
+ for(auto model : models) {
+ ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped");
+ LOG(info, "Loading model from {}", model);
+ model_mmaps_.push_back(mio::mmap_source(model));
+ }
+ }
+ else {
+ for(auto model : models) {
+ LOG(info, "Loading model from {}", model);
+ auto items = io::loadItems(model);
+ model_items_.push_back(std::move(items));
+ }
}
-#endif
size_t id = 0;
for(auto device : devices) {
@@ -101,11 +101,14 @@ public:
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_[id] = graph;
-#if MMAP
- auto scorers = createScorers(options_, mmaps_);
-#else
- auto scorers = createScorers(options_);
-#endif
+ std::vector<Ptr<Scorer>> scorers;
+ if(options_->get<bool>("model-mmap", false)) {
+ scorers = createScorers(options_, model_mmaps_);
+ }
+ else {
+ scorers = createScorers(options_, model_items_);
+ }
+
for(auto scorer : scorers) {
scorer->init(graph);
if(shortlistGenerator_)
@@ -146,11 +149,11 @@ public:
std::mutex syncCounts;
// timer and counters for total elapsed time and statistics
- std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
+ std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
size_t totBatches = 0;
size_t totLines = 0;
size_t totSourceTokens = 0;
-
+
// timer and counters for elapsed time and statistics between updates
std::unique_ptr<timer::Timer> curTimer(new timer::Timer());
size_t curBatches = 0;
@@ -176,7 +179,7 @@ public:
bg.prepare();
for(auto batch : bg) {
auto task = [=, &syncCounts,
- &totBatches, &totLines, &totSourceTokens, &totTimer,
+ &totBatches, &totLines, &totSourceTokens, &totTimer,
&curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local std::vector<Ptr<Scorer>> scorers;
@@ -200,12 +203,12 @@ public:
}
// if we asked for speed information display this
- if(statFreq.n > 0) {
+ if(statFreq.n > 0) {
std::lock_guard<std::mutex> lock(syncCounts);
- totBatches++;
+ totBatches++;
totLines += batch->size();
totSourceTokens += batch->front()->batchWords();
-
+
curBatches++;
curLines += batch->size();
curSourceTokens += batch->front()->batchWords();
@@ -214,10 +217,10 @@ public:
double totTime = totTimer->elapsed();
double curTime = curTimer->elapsed();
- LOG(info,
- "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
+ LOG(info,
+ "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime);
-
+
// reset stats between updates
curBatches = curLines = curSourceTokens = 0;
curTimer.reset(new timer::Timer());
@@ -230,12 +233,12 @@ public:
// make sure threads are joined before other local variables get de-allocated
threadPool.join_all();
-
+
// display final speed numbers over total translation if intermediate displays were requested
if(statFreq.n > 0) {
double totTime = totTimer->elapsed();
- LOG(info,
- "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
+ LOG(info,
+ "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime);
}
}
@@ -288,6 +291,14 @@ public:
auto devices = Config::getDevices(options_);
numDevices_ = devices.size();
+ // preload models
+ std::vector<std::vector<io::Item>> model_items_;
+ auto models = options->get<std::vector<std::string>>("models");
+ for(auto model : models) {
+ auto items = io::loadItems(model);
+ model_items_.push_back(std::move(items));
+ }
+
// initialize scorers
for(auto device : devices) {
auto graph = New<ExpressionGraph>(true);
@@ -303,7 +314,7 @@ public:
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
- auto scorers = createScorers(options_);
+ auto scorers = createScorers(options_, model_items_);
for(auto scorer : scorers) {
scorer->init(graph);
if(shortlistGenerator_)