Welcome to mirror list, hosted at ThFree Co, Russian Federation.

kenlm_benchmark_main.cc « lm - github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 93196ece2188ff169c45dc138b83a5fcdc93c48a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
#include "model.hh"
#include "../util/file_stream.hh"
#include "../util/file.hh"
#include "../util/file_piece.hh"
#include "../util/usage.hh"
#include "../util/thread_pool.hh"

#include <boost/range/iterator_range.hpp>
#include <boost/program_options.hpp>

#include <cstring>
#include <iostream>
#include <string>
#include <vector>

#include <stdint.h>

namespace {

template <class Model, class Width> void ConvertToBytes(const Model &model, int fd_in) {
  util::FilePiece in(fd_in);
  util::FileStream out(1);
  Width width;
  StringPiece word;
  const Width end_sentence = (Width)model.GetVocabulary().EndSentence();
  while (true) {
    while (in.ReadWordSameLine(word)) {
      width = (Width)model.GetVocabulary().Index(word);
      out.write(&width, sizeof(Width));
    }
    if (!in.ReadLineOrEOF(word)) break;
    out.write(&end_sentence, sizeof(Width));
  }
}

// Thread-pool handler used by QueryFromBytes.  It scores every vocabulary id
// in each Request with model_.FullScore and accumulates the sum of the
// returned .prob values into total_.
//
// Each Request is assumed to begin at a sentence boundary (QueryFromBytes cuts
// its input after end-of-sentence ids).  Scoring therefore starts from
// BeginSentenceState() at the start of every request, and restarts there after
// every end-of-sentence id inside it.
template <class Model, class Width> class Worker {
  public:
    // model is held by reference, so it must outlive the Worker.  add_total
    // receives this Worker's total_ when the Worker is destroyed.
    explicit Worker(const Model &model, double &add_total) : model_(model), total_(0.0), add_total_(add_total) {}

    // Destructors happen in the main thread, so there's no race for add_total_.
    // NOTE(review): every copy of a Worker adds its own total_ here.  This is
    // only correct if the pool copies Workers before they do any work (copies
    // with total_ == 0 add nothing).  Confirm against util/thread_pool.hh.
    ~Worker() { add_total_ += total_; }

    // A contiguous range of vocabulary ids to score.
    typedef boost::iterator_range<Width *> Request;

    // Scores the ids in request in order.  Each word's output state is the
    // context for the next word, except after an end-of-sentence id.
    void operator()(Request request) {
      const lm::ngram::State *const begin_state = &model_.BeginSentenceState();
      const lm::ngram::State *next_state = begin_state;
      const Width kEOS = model_.GetVocabulary().EndSentence();
      // The per-request sum is kept in a float, then added to the double total_.
      float sum = 0.0;
      // Do even stuff first.
      // The length is rounded down to even so the loop can handle two words per pass.
      const Width *even_end = request.begin() + (request.size() & ~1);
      // Alternating states
      // Each FullScore writes its output into a slot that next_state is not
      // currently pointing at, so the input and output states of one call are
      // never the same object.
      const Width *i;
      for (i = request.begin(); i != even_end;) {
        sum += model_.FullScore(*next_state, *i, state_[1]).prob;
        next_state = (*i++ == kEOS) ? begin_state : &state_[1];
        sum += model_.FullScore(*next_state, *i, state_[0]).prob;
        next_state = (*i++ == kEOS) ? begin_state : &state_[0];
      }
      // Odd corner case.
      // One word remains.  next_state is begin_state or &state_[0] here, so
      // state_[2] is a safe output slot.  The final next_state assignment has
      // no further effect.
      if (request.size() & 1) {
        sum += model_.FullScore(*next_state, *i, state_[2]).prob;
        next_state = (*i++ == kEOS) ? begin_state : &state_[2];
      }
      total_ += sum;
    }

  private:
    const Model &model_;
    // Running sum over all requests handled by this Worker.
    double total_;
    // Receives total_ on destruction (QueryFromBytes passes its `total`).
    double &add_total_;

    // Scratch output states for operator(); see the alternation comment there.
    lm::ngram::State state_[3];
};

// Benchmark settings collected from the command line by main().
struct Config {
  // Gives every field a defined value so a Config is never read
  // uninitialized.  main() overwrites these from the parsed options.
  Config() : fd_in(0), threads(1), buf_per_thread(4096), query(false) {}

  int fd_in;                   // Input file descriptor (0 = stdin).
  std::size_t threads;         // Worker threads used by the query benchmark.
  std::size_t buf_per_thread;  // Words buffered per task when querying.
  bool query;                  // true: -q (query ids); false: -v (convert text).
};

template <class Model, class Width> void QueryFromBytes(const Model &model, const Config &config) {
  util::FileStream out(1);
  out << "Threads: " << config.threads << '\n';
  const Width kEOS = model.GetVocabulary().EndSentence();
  double total = 0.0;
  // Number of items to have in queue in addition to everything in flight.
  const std::size_t kInQueue = 3;
  std::size_t total_queue = config.threads + kInQueue;
  std::vector<Width> backing(config.buf_per_thread * total_queue);
  double loaded_cpu;
  double loaded_wall;
  uint64_t queries = 0;
  {
    util::RecyclingThreadPool<Worker<Model, Width> > pool(total_queue, config.threads, Worker<Model, Width>(model, total), boost::iterator_range<Width *>((Width*)0, (Width*)0));

    for (std::size_t i = 0; i < total_queue; ++i) {
      pool.PopulateRecycling(boost::iterator_range<Width *>(&backing[i * config.buf_per_thread], &backing[i * config.buf_per_thread]));
    }

    loaded_cpu = util::CPUTime();
    loaded_wall = util::WallTime();
    out << "To Load, CPU: " << loaded_cpu << " Wall: " << loaded_wall << '\n';
    boost::iterator_range<Width *> overhang((Width*)0, (Width*)0);
    while (true) {
      boost::iterator_range<Width *> buf = pool.Consume();
      std::memmove(buf.begin(), overhang.begin(), overhang.size() * sizeof(Width));
      std::size_t got = util::ReadOrEOF(config.fd_in, buf.begin() + overhang.size(), (config.buf_per_thread - overhang.size()) * sizeof(Width));
      if (!got && overhang.empty()) break;
      UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width));
      Width *read_end = buf.begin() + overhang.size() + got / sizeof(Width);
      Width *last_eos;
      for (last_eos = read_end - 1; ; --last_eos) {
        UTIL_THROW_IF2(last_eos <= buf.begin(), "Encountered a sentence longer than the buffer size of " << config.buf_per_thread << " words.  Rerun with increased buffer size. TODO: adaptable buffer");
        if (*last_eos == kEOS) break;
      }
      buf = boost::iterator_range<Width*>(buf.begin(), last_eos + 1);
      overhang = boost::iterator_range<Width*>(last_eos + 1, read_end);
      queries += buf.size();
      pool.Produce(buf);
    }
  } // Drain pool.

  double after_cpu = util::CPUTime();
  double after_wall = util::WallTime();
  util::FileStream(2, 70) << "Probability sum: " << total << '\n';
  out << "Queries: " << queries << '\n';
  out << "Excluding load, CPU: " << (after_cpu - loaded_cpu) << " Wall: " << (after_wall - loaded_wall) << '\n';
  double cpu_per_entry = ((after_cpu - loaded_cpu) / static_cast<double>(queries));
  double wall_per_entry = ((after_wall - loaded_wall) / static_cast<double>(queries));
  out << "Seconds per query excluding load, CPU: " << cpu_per_entry << " Wall: " << wall_per_entry << '\n';
  out << "Queries per second excluding load, CPU: " << (1.0/cpu_per_entry) << " Wall: " << (1.0/wall_per_entry) << '\n';
  out << "RSSMax: " << util::RSSMax() << '\n';
}

// Runs the mode selected on the command line with ids of type Width.
// -v converts text to ids (ConvertToBytes); -q runs the timed benchmark
// (QueryFromBytes).
template <class Model, class Width> void DispatchFunction(const Model &model, const Config &config) {
  if (!config.query) {
    ConvertToBytes<Model, Width>(model, config.fd_in);
    return;
  }
  QueryFromBytes<Model, Width>(model, config);
}

template <class Model> void DispatchWidth(const char *file, const Config &config) {
  lm::ngram::Config model_config;
  model_config.load_method = util::READ;
  Model model(file, model_config);
  uint64_t bound = model.GetVocabulary().Bound();
  if (bound <= 256) {
    DispatchFunction<Model, uint8_t>(model, config);
  } else if (bound <= 65536) {
    DispatchFunction<Model, uint16_t>(model, config);
  } else if (bound <= (1ULL << 32)) {
    DispatchFunction<Model, uint32_t>(model, config);
  } else {
    DispatchFunction<Model, uint64_t>(model, config);
  }
}

void Dispatch(const char *file, const Config &config) {
  using namespace lm::ngram;
  lm::ngram::ModelType model_type;
  if (lm::ngram::RecognizeBinary(file, model_type)) {
    switch(model_type) {
      case PROBING:
        DispatchWidth<lm::ngram::ProbingModel>(file, config);
        break;
      case REST_PROBING:
        DispatchWidth<lm::ngram::RestProbingModel>(file, config);
        break;
      case TRIE:
        DispatchWidth<lm::ngram::TrieModel>(file, config);
        break;
      case QUANT_TRIE:
        DispatchWidth<lm::ngram::QuantTrieModel>(file, config);
        break;
      case ARRAY_TRIE:
        DispatchWidth<lm::ngram::ArrayTrieModel>(file, config);
        break;
      case QUANT_ARRAY_TRIE:
        DispatchWidth<lm::ngram::QuantArrayTrieModel>(file, config);
        break;
      default:
        UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type);
    }
  } else {
    UTIL_THROW(util::Exception, "Binarize before running benchmarks.");
  }
}

} // namespace

// Entry point.  Parses options, then either converts text to vocabulary ids
// (-v) or runs the timed query benchmark (-q) against the model given with -m.
// Returns 0 on success or after printing help.  Returns 1 on usage errors or
// when an exception escapes.
int main(int argc, char *argv[]) {
  try {
    Config config;
    config.fd_in = 0;
    std::string model;
    // hardware_concurrency() reports 0 when the core count is unknown.  Fall
    // back to one thread rather than defaulting to an unusable 0.
    std::size_t default_threads = boost::thread::hardware_concurrency();
    if (!default_threads) default_threads = 1;
    namespace po = boost::program_options;
    po::options_description options("Benchmark options");
    options.add_options()
      ("help,h", po::bool_switch(), "Show help message")
      ("model,m", po::value<std::string>(&model)->required(), "Model to query or convert vocab ids")
      ("threads,t", po::value<std::size_t>(&config.threads)->default_value(default_threads), "Threads to use (querying only; TODO vocab conversion)")
      ("buffer,b", po::value<std::size_t>(&config.buf_per_thread)->default_value(4096), "Number of words to buffer per task.")
      ("vocab,v", po::bool_switch(), "Convert strings to vocab ids")
      ("query,q", po::bool_switch(), "Query from vocab ids");
    po::variables_map vm;
    po::store(po::parse_command_line(argc, argv, options), vm);
    if (argc == 1 || vm["help"].as<bool>()) {
      std::cerr << "Benchmark program for KenLM.  Intended usage:\n"
        << "#Convert text to vocabulary ids offline.  These ids are tied to a model.\n"
        << argv[0] << " -v -m $model <$text >$text.vocab\n"
        << "#Ensure files are in RAM.\n"
        << "cat $text.vocab $model >/dev/null\n"
        << "#Timed query against the model.\n"
        << argv[0] << " -q -m $model <$text.vocab\n";
      return 0;
    }
    // Run after the help check so a missing -m does not hide the help text.
    po::notify(vm);
    if (!(vm["vocab"].as<bool>() ^ vm["query"].as<bool>())) {
      std::cerr << "Specify exactly one of -v (vocab conversion) or -q (query)." << std::endl;
      return 1;
    }
    config.query = vm["query"].as<bool>();
    // Threads and buffers are only used by the query path.  A query pool with
    // zero workers would never consume its queue.
    if (config.query && !config.threads) {
      std::cerr << "Specify a non-zero number of threads with -t." << std::endl;
      return 1;
    }
    if (config.query && !config.buf_per_thread) {
      std::cerr << "Specify a non-zero buffer size with -b." << std::endl;
      return 1;
    }
    Dispatch(model.c_str(), config);
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;
    return 1;
  }
  return 0;
}