Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGraeme Nail <graemenail.work@gmail.com>2022-01-18 15:58:52 +0300
committerGitHub <noreply@github.com>2022-01-18 15:58:52 +0300
commitb29cc07a95f49df7825f3a92e860bd642db0e812 (patch)
tree222dacb68ca3b41692dc6b4f90bb5b8d2108c877
parentc84599d08ad69059279abd5a7417a8053db8b631 (diff)
Scorer model loading (#860)
* Add MMAP as an option * Use io::isBin * Allow getYamlFromModel from an Item vector * ScorerWrapper can now load on to a graph from Item vector The interface IEncoderDecoder can now call graph loads directly from an Item Vector. * Translator loads model before creating scorers Scorers are created from an Item vector * Replace model-config try-catch with check using IsNull * Prefer empty vs size * load by items should be pure virtual * Stepwise forward load to encdec * nematus can load from items * amun can load from items * loadItems in TranslateService * Remove logging * Remove by filename scorer functions * Replace by filename createScorer * Explicitly provide default value for get model-mmap * CLI option for model-mmap only for translation and CPU compile * Ensure model-mmap option is CPU only * Remove move on temporary object * Reinstate log messages for model loading in Amun / Nematus * Add log messages for model loading in scorers Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
-rw-r--r--CHANGELOG.md1
-rw-r--r--src/common/config_parser.cpp7
-rw-r--r--src/common/config_validator.cpp3
-rw-r--r--src/common/io.cpp12
-rw-r--r--src/common/io.h1
-rw-r--r--src/graph/expression_graph.h2
-rw-r--r--src/models/amun.h14
-rw-r--r--src/models/costs.h6
-rw-r--r--src/models/encoder_decoder.cpp6
-rw-r--r--src/models/encoder_decoder.h9
-rw-r--r--src/models/nematus.h16
-rw-r--r--src/translator/scorers.cpp51
-rw-r--r--src/translator/scorers.h21
-rw-r--r--src/translator/translator.h77
14 files changed, 162 insertions, 64 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a5dd305f..d42c652e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Dynamic gradient-scaling with `--dynamic-gradient-scaling`.
- Add unit tests for binary files.
- Fix compilation with OMP
+- Added `--model-mmap` option to enable mmap loading for CPU-based translation
- Compute aligned memory sizes using exact sizing
- Support for loading lexical shortlist from a binary blob
- Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 8da9520c..9705d5b7 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -183,7 +183,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Path prefix for pre-trained model to initialize model weights");
}
}
-
+#ifdef COMPILE_CPU
+ if(mode_ == cli::mode::translation) {
+ cli.add<bool>("--model-mmap",
+ "Use memory-mapping when loading model (CPU only)");
+ }
+#endif
cli.add<bool>("--ignore-model-config",
"Ignore the model configuration saved in npz file");
cli.add<std::string>("--type",
diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp
index fea7578f..b0230da9 100644
--- a/src/common/config_validator.cpp
+++ b/src/common/config_validator.cpp
@@ -54,6 +54,9 @@ void ConfigValidator::validateOptionsTranslation() const {
ABORT_IF(models.empty() && configs.empty(),
"You need to provide at least one model file or a config file");
+ ABORT_IF(get<bool>("model-mmap") && get<size_t>("cpu-threads") == 0,
+ "Model MMAP is CPU-only, please use --cpu-threads");
+
for(const auto& modelFile : models) {
filesystem::Path modelPath(modelFile);
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile);
diff --git a/src/common/io.cpp b/src/common/io.cpp
index a9984b5d..e0b3f39a 100644
--- a/src/common/io.cpp
+++ b/src/common/io.cpp
@@ -56,6 +56,18 @@ void getYamlFromModel(YAML::Node& yaml,
yaml = YAML::Load(item.data());
}
+// Load YAML from item
+void getYamlFromModel(YAML::Node& yaml,
+ const std::string& varName,
+ const std::vector<Item>& items) {
+ for(auto& item : items) {
+ if(item.name == varName) {
+ yaml = YAML::Load(item.data());
+ return;
+ }
+ }
+}
+
void addMetaToItems(const std::string& meta,
const std::string& varName,
std::vector<io::Item>& items) {
diff --git a/src/common/io.h b/src/common/io.h
index 2d18d66e..3f340ed2 100644
--- a/src/common/io.h
+++ b/src/common/io.h
@@ -21,6 +21,7 @@ bool isBin(const std::string& fileName);
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::string& fileName);
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const void* ptr);
+void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::vector<Item>& items);
void addMetaToItems(const std::string& meta,
const std::string& varName,
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index 553a5d63..c532abff 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -739,7 +739,7 @@ public:
public:
/** Load model (mainly parameter objects) from array of io::Items */
- void load(std::vector<io::Item>& ioItems, bool markReloaded = true) {
+ void load(const std::vector<io::Item>& ioItems, bool markReloaded = true) {
setReloaded(false);
for(auto& item : ioItems) {
std::string pName = item.name;
diff --git a/src/models/amun.h b/src/models/amun.h
index 1bfda269..135ce359 100644
--- a/src/models/amun.h
+++ b/src/models/amun.h
@@ -36,7 +36,7 @@ public:
}
void load(Ptr<ExpressionGraph> graph,
- const std::string& name,
+ const std::vector<io::Item>& items,
bool /*markedReloaded*/ = true) override {
std::map<std::string, std::string> nameMap
= {{"decoder_U", "decoder_cell1_U"},
@@ -89,9 +89,7 @@ public:
if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
nameMap["Wemb"] = "Wemb";
- LOG(info, "Loading model from {}", name);
- // load items from .npz file
- auto ioItems = io::loadItems(name);
+ auto ioItems = items;
// map names and remove a dummy matrices
for(auto it = ioItems.begin(); it != ioItems.end();) {
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
@@ -120,6 +118,14 @@ public:
graph->load(ioItems);
}
+ void load(Ptr<ExpressionGraph> graph,
+ const std::string& name,
+ bool /*markReloaded*/ = true) override {
+ LOG(info, "Loading model from {}", name);
+ auto ioItems = io::loadItems(name);
+ load(graph, ioItems);
+ }
+
void save(Ptr<ExpressionGraph> graph,
const std::string& name,
bool saveTranslatorConfig = false) override {
diff --git a/src/models/costs.h b/src/models/costs.h
index e5463bfd..982a13c5 100644
--- a/src/models/costs.h
+++ b/src/models/costs.h
@@ -326,6 +326,12 @@ public:
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}
virtual void load(Ptr<ExpressionGraph> graph,
+ const std::vector<io::Item>& items,
+ bool markedReloaded = true) override {
+ encdec_->load(graph, items, markedReloaded);
+ }
+
+ virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override {
encdec_->load(graph, name, markedReloaded);
diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp
index 66ff16ce..bb938ee5 100644
--- a/src/models/encoder_decoder.cpp
+++ b/src/models/encoder_decoder.cpp
@@ -145,6 +145,12 @@ std::string EncoderDecoder::getModelParametersAsString() {
}
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
+ const std::vector<io::Item>& items,
+ bool markedReloaded) {
+ graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
+}
+
+void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded) {
graph->load(name, markedReloaded && !opt<bool>("ignore-model-config", false));
diff --git a/src/models/encoder_decoder.h b/src/models/encoder_decoder.h
index 92c1647f..0fbf3faf 100644
--- a/src/models/encoder_decoder.h
+++ b/src/models/encoder_decoder.h
@@ -12,6 +12,11 @@ namespace marian {
class IEncoderDecoder : public models::IModel {
public:
virtual ~IEncoderDecoder() {}
+
+ virtual void load(Ptr<ExpressionGraph> graph,
+ const std::vector<io::Item>& items,
+ bool markedReloaded = true) = 0;
+
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override
@@ -92,6 +97,10 @@ public:
void push_back(Ptr<DecoderBase> decoder);
virtual void load(Ptr<ExpressionGraph> graph,
+ const std::vector<io::Item>& items,
+ bool markedReloaded = true) override;
+
+ virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override;
diff --git a/src/models/nematus.h b/src/models/nematus.h
index 730418e5..aee8e3b0 100644
--- a/src/models/nematus.h
+++ b/src/models/nematus.h
@@ -26,11 +26,9 @@ public:
}
void load(Ptr<ExpressionGraph> graph,
- const std::string& name,
+ const std::vector<io::Item>& items,
bool /*markReloaded*/ = true) override {
- LOG(info, "Loading model from {}", name);
- // load items from .npz file
- auto ioItems = io::loadItems(name);
+ auto ioItems = items;
// map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node
for(auto it = ioItems.begin(); it != ioItems.end();) {
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
@@ -41,7 +39,7 @@ public:
it->shape.set(0, 1);
it->shape.set(1, dim);
}
-
+
if(it->name == "decoder_c_tt") {
it = ioItems.erase(it);
} else if(it->name == "uidx") {
@@ -59,6 +57,14 @@ public:
graph->load(ioItems);
}
+ void load(Ptr<ExpressionGraph> graph,
+ const std::string& name,
+ bool /*markReloaded*/ = true) override {
+ LOG(info, "Loading model from {}", name);
+ auto ioItems = io::loadItems(name);
+ load(graph, ioItems);
+ }
+
void save(Ptr<ExpressionGraph> graph,
const std::string& name,
bool saveTranslatorConfig = false) override {
diff --git a/src/translator/scorers.cpp b/src/translator/scorers.cpp
index d1c8b160..60ec03dd 100644
--- a/src/translator/scorers.cpp
+++ b/src/translator/scorers.cpp
@@ -5,7 +5,7 @@ namespace marian {
Ptr<Scorer> scorerByType(const std::string& fname,
float weight,
- const std::string& model,
+ std::vector<io::Item> items,
Ptr<Options> options) {
options->set("inference", true);
std::string type = options->get<std::string>("type");
@@ -22,7 +22,7 @@ Ptr<Scorer> scorerByType(const std::string& fname,
LOG(info, "Loading scorer of type {} as feature {}", type, fname);
- return New<ScorerWrapper>(encdec, fname, weight, model);
+ return New<ScorerWrapper>(encdec, fname, weight, items);
}
Ptr<Scorer> scorerByType(const std::string& fname,
@@ -47,30 +47,30 @@ Ptr<Scorer> scorerByType(const std::string& fname,
return New<ScorerWrapper>(encdec, fname, weight, ptr);
}
-std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
+std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models) {
std::vector<Ptr<Scorer>> scorers;
- auto models = options->get<std::vector<std::string>>("models");
-
std::vector<float> weights(models.size(), 1.f);
if(options->hasAndNotEmpty("weights"))
weights = options->get<std::vector<float>>("weights");
bool isPrevRightLeft = false; // if the previous model was a right-to-left model
size_t i = 0;
- for(auto model : models) {
+ for(auto items : models) {
std::string fname = "F" + std::to_string(i);
// load options specific for the scorer
auto modelOptions = New<Options>(options->clone());
- try {
- if(!options->get<bool>("ignore-model-config")) {
- YAML::Node modelYaml;
- io::getYamlFromModel(modelYaml, "special:model.yml", model);
+ if(!options->get<bool>("ignore-model-config")) {
+ YAML::Node modelYaml;
+ io::getYamlFromModel(modelYaml, "special:model.yml", items);
+ if(!modelYaml.IsNull()) {
+ LOG(info, "Loaded model config");
modelOptions->merge(modelYaml, true);
}
- } catch(std::runtime_error&) {
- LOG(warn, "No model settings found in model file");
+ else {
+ LOG(warn, "No model settings found in model file");
+ }
}
// l2r and r2l cannot be used in the same ensemble
@@ -85,13 +85,24 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
}
}
- scorers.push_back(scorerByType(fname, weights[i], model, modelOptions));
+ scorers.push_back(scorerByType(fname, weights[i], items, modelOptions));
i++;
}
return scorers;
}
+std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
+ std::vector<std::vector<io::Item>> model_items;
+ auto models = options->get<std::vector<std::string>>("models");
+ for(auto model : models) {
+ auto items = io::loadItems(model);
+ model_items.push_back(std::move(items));
+ }
+
+ return createScorers(options, model_items);
+}
+
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<const void*>& ptrs) {
std::vector<Ptr<Scorer>> scorers;
@@ -105,14 +116,16 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<c
// load options specific for the scorer
auto modelOptions = New<Options>(options->clone());
- try {
- if(!options->get<bool>("ignore-model-config")) {
- YAML::Node modelYaml;
- io::getYamlFromModel(modelYaml, "special:model.yml", ptr);
+ if(!options->get<bool>("ignore-model-config")) {
+ YAML::Node modelYaml;
+ io::getYamlFromModel(modelYaml, "special:model.yml", ptr);
+ if(!modelYaml.IsNull()) {
+ LOG(info, "Loaded model config");
modelOptions->merge(modelYaml, true);
}
- } catch(std::runtime_error&) {
- LOG(warn, "No model settings found in model file");
+ else {
+ LOG(warn, "No model settings found in model file");
+ }
}
scorers.push_back(scorerByType(fname, weights[i], ptr, modelOptions));
diff --git a/src/translator/scorers.h b/src/translator/scorers.h
index a5a0be2c..72ebff5d 100644
--- a/src/translator/scorers.h
+++ b/src/translator/scorers.h
@@ -73,12 +73,22 @@ class ScorerWrapper : public Scorer {
private:
Ptr<IEncoderDecoder> encdec_;
std::string fname_;
+ std::vector<io::Item> items_;
const void* ptr_;
public:
ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
+ std::vector<io::Item>& items)
+ : Scorer(name, weight),
+ encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
+ items_(items),
+ ptr_{0} {}
+
+ ScorerWrapper(Ptr<models::IModel> encdec,
+ const std::string& name,
+ float weight,
const std::string& fname)
: Scorer(name, weight),
encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
@@ -97,7 +107,9 @@ public:
virtual void init(Ptr<ExpressionGraph> graph) override {
graph->switchParams(getName());
- if(ptr_)
+ if(!items_.empty())
+ encdec_->load(graph, items_);
+ else if(ptr_)
encdec_->mmap(graph, ptr_);
else
encdec_->load(graph, fname_);
@@ -143,11 +155,18 @@ public:
};
Ptr<Scorer> scorerByType(const std::string& fname,
+ float weight,
+ std::vector<io::Item> items,
+ Ptr<Options> options);
+
+Ptr<Scorer> scorerByType(const std::string& fname,
float weight,
const std::string& model,
Ptr<Options> config);
+
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options);
+std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models);
Ptr<Scorer> scorerByType(const std::string& fname,
float weight,
diff --git a/src/translator/translator.h b/src/translator/translator.h
index db1f3d03..4084ced9 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -20,12 +20,7 @@
#include "translator/scorers.h"
// currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled.
-// @TODO: add this as an actual feature.
-#define MMAP 0
-
-#if MMAP
#include "3rd_party/mio/mio.hpp"
-#endif
namespace marian {
@@ -42,9 +37,8 @@ private:
size_t numDevices_;
-#if MMAP
- std::vector<mio::mmap_source> mmaps_;
-#endif
+ std::vector<mio::mmap_source> model_mmaps_; // map
+ std::vector<std::vector<io::Item>> model_items_; // non-mmap
public:
Translate(Ptr<Options> options)
@@ -76,15 +70,21 @@ public:
scorers_.resize(numDevices_);
graphs_.resize(numDevices_);
-#if MMAP
auto models = options->get<std::vector<std::string>>("models");
- for(auto model : models) {
- marian::filesystem::Path modelPath(model);
- ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"),
- "Non-binarized models cannot be mmapped");
- mmaps_.push_back(std::move(mio::mmap_source(model)));
+ if(options_->get<bool>("model-mmap", false)) {
+ for(auto model : models) {
+ ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped");
+ LOG(info, "Loading model from {}", model);
+ model_mmaps_.push_back(mio::mmap_source(model));
+ }
+ }
+ else {
+ for(auto model : models) {
+ LOG(info, "Loading model from {}", model);
+ auto items = io::loadItems(model);
+ model_items_.push_back(std::move(items));
+ }
}
-#endif
size_t id = 0;
for(auto device : devices) {
@@ -101,11 +101,14 @@ public:
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_[id] = graph;
-#if MMAP
- auto scorers = createScorers(options_, mmaps_);
-#else
- auto scorers = createScorers(options_);
-#endif
+ std::vector<Ptr<Scorer>> scorers;
+ if(options_->get<bool>("model-mmap", false)) {
+ scorers = createScorers(options_, model_mmaps_);
+ }
+ else {
+ scorers = createScorers(options_, model_items_);
+ }
+
for(auto scorer : scorers) {
scorer->init(graph);
if(shortlistGenerator_)
@@ -146,11 +149,11 @@ public:
std::mutex syncCounts;
// timer and counters for total elapsed time and statistics
- std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
+ std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
size_t totBatches = 0;
size_t totLines = 0;
size_t totSourceTokens = 0;
-
+
// timer and counters for elapsed time and statistics between updates
std::unique_ptr<timer::Timer> curTimer(new timer::Timer());
size_t curBatches = 0;
@@ -176,7 +179,7 @@ public:
bg.prepare();
for(auto batch : bg) {
auto task = [=, &syncCounts,
- &totBatches, &totLines, &totSourceTokens, &totTimer,
+ &totBatches, &totLines, &totSourceTokens, &totTimer,
&curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local std::vector<Ptr<Scorer>> scorers;
@@ -200,12 +203,12 @@ public:
}
// if we asked for speed information display this
- if(statFreq.n > 0) {
+ if(statFreq.n > 0) {
std::lock_guard<std::mutex> lock(syncCounts);
- totBatches++;
+ totBatches++;
totLines += batch->size();
totSourceTokens += batch->front()->batchWords();
-
+
curBatches++;
curLines += batch->size();
curSourceTokens += batch->front()->batchWords();
@@ -214,10 +217,10 @@ public:
double totTime = totTimer->elapsed();
double curTime = curTimer->elapsed();
- LOG(info,
- "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
+ LOG(info,
+ "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime);
-
+
// reset stats between updates
curBatches = curLines = curSourceTokens = 0;
curTimer.reset(new timer::Timer());
@@ -230,12 +233,12 @@ public:
// make sure threads are joined before other local variables get de-allocated
threadPool.join_all();
-
+
// display final speed numbers over total translation if intermediate displays were requested
if(statFreq.n > 0) {
double totTime = totTimer->elapsed();
- LOG(info,
- "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
+ LOG(info,
+ "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime);
}
}
@@ -288,6 +291,14 @@ public:
auto devices = Config::getDevices(options_);
numDevices_ = devices.size();
+ // preload models
+ std::vector<std::vector<io::Item>> model_items_;
+ auto models = options->get<std::vector<std::string>>("models");
+ for(auto model : models) {
+ auto items = io::loadItems(model);
+ model_items_.push_back(std::move(items));
+ }
+
// initialize scorers
for(auto device : devices) {
auto graph = New<ExpressionGraph>(true);
@@ -303,7 +314,7 @@ public:
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
- auto scorers = createScorers(options_);
+ auto scorers = createScorers(options_, model_items_);
for(auto scorer : scorers) {
scorer->init(graph);
if(shortlistGenerator_)