#include "lm/model.hh" #include "util/file_stream.hh" #include "util/file.hh" #include "util/file_piece.hh" #include "util/usage.hh" #include namespace { template void ConvertToBytes(const Model &model, int fd_in) { util::FilePiece in(fd_in); util::FileStream out(1); Width width; StringPiece word; const Width end_sentence = (Width)model.GetVocabulary().EndSentence(); while (true) { while (in.ReadWordSameLine(word)) { width = (Width)model.GetVocabulary().Index(word); out.write(&width, sizeof(Width)); } if (!in.ReadLineOrEOF(word)) break; out.write(&end_sentence, sizeof(Width)); } } template void QueryFromBytes(const Model &model, int fd_in) { lm::ngram::State state[3]; const lm::ngram::State *const begin_state = &model.BeginSentenceState(); const lm::ngram::State *next_state = begin_state; Width kEOS = model.GetVocabulary().EndSentence(); Width buf[4096]; uint64_t completed = 0; double loaded = util::CPUTime(); std::cout << "CPU_to_load: " << loaded << std::endl; // Numerical precision: batch sums. double total = 0.0; while (std::size_t got = util::ReadOrEOF(fd_in, buf, sizeof(buf))) { float sum = 0.0; UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width)); got /= sizeof(Width); completed += got; // Do even stuff first. const Width *even_end = buf + (got & ~1); // Alternating states const Width *i; for (i = buf; i != even_end;) { sum += model.FullScore(*next_state, *i, state[1]).prob; next_state = (*i++ == kEOS) ? begin_state : &state[1]; sum += model.FullScore(*next_state, *i, state[0]).prob; next_state = (*i++ == kEOS) ? begin_state : &state[0]; } // Odd corner case. if (got & 1) { sum += model.FullScore(*next_state, *i, state[2]).prob; next_state = (*i++ == kEOS) ? begin_state : &state[2]; } total += sum; } double after = util::CPUTime(); std::cerr << "Probability sum is " << total << std::endl; std::cout << "Queries: " << completed << std::endl; std::cout << "CPU_excluding_load: " << (after - loaded) << "\nCPU_per_query: " << ((after - loaded) / static_cast(completed)) << std::endl; std::cout << "RSSMax: " << util::RSSMax() << std::endl; } template void DispatchFunction(const Model &model, bool query) { if (query) { QueryFromBytes(model, 0); } else { ConvertToBytes(model, 0); } } template void DispatchWidth(const char *file, bool query) { lm::ngram::Config config; config.load_method = util::READ; std::cerr << "Using load_method = READ." 
<< std::endl; Model model(file, config); lm::WordIndex bound = model.GetVocabulary().Bound(); if (bound <= 256) { DispatchFunction(model, query); } else if (bound <= 65536) { DispatchFunction(model, query); } else if (bound <= (1ULL << 32)) { DispatchFunction(model, query); } else { DispatchFunction(model, query); } } void Dispatch(const char *file, bool query) { using namespace lm::ngram; lm::ngram::ModelType model_type; if (lm::ngram::RecognizeBinary(file, model_type)) { switch(model_type) { case PROBING: DispatchWidth(file, query); break; case REST_PROBING: DispatchWidth(file, query); break; case TRIE: DispatchWidth(file, query); break; case QUANT_TRIE: DispatchWidth(file, query); break; case ARRAY_TRIE: DispatchWidth(file, query); break; case QUANT_ARRAY_TRIE: DispatchWidth(file, query); break; default: UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type); } } else { UTIL_THROW(util::Exception, "Binarize before running benchmarks."); } } } // namespace int main(int argc, char *argv[]) { if (argc != 3 || (strcmp(argv[1], "vocab") && strcmp(argv[1], "query"))) { std::cerr << "Benchmark program for KenLM. Intended usage:\n" << "#Convert text to vocabulary ids offline. These ids are tied to a model.\n" << argv[0] << " vocab $model <$text >$text.vocab\n" << "#Ensure files are in RAM.\n" << "cat $text.vocab $model >/dev/null\n" << "#Timed query against the model.\n" << argv[0] << " query $model <$text.vocab\n"; return 1; } Dispatch(argv[2], !strcmp(argv[1], "query")); return 0; }