Welcome to mirror list, hosted at ThFree Co, Russian Federation.

scorers.cpp « translator « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 1d84423c8be8d00c8248b91cf68daaea2ac08eb5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include "translator/scorers.h"

namespace marian {

/**
 * Creates a single scorer that wraps the encoder-decoder described by `config`.
 *
 * The scorer's options are copied from `config` and switched to inference
 * mode before the model object is constructed.
 *
 * @param fname  feature name reported for this scorer (e.g. "F0")
 * @param weight weight handed to the ScorerWrapper together with the model
 * @param model  entry from the "models" option (presumably a model file
 *               path -- confirm), passed through to the wrapper unchanged
 * @param config scorer configuration; its "type" entry selects the model type
 * @return a ScorerWrapper around the constructed model
 */
Ptr<Scorer> scorerByType(std::string fname,
                         float weight,
                         std::string model,
                         Ptr<Config> config) {
  auto opts = New<Options>();
  opts->merge(config);
  opts->set("inference", true);

  const std::string modelType = opts->get<std::string>("type");

  // @TODO: solve this better
  // For language models, "index" is set to the number of "input" entries
  // (presumably selecting the stream that follows the inputs -- confirm).
  if(modelType == "lm" && config->has("input")) {
    const size_t numInputs
        = config->get<std::vector<std::string>>("input").size();
    opts->set("index", numInputs);
  }

  auto encdec = models::from_options(opts, models::usage::translation);
  LOG(info, "Loading scorer of type {} as feature {}", modelType, fname);
  return New<ScorerWrapper>(encdec, fname, weight, model);
}

/**
 * Builds one scorer per entry of the "models" option.
 *
 * Scorers are named "F0", "F1", ... in model order. Each scorer is weighted
 * by the matching entry of "weights", or by 1.0 for every model when
 * "weights" is not set. Unless "ignore-model-config" is true, each model's
 * own parameters are loaded into a private copy of the global configuration.
 * A std::runtime_error raised while doing so only produces a warning.
 *
 * @param options global configuration
 * @return the scorers, in the same order as "models"
 * @throws std::runtime_error if "weights" has fewer entries than "models"
 */
std::vector<Ptr<Scorer>> createScorers(Ptr<Config> options) {
  auto models = options->get<std::vector<std::string>>("models");

  std::vector<float> weights(models.size(), 1.f);
  if(options->has("weights"))
    weights = options->get<std::vector<float>>("weights");

  // Without this check, weights[i] below would read past the end of the
  // vector when the user supplies fewer weights than models.
  if(weights.size() < models.size())
    throw std::runtime_error("Number of weights ("
                             + std::to_string(weights.size())
                             + ") is smaller than number of models ("
                             + std::to_string(models.size()) + ")");

  std::vector<Ptr<Scorer>> scorers;
  scorers.reserve(models.size());

  for(size_t i = 0; i < models.size(); ++i) {
    const std::string& model = models[i];
    std::string fname = "F" + std::to_string(i);

    // Each scorer gets its own copy so model-specific settings do not leak
    // into the global configuration or into other scorers.
    auto modelOptions = New<Config>(*options);

    try {
      if(!options->get<bool>("ignore-model-config"))
        modelOptions->loadModelParameters(model);
    } catch(const std::runtime_error&) {
      LOG(warn, "No model settings found in model file");
    }

    scorers.push_back(scorerByType(fname, weights[i], model, modelOptions));
  }

  return scorers;
}
}