diff options
Diffstat (limited to 'src/translator/scorers.h')
-rw-r--r-- | src/translator/scorers.h | 21 |
1 files changed, 20 insertions, 1 deletions
diff --git a/src/translator/scorers.h b/src/translator/scorers.h index a5a0be2c..72ebff5d 100644 --- a/src/translator/scorers.h +++ b/src/translator/scorers.h @@ -73,12 +73,22 @@ class ScorerWrapper : public Scorer { private: Ptr<IEncoderDecoder> encdec_; std::string fname_; + std::vector<io::Item> items_; const void* ptr_; public: ScorerWrapper(Ptr<models::IModel> encdec, const std::string& name, float weight, + std::vector<io::Item>& items) + : Scorer(name, weight), + encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)), + items_(items), + ptr_{0} {} + + ScorerWrapper(Ptr<models::IModel> encdec, + const std::string& name, + float weight, const std::string& fname) : Scorer(name, weight), encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)), @@ -97,7 +107,9 @@ public: virtual void init(Ptr<ExpressionGraph> graph) override { graph->switchParams(getName()); - if(ptr_) + if(!items_.empty()) + encdec_->load(graph, items_); + else if(ptr_) encdec_->mmap(graph, ptr_); else encdec_->load(graph, fname_); @@ -143,11 +155,18 @@ public: }; Ptr<Scorer> scorerByType(const std::string& fname, + float weight, + std::vector<io::Item> items, + Ptr<Options> options); + +Ptr<Scorer> scorerByType(const std::string& fname, float weight, const std::string& model, Ptr<Options> config); + std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options); +std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models); Ptr<Scorer> scorerByType(const std::string& fname, float weight, |