Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/translator/scorers.h')
-rw-r--r--src/translator/scorers.h21
1 files changed, 20 insertions, 1 deletions
diff --git a/src/translator/scorers.h b/src/translator/scorers.h
index a5a0be2c..72ebff5d 100644
--- a/src/translator/scorers.h
+++ b/src/translator/scorers.h
@@ -73,12 +73,22 @@ class ScorerWrapper : public Scorer {
private:
Ptr<IEncoderDecoder> encdec_;
std::string fname_;
+ std::vector<io::Item> items_;
const void* ptr_;
public:
ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
+ std::vector<io::Item>& items)
+ : Scorer(name, weight),
+ encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
+ items_(items),
+ ptr_{0} {}
+
+ ScorerWrapper(Ptr<models::IModel> encdec,
+ const std::string& name,
+ float weight,
const std::string& fname)
: Scorer(name, weight),
encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
@@ -97,7 +107,9 @@ public:
virtual void init(Ptr<ExpressionGraph> graph) override {
graph->switchParams(getName());
- if(ptr_)
+ if(!items_.empty())
+ encdec_->load(graph, items_);
+ else if(ptr_)
encdec_->mmap(graph, ptr_);
else
encdec_->load(graph, fname_);
@@ -143,11 +155,18 @@ public:
};
Ptr<Scorer> scorerByType(const std::string& fname,
+ float weight,
+ std::vector<io::Item> items,
+ Ptr<Options> options);
+
+Ptr<Scorer> scorerByType(const std::string& fname,
float weight,
const std::string& model,
Ptr<Options> config);
+
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options);
+std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models);
Ptr<Scorer> scorerByType(const std::string& fname,
float weight,