Welcome to mirror list, hosted at ThFree Co, Russian Federation.

count_ngrams_main.cc « builder « lm - github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: b20f68f988167c1fd312db58ddee19e630d0ab3b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include "combine_counts.hh"
#include "corpus_count.hh"
#include "../common/compare.hh"
#include "../../util/stream/chain.hh"
#include "../../util/stream/io.hh"
#include "../../util/stream/sort.hh"
#include "../../util/file.hh"
#include "../../util/file_piece.hh"
#include "../../util/usage.hh"

#include <boost/program_options.hpp>

#include <string>

namespace {
class SizeNotify {
  public:
    SizeNotify(std::size_t &out) : behind_(out) {}

    void operator()(const std::string &from) {
      behind_ = util::ParseSize(from);
    }

  private:
    std::size_t &behind_;
};

boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) {
  return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
}

} // namespace

int main(int argc, char *argv[]) {
  namespace po = boost::program_options;
  unsigned order;
  std::size_t ram;
  std::string temp_prefix, vocab_table, vocab_list;
  po::options_description options("corpus count");
  options.add_options()
    ("help,h", po::bool_switch(), "Show this help message")
    ("order,o", po::value<unsigned>(&order)->required(), "Order")
    ("temp_prefix,T", po::value<std::string>(&temp_prefix)->default_value(util::DefaultTempDirectory()), "Temporary file prefix")
    ("memory,S", SizeOption(ram, "80%"), "RAM")
    ("read_vocab_table", po::value<std::string>(&vocab_table), "Vocabulary hash table to read.  This should be a probing hash table with size at the beginning.")
    ("write_vocab_list", po::value<std::string>(&vocab_list), "Vocabulary list to write as null-delimited strings.");

  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, options), vm);
  if (argc == 1 || vm["help"].as<bool>()) {
    std::cerr << "Counts n-grams from standard input.\n" << options << std::endl;
    return 1;
  }
  po::notify(vm);

  if (!(vocab_table.empty() ^ vocab_list.empty())) {
    std::cerr << "Specify one of --read_vocab_table or --write_vocab_list for vocabulary handling." << std::endl;
    return 1;
  }

  util::NormalizeTempPrefix(temp_prefix);

  util::scoped_fd vocab_file(vocab_table.empty() ? util::CreateOrThrow(vocab_list.c_str()) : util::OpenReadOrThrow(vocab_table.c_str()));

  std::size_t blocks = 2;
  std::size_t remaining_size = ram - util::SizeOrThrow(vocab_file.get());

  std::size_t memory_for_chain =
    // This much memory to work with after vocab hash table.
    static_cast<float>(remaining_size) /
    // Solve for block size including the dedupe multiplier for one block.
    (static_cast<float>(blocks) + lm::builder::CorpusCount::DedupeMultiplier(order)) *
    // Chain likes memory expressed in terms of total memory.
    static_cast<float>(blocks);
  std::cerr << "Using " << memory_for_chain << " for chains." << std::endl;
  
  util::stream::Chain chain(util::stream::ChainConfig(lm::NGram<uint64_t>::TotalSize(order), blocks, memory_for_chain));
  util::FilePiece f(0, NULL, &std::cerr);
  uint64_t token_count = 0;
  lm::WordIndex type_count = 0;
  std::vector<bool> empty_prune;
  std::string empty_string;
  lm::builder::CorpusCount counter(f, vocab_file.get(), vocab_table.empty(), token_count, type_count, empty_prune, empty_string, chain.BlockSize() / chain.EntrySize(), lm::THROW_UP);
  chain >> boost::ref(counter);

  util::stream::SortConfig sort_config;
  sort_config.temp_prefix = temp_prefix;
  sort_config.buffer_size = 64 * 1024 * 1024;
  // Intended to run in parallel.
  sort_config.total_memory = remaining_size;
  util::stream::Sort<lm::SuffixOrder, lm::builder::CombineCounts> sorted(chain, sort_config, lm::SuffixOrder(order), lm::builder::CombineCounts());
  chain.Wait(true);
  util::stream::Chain chain2(util::stream::ChainConfig(lm::NGram<uint64_t>::TotalSize(order), blocks, sort_config.buffer_size));
  sorted.Output(chain2);
  // Inefficiently copies if there's only one block.
  chain2 >> util::stream::WriteAndRecycle(1);
  chain2.Wait(true);
  return 0;
}