
github.com/moses-smt/mosesdecoder.git
author    Kenneth Heafield <github@kheafield.com>    2013-01-18 19:58:54 +0400
committer Kenneth Heafield <github@kheafield.com>    2013-01-18 19:59:51 +0400
commit    fc5868d0fff647c3879668af4bfe6e1bab9e83ab
tree      d9b5026f32cc5ae603e6f42b1fc7128aeee4ba55
parent    5f7b91e702f809577d82a7570778d254d985ba93
KenLM df5be22 lmplz for estimation
-rw-r--r--  jam-files/sanity.jam | 11
-rw-r--r--  lm/Jamfile | 4
-rw-r--r--  lm/build_binary.cc | 77
-rw-r--r--  lm/builder/Jamfile | 9
-rw-r--r--  lm/builder/README.md | 47
-rw-r--r--  lm/builder/TODO | 5
-rw-r--r--  lm/builder/adjust_counts.cc | 216
-rw-r--r--  lm/builder/adjust_counts.hh | 44
-rw-r--r--  lm/builder/adjust_counts_test.cc | 106
-rw-r--r--  lm/builder/corpus_count.cc | 223
-rw-r--r--  lm/builder/corpus_count.hh | 42
-rw-r--r--  lm/builder/corpus_count_test.cc | 76
-rw-r--r--  lm/builder/discount.hh | 26
-rw-r--r--  lm/builder/header_info.hh | 20
-rw-r--r--  lm/builder/initial_probabilities.cc | 136
-rw-r--r--  lm/builder/initial_probabilities.hh | 34
-rw-r--r--  lm/builder/interpolate.cc | 65
-rw-r--r--  lm/builder/interpolate.hh | 27
-rw-r--r--  lm/builder/joint_order.hh | 43
-rw-r--r--  lm/builder/main.cc | 94
-rw-r--r--  lm/builder/multi_stream.hh | 180
-rw-r--r--  lm/builder/ngram.hh | 84
-rw-r--r--  lm/builder/ngram_stream.hh | 55
-rw-r--r--  lm/builder/pipeline.cc | 320
-rw-r--r--  lm/builder/pipeline.hh | 40
-rw-r--r--  lm/builder/print.cc | 135
-rw-r--r--  lm/builder/print.hh | 102
-rw-r--r--  lm/builder/sort.hh | 103
-rw-r--r--  lm/filter/Jamfile | 5
-rw-r--r--  lm/filter/arpa_io.cc | 122
-rw-r--r--  lm/filter/arpa_io.hh | 122
-rw-r--r--  lm/filter/count_io.hh | 91
-rw-r--r--  lm/filter/format.hh | 250
-rw-r--r--  lm/filter/main.cc | 249
-rw-r--r--  lm/filter/phrase.cc | 281
-rw-r--r--  lm/filter/phrase.hh | 153
-rw-r--r--  lm/filter/thread.hh | 167
-rw-r--r--  lm/filter/vocab.cc | 54
-rw-r--r--  lm/filter/vocab.hh | 132
-rw-r--r--  lm/filter/wrapper.hh | 58
-rw-r--r--  lm/read_arpa.cc | 10
-rw-r--r--  lm/sizes.cc | 63
-rw-r--r--  lm/sizes.hh | 17
-rw-r--r--  util/Jamfile | 3
-rw-r--r--  util/double-conversion/Jamfile | 2
-rw-r--r--  util/double-conversion/double-conversion.cc | 2
-rw-r--r--  util/double-conversion/double-conversion.h | 6
-rw-r--r--  util/ersatz_progress.cc | 4
-rw-r--r--  util/ersatz_progress.hh | 3
-rw-r--r--  util/file.cc | 13
-rw-r--r--  util/file.hh | 2
-rw-r--r--  util/file_piece.cc | 53
-rw-r--r--  util/multi_intersection.hh | 80
-rw-r--r--  util/multi_intersection_test.cc | 63
-rw-r--r--  util/pcqueue.hh | 105
-rw-r--r--  util/pool.cc | 5
-rw-r--r--  util/probing_hash_table.hh | 5
-rw-r--r--  util/scoped.cc | 29
-rw-r--r--  util/scoped.hh | 17
-rw-r--r--  util/stream/Jamfile | 12
-rw-r--r--  util/stream/block.hh | 43
-rw-r--r--  util/stream/chain.cc | 155
-rw-r--r--  util/stream/chain.hh | 198
-rw-r--r--  util/stream/config.hh | 32
-rw-r--r--  util/stream/io.cc | 64
-rw-r--r--  util/stream/io.hh | 76
-rw-r--r--  util/stream/io_test.cc | 38
-rw-r--r--  util/stream/line_input.cc | 52
-rw-r--r--  util/stream/line_input.hh | 22
-rw-r--r--  util/stream/multi_progress.cc | 86
-rw-r--r--  util/stream/multi_progress.hh | 90
-rw-r--r--  util/stream/sort.hh | 542
-rw-r--r--  util/stream/sort_test.cc | 62
-rw-r--r--  util/stream/stream.hh | 74
-rw-r--r--  util/stream/stream_test.cc | 35
-rw-r--r--  util/stream/timer.hh | 14
-rw-r--r--  util/thread_pool.hh | 95
-rw-r--r--  util/usage.cc | 60
-rw-r--r--  util/usage.hh | 10
79 files changed, 6155 insertions(+), 95 deletions(-)
diff --git a/jam-files/sanity.jam b/jam-files/sanity.jam
index a35042f79..530ecc420 100644
--- a/jam-files/sanity.jam
+++ b/jam-files/sanity.jam
@@ -138,9 +138,9 @@ rule boost ( min-version ) {
echo Failed to run "$(cmd)" ;
exit Boost does not seem to be installed or g++ is confused. : 1 ;
}
- boost-version = [ MATCH "#define BOOST_VERSION ([0-9]*)" : $(boost-shell[1]) ] ;
- if $(boost-version) < $(min-version) && $(CLEANING) = no {
- exit You have Boost $(boost-version). This package requires Boost at least $(min-version) (and preferably newer). : 1 ;
+ constant BOOST-VERSION : [ MATCH "#define BOOST_VERSION ([0-9]*)" : $(boost-shell[1]) ] ;
+ if $(BOOST-VERSION) < $(min-version) && $(CLEANING) = no {
+ exit You have Boost $(BOOST-VERSION). This package requires Boost at least $(min-version) (and preferably newer). : 1 ;
}
# If matching version tags exist, use them.
boost-lib-version = [ MATCH "#define BOOST_LIB_VERSION \"([^\"]*)\"" : $(boost-shell[1]) ] ;
@@ -159,6 +159,11 @@ rule boost ( min-version ) {
boost-lib program_options PROGRAM_OPTIONS_DYN_LINK ;
boost-lib unit_test_framework TEST_DYN_LINK ;
boost-lib iostreams IOSTREAMS_DYN_LINK ;
+ boost-lib filesystem FILE_SYSTEM_DYN_LINK ;
+ if $(BOOST-VERSION) >= 104800 {
+ boost-lib chrono CHRONO_DYN_LINK ;
+ boost-lib timer TIMER_DYN_LINK : boost_chrono ;
+ }
}
#Link normally to a library, but sometimes static isn't installed so fall back to dynamic.
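For context on the 104800 threshold above: Boost packs its release into the single integer BOOST_VERSION as major * 100000 + minor * 100 + patch (per boost/version.hpp), so 104800 is Boost 1.48.0 — as the conditional suggests, the first release where the chrono-based Boost.Timer is available. A minimal sketch of the decoding:

```cpp
// Sketch: decode the BOOST_VERSION integer that sanity.jam compares against.
// BOOST_VERSION = major * 100000 + minor * 100 + patch (boost/version.hpp).
#include <iostream>

int main() {
  const int version = 104800;               // the threshold used above
  std::cout << version / 100000 << '.'      // major -> 1
            << version / 100 % 1000 << '.'  // minor -> 48
            << version % 100 << '\n';       // patch -> 0, i.e. "1.48.0"
  return 0;
}
```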
diff --git a/lm/Jamfile b/lm/Jamfile
index a4ec75002..0faea30f4 100644
--- a/lm/Jamfile
+++ b/lm/Jamfile
@@ -13,7 +13,7 @@ update-if-changed $(ORDER-LOG) $(max-order) ;
max-order += <dependency>$(ORDER-LOG) ;
-fakelib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quantize.cc read_arpa.cc search_hashed.cc search_trie.cc trie.cc trie_sort.cc value_build.cc virtual_interface.cc vocab.cc ../util//kenutil : <include>.. $(max-order) : : <include>.. $(max-order) ;
+fakelib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quantize.cc read_arpa.cc search_hashed.cc search_trie.cc sizes.cc trie.cc trie_sort.cc value_build.cc virtual_interface.cc vocab.cc ../util//kenutil : <include>.. $(max-order) : : <include>.. $(max-order) ;
import testing ;
@@ -26,4 +26,4 @@ exe build_binary : build_binary.cc kenlm ../util//kenutil ;
exe kenlm_max_order : max_order.cc : <include>.. $(max-order) ;
exe fragment : fragment.cc kenlm ;
-alias programs : query build_binary kenlm_max_order fragment ;
+alias programs : query build_binary kenlm_max_order fragment filter//filter : <threading>multi:<source>builder//lmplz ;
diff --git a/lm/build_binary.cc b/lm/build_binary.cc
index 2b8c9d5b2..ab2c0c320 100644
--- a/lm/build_binary.cc
+++ b/lm/build_binary.cc
@@ -1,10 +1,14 @@
#include "lm/model.hh"
+#include "lm/sizes.hh"
#include "util/file_piece.hh"
+#include "util/usage.hh"
+#include <algorithm>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <iomanip>
+#include <limits>
#include <math.h>
#include <stdlib.h>
@@ -19,8 +23,8 @@ namespace lm {
namespace ngram {
namespace {
-void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
+void Usage(const char *name, const char *default_mem) {
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
@@ -38,8 +42,11 @@ void Usage(const char *name) {
"trie is a straightforward trie with bit-level packing. It uses the least\n"
"memory and is still faster than SRI or IRST. Building the trie format uses an\n"
"on-disk sort to save memory.\n"
-"-t is the temporary directory prefix. Default is the output file name.\n"
-"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n"
+"-T is the temporary directory prefix. Default is the output file name.\n"
+"-S determines memory use for sorting. Default is " << default_mem << ". This is compatible\n"
+" with GNU sort. The number is followed by a unit: \% for percent of physical\n"
+" memory, b for bytes, K for Kilobytes, M for megabytes, then G,T,P,E,Z,Y. \n"
+" Default unit is K for Kilobytes.\n"
"-q turns quantization on and sets the number of bits (e.g. -q 8).\n"
"-b sets backoff quantization bits. Requires -q and defaults to that value.\n"
"-a compresses pointers using an array of offsets. The parameter is the\n"
@@ -83,47 +90,6 @@ void ParseFileList(const char *from, std::vector<std::string> &to) {
}
}
-void ShowSizes(const char *file, const lm::ngram::Config &config) {
- std::vector<uint64_t> counts;
- util::FilePiece f(file);
- lm::ReadARPACounts(f, counts);
- uint64_t sizes[6];
- sizes[0] = ProbingModel::Size(counts, config);
- sizes[1] = RestProbingModel::Size(counts, config);
- sizes[2] = TrieModel::Size(counts, config);
- sizes[3] = QuantTrieModel::Size(counts, config);
- sizes[4] = ArrayTrieModel::Size(counts, config);
- sizes[5] = QuantArrayTrieModel::Size(counts, config);
- uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
- uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
- uint64_t divide;
- char prefix;
- if (min_length < (1 << 10) * 10) {
- prefix = ' ';
- divide = 1;
- } else if (min_length < (1 << 20) * 10) {
- prefix = 'k';
- divide = 1 << 10;
- } else if (min_length < (1ULL << 30) * 10) {
- prefix = 'M';
- divide = 1 << 20;
- } else {
- prefix = 'G';
- divide = 1 << 30;
- }
- long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide))));
- std::cout << "Memory estimate:\ntype ";
- // right align bytes.
- for (long int i = 0; i < length - 2; ++i) std::cout << ' ';
- std::cout << prefix << "B\n"
- "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
- "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n"
- "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
- "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
- "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
- "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
-}
-
void ProbingQuantizationUnsupported() {
std::cerr << "Quantization is only implemented in the trie data structure." << std::endl;
exit(1);
@@ -136,11 +102,14 @@ void ProbingQuantizationUnsupported() {
int main(int argc, char *argv[]) {
using namespace lm::ngram;
+ const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G";
+
try {
bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;
lm::ngram::Config config;
+ config.building_memory = util::ParseSize(default_mem);
int opt;
- while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:sir:")) != -1) {
+ while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -161,12 +130,16 @@ int main(int argc, char *argv[]) {
case 'p':
config.probing_multiplier = ParseFloat(optarg);
break;
- case 't':
+ case 't': // legacy
+ case 'T':
config.temporary_directory_prefix = optarg;
break;
- case 'm':
+ case 'm': // legacy
config.building_memory = ParseUInt(optarg) * 1048576;
break;
+ case 'S':
+ config.building_memory = std::min(static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), util::ParseSize(optarg));
+ break;
case 'w':
set_write_method = true;
if (!strcmp(optarg, "mmap")) {
@@ -174,7 +147,7 @@ int main(int argc, char *argv[]) {
} else if (!strcmp(optarg, "after")) {
config.write_method = Config::WRITE_AFTER;
} else {
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
break;
case 's':
@@ -189,7 +162,7 @@ int main(int argc, char *argv[]) {
config.rest_function = Config::REST_LOWER;
break;
default:
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
}
if (!quantize && set_backoff_bits) {
@@ -212,7 +185,7 @@ int main(int argc, char *argv[]) {
from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
} else {
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
if (!strcmp(model_type, "probing")) {
if (!set_write_method) config.write_method = Config::WRITE_AFTER;
@@ -242,7 +215,7 @@ int main(int argc, char *argv[]) {
}
}
} else {
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
}
catch (const std::exception &e) {
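The new -S flag parses sizes with util::ParseSize, which lives in util/usage.cc and is not shown in this diff. The following is therefore only an independent sketch of the GNU-sort-style grammar the help text describes, with a hypothetical ParseSizeSketch standing in for the real function:

```cpp
// Hedged sketch, NOT KenLM's util::ParseSize: a number, then an optional unit.
// '%' means percent of physical memory; 'b' is bytes; K, M, G, T, P, E, Z, Y
// each scale by a further factor of 1024; no suffix defaults to K.
#include <cstdlib>
#include <cstring>
#include <stdint.h>

uint64_t ParseSizeSketch(const char *arg, uint64_t physical_memory) {
  char *end;
  double number = std::strtod(arg, &end);
  if (*end == '%')
    return static_cast<uint64_t>(number / 100.0 * physical_memory);
  static const char units[] = "bKMGTPEZY";
  const char *unit = *end ? std::strchr(units, *end) : units + 1;  // default: K
  double multiplier = 1.0;
  for (const char *i = units; unit && i < unit; ++i) multiplier *= 1024.0;
  return static_cast<uint64_t>(number * multiplier);  // unknown suffix: bytes here
}
```

Note that build_binary clamps the parsed value to std::numeric_limits<std::size_t>::max() before storing it, so out-of-range inputs are bounded by the caller.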
diff --git a/lm/builder/Jamfile b/lm/builder/Jamfile
new file mode 100644
index 000000000..9b54911a0
--- /dev/null
+++ b/lm/builder/Jamfile
@@ -0,0 +1,9 @@
+fakelib builder : corpus_count.cc adjust_counts.cc initial_probabilities.cc interpolate.cc pipeline.cc print.cc
+ ../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm
+ : : : <library>/top//boost_thread $(timer-link) ;
+
+exe lmplz : main.cc builder /top//boost_program_options ;
+
+import testing ;
+unit-test corpus_count_test : corpus_count_test.cc builder /top//boost_unit_test_framework ;
+unit-test adjust_counts_test : adjust_counts_test.cc builder /top//boost_unit_test_framework ;
diff --git a/lm/builder/README.md b/lm/builder/README.md
new file mode 100644
index 000000000..be0d35e26
--- /dev/null
+++ b/lm/builder/README.md
@@ -0,0 +1,47 @@
+Dependencies
+============
+
+Boost >= 1.42.0 is required.
+
+For Ubuntu,
+```bash
+sudo apt-get install libboost1.48-all-dev
+```
+
+Alternatively, you can download, compile, and install it yourself:
+
+```bash
+wget http://sourceforge.net/projects/boost/files/boost/1.52.0/boost_1_52_0.tar.gz/download -O boost_1_52_0.tar.gz
+tar -xvzf boost_1_52_0.tar.gz
+cd boost_1_52_0
+./bootstrap.sh
+./b2
+sudo ./b2 install
+```
+
+Local install options (in a user-space prefix directory) are also possible. See http://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html.
+
+
+Building
+========
+
+```bash
+bjam
+```
+Your distribution might package bjam and boost-build separately from Boost. Both are required.
+
+Usage
+=====
+
+Run
+```bash
+$ bin/lmplz
+```
+to see the command line arguments.
+
+Running
+=======
+
+```bash
+bin/lmplz -o 5 <text >text.arpa
+```
diff --git a/lm/builder/TODO b/lm/builder/TODO
new file mode 100644
index 000000000..cb5aef3a0
--- /dev/null
+++ b/lm/builder/TODO
@@ -0,0 +1,5 @@
+More tests!
+Sharding.
+Some way to manage all the crazy config options.
+Option to build the binary file directly.
+Interpolation of different orders.
diff --git a/lm/builder/adjust_counts.cc b/lm/builder/adjust_counts.cc
new file mode 100644
index 000000000..a6f48011f
--- /dev/null
+++ b/lm/builder/adjust_counts.cc
@@ -0,0 +1,216 @@
+#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/multi_stream.hh"
+#include "util/stream/timer.hh"
+
+#include <algorithm>
+
+namespace lm { namespace builder {
+
+BadDiscountException::BadDiscountException() throw() {}
+BadDiscountException::~BadDiscountException() throw() {}
+
+namespace {
+// Return last word in full that is different.
+const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
+ const WordIndex *cur_word = full.end() - 1;
+ const WordIndex *pre_word = lower_last.end() - 1;
+ // Find last difference.
+ for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
+ return cur_word;
+}
+
+class StatCollector {
+ public:
+ StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+ : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) {
+ memset(&orders_[0], 0, sizeof(OrderStat) * order);
+ }
+
+ ~StatCollector() {}
+
+ void CalculateDiscounts() {
+ counts_.resize(orders_.size());
+ discounts_.resize(orders_.size());
+ for (std::size_t i = 0; i < orders_.size(); ++i) {
+ const OrderStat &s = orders_[i];
+ counts_[i] = s.count;
+
+ for (unsigned j = 1; j < 4; ++j) {
+ // TODO: Specialize error message for j == 3, meaning 3+
+ UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
+ << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
+ << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
+ }
+
+ // See equation (26) in Chen and Goodman.
+ discounts_[i].amount[0] = 0.0;
+ float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
+ for (unsigned j = 1; j < 4; ++j) {
+ discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
+ UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+ }
+ }
+ }
+
+ void Add(std::size_t order_minus_1, uint64_t count) {
+ OrderStat &stat = orders_[order_minus_1];
+ ++stat.count;
+ if (count < 5) ++stat.n[count];
+ }
+
+ void AddFull(uint64_t count) {
+ ++full_.count;
+ if (count < 5) ++full_.n[count];
+ }
+
+ private:
+ struct OrderStat {
+ // n_1 in equation 26 of Chen and Goodman etc
+ uint64_t n[5];
+ uint64_t count;
+ };
+
+ std::vector<OrderStat> orders_;
+ OrderStat &full_;
+
+ std::vector<uint64_t> &counts_;
+ std::vector<Discount> &discounts_;
+};
+
+// Reads all entries in order like NGramStream does.
+// But deletes any entries that have <s> in the 1st (not 0th) position on the
+// way out by putting other entries in their place. This disrupts the sort
+// order but we don't care because the data is going to be sorted again.
+class CollapseStream {
+ public:
+ CollapseStream(const util::stream::ChainPosition &position) :
+ current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
+ block_(position) {
+ StartBlock();
+ }
+
+ const NGram &operator*() const { return current_; }
+ const NGram *operator->() const { return &current_; }
+
+ operator bool() const { return block_; }
+
+ CollapseStream &operator++() {
+ assert(block_);
+ if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
+ memcpy(current_.Base(), copy_from_, current_.TotalSize());
+ UpdateCopyFrom();
+ }
+ current_.NextInMemory();
+ uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
+ if (current_.Base() == block_base + block_->ValidSize()) {
+ block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base);
+ ++block_;
+ StartBlock();
+ }
+ return *this;
+ }
+
+ private:
+ void StartBlock() {
+ for (; ; ++block_) {
+ if (!block_) return;
+ if (block_->ValidSize()) break;
+ }
+ current_.ReBase(block_->Get());
+ copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
+ UpdateCopyFrom();
+ }
+
+ // Find last without bos.
+ void UpdateCopyFrom() {
+ for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
+ if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
+ }
+ }
+
+ NGram current_;
+
+ // Goes backwards in the block
+ uint8_t *copy_from_;
+
+ util::stream::Link block_;
+};
+
+} // namespace
+
+void AdjustCounts::Run(const ChainPositions &positions) {
+ UTIL_TIMER("(%w s) Adjusted counts\n");
+
+ const std::size_t order = positions.size();
+ StatCollector stats(order, counts_, discounts_);
+ if (order == 1) {
+ // Only unigrams. Just collect stats.
+ for (NGramStream full(positions[0]); full; ++full)
+ stats.AddFull(full->Count());
+ stats.CalculateDiscounts();
+ return;
+ }
+
+ NGramStreams streams;
+ streams.Init(positions, positions.size() - 1);
+ CollapseStream full(positions[positions.size() - 1]);
+
+ // Initialization: <unk> has count 0 and so does <s>.
+ NGramStream *lower_valid = streams.begin();
+ streams[0]->Count() = 0;
+ *streams[0]->begin() = kUNK;
+ stats.Add(0, 0);
+ (++streams[0])->Count() = 0;
+ *streams[0]->begin() = kBOS;
+ // not in stats because it will get put in later.
+
+ // iterate over full (the stream of the highest order ngrams)
+ for (; full; ++full) {
+ const WordIndex *different = FindDifference(*full, **lower_valid);
+ std::size_t same = full->end() - 1 - different;
+ // Increment the adjusted count.
+ if (same) ++streams[same - 1]->Count();
+
+ // Output all the valid ones that changed.
+ for (; lower_valid >= &streams[same]; --lower_valid) {
+ stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count());
+ ++*lower_valid;
+ }
+
+ // This is here because bos is also const WordIndex *, so copy gets
+ // consistent argument types.
+ const WordIndex *full_end = full->end();
+ // Initialize and mark as valid up to bos.
+ const WordIndex *bos;
+ for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
+ ++lower_valid;
+ std::copy(bos, full_end, (*lower_valid)->begin());
+ (*lower_valid)->Count() = 1;
+ }
+ // Now bos indicates where <s> is or is the 0th word of full.
+ if (bos != full->begin()) {
+ // There is an <s> beyond the 0th word.
+ NGramStream &to = *++lower_valid;
+ std::copy(bos, full_end, to->begin());
+ to->Count() = full->Count();
+ } else {
+ stats.AddFull(full->Count());
+ }
+ assert(lower_valid >= &streams[0]);
+ }
+
+ // Output everything valid.
+ for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
+ stats.Add(s - streams.begin(), (*s)->Count());
+ ++*s;
+ }
+ // Poison everyone! Except the N-grams which were already poisoned by the input.
+ for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
+ s->Poison();
+
+ stats.CalculateDiscounts();
+
+ // NOTE: See special early-return case for unigrams near the top of this function
+}
+
+}} // namespaces
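StatCollector::CalculateDiscounts above is a direct transcription of equation (26) in Chen and Goodman. With n_j the number of n-grams of a given order whose adjusted count is exactly j (only j <= 4 is tracked, which is all the formula needs):

```latex
Y = \frac{n_1}{n_1 + 2 n_2}, \qquad
D_0 = 0, \qquad
D_j = j - (j+1)\, Y\, \frac{n_{j+1}}{n_j} \quad (j = 1, 2, 3).
```

BadDiscountException fires exactly when some required n_j is zero or a computed D_j falls outside the sanity range [0, j] — which, as the error message says, points at small or artificial data.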
diff --git a/lm/builder/adjust_counts.hh b/lm/builder/adjust_counts.hh
new file mode 100644
index 000000000..f38ff79db
--- /dev/null
+++ b/lm/builder/adjust_counts.hh
@@ -0,0 +1,44 @@
+#ifndef LM_BUILDER_ADJUST_COUNTS__
+#define LM_BUILDER_ADJUST_COUNTS__
+
+#include "lm/builder/discount.hh"
+#include "util/exception.hh"
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+
+class ChainPositions;
+
+class BadDiscountException : public util::Exception {
+ public:
+ BadDiscountException() throw();
+ ~BadDiscountException() throw();
+};
+
+/* Compute adjusted counts.
+ * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.
+ * Output: [1,N]-grams with adjusted counts.
+ * [1,N)-grams are in suffix order
+ * N-grams are in undefined order (they're going to be sorted anyway).
+ */
+class AdjustCounts {
+ public:
+ AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+ : counts_(counts), discounts_(discounts) {}
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ std::vector<uint64_t> &counts_;
+ std::vector<Discount> &discounts_;
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_ADJUST_COUNTS__
+
diff --git a/lm/builder/adjust_counts_test.cc b/lm/builder/adjust_counts_test.cc
new file mode 100644
index 000000000..68b5f33ea
--- /dev/null
+++ b/lm/builder/adjust_counts_test.cc
@@ -0,0 +1,106 @@
+#include "lm/builder/adjust_counts.hh"
+
+#include "lm/builder/multi_stream.hh"
+#include "util/scoped.hh"
+
+#include <boost/thread/thread.hpp>
+#define BOOST_TEST_MODULE AdjustCounts
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+class KeepCopy {
+ public:
+ KeepCopy() : size_(0) {}
+
+ void Run(const util::stream::ChainPosition &position) {
+ for (util::stream::Link link(position); link; ++link) {
+ mem_.call_realloc(size_ + link->ValidSize());
+ memcpy(static_cast<uint8_t*>(mem_.get()) + size_, link->Get(), link->ValidSize());
+ size_ += link->ValidSize();
+ }
+ }
+
+ uint8_t *Get() { return static_cast<uint8_t*>(mem_.get()); }
+ std::size_t Size() const { return size_; }
+
+ private:
+ util::scoped_malloc mem_;
+ std::size_t size_;
+};
+
+struct Gram4 {
+ WordIndex ids[4];
+ uint64_t count;
+};
+
+class WriteInput {
+ public:
+ void Run(const util::stream::ChainPosition &position) {
+ NGramStream input(position);
+ Gram4 grams[] = {
+ {{0,0,0,0},10},
+ {{0,0,3,0},3},
+ // bos
+ {{1,1,1,2},5},
+ {{0,0,3,2},5},
+ };
+ for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) {
+ memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4);
+ input->Count() = grams[i].count;
+ }
+ input.Poison();
+ }
+};
+
+BOOST_AUTO_TEST_CASE(Simple) {
+ KeepCopy outputs[4];
+ std::vector<uint64_t> counts;
+ std::vector<Discount> discount;
+ {
+ util::stream::ChainConfig config;
+ config.total_memory = 100;
+ config.block_count = 1;
+ Chains chains(4);
+ for (unsigned i = 0; i < 4; ++i) {
+ config.entry_size = NGram::TotalSize(i + 1);
+ chains.push_back(config);
+ }
+
+ chains[3] >> WriteInput();
+ ChainPositions for_adjust(chains);
+ for (unsigned i = 0; i < 4; ++i) {
+ chains[i] >> boost::ref(outputs[i]);
+ }
+ chains >> util::stream::kRecycle;
+ BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException);
+ }
+ BOOST_REQUIRE_EQUAL(4UL, counts.size());
+ BOOST_CHECK_EQUAL(4UL, counts[0]);
+ // These are no longer set because the discounts are bad.
+/* BOOST_CHECK_EQUAL(4UL, counts[1]);
+ BOOST_CHECK_EQUAL(3UL, counts[2]);
+ BOOST_CHECK_EQUAL(3UL, counts[3]);*/
+ BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size());
+ NGram uni(outputs[0].Get(), 1);
+ BOOST_CHECK_EQUAL(kUNK, *uni.begin());
+ BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(kBOS, *uni.begin());
+ BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(0UL, *uni.begin());
+ BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ BOOST_CHECK_EQUAL(2UL, *uni.begin());
+
+ BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size());
+ NGram bi(outputs[1].Get(), 2);
+ BOOST_CHECK_EQUAL(0UL, *bi.begin());
+ BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1));
+ BOOST_CHECK_EQUAL(1ULL, bi.Count());
+ bi.NextInMemory();
+}
+
+}}} // namespaces
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc
new file mode 100644
index 000000000..8c3de57dd
--- /dev/null
+++ b/lm/builder/corpus_count.cc
@@ -0,0 +1,223 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/lm_exception.hh"
+#include "lm/word_index.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/murmur_hash.hh"
+#include "util/probing_hash_table.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/timer.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_set.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <functional>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+namespace {
+
+class VocabHandout {
+ public:
+ explicit VocabHandout(int fd) {
+ util::scoped_fd duped(util::DupOrThrow(fd));
+ word_list_.reset(util::FDOpenOrThrow(duped));
+
+ Lookup("<unk>"); // Force 0
+ Lookup("<s>"); // Force 1
+ Lookup("</s>"); // Force 2
+ }
+
+ WordIndex Lookup(const StringPiece &word) {
+ uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
+ std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
+ if (ret.second) {
+ char null_delimit = 0;
+ util::WriteOrThrow(word_list_.get(), word.data(), word.size());
+ util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
+ UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ }
+ return ret.first->second;
+ }
+
+ WordIndex Size() const {
+ return seen_.size();
+ }
+
+ private:
+ typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+
+ Seen seen_;
+
+ util::scoped_FILE word_list_;
+};
+
+class DedupeHash : public std::unary_function<const WordIndex *, bool> {
+ public:
+ explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+ std::size_t operator()(const WordIndex *start) const {
+ return util::MurmurHashNative(start, size_);
+ }
+
+ private:
+ const std::size_t size_;
+};
+
+class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> {
+ public:
+ explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+ bool operator()(const WordIndex *first, const WordIndex *second) const {
+ return !memcmp(first, second, size_);
+ }
+
+ private:
+ const std::size_t size_;
+};
+
+struct DedupeEntry {
+ typedef WordIndex *Key;
+ Key GetKey() const { return key; }
+ Key key;
+ static DedupeEntry Construct(WordIndex *at) {
+ DedupeEntry ret;
+ ret.key = at;
+ return ret;
+ }
+};
+
+typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
+
+const float kProbingMultiplier = 1.5;
+
+class Writer {
+ public:
+ Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
+ : block_(position), gram_(block_->Get(), order),
+ dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
+ dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
+ buffer_(new WordIndex[order - 1]),
+ block_size_(position.GetChain().BlockSize()) {
+ dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
+ if (order == 1) {
+ // Add special words. AdjustCounts is responsible if order != 1.
+ AddUnigramWord(kUNK);
+ AddUnigramWord(kBOS);
+ }
+ }
+
+ ~Writer() {
+ block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
+ (++block_).Poison();
+ }
+
+ // Write context with a bunch of <s>
+ void StartSentence() {
+ for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
+ *i = kBOS;
+ }
+ }
+
+ void Append(WordIndex word) {
+ *(gram_.end() - 1) = word;
+ Dedupe::MutableIterator at;
+ bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
+ if (found) {
+ // Already present.
+ NGram already(at->key, gram_.Order());
+ ++(already.Count());
+ // Shift left by one.
+ memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
+ return;
+ }
+ // Complete the write.
+ gram_.Count() = 1;
+ // Prepare the next n-gram.
+ if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
+ NGram last(gram_);
+ gram_.NextInMemory();
+ std::copy(last.begin() + 1, last.end(), gram_.begin());
+ return;
+ }
+ // Block end. Need to store the context in a temporary buffer.
+ std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
+ dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ block_->SetValidSize(block_size_);
+ gram_.ReBase((++block_)->Get());
+ std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
+ }
+
+ private:
+ void AddUnigramWord(WordIndex index) {
+ *gram_.begin() = index;
+ gram_.Count() = 0;
+ gram_.NextInMemory();
+ if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
+ block_->SetValidSize(block_size_);
+ gram_.ReBase((++block_)->Get());
+ }
+ }
+
+ util::stream::Link block_;
+
+ NGram gram_;
+
+ // This is the memory behind the invalid value in dedupe_.
+ std::vector<WordIndex> dedupe_invalid_;
+ // Hash table combiner implementation.
+ Dedupe dedupe_;
+
+ // Small buffer to hold existing ngrams when shifting across a block boundary.
+ boost::scoped_array<WordIndex> buffer_;
+
+ const std::size_t block_size_;
+};
+
+} // namespace
+
+float CorpusCount::DedupeMultiplier(std::size_t order) {
+ return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
+}
+
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
+ : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
+ dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
+ dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
+ token_count_ = 0;
+ type_count_ = 0;
+}
+
+void CorpusCount::Run(const util::stream::ChainPosition &position) {
+ UTIL_TIMER("(%w s) Counted n-grams\n");
+
+ VocabHandout vocab(vocab_write_);
+ const WordIndex end_sentence = vocab.Lookup("</s>");
+ Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
+ uint64_t count = 0;
+ try {
+ while(true) {
+ StringPiece line(from_.ReadLine());
+ writer.StartSentence();
+ for (util::TokenIter<util::AnyCharacter, true> w(line, " \t"); w; ++w) {
+ WordIndex word = vocab.Lookup(*w);
+ UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
+ writer.Append(word);
+ ++count;
+ }
+ writer.Append(end_sentence);
+ }
+ } catch (const util::EndOfFileException &e) {}
+ token_count_ = count;
+ type_count_ = vocab.Size();
+}
+
+} // namespace builder
+} // namespace lm
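Stripped of the probing hash table, block management, and vocabulary handling, the counting loop above amounts to a sliding window over each sentence: StartSentence pads the context with <s>, every Append completes one n-gram and then shifts the context left by one word, and </s> is appended at sentence end. A deliberately naive sketch of just that windowing, using std::map as the combiner instead of the dedupe table:

```cpp
// Simplified model of Writer::Append's sliding window; illustration only.
#include <deque>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include <stdint.h>

int main() {
  const std::size_t order = 3;
  std::map<std::vector<std::string>, uint64_t> counts;
  std::string line;
  while (std::getline(std::cin, line)) {
    std::deque<std::string> window(order - 1, "<s>");  // StartSentence()
    std::istringstream tokens(line);
    std::string word;
    std::vector<std::string> sentence;
    while (tokens >> word) sentence.push_back(word);
    sentence.push_back("</s>");                        // end-of-sentence marker
    for (std::size_t i = 0; i < sentence.size(); ++i) {
      window.push_back(sentence[i]);                   // Append(word)
      ++counts[std::vector<std::string>(window.begin(), window.end())];
      window.pop_front();                              // shift left by one
    }
  }
  typedef std::map<std::vector<std::string>, uint64_t>::const_iterator It;
  for (It i = counts.begin(); i != counts.end(); ++i) {
    for (std::size_t j = 0; j < i->first.size(); ++j) std::cout << i->first[j] << ' ';
    std::cout << i->second << '\n';
  }
  return 0;
}
```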
diff --git a/lm/builder/corpus_count.hh b/lm/builder/corpus_count.hh
new file mode 100644
index 000000000..e255bad13
--- /dev/null
+++ b/lm/builder/corpus_count.hh
@@ -0,0 +1,42 @@
+#ifndef LM_BUILDER_CORPUS_COUNT__
+#define LM_BUILDER_CORPUS_COUNT__
+
+#include "lm/word_index.hh"
+#include "util/scoped.hh"
+
+#include <cstddef>
+#include <string>
+#include <stdint.h>
+
+namespace util {
+class FilePiece;
+namespace stream {
+class ChainPosition;
+} // namespace stream
+} // namespace util
+
+namespace lm {
+namespace builder {
+
+class CorpusCount {
+ public:
+ // Memory usage will be DedupeMultiplier(order) * block_size + total_chain_size + unknown vocab_hash_size
+ static float DedupeMultiplier(std::size_t order);
+
+ CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
+
+ void Run(const util::stream::ChainPosition &position);
+
+ private:
+ util::FilePiece &from_;
+ int vocab_write_;
+ uint64_t &token_count_;
+ WordIndex &type_count_;
+
+ std::size_t dedupe_mem_size_;
+ util::scoped_malloc dedupe_mem_;
+};
+
+} // namespace builder
+} // namespace lm
+#endif // LM_BUILDER_CORPUS_COUNT__
diff --git a/lm/builder/corpus_count_test.cc b/lm/builder/corpus_count_test.cc
new file mode 100644
index 000000000..8d53ca9d1
--- /dev/null
+++ b/lm/builder/corpus_count_test.cc
@@ -0,0 +1,76 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/ngram_stream.hh"
+
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/tokenize_piece.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/stream.hh"
+
+#define BOOST_TEST_MODULE CorpusCountTest
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+#define Check(str, count) { \
+ BOOST_REQUIRE(stream); \
+ w = stream->begin(); \
+ for (util::TokenIter<util::AnyCharacter, true> t(str, " "); t; ++t, ++w) { \
+ BOOST_CHECK_EQUAL(*t, v[*w]); \
+ } \
+ BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \
+ ++stream; \
+}
+
+BOOST_AUTO_TEST_CASE(Short) {
+ util::scoped_fd input_file(util::MakeTemp("corpus_count_test_temp"));
+ const char input[] = "looking on a little more loin\non a little more loin\non foo little more loin\nbar\n\n";
+ // Blocks of 10 are
+ // looking on a little more loin </s> on a little[duplicate] more[duplicate] loin[duplicate] </s>[duplicate] on[duplicate] foo
+ // little more loin </s> bar </s> </s>
+
+ util::WriteOrThrow(input_file.get(), input, sizeof(input) - 1);
+ util::FilePiece input_piece(input_file.release(), "temp file");
+
+ util::stream::ChainConfig config;
+ config.entry_size = NGram::TotalSize(3);
+ config.total_memory = config.entry_size * 20;
+ config.block_count = 2;
+
+ util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab"));
+
+ util::stream::Chain chain(config);
+ NGramStream stream;
+ uint64_t token_count;
+ WordIndex type_count;
+ CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
+
+ const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"};
+
+ WordIndex *w;
+
+ Check("<s> <s> looking", 1);
+ Check("<s> looking on", 1);
+ Check("looking on a", 1);
+ Check("on a little", 2);
+ Check("a little more", 2);
+ Check("little more loin", 2);
+ Check("more loin </s>", 2);
+ Check("<s> <s> on", 2);
+ Check("<s> on a", 1);
+ Check("<s> on foo", 1);
+ Check("on foo little", 1);
+ Check("foo little more", 1);
+ Check("little more loin", 1);
+ Check("more loin </s>", 1);
+ Check("<s> <s> bar", 1);
+ Check("<s> bar </s>", 1);
+ Check("<s> <s> </s>", 1);
+ BOOST_CHECK(!stream);
+ BOOST_CHECK_EQUAL(sizeof(v) / sizeof(const char*), type_count);
+}
+
+}}} // namespaces
diff --git a/lm/builder/discount.hh b/lm/builder/discount.hh
new file mode 100644
index 000000000..754fb20db
--- /dev/null
+++ b/lm/builder/discount.hh
@@ -0,0 +1,26 @@
+#ifndef BUILDER_DISCOUNT__
+#define BUILDER_DISCOUNT__
+
+#include <algorithm>
+
+#include <inttypes.h>
+
+namespace lm {
+namespace builder {
+
+struct Discount {
+ float amount[4];
+
+ float Get(uint64_t count) const {
+ return amount[std::min<uint64_t>(count, 3)];
+ }
+
+ float Apply(uint64_t count) const {
+ return static_cast<float>(count) - Get(count);
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // BUILDER_DISCOUNT__
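In terms of the adjusted counts a(c w), Discount::Apply computes the standard modified Kneser-Ney discounted count, with every count of three or more sharing D_3:

```latex
\tilde{a}(c\,w) = a(c\,w) - D_{\min(a(c\,w),\,3)}, \qquad D_0 = 0.
```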
diff --git a/lm/builder/header_info.hh b/lm/builder/header_info.hh
new file mode 100644
index 000000000..ccca14566
--- /dev/null
+++ b/lm/builder/header_info.hh
@@ -0,0 +1,20 @@
+#ifndef LM_BUILDER_HEADER_INFO__
+#define LM_BUILDER_HEADER_INFO__
+
+#include <string>
+#include <stdint.h>
+
+// Some configuration info that is used to add
+// comments to the beginning of an ARPA file
+struct HeaderInfo {
+ const std::string input_file;
+ const uint64_t token_count;
+
+ HeaderInfo(const std::string& input_file_in, uint64_t token_count_in)
+ : input_file(input_file_in), token_count(token_count_in) {}
+
+ // TODO: Add smoothing type
+ // TODO: More info if multiple models were interpolated
+};
+
+#endif
diff --git a/lm/builder/initial_probabilities.cc b/lm/builder/initial_probabilities.cc
new file mode 100644
index 000000000..58b42a20c
--- /dev/null
+++ b/lm/builder/initial_probabilities.cc
@@ -0,0 +1,136 @@
+#include "lm/builder/initial_probabilities.hh"
+
+#include "lm/builder/discount.hh"
+#include "lm/builder/ngram_stream.hh"
+#include "lm/builder/sort.hh"
+#include "util/file.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/io.hh"
+#include "util/stream/stream.hh"
+
+#include <vector>
+
+namespace lm { namespace builder {
+
+namespace {
+struct BufferEntry {
+ // Gamma from page 20 of Chen and Goodman.
+ float gamma;
+ // \sum_w a(c w) for all w.
+ float denominator;
+};
+
+// Extract an array of gamma from an array of BufferEntry.
+class OnlyGamma {
+ public:
+ void Run(const util::stream::ChainPosition &position) {
+ for (util::stream::Link block_it(position); block_it; ++block_it) {
+ float *out = static_cast<float*>(block_it->Get());
+ const float *in = out;
+ const float *end = static_cast<const float*>(block_it->ValidEnd());
+ for (out += 1, in += 2; in < end; out += 1, in += 2) {
+ *out = *in;
+ }
+ block_it->SetValidSize(block_it->ValidSize() / 2);
+ }
+ }
+};
+
+class AddRight {
+ public:
+ AddRight(const Discount &discount, const util::stream::ChainPosition &input)
+ : discount_(discount), input_(input) {}
+
+ void Run(const util::stream::ChainPosition &output) {
+ NGramStream in(input_);
+ util::stream::Stream out(output);
+
+ std::vector<WordIndex> previous(in->Order() - 1);
+ const std::size_t size = sizeof(WordIndex) * previous.size();
+ for(; in; ++out) {
+ memcpy(&previous[0], in->begin(), size);
+ uint64_t denominator = 0;
+ uint64_t counts[4];
+ memset(counts, 0, sizeof(counts));
+ do {
+ denominator += in->Count();
+ ++counts[std::min(in->Count(), static_cast<uint64_t>(3))];
+ } while (++in && !memcmp(&previous[0], in->begin(), size));
+ BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());
+ entry.denominator = static_cast<float>(denominator);
+ entry.gamma = 0.0;
+ for (unsigned i = 1; i <= 3; ++i) {
+ entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);
+ }
+ entry.gamma /= entry.denominator;
+ }
+ out.Poison();
+ }
+
+ private:
+ const Discount &discount_;
+ const util::stream::ChainPosition input_;
+};
+
+class MergeRight {
+ public:
+ MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount)
+ : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {}
+
+ // calculate the initial probability of each n-gram (before order-interpolation)
+ // Run() gets invoked once for each order
+ void Run(const util::stream::ChainPosition &primary) {
+ util::stream::Stream summed(from_adder_);
+
+ NGramStream grams(primary);
+
+ // Without interpolation, the interpolation weight goes to <unk>.
+ if (grams->Order() == 1 && !interpolate_unigrams_) {
+ BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get()));
+ assert(*grams->begin() == kUNK);
+ grams->Value().uninterp.prob = sums.gamma;
+ grams->Value().uninterp.gamma = 0.0;
+ while (++grams) {
+ grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator;
+ grams->Value().uninterp.gamma = 0.0;
+ }
+ ++summed;
+ return;
+ }
+
+ std::vector<WordIndex> previous(grams->Order() - 1);
+ const std::size_t size = sizeof(WordIndex) * previous.size();
+ for (; grams; ++summed) {
+ memcpy(&previous[0], grams->begin(), size);
+ const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
+ do {
+ Payload &pay = grams->Value();
+ pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator;
+ pay.uninterp.gamma = sums.gamma;
+ } while (++grams && !memcmp(&previous[0], grams->begin(), size));
+ }
+ }
+
+ private:
+ bool interpolate_unigrams_;
+ util::stream::ChainPosition from_adder_;
+ Discount discount_;
+};
+
+} // namespace
+
+void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) {
+ util::stream::ChainConfig gamma_config = config.adder_out;
+ gamma_config.entry_size = sizeof(BufferEntry);
+ for (size_t i = 0; i < primary.size(); ++i) {
+ util::stream::ChainPosition second(second_in[i].Add());
+ second_in[i] >> util::stream::kRecycle;
+ gamma_out.push_back(gamma_config);
+ gamma_out[i] >> AddRight(discounts[i], second);
+ primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
+ // Don't bother with the OnlyGamma thread for something to discard.
+ if (i) gamma_out[i] >> OnlyGamma();
+ }
+}
+
+}} // namespaces
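Per context c, AddRight accumulates the denominator and the counts-of-counts N_j(c) (with N_3 covering adjusted counts of three or more), and MergeRight divides. Together, in the notation of Chen and Goodman (gamma from page 20), they compute

```latex
p_{\text{uninterp}}(w \mid c) = \frac{a(c\,w) - D_{\min(a(c\,w),\,3)}}{\sum_{w'} a(c\,w')},
\qquad
\gamma(c) = \frac{\sum_{j=1}^{3} D_j\, N_j(c)}{\sum_{w'} a(c\,w')}.
```

The uninterpolated-unigram special case instead hands the entire interpolation mass gamma to <unk>, matching SRILM's default. For the orders whose gammas are actually consumed, OnlyGamma then strips the BufferEntry pairs down to bare gamma floats, halving each block in place.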
diff --git a/lm/builder/initial_probabilities.hh b/lm/builder/initial_probabilities.hh
new file mode 100644
index 000000000..626388eb0
--- /dev/null
+++ b/lm/builder/initial_probabilities.hh
@@ -0,0 +1,34 @@
+#ifndef LM_BUILDER_INITIAL_PROBABILITIES__
+#define LM_BUILDER_INITIAL_PROBABILITIES__
+
+#include "lm/builder/discount.hh"
+#include "util/stream/config.hh"
+
+#include <vector>
+
+namespace lm {
+namespace builder {
+class Chains;
+
+struct InitialProbabilitiesConfig {
+ // These should be small buffers to keep the adder from getting too far ahead
+ util::stream::ChainConfig adder_in;
+ util::stream::ChainConfig adder_out;
+ // SRILM doesn't normally interpolate unigrams.
+ bool interpolate_unigrams;
+};
+
+/* Compute initial (uninterpolated) probabilities
+ * primary: the normal chain of n-grams. Incoming is context sorted adjusted
+ * counts. Outgoing has uninterpolated probabilities for use by Interpolate.
+ * second_in: a second copy of the primary input. Discard the output.
+ * gamma_out: Computed gamma values are output on these chains in suffix order.
+ * The values are bare floats and should be buffered for interpolation to
+ * use.
+ */
+void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out);
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_INITIAL_PROBABILITIES__
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
new file mode 100644
index 000000000..500268069
--- /dev/null
+++ b/lm/builder/interpolate.cc
@@ -0,0 +1,65 @@
+#include "lm/builder/interpolate.hh"
+
+#include "lm/builder/joint_order.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/sort.hh"
+#include "lm/lm_exception.hh"
+
+#include <assert.h>
+
+namespace lm { namespace builder {
+namespace {
+
+class Callback {
+ public:
+ Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) {
+ probs_[0] = uniform_prob;
+ for (std::size_t i = 0; i < backoffs.size(); ++i) {
+ backoffs_.push_back(backoffs[i]);
+ }
+ }
+
+ ~Callback() {
+ for (std::size_t i = 0; i < backoffs_.size(); ++i) {
+ if (backoffs_[i]) {
+ std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
+ abort();
+ }
+ }
+ }
+
+ void Enter(unsigned order_minus_1, NGram &gram) {
+ Payload &pay = gram.Value();
+ pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
+ probs_[order_minus_1 + 1] = pay.complete.prob;
+ pay.complete.prob = log10(pay.complete.prob);
+ // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
+ pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
+ ++backoffs_[order_minus_1];
+ } else {
+ // Not a context.
+ pay.complete.backoff = 0.0;
+ }
+ }
+
+ void Exit(unsigned, const NGram &) const {}
+
+ private:
+ FixedArray<util::stream::Stream> backoffs_;
+
+ std::vector<float> probs_;
+};
+} // namespace
+
+Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
+ : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {}
+
+// perform order-wise interpolation
+void Interpolate::Run(const ChainPositions &positions) {
+ assert(positions.size() == backoffs_.size() + 1);
+ Callback callback(uniform_prob_, backoffs_);
+ JointOrder<Callback, SuffixOrder>(positions, callback);
+}
+
+}} // namespaces
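Callback::Enter folds the orders together bottom-up in a single suffix-ordered pass. Writing c_k for the context of a k-gram and starting from the uniform distribution over the vocabulary less <s> (the 1/(unigram_count - 1) in the constructor), the interpolated probability is

```latex
p(w \mid c_k) = p_{\text{uninterp}}(w \mid c_k) + \gamma(c_k)\, p(w \mid c_{k-1}),
\qquad
p(w \mid c_0) = \frac{1}{|V| - 1},
```

with prob converted to log10 on the way out, as the ARPA format expects. The backoff written for each n-gram that occurs as a context is log10 of its gamma, read from the gamma chains produced by InitialProbabilities.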
diff --git a/lm/builder/interpolate.hh b/lm/builder/interpolate.hh
new file mode 100644
index 000000000..9268d404c
--- /dev/null
+++ b/lm/builder/interpolate.hh
@@ -0,0 +1,27 @@
+#ifndef LM_BUILDER_INTERPOLATE__
+#define LM_BUILDER_INTERPOLATE__
+
+#include <stdint.h>
+
+#include "lm/builder/multi_stream.hh"
+
+namespace lm { namespace builder {
+
+/* Interpolate step.
+ * Input: suffix sorted n-grams with (p_uninterpolated, gamma) from
+ * InitialProbabilities.
+ * Output: suffix sorted n-grams with complete probability
+ */
+class Interpolate {
+ public:
+ explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs);
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ float uniform_prob_;
+ ChainPositions backoffs_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_INTERPOLATE__
diff --git a/lm/builder/joint_order.hh b/lm/builder/joint_order.hh
new file mode 100644
index 000000000..b56201447
--- /dev/null
+++ b/lm/builder/joint_order.hh
@@ -0,0 +1,43 @@
+#ifndef LM_BUILDER_JOINT_ORDER__
+#define LM_BUILDER_JOINT_ORDER__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/lm_exception.hh"
+
+#include <string.h>
+
+namespace lm { namespace builder {
+
+template <class Callback, class Compare> void JointOrder(const ChainPositions &positions, Callback &callback) {
+ // Allow matching to reference streams[-1].
+ NGramStreams streams_with_dummy;
+ streams_with_dummy.InitWithDummy(positions);
+ NGramStream *streams = streams_with_dummy.begin() + 1;
+
+ unsigned int order;
+ for (order = 0; order < positions.size() && streams[order]; ++order) {}
+ assert(order); // should always have <unk>.
+ unsigned int current = 0;
+ while (true) {
+ // Does the context match the lower one?
+ if (!memcmp(streams[static_cast<int>(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) {
+ callback.Enter(current, *streams[current]);
+ // Transition to looking for extensions.
+ if (++current < order) continue;
+ }
+ // No extension left.
+ while(true) {
+ assert(current > 0);
+ --current;
+ callback.Exit(current, *streams[current]);
+ if (++streams[current]) break;
+ UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix");
+ order = current;
+ if (!order) return;
+ }
+ }
+}
+
+}} // namespaces
+
+#endif // LM_BUILDER_JOINT_ORDER__
diff --git a/lm/builder/main.cc b/lm/builder/main.cc
new file mode 100644
index 000000000..90b9dca2d
--- /dev/null
+++ b/lm/builder/main.cc
@@ -0,0 +1,94 @@
+#include "lm/builder/pipeline.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/usage.hh"
+
+#include <iostream>
+
+#include <boost/program_options.hpp>
+
+namespace {
+class SizeNotify {
+ public:
+ SizeNotify(std::size_t &out) : behind_(out) {}
+
+ void operator()(const std::string &from) {
+ behind_ = util::ParseSize(from);
+ }
+
+ private:
+ std::size_t &behind_;
+};
+
+boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) {
+ return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
+}
+
+} // namespace
+
+int main(int argc, char *argv[]) {
+ try {
+ namespace po = boost::program_options;
+ po::options_description options("Language model building options");
+ lm::builder::PipelineConfig pipeline;
+
+ options.add_options()
+ ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model")
+ ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
+ ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
+ ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
+ ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
+ ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
+ ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+ ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
+ ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
+ ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
+ if (argc == 1) {
+ std::cerr <<
+ "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
+ "Please cite:\n"
+ "@inproceedings{kenlm,\n"
+ "author = {Kenneth Heafield},\n"
+ "title = {{KenLM}: Faster and Smaller Language Model Queries},\n"
+ "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
+ "month = {July}, year={2011},\n"
+ "address = {Edinburgh, UK},\n"
+ "publisher = {Association for Computational Linguistics},\n"
+ "}\n\n"
+ "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n"
+ "the model (-o) is the only mandatory option. As this is an on-disk program,\n"
+ "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
+ "Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
+ "Valid units are \% for percentage of memory (supported platforms only) and (in\n"
+ "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n";
+ std::cerr << options << std::endl;
+ return 1;
+ }
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, options), vm);
+ po::notify(vm);
+
+ util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
+
+ lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
+ // TODO: evaluate options for these.
+ initial.adder_in.total_memory = 32768;
+ initial.adder_in.block_count = 2;
+ initial.adder_out.total_memory = 32768;
+ initial.adder_out.block_count = 2;
+ pipeline.read_backoffs = initial.adder_out;
+
+ // Read from stdin
+ try {
+ lm::builder::Pipeline(pipeline, 0, 1);
+ } catch (const util::MallocException &e) {
+ std::cerr << e.what() << std::endl;
+ std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
+ return 1;
+ }
+ util::PrintUsage(std::cerr);
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
+ }
+}
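With these options in place, a typical invocation following the help text's advice to set -T and -S explicitly would be `bin/lmplz -o 5 -S 50% -T /tmp <text >text.arpa` (illustrative, not from this diff; any GNU-sort-style size works for -S). Note the dedicated catch for util::MallocException: rather than dying opaquely, lmplz suggests rerunning with a more conservative -S.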
diff --git a/lm/builder/multi_stream.hh b/lm/builder/multi_stream.hh
new file mode 100644
index 000000000..707a98c7f
--- /dev/null
+++ b/lm/builder/multi_stream.hh
@@ -0,0 +1,180 @@
+#ifndef LM_BUILDER_MULTI_STREAM__
+#define LM_BUILDER_MULTI_STREAM__
+
+#include "lm/builder/ngram_stream.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+
+#include <cstddef>
+#include <new>
+
+#include <assert.h>
+#include <stdlib.h>
+
+namespace lm { namespace builder {
+
+template <class T> class FixedArray {
+ public:
+ explicit FixedArray(std::size_t count) {
+ Init(count);
+ }
+
+ FixedArray() : newed_end_(NULL) {}
+
+ void Init(std::size_t count) {
+ assert(!block_.get());
+ block_.reset(malloc(sizeof(T) * count));
+ if (!block_.get()) throw std::bad_alloc();
+ newed_end_ = begin();
+ }
+
+ FixedArray(const FixedArray &from) {
+ std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get());
+ Init(size);
+ for (std::size_t i = 0; i < size; ++i) {
+ new(end()) T(from[i]);
+ Constructed();
+ }
+ }
+
+ ~FixedArray() { clear(); }
+
+ T *begin() { return static_cast<T*>(block_.get()); }
+ const T *begin() const { return static_cast<const T*>(block_.get()); }
+ // Always call Constructed after successful completion of new.
+ T *end() { return newed_end_; }
+ const T *end() const { return newed_end_; }
+
+ T &back() { return *(end() - 1); }
+ const T &back() const { return *(end() - 1); }
+
+ std::size_t size() const { return end() - begin(); }
+ bool empty() const { return begin() == end(); }
+
+ T &operator[](std::size_t i) { return begin()[i]; }
+ const T &operator[](std::size_t i) const { return begin()[i]; }
+
+ template <class C> void push_back(const C &c) {
+ new (end()) T(c);
+ Constructed();
+ }
+
+ void clear() {
+ for (T *i = begin(); i != end(); ++i)
+ i->~T();
+ newed_end_ = begin();
+ }
+
+ protected:
+ void Constructed() {
+ ++newed_end_;
+ }
+
+ private:
+ util::scoped_malloc block_;
+
+ T *newed_end_;
+};
+
+class Chains;
+
+class ChainPositions : public FixedArray<util::stream::ChainPosition> {
+ public:
+ ChainPositions() {}
+
+ void Init(Chains &chains);
+
+ explicit ChainPositions(Chains &chains) {
+ Init(chains);
+ }
+};
+
+class Chains : public FixedArray<util::stream::Chain> {
+ private:
+ template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun {
+ typedef Chains type;
+ };
+
+ public:
+ explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {}
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ Chains &operator>>(const util::stream::Recycler &recycler) {
+ for (util::stream::Chain *i = begin(); i != end(); ++i)
+ *i >> recycler;
+ return *this;
+ }
+
+ void Wait(bool release_memory = true) {
+ threads_.clear();
+ for (util::stream::Chain *i = begin(); i != end(); ++i) {
+ i->Wait(release_memory);
+ }
+ }
+
+ private:
+ boost::ptr_vector<util::stream::Thread> threads_;
+
+ Chains(const Chains &);
+ void operator=(const Chains &);
+};
+
+inline void ChainPositions::Init(Chains &chains) {
+ FixedArray<util::stream::ChainPosition>::Init(chains.size());
+ for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) {
+ new (end()) util::stream::ChainPosition(i->Add()); Constructed();
+ }
+}
+
+inline Chains &operator>>(Chains &chains, ChainPositions &positions) {
+ positions.Init(chains);
+ return chains;
+}
+
+class NGramStreams : public FixedArray<NGramStream> {
+ public:
+ NGramStreams() {}
+
+ // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning).
+ void InitWithDummy(const ChainPositions &positions) {
+ FixedArray<NGramStream>::Init(positions.size() + 1);
+ new (end()) NGramStream(); Constructed();
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) {
+ push_back(*i);
+ }
+ }
+
+ // Limit restricts to positions[0,limit)
+ void Init(const ChainPositions &positions, std::size_t limit) {
+ FixedArray<NGramStream>::Init(limit);
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) {
+ push_back(*i);
+ }
+ }
+ void Init(const ChainPositions &positions) {
+ Init(positions, positions.size());
+ }
+
+ NGramStreams(const ChainPositions &positions) {
+ Init(positions);
+ }
+};
+
+inline Chains &operator>>(Chains &chains, NGramStreams &streams) {
+ ChainPositions positions;
+ chains >> positions;
+ streams.Init(positions);
+ return chains;
+}
+
+}} // namespaces
+#endif // LM_BUILDER_MULTI_STREAM__
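
As a usage sketch of the machinery above (assuming producers feed and eventually poison the chains): a worker exposing Run(const ChainPositions &) attaches to all chains at once through operator>>. The CountPrinter worker here is hypothetical, named only to illustrate the hook that CheckForRun detects.

    #include "lm/builder/multi_stream.hh"

    #include <iostream>

    namespace lm { namespace builder {

    // Hypothetical worker: drains each chain, printing the count payloads.
    class CountPrinter {
      public:
        void Run(const ChainPositions &positions) {
          NGramStreams streams(positions);
          for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
            for (; *s; ++*s) std::cout << (*s)->Count() << '\n';
          }
        }
    };

    void Example(Chains &chains) {
      chains >> CountPrinter(); // Launches a thread running CountPrinter::Run.
      chains.Wait(true);        // Join the thread and release chain memory.
    }

    }} // namespaces
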
diff --git a/lm/builder/ngram.hh b/lm/builder/ngram.hh
new file mode 100644
index 000000000..2984ed0b6
--- /dev/null
+++ b/lm/builder/ngram.hh
@@ -0,0 +1,84 @@
+#ifndef LM_BUILDER_NGRAM__
+#define LM_BUILDER_NGRAM__
+
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+
+#include <cstddef>
+
+#include <assert.h>
+#include <stdint.h>
+#include <string.h>
+
+namespace lm {
+namespace builder {
+
+struct Uninterpolated {
+ float prob; // Uninterpolated probability.
+ float gamma; // Interpolation weight for lower order.
+};
+
+union Payload {
+ uint64_t count;
+ Uninterpolated uninterp;
+ ProbBackoff complete;
+};
+
+class NGram {
+ public:
+ NGram(void *begin, std::size_t order)
+ : begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {}
+
+ const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); }
+ uint8_t *Base() { return reinterpret_cast<uint8_t*>(begin_); }
+
+ void ReBase(void *to) {
+ std::size_t difference = end_ - begin_;
+ begin_ = reinterpret_cast<WordIndex*>(to);
+ end_ = begin_ + difference;
+ }
+
+ // Would do operator++ but that can get confusing for a stream.
+ void NextInMemory() {
+ ReBase(&Value() + 1);
+ }
+
+ // Lower-case in deference to STL.
+ const WordIndex *begin() const { return begin_; }
+ WordIndex *begin() { return begin_; }
+ const WordIndex *end() const { return end_; }
+ WordIndex *end() { return end_; }
+
+ const Payload &Value() const { return *reinterpret_cast<const Payload *>(end_); }
+ Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
+
+ uint64_t &Count() { return Value().count; }
+    uint64_t Count() const { return Value().count; }
+
+ std::size_t Order() const { return end_ - begin_; }
+
+ static std::size_t TotalSize(std::size_t order) {
+ return order * sizeof(WordIndex) + sizeof(Payload);
+ }
+ std::size_t TotalSize() const {
+ // Compiler should optimize this.
+ return TotalSize(Order());
+ }
+ static std::size_t OrderFromSize(std::size_t size) {
+ std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex);
+ assert(size == TotalSize(ret));
+ return ret;
+ }
+
+ private:
+ WordIndex *begin_, *end_;
+};
+
+const WordIndex kUNK = 0;
+const WordIndex kBOS = 1;
+const WordIndex kEOS = 2;
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_NGRAM__
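
NGram is a non-owning view over a record laid out as order WordIndex slots followed by one Payload, so TotalSize(order) is order * sizeof(WordIndex) + sizeof(Payload): 20 bytes for a trigram with the typical 4-byte WordIndex and 8-byte Payload. A minimal sketch under those assumptions:

    #include "lm/builder/ngram.hh"

    #include <assert.h>
    #include <stdlib.h>

    int main() {
      using namespace lm::builder;
      const std::size_t order = 3;
      void *record = malloc(NGram::TotalSize(order));
      NGram gram(record, order);   // View over the raw bytes; no ownership.
      gram.begin()[0] = kBOS;
      gram.begin()[1] = 7;         // Some vocabulary id.
      gram.begin()[2] = kEOS;
      gram.Count() = 1;            // Payload interpreted as a count.
      assert(NGram::OrderFromSize(gram.TotalSize()) == order);
      free(record);
      return 0;
    }
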
diff --git a/lm/builder/ngram_stream.hh b/lm/builder/ngram_stream.hh
new file mode 100644
index 000000000..3c9946643
--- /dev/null
+++ b/lm/builder/ngram_stream.hh
@@ -0,0 +1,55 @@
+#ifndef LM_BUILDER_NGRAM_STREAM__
+#define LM_BUILDER_NGRAM_STREAM__
+
+#include "lm/builder/ngram.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/stream.hh"
+
+#include <cstddef>
+
+namespace lm { namespace builder {
+
+class NGramStream {
+ public:
+ NGramStream() : gram_(NULL, 0) {}
+
+ NGramStream(const util::stream::ChainPosition &position) : gram_(NULL, 0) {
+ Init(position);
+ }
+
+ void Init(const util::stream::ChainPosition &position) {
+ stream_.Init(position);
+ gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize()));
+ }
+
+ NGram &operator*() { return gram_; }
+ const NGram &operator*() const { return gram_; }
+
+ NGram *operator->() { return &gram_; }
+ const NGram *operator->() const { return &gram_; }
+
+ void *Get() { return stream_.Get(); }
+ const void *Get() const { return stream_.Get(); }
+
+ operator bool() const { return stream_; }
+ bool operator!() const { return !stream_; }
+ void Poison() { stream_.Poison(); }
+
+ NGramStream &operator++() {
+ ++stream_;
+ gram_.ReBase(stream_.Get());
+ return *this;
+ }
+
+ private:
+ NGram gram_;
+ util::stream::Stream stream_;
+};
+
+inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) {
+ str.Init(chain.Add());
+ return chain;
+}
+
+}} // namespaces
+#endif // LM_BUILDER_NGRAM_STREAM__
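
A consumer-side sketch (the producer filling the chain is assumed): attach an NGramStream and iterate records until the chain is poisoned.

    #include "lm/builder/ngram_stream.hh"

    namespace lm { namespace builder {

    // Sum the count payload of every record flowing through the chain.
    uint64_t SumCounts(util::stream::Chain &chain) {
      NGramStream stream;
      chain >> stream;            // Calls stream.Init(chain.Add()).
      uint64_t total = 0;
      for (; stream; ++stream) {  // operator bool is false after poison.
        total += stream->Count();
      }
      return total;
    }

    }} // namespaces
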
diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc
new file mode 100644
index 000000000..14a1f7218
--- /dev/null
+++ b/lm/builder/pipeline.cc
@@ -0,0 +1,320 @@
+#include "lm/builder/pipeline.hh"
+
+#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/corpus_count.hh"
+#include "lm/builder/initial_probabilities.hh"
+#include "lm/builder/interpolate.hh"
+#include "lm/builder/print.hh"
+#include "lm/builder/sort.hh"
+
+#include "lm/sizes.hh"
+
+#include "util/exception.hh"
+#include "util/file.hh"
+#include "util/stream/io.hh"
+
+#include <algorithm>
+#include <iostream>
+#include <vector>
+
+namespace lm { namespace builder {
+
+namespace {
+void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) {
+ std::cerr << "Statistics:\n";
+ for (size_t i = 0; i < counts.size(); ++i) {
+ std::cerr << (i + 1) << ' ' << counts[i];
+ for (size_t d = 1; d <= 3; ++d)
+ std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d];
+ std::cerr << '\n';
+ }
+}
+
+class Master {
+ public:
+ explicit Master(const PipelineConfig &config)
+ : config_(config), chains_(config.order), files_(config.order) {
+ config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block);
+ }
+
+ const PipelineConfig &Config() const { return config_; }
+
+ Chains &MutableChains() { return chains_; }
+
+ template <class T> Master &operator>>(const T &worker) {
+ chains_ >> worker;
+ return *this;
+ }
+
+ // This takes the (partially) sorted ngrams and sets up for adjusted counts.
+ void InitForAdjust(util::stream::Sort<SuffixOrder, AddCombiner> &ngrams, WordIndex types) {
+ const std::size_t each_order_min = config_.minimum_block * config_.block_count;
+ // We know how many unigrams there are. Don't allocate more than needed to them.
+ const std::size_t min_chains = (config_.order - 1) * each_order_min +
+ std::min(types * NGram::TotalSize(1), each_order_min);
+ // Do merge sort with calculated laziness.
+ const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy()));
+
+ std::vector<uint64_t> count_bounds(1, types);
+ CreateChains(config_.TotalMemory() - merge_using, count_bounds);
+ ngrams.Output(chains_.back(), merge_using);
+
+ // Setup unigram file.
+ files_.push_back(util::MakeTemp(config_.TempPrefix()));
+ }
+
+ // For initial probabilities, but this is generic.
+ void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) {
+ // Do merge first before allocating chain memory.
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts[i - 1].Merge(0);
+ }
+ // There's no lazy merge, so just divide memory amongst the chains.
+ CreateChains(config_.TotalMemory(), counts);
+ chains_.back().ActivateProgress();
+ chains_[0] >> files_[0].Source();
+ second_config.entry_size = NGram::TotalSize(1);
+ second.push_back(second_config);
+ second.back() >> files_[0].Source();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ util::scoped_fd fd(sorts[i - 1].StealCompleted());
+ chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get()));
+ chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true);
+ second_config.entry_size = NGram::TotalSize(i + 1);
+ second.push_back(second_config);
+ second.back() >> util::stream::PRead(fd.release(), true);
+ }
+ }
+
+ // There is no sort after this, so go for broke on lazy merging.
+ template <class Compare> void MaximumLazyInput(const std::vector<uint64_t> &counts, Sorts<Compare> &sorts) {
+ // Determine the minimum we can use for all the chains.
+ std::size_t min_chains = 0;
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block));
+ }
+ std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains);
+ std::vector<std::size_t> laziness;
+ // Prioritize longer n-grams.
+    for (util::stream::Sort<Compare> *i = sorts.end() - 1; i >= sorts.begin(); --i) {
+ laziness.push_back(i->Merge(for_merge));
+ assert(for_merge >= laziness.back());
+ for_merge -= laziness.back();
+ }
+ std::reverse(laziness.begin(), laziness.end());
+
+ CreateChains(for_merge + min_chains, counts);
+ chains_.back().ActivateProgress();
+ chains_[0] >> files_[0].Source();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts[i - 1].Output(chains_[i], laziness[i - 1]);
+ }
+ }
+
+ void BufferFinal(const std::vector<uint64_t> &counts) {
+ chains_[0] >> files_[0].Sink();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ files_.push_back(util::MakeTemp(config_.TempPrefix()));
+ chains_[i] >> files_[i].Sink();
+ }
+ chains_.Wait(true);
+ // Use less memory. Because we can.
+ CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts);
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ chains_[i] >> files_[i].Source();
+ }
+ }
+
+ template <class Compare> void SetupSorts(Sorts<Compare> &sorts) {
+ sorts.Init(config_.order - 1);
+    // Unigrams don't get sorted: with one word per record every comparator yields the same order, so they are already in order.
+ chains_[0] >> files_[0].Sink();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts.push_back(chains_[i], config_.sort, Compare(i + 1));
+ }
+ chains_.Wait(true);
+ }
+
+ private:
+  // Create chains, allocating memory to them. The split is heuristic.
+  // count_bounds gives upper bounds on the counts where known; orders beyond
+  // its size are unbounded.
+ void CreateChains(std::size_t remaining_mem, const std::vector<uint64_t> &count_bounds) {
+ std::vector<std::size_t> assignments;
+ assignments.reserve(config_.order);
+ // Start by assigning maximum memory usage (to be refined later).
+ for (std::size_t i = 0; i < count_bounds.size(); ++i) {
+ assignments.push_back(static_cast<std::size_t>(std::min(
+ static_cast<uint64_t>(remaining_mem),
+ count_bounds[i] * static_cast<uint64_t>(NGram::TotalSize(i + 1)))));
+ }
+ assignments.resize(config_.order, remaining_mem);
+
+ // Now we know how much memory everybody wants. How much will they get?
+ // Proportional to this.
+ std::vector<float> portions;
+ // Indices of orders that have yet to be assigned.
+ std::vector<std::size_t> unassigned;
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ portions.push_back(static_cast<float>((i+1) * NGram::TotalSize(i+1)));
+ unassigned.push_back(i);
+ }
+    /* If somebody doesn't eat their full dinner, give it to the rest of the
+     * family. Then somebody else might not eat their full dinner, etc. This
+     * ends when everybody still unassigned is hungry.
+     */
+ float sum;
+ bool found_more;
+ std::vector<std::size_t> block_count(config_.order);
+ do {
+ sum = 0.0;
+ for (std::size_t i = 0; i < unassigned.size(); ++i) {
+ sum += portions[unassigned[i]];
+ }
+ found_more = false;
+ // If the proportional assignment is more than needed, give it just what it needs.
+ for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end();) {
+ if (assignments[*i] <= remaining_mem * (portions[*i] / sum)) {
+ remaining_mem -= assignments[*i];
+ block_count[*i] = 1;
+ i = unassigned.erase(i);
+ found_more = true;
+ } else {
+ ++i;
+ }
+ }
+ } while (found_more);
+ for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end(); ++i) {
+ assignments[*i] = remaining_mem * (portions[*i] / sum);
+ block_count[*i] = config_.block_count;
+ }
+ chains_.clear();
+ std::cerr << "Chain sizes:";
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ std::cerr << ' ' << (i+1) << ":" << assignments[i];
+ chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i]));
+ }
+ std::cerr << std::endl;
+ }
+
+ PipelineConfig config_;
+
+ Chains chains_;
+ // Often only unigrams, but sometimes all orders.
+ FixedArray<util::stream::FileBuffer> files_;
+};
+
+void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) {
+ const PipelineConfig &config = master.Config();
+ std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
+
+ UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory());
+ std::size_t memory_for_chain =
+ // This much memory to work with after vocab hash table.
+ static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) /
+ // Solve for block size including the dedupe multiplier for one block.
+ (static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
+ // Chain likes memory expressed in terms of total memory.
+ static_cast<float>(config.block_count);
+ util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
+
+ WordIndex type_count;
+ util::FilePiece text(text_file, NULL, &std::cerr);
+ text_file_name = text.FileName();
+ CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ chain >> boost::ref(counter);
+
+ util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
+ chain.Wait(true);
+ std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
+ master.InitForAdjust(sorter, type_count);
+}
+
+void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+ const PipelineConfig &config = master.Config();
+ Chains second(config.order);
+
+ {
+ Sorts<ContextOrder> sorts;
+ master.SetupSorts(sorts);
+ PrintStatistics(counts, discounts);
+ lm::ngram::ShowSizes(counts);
+ std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl;
+ master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in);
+ }
+
+ Chains gamma_chains(config.order);
+ InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains);
+ // Don't care about gamma for 0.
+ gamma_chains[0] >> util::stream::kRecycle;
+ gammas.Init(config.order - 1);
+ for (std::size_t i = 1; i < config.order; ++i) {
+ gammas.push_back(util::MakeTemp(config.TempPrefix()));
+ gamma_chains[i] >> gammas[i - 1].Sink();
+ }
+ // Has to be done here due to gamma_chains scope.
+ master.SetupSorts(primary);
+}
+
+void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+ std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
+ const PipelineConfig &config = master.Config();
+ master.MaximumLazyInput(counts, primary);
+
+ Chains gamma_chains(config.order - 1);
+ util::stream::ChainConfig read_backoffs(config.read_backoffs);
+ read_backoffs.entry_size = sizeof(float);
+ for (std::size_t i = 0; i < config.order - 1; ++i) {
+ gamma_chains.push_back(read_backoffs);
+ gamma_chains.back() >> gammas[i].Source();
+ }
+ master >> Interpolate(counts[0], ChainPositions(gamma_chains));
+ gamma_chains >> util::stream::kRecycle;
+ master.BufferFinal(counts);
+}
+
+} // namespace
+
+void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
+ // Some fail-fast sanity checks.
+ if (config.sort.buffer_size * 4 > config.TotalMemory()) {
+ config.sort.buffer_size = config.TotalMemory() / 4;
+ std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl;
+ }
+ if (config.minimum_block < NGram::TotalSize(config.order)) {
+ config.minimum_block = NGram::TotalSize(config.order);
+ std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl;
+ }
+ UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << ".");
+ UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception,
+ "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size.");
+
+ UTIL_TIMER("(%w s) Total wall time elapsed\n");
+ Master master(config);
+
+ util::scoped_fd vocab_file(config.vocab_file.empty() ?
+ util::MakeTemp(config.TempPrefix()) :
+ util::CreateOrThrow(config.vocab_file.c_str()));
+ uint64_t token_count;
+ std::string text_file_name;
+ CountText(text_file, vocab_file.get(), master, token_count, text_file_name);
+
+ std::vector<uint64_t> counts;
+ std::vector<Discount> discounts;
+ master >> AdjustCounts(counts, discounts);
+
+ {
+ FixedArray<util::stream::FileBuffer> gammas;
+ Sorts<SuffixOrder> primary;
+ InitialProbabilities(counts, discounts, master, primary, gammas);
+ InterpolateProbabilities(counts, master, primary, gammas);
+ }
+
+ std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
+ VocabReconstitute vocab(vocab_file.get());
+ UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?");
+ HeaderInfo header_info(text_file_name, token_count);
+ master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle;
+ master.MutableChains().Wait(true);
+}
+
+}} // namespaces
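
The redistribution loop in CreateChains is worth seeing in isolation: every order asks for at most count_bound * record_size bytes, shares are proportional to (order * record size), and any order whose share exceeds its ask is capped, with the surplus redistributed until nobody else can be capped. A standalone sketch with made-up numbers (not the pipeline's real configuration):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      std::size_t remaining = 1000;
      std::vector<std::size_t> ask;   // Upper bound each order wants.
      ask.push_back(100); ask.push_back(600); ask.push_back(900);
      std::vector<float> weight;      // Stands in for (order * record size).
      weight.push_back(1.0f); weight.push_back(2.0f); weight.push_back(3.0f);
      std::vector<std::size_t> unassigned;
      for (std::size_t i = 0; i < ask.size(); ++i) unassigned.push_back(i);

      float sum = 0.0f;
      bool found_more;
      do {
        sum = 0.0f;
        for (std::size_t i = 0; i < unassigned.size(); ++i)
          sum += weight[unassigned[i]];
        found_more = false;
        for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end();) {
          // Cap anybody whose proportional share exceeds their ask.
          if (ask[*i] <= remaining * (weight[*i] / sum)) {
            remaining -= ask[*i];
            i = unassigned.erase(i);
            found_more = true;
          } else {
            ++i;
          }
        }
      } while (found_more);
      // Whoever is left splits the remainder proportionally.
      for (std::size_t j = 0; j < unassigned.size(); ++j)
        ask[unassigned[j]] = static_cast<std::size_t>(remaining * (weight[unassigned[j]] / sum));
      for (std::size_t i = 0; i < ask.size(); ++i)
        std::cout << i << ": " << ask[i] << '\n';  // Prints 0: 100, 1: 360, 2: 540.
      return 0;
    }
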
diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh
new file mode 100644
index 000000000..f1d6c5f61
--- /dev/null
+++ b/lm/builder/pipeline.hh
@@ -0,0 +1,40 @@
+#ifndef LM_BUILDER_PIPELINE__
+#define LM_BUILDER_PIPELINE__
+
+#include "lm/builder/initial_probabilities.hh"
+#include "lm/builder/header_info.hh"
+#include "util/stream/config.hh"
+#include "util/file_piece.hh"
+
+#include <string>
+#include <cstddef>
+
+namespace lm { namespace builder {
+
+struct PipelineConfig {
+ std::size_t order;
+ std::string vocab_file;
+ util::stream::SortConfig sort;
+ InitialProbabilitiesConfig initial_probs;
+ util::stream::ChainConfig read_backoffs;
+ bool verbose_header;
+
+ // Amount of memory to assume that the vocabulary hash table will use. This
+ // is subtracted from total memory for CorpusCount.
+ std::size_t assume_vocab_hash_size;
+
+ // Minimum block size to tolerate.
+ std::size_t minimum_block;
+
+ // Number of blocks to use. This will be overridden to 1 if everything fits.
+ std::size_t block_count;
+
+ const std::string &TempPrefix() const { return sort.temp_prefix; }
+ std::size_t TotalMemory() const { return sort.total_memory; }
+};
+
+// Takes ownership of text_file.
+void Pipeline(PipelineConfig config, int text_file, int out_arpa);
+
+}} // namespaces
+#endif // LM_BUILDER_PIPELINE__
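
A hedged sketch of driving the pipeline: the numeric values are arbitrary, and the initial_probs/read_backoffs chain configs, whose fields are defined elsewhere, are left at defaults here and would need real settings.

    #include "lm/builder/pipeline.hh"

    int main() {
      lm::builder::PipelineConfig config;
      config.order = 3;
      config.sort.temp_prefix = "/tmp/lm";
      config.sort.total_memory = 1ULL << 30;        // 1 GiB overall budget.
      config.sort.buffer_size = 64ULL << 20;        // 64 MiB sort buffers.
      config.minimum_block = 8192;
      config.block_count = 2;
      config.assume_vocab_hash_size = 50ULL << 20;  // Reserved for CorpusCount's table.
      config.verbose_header = true;
      // config.initial_probs and config.read_backoffs also need chain settings;
      // they are omitted in this sketch.
      // Pipeline takes ownership of fd 0 (corpus on stdin); ARPA goes to fd 1.
      lm::builder::Pipeline(config, 0, 1);
      return 0;
    }
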
diff --git a/lm/builder/print.cc b/lm/builder/print.cc
new file mode 100644
index 000000000..b0323221a
--- /dev/null
+++ b/lm/builder/print.cc
@@ -0,0 +1,135 @@
+#include "lm/builder/print.hh"
+
+#include "util/double-conversion/double-conversion.h"
+#include "util/double-conversion/utils.h"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/scoped.hh"
+#include "util/stream/timer.hh"
+
+#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
+#include <boost/lexical_cast.hpp>
+
+#include <sstream>
+
+#include <string.h>
+
+namespace lm { namespace builder {
+
+VocabReconstitute::VocabReconstitute(int fd) {
+ uint64_t size = util::SizeOrThrow(fd);
+ util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_);
+ const char *const start = static_cast<const char*>(memory_.get());
+ const char *i;
+ for (i = start; i != start + size; i += strlen(i) + 1) {
+ map_.push_back(i);
+ }
+ // Last one for LookupPiece.
+ map_.push_back(i);
+}
+
+namespace {
+class OutputManager {
+ public:
+ static const std::size_t kOutBuf = 1048576;
+
+ // Does not take ownership of out.
+ explicit OutputManager(int out)
+ : buf_(util::MallocOrThrow(kOutBuf)),
+ builder_(static_cast<char*>(buf_.get()), kOutBuf),
+        // Mostly the defaults, but print "inf" for infinity and pass no flags.
+ convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
+ fd_(out) {}
+
+ ~OutputManager() {
+ Flush();
+ }
+
+ OutputManager &operator<<(float value) {
+ // Odd, but this is the largest number found in the comments.
+ EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
+ convert_.ToShortestSingle(value, &builder_);
+ return *this;
+ }
+
+ OutputManager &operator<<(StringPiece str) {
+ if (str.size() > kOutBuf) {
+ Flush();
+ util::WriteOrThrow(fd_, str.data(), str.size());
+ } else {
+ EnsureRemaining(str.size());
+ builder_.AddSubstring(str.data(), str.size());
+ }
+ return *this;
+ }
+
+ // Inefficient!
+ OutputManager &operator<<(unsigned val) {
+ return *this << boost::lexical_cast<std::string>(val);
+ }
+
+ OutputManager &operator<<(char c) {
+ EnsureRemaining(1);
+ builder_.AddCharacter(c);
+ return *this;
+ }
+
+ void Flush() {
+ util::WriteOrThrow(fd_, buf_.get(), builder_.position());
+ builder_.Reset();
+ }
+
+ private:
+ void EnsureRemaining(std::size_t amount) {
+ if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) {
+ Flush();
+ }
+ }
+
+ util::scoped_malloc buf_;
+ double_conversion::StringBuilder builder_;
+ double_conversion::DoubleToStringConverter convert_;
+ int fd_;
+};
+} // namespace
+
+PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
+ : vocab_(vocab), out_fd_(out_fd) {
+ std::stringstream stream;
+
+ if (header_info) {
+ stream << "# Input file: " << header_info->input_file << '\n';
+ stream << "# Token count: " << header_info->token_count << '\n';
+ stream << "# Smoothing: Modified Kneser-Ney" << '\n';
+ }
+ stream << "\\data\\\n";
+ for (size_t i = 0; i < counts.size(); ++i) {
+ stream << "ngram " << (i+1) << '=' << counts[i] << '\n';
+ }
+ stream << '\n';
+ std::string as_string(stream.str());
+ util::WriteOrThrow(out_fd, as_string.data(), as_string.size());
+}
+
+void PrintARPA::Run(const ChainPositions &positions) {
+ UTIL_TIMER("(%w s) Wrote ARPA file\n");
+ OutputManager out(out_fd_);
+ for (unsigned order = 1; order <= positions.size(); ++order) {
+ out << "\\" << order << "-grams:" << '\n';
+ for (NGramStream stream(positions[order - 1]); stream; ++stream) {
+ // Correcting for numerical precision issues. Take that IRST.
+ out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin());
+ for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
+ out << ' ' << vocab_.Lookup(*i);
+ }
+ float backoff = stream->Value().complete.backoff;
+ if (backoff != 0.0)
+ out << '\t' << backoff;
+ out << '\n';
+ }
+ out << '\n';
+ }
+ out << "\\end\\\n";
+}
+
+}} // namespaces
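
For reference, the writer above produces the standard ARPA layout: the \data\ header from the constructor, then one block per order from Run, each line holding the probability, a tab, the words, and a tab plus the backoff when it is nonzero. A toy example of the shape (values illustrative, tabs shown as wide gaps):

    \data\
    ngram 1=3
    ngram 2=2

    \1-grams:
    -0.5    <unk>
    0.0     <s>     -0.3
    -1.0    </s>

    \2-grams:
    -0.25   <s> <unk>
    -0.7    <unk> </s>

    \end\
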
diff --git a/lm/builder/print.hh b/lm/builder/print.hh
new file mode 100644
index 000000000..aa932e757
--- /dev/null
+++ b/lm/builder/print.hh
@@ -0,0 +1,102 @@
+#ifndef LM_BUILDER_PRINT__
+#define LM_BUILDER_PRINT__
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/header_info.hh"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/string_piece.hh"
+
+#include <ostream>
+#include <vector>
+
+#include <assert.h>
+#include <math.h>
+
+// Warning: print routines read all unigrams before all bigrams before all
+// trigrams etc. So if other parts of the chain move jointly, you'll have to
+// buffer.
+
+namespace lm { namespace builder {
+
+class VocabReconstitute {
+ public:
+ // fd must be alive for life of this object; does not take ownership.
+ explicit VocabReconstitute(int fd);
+
+ const char *Lookup(WordIndex index) const {
+ assert(index < map_.size() - 1);
+ return map_[index];
+ }
+
+ StringPiece LookupPiece(WordIndex index) const {
+ return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]);
+ }
+
+ std::size_t Size() const {
+ // There's an extra entry to support StringPiece lengths.
+ return map_.size() - 1;
+ }
+
+ private:
+ util::scoped_memory memory_;
+ std::vector<const char*> map_;
+};
+
+// Not defined, only specialized.
+template <class T> void PrintPayload(std::ostream &to, const Payload &payload);
+template <> inline void PrintPayload<uint64_t>(std::ostream &to, const Payload &payload) {
+ to << payload.count;
+}
+template <> inline void PrintPayload<Uninterpolated>(std::ostream &to, const Payload &payload) {
+ to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma);
+}
+template <> inline void PrintPayload<ProbBackoff>(std::ostream &to, const Payload &payload) {
+ to << payload.complete.prob << ' ' << payload.complete.backoff;
+}
+
+// template parameter is the type stored.
+template <class V> class Print {
+ public:
+ explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {}
+
+ void Run(const ChainPositions &chains) {
+ NGramStreams streams(chains);
+ for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
+ DumpStream(*s);
+ }
+ }
+
+ void Run(const util::stream::ChainPosition &position) {
+ NGramStream stream(position);
+ DumpStream(stream);
+ }
+
+ private:
+ void DumpStream(NGramStream &stream) {
+ for (; stream; ++stream) {
+ PrintPayload<V>(to_, stream->Value());
+ for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) {
+ to_ << ' ' << vocab_.Lookup(*w) << '=' << *w;
+ }
+ to_ << '\n';
+ }
+ }
+
+ const VocabReconstitute &vocab_;
+ std::ostream &to_;
+};
+
+class PrintARPA {
+ public:
+ // header_info may be NULL to disable the header
+ explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ const VocabReconstitute &vocab_;
+ int out_fd_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_PRINT__
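
A debugging sketch (chain setup assumed): the Print worker attaches to a single chain like any other worker.

    #include "lm/builder/print.hh"

    #include <iostream>

    // Dump each record flowing through one chain as its payload (here a
    // count) followed by word=id pairs.
    void DebugCounts(util::stream::Chain &chain, const lm::builder::VocabReconstitute &vocab) {
      chain >> lm::builder::Print<uint64_t>(vocab, std::cerr);
    }
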
diff --git a/lm/builder/sort.hh b/lm/builder/sort.hh
new file mode 100644
index 000000000..9989389b6
--- /dev/null
+++ b/lm/builder/sort.hh
@@ -0,0 +1,103 @@
+#ifndef LM_BUILDER_SORT__
+#define LM_BUILDER_SORT__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/ngram.hh"
+#include "lm/word_index.hh"
+#include "util/stream/sort.hh"
+
+#include "util/stream/timer.hh"
+
+#include <functional>
+#include <string>
+
+namespace lm {
+namespace builder {
+
+template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> {
+ public:
+ explicit Comparator(std::size_t order) : order_(order) {}
+
+ inline bool operator()(const void *lhs, const void *rhs) const {
+ return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs));
+ }
+
+ std::size_t Order() const { return order_; }
+
+ protected:
+ std::size_t order_;
+};
+
+class SuffixOrder : public Comparator<SuffixOrder> {
+ public:
+ explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = order_ - 1; i != 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[0] < rhs[0];
+ }
+
+ static const unsigned kMatchOffset = 1;
+};
+
+class ContextOrder : public Comparator<ContextOrder> {
+ public:
+ explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (int i = order_ - 2; i >= 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[order_ - 1] < rhs[order_ - 1];
+ }
+};
+
+class PrefixOrder : public Comparator<PrefixOrder> {
+ public:
+ explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = 0; i < order_; ++i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return false;
+ }
+
+ static const unsigned kMatchOffset = 0;
+};
+
+// Sum counts for the same n-gram.
+struct AddCombiner {
+ bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
+ NGram first(first_void, compare.Order());
+ // There isn't a const version of NGram.
+ NGram second(const_cast<void*>(second_void), compare.Order());
+ if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
+ first.Count() += second.Count();
+ return true;
+ }
+};
+
+// The combiner is only used on a single chain, so I didn't bother to make it
+// a template parameter here.
+template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > {
+ private:
+ typedef util::stream::Sort<Compare> S;
+ typedef FixedArray<S> P;
+
+ public:
+ void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
+ new (P::end()) S(chain, config, compare);
+ P::Constructed();
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_SORT__
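
A worked example of the three orders on concrete arrays; everything used here is defined above:

    #include "lm/builder/sort.hh"

    #include <assert.h>

    int main() {
      using namespace lm;
      using namespace lm::builder;
      // Two trigrams as raw WordIndex arrays (vocabulary ids are made up).
      WordIndex abc[] = {1, 2, 3};
      WordIndex dbc[] = {4, 2, 3};
      // SuffixOrder keys on the last word first; these tie until word 0 breaks it.
      assert(SuffixOrder(3).Compare(abc, dbc));   // 3==3, 2==2, then 1 < 4.
      // ContextOrder keys on the context (all but the last word), right to left.
      assert(ContextOrder(3).Compare(abc, dbc));  // 2==2, then 1 < 4.
      // PrefixOrder is plain left-to-right lexicographic comparison.
      assert(PrefixOrder(3).Compare(abc, dbc));   // 1 < 4 immediately.
      return 0;
    }
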
diff --git a/lm/filter/Jamfile b/lm/filter/Jamfile
new file mode 100644
index 000000000..16e417261
--- /dev/null
+++ b/lm/filter/Jamfile
@@ -0,0 +1,5 @@
+fakelib lm_filter : phrase.cc vocab.cc arpa_io.cc ../../util//kenutil : <threading>multi:<library>/top//boost_thread ;
+
+obj main : main.cc : <threading>single:<define>NTHREAD <include>../.. ;
+
+exe filter : main lm_filter ../../util//kenutil ..//kenlm : <threading>multi:<library>/top//boost_thread ;
diff --git a/lm/filter/arpa_io.cc b/lm/filter/arpa_io.cc
new file mode 100644
index 000000000..caf8df953
--- /dev/null
+++ b/lm/filter/arpa_io.cc
@@ -0,0 +1,122 @@
+#include "lm/filter/arpa_io.hh"
+#include "util/file_piece.hh"
+
+#include <iostream>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include <ctype.h>
+#include <errno.h>
+#include <string.h>
+
+namespace lm {
+
+ARPAInputException::ARPAInputException(const StringPiece &message) throw() : what_("Error: ") {
+ what_.append(message.data(), message.size());
+}
+
+ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() {
+ what_ = "Error: ";
+ what_.append(message.data(), message.size());
+ what_ += " in line '";
+ what_.append(line.data(), line.size());
+ what_ += "'.";
+}
+
+ARPAOutputException::ARPAOutputException(const char *message, const std::string &file_name) throw()
+ : what_(std::string(message) + " file " + file_name), file_name_(file_name) {
+ if (errno) {
+ char buf[1024];
+ buf[0] = 0;
+#if (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600) && ! _GNU_SOURCE
+ const char *add = buf;
+ if (!strerror_r(errno, buf, 1024)) {
+#else
+ const char *add = strerror_r(errno, buf, 1024);
+ if (add) {
+#endif
+      what_ += ": ";
+ what_ += add;
+ }
+ }
+}
+
+// Seeking is the responsibility of the caller.
+void WriteCounts(std::ostream &out, const std::vector<size_t> &number) {
+ out << "\n\\data\\\n";
+ for (unsigned int i = 0; i < number.size(); ++i) {
+ out << "ngram " << i+1 << "=" << number[i] << '\n';
+ }
+ out << '\n';
+}
+
+size_t SizeNeededForCounts(const std::vector<size_t> &number) {
+ std::ostringstream buf;
+ WriteCounts(buf, number);
+ return buf.tellp();
+}
+
+bool IsEntirelyWhiteSpace(const StringPiece &line) {
+ for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
+ if (!isspace(line.data()[i])) return false;
+ }
+ return true;
+}
+
+ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) : file_name_(name), buffer_(new char[buffer_size]) {
+ try {
+ file_.exceptions(std::ostream::eofbit | std::ostream::failbit | std::ostream::badbit);
+ if (!file_.rdbuf()->pubsetbuf(buffer_.get(), buffer_size)) {
+ std::cerr << "Warning: could not enlarge buffer for " << name << std::endl;
+ buffer_.reset();
+ }
+ file_.open(name, std::ios::out | std::ios::binary);
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Opening", file_name_);
+ }
+}
+
+void ARPAOutput::ReserveForCounts(std::streampos reserve) {
+ try {
+ for (std::streampos i = 0; i < reserve; i += std::streampos(1)) {
+ file_ << '\n';
+ }
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing blanks to reserve space for counts to ", file_name_);
+ }
+}
+
+void ARPAOutput::BeginLength(unsigned int length) {
+ fast_counter_ = 0;
+ try {
+ file_ << '\\' << length << "-grams:" << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing n-gram header to ", file_name_);
+ }
+}
+
+void ARPAOutput::EndLength(unsigned int length) {
+ try {
+ file_ << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing blank at end of count list to ", file_name_);
+ }
+ if (length > counts_.size()) {
+ counts_.resize(length);
+ }
+ counts_[length - 1] = fast_counter_;
+}
+
+void ARPAOutput::Finish() {
+ try {
+ file_ << "\\end\\\n";
+ file_.seekp(0);
+ WriteCounts(file_, counts_);
+ file_ << std::flush;
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Finishing including writing counts at beginning to ", file_name_);
+ }
+}
+
+} // namespace lm
diff --git a/lm/filter/arpa_io.hh b/lm/filter/arpa_io.hh
new file mode 100644
index 000000000..90f484473
--- /dev/null
+++ b/lm/filter/arpa_io.hh
@@ -0,0 +1,122 @@
+#ifndef LM_FILTER_ARPA_IO__
+#define LM_FILTER_ARPA_IO__
+/* Input and output for ARPA format language model files.
+ */
+#include "lm/read_arpa.hh"
+#include "util/exception.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/scoped_array.hpp>
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include <err.h>
+#include <string.h>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+
+class ARPAInputException : public util::Exception {
+ public:
+ explicit ARPAInputException(const StringPiece &message) throw();
+ explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw();
+ virtual ~ARPAInputException() throw() {}
+
+ const char *what() const throw() { return what_.c_str(); }
+
+ private:
+ std::string what_;
+};
+
+class ARPAOutputException : public std::exception {
+ public:
+ ARPAOutputException(const char *prefix, const std::string &file_name) throw();
+ virtual ~ARPAOutputException() throw() {}
+
+ const char *what() const throw() { return what_.c_str(); }
+
+ const std::string &File() const throw() { return file_name_; }
+
+ private:
+ std::string what_;
+ const std::string file_name_;
+};
+
+// Handling for the counts of n-grams at the beginning of ARPA files.
+size_t SizeNeededForCounts(const std::vector<size_t> &number);
+
+/* Writes an ARPA file. This has to be seekable so the counts can be written
+ * at the end. Hence, I just have it own a std::fstream instead of accepting
+ * a separately held std::ostream.
+ */
+class ARPAOutput : boost::noncopyable {
+ public:
+ explicit ARPAOutput(const char *name, size_t buffer_size = 65536);
+
+ void ReserveForCounts(std::streampos reserve);
+
+ void BeginLength(unsigned int length);
+
+ void AddNGram(const StringPiece &line) {
+ try {
+ file_ << line << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing an n-gram", file_name_);
+ }
+ ++fast_counter_;
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ void EndLength(unsigned int length);
+
+ void Finish();
+
+ private:
+ const std::string file_name_;
+ boost::scoped_array<char> buffer_;
+ std::fstream file_;
+ size_t fast_counter_;
+ std::vector<size_t> counts_;
+};
+
+
+template <class Output> void ReadNGrams(util::FilePiece &in, unsigned int length, size_t number, Output &out) {
+ ReadNGramHeader(in, length);
+ out.BeginLength(length);
+ for (size_t i = 0; i < number; ++i) {
+ StringPiece line = in.ReadLine();
+ util::TokenIter<util::SingleCharacter> tabber(line, '\t');
+ if (!tabber) throw ARPAInputException("blank line", line);
+ if (!++tabber) throw ARPAInputException("no tab", line);
+
+ out.AddNGram(*tabber, line);
+ }
+ out.EndLength(length);
+}
+
+template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) {
+ std::vector<size_t> number;
+ ReadARPACounts(in_lm, number);
+ out.ReserveForCounts(SizeNeededForCounts(number));
+ for (unsigned int i = 0; i < number.size(); ++i) {
+ ReadNGrams(in_lm, i + 1, number[i], out);
+ }
+ ReadEnd(in_lm);
+ out.Finish();
+}
+
+} // namespace lm
+
+#endif // LM_FILTER_ARPA_IO__
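
A sketch of the simplest client, copying an ARPA file through ReadARPA so the count header is rewritten; it assumes util::FilePiece's open-by-name constructor, and the file names are illustrative.

    #include "lm/filter/arpa_io.hh"
    #include "util/file_piece.hh"

    int main() {
      util::FilePiece in("model.arpa");
      lm::ARPAOutput out("copy.arpa");
      lm::ReadARPA(in, out);  // Streams all n-grams, then writes the counts.
      return 0;
    }
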
diff --git a/lm/filter/count_io.hh b/lm/filter/count_io.hh
new file mode 100644
index 000000000..97c0fa25e
--- /dev/null
+++ b/lm/filter/count_io.hh
@@ -0,0 +1,91 @@
+#ifndef LM_FILTER_COUNT_IO__
+#define LM_FILTER_COUNT_IO__
+
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include <boost/noncopyable.hpp>
+
+#include <err.h>
+
+#include "util/file_piece.hh"
+
+namespace lm {
+
+class CountOutput : boost::noncopyable {
+ public:
+ explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
+
+ void AddNGram(const StringPiece &line) {
+ if (!(file_ << line << '\n')) {
+ err(3, "Writing counts file failed");
+ }
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ private:
+ std::fstream file_;
+};
+
+class CountBatch {
+ public:
+ explicit CountBatch(std::streamsize initial_read)
+ : initial_read_(initial_read) {
+ buffer_.reserve(initial_read);
+ }
+
+ void Read(std::istream &in) {
+ buffer_.resize(initial_read_);
+ in.read(&*buffer_.begin(), initial_read_);
+ buffer_.resize(in.gcount());
+ char got;
+ while (in.get(got) && got != '\n')
+ buffer_.push_back(got);
+ }
+
+ template <class Output> void Send(Output &out) {
+ for (util::TokenIter<util::SingleCharacter> line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) {
+ util::TokenIter<util::SingleCharacter> tabber(*line, '\t');
+ if (!tabber) {
+ std::cerr << "Warning: empty n-gram count line being removed\n";
+ continue;
+ }
+ util::TokenIter<util::SingleCharacter, true> words(*tabber, ' ');
+ if (!words) {
+ std::cerr << "Line has a tab but no words.\n";
+ continue;
+ }
+ out.AddNGram(words, util::TokenIter<util::SingleCharacter, true>::end(), *line);
+ }
+ }
+
+ private:
+ std::streamsize initial_read_;
+
+ // This could have been a std::string but that's less happy with raw writes.
+ std::vector<char> buffer_;
+};
+
+template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) {
+ try {
+ while (true) {
+ StringPiece line = in_file.ReadLine();
+ util::TokenIter<util::SingleCharacter> tabber(line, '\t');
+ if (!tabber) {
+ std::cerr << "Warning: empty n-gram count line being removed\n";
+ continue;
+ }
+ out.AddNGram(*tabber, line);
+ }
+ } catch (const util::EndOfFileException &e) {}
+}
+
+} // namespace lm
+
+#endif // LM_FILTER_COUNT_IO__
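
The count format these readers expect is one n-gram per line: space-separated tokens, then a tab, then arbitrary text, typically the count (the tab is shown as a wide gap below):

    in the              483
    the quick brown     12
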
diff --git a/lm/filter/format.hh b/lm/filter/format.hh
new file mode 100644
index 000000000..7f945b0d6
--- /dev/null
+++ b/lm/filter/format.hh
@@ -0,0 +1,250 @@
+#ifndef LM_FILTER_FORMAT_H__
+#define LM_FILTER_FORMAT_H__
+
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/count_io.hh"
+
+#include <boost/lexical_cast.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <iosfwd>
+
+namespace lm {
+
+template <class Single> class MultipleOutput {
+ private:
+ typedef boost::ptr_vector<Single> Singles;
+ typedef typename Singles::iterator SinglesIterator;
+
+ public:
+ MultipleOutput(const char *prefix, size_t number) {
+ files_.reserve(number);
+ std::string tmp;
+ for (unsigned int i = 0; i < number; ++i) {
+ tmp = prefix;
+ tmp += boost::lexical_cast<std::string>(i);
+ files_.push_back(new Single(tmp.c_str()));
+ }
+ }
+
+ void AddNGram(const StringPiece &line) {
+ for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+ i->AddNGram(line);
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+ i->AddNGram(begin, end, line);
+ }
+
+ void SingleAddNGram(size_t offset, const StringPiece &line) {
+ files_[offset].AddNGram(line);
+ }
+
+ template <class Iterator> void SingleAddNGram(size_t offset, const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ files_[offset].AddNGram(begin, end, line);
+ }
+
+ protected:
+ Singles files_;
+};
+
+class MultipleARPAOutput : public MultipleOutput<ARPAOutput> {
+ public:
+ MultipleARPAOutput(const char *prefix, size_t number) : MultipleOutput<ARPAOutput>(prefix, number) {}
+
+ void ReserveForCounts(std::streampos reserve) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->ReserveForCounts(reserve);
+ }
+
+ void BeginLength(unsigned int length) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->BeginLength(length);
+ }
+
+ void EndLength(unsigned int length) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->EndLength(length);
+ }
+
+ void Finish() {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->Finish();
+ }
+};
+
+template <class Filter, class Output> class DispatchInput {
+ public:
+ DispatchInput(Filter &filter, Output &output) : filter_(filter), output_(output) {}
+
+/* template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ filter_.AddNGram(begin, end, line, output_);
+ }*/
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ filter_.AddNGram(ngram, line, output_);
+ }
+
+ protected:
+ Filter &filter_;
+ Output &output_;
+};
+
+template <class Filter, class Output> class DispatchARPAInput : public DispatchInput<Filter, Output> {
+ private:
+ typedef DispatchInput<Filter, Output> B;
+
+ public:
+ DispatchARPAInput(Filter &filter, Output &output) : B(filter, output) {}
+
+ void ReserveForCounts(std::streampos reserve) { B::output_.ReserveForCounts(reserve); }
+ void BeginLength(unsigned int length) { B::output_.BeginLength(length); }
+
+ void EndLength(unsigned int length) {
+ B::filter_.Flush();
+ B::output_.EndLength(length);
+ }
+ void Finish() { B::output_.Finish(); }
+};
+
+struct ARPAFormat {
+ typedef ARPAOutput Output;
+ typedef MultipleARPAOutput Multiple;
+ static void Copy(util::FilePiece &in, Output &out) {
+ ReadARPA(in, out);
+ }
+ template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
+ DispatchARPAInput<Filter, Out> dispatcher(filter, output);
+ ReadARPA(in, dispatcher);
+ }
+};
+
+struct CountFormat {
+ typedef CountOutput Output;
+ typedef MultipleOutput<Output> Multiple;
+ static void Copy(util::FilePiece &in, Output &out) {
+ ReadCount(in, out);
+ }
+ template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
+ DispatchInput<Filter, Out> dispatcher(filter, output);
+ ReadCount(in, dispatcher);
+ }
+};
+
+/* For multithreading, the buffer classes hold batches of filter inputs and
+ * outputs in memory. The strings get reused a lot, so keep them around
+ * instead of clearing each time.
+ */
+class InputBuffer {
+ public:
+ InputBuffer() : actual_(0) {}
+
+ void Reserve(size_t size) { lines_.reserve(size); }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ if (lines_.size() == actual_) lines_.resize(lines_.size() + 1);
+ // TODO avoid this copy.
+ std::string &copied = lines_[actual_].line;
+ copied.assign(line.data(), line.size());
+ lines_[actual_].ngram.set(copied.data() + (ngram.data() - line.data()), ngram.size());
+ ++actual_;
+ }
+
+ template <class Filter, class Output> void CallFilter(Filter &filter, Output &output) const {
+ for (std::vector<Line>::const_iterator i = lines_.begin(); i != lines_.begin() + actual_; ++i) {
+ filter.AddNGram(i->ngram, i->line, output);
+ }
+ }
+
+ void Clear() { actual_ = 0; }
+ bool Empty() { return actual_ == 0; }
+ size_t Size() { return actual_; }
+
+ private:
+ struct Line {
+ std::string line;
+ StringPiece ngram;
+ };
+
+ size_t actual_;
+
+ std::vector<Line> lines_;
+};
+
+class BinaryOutputBuffer {
+ public:
+ BinaryOutputBuffer() {}
+
+ void Reserve(size_t size) {
+ lines_.reserve(size);
+ }
+
+ void AddNGram(const StringPiece &line) {
+ lines_.push_back(line);
+ }
+
+ template <class Output> void Flush(Output &output) {
+ for (std::vector<StringPiece>::const_iterator i = lines_.begin(); i != lines_.end(); ++i) {
+ output.AddNGram(*i);
+ }
+ lines_.clear();
+ }
+
+ private:
+ std::vector<StringPiece> lines_;
+};
+
+class MultipleOutputBuffer {
+ public:
+ MultipleOutputBuffer() : last_(NULL) {}
+
+ void Reserve(size_t size) {
+ annotated_.reserve(size);
+ }
+
+ void AddNGram(const StringPiece &line) {
+ annotated_.resize(annotated_.size() + 1);
+ annotated_.back().line = line;
+ }
+
+ void SingleAddNGram(size_t offset, const StringPiece &line) {
+ if ((line.data() == last_.data()) && (line.length() == last_.length())) {
+ annotated_.back().systems.push_back(offset);
+ } else {
+ annotated_.resize(annotated_.size() + 1);
+ annotated_.back().systems.push_back(offset);
+ annotated_.back().line = line;
+ last_ = line;
+ }
+ }
+
+ template <class Output> void Flush(Output &output) {
+ for (std::vector<Annotated>::const_iterator i = annotated_.begin(); i != annotated_.end(); ++i) {
+ if (i->systems.empty()) {
+ output.AddNGram(i->line);
+ } else {
+ for (std::vector<size_t>::const_iterator j = i->systems.begin(); j != i->systems.end(); ++j) {
+ output.SingleAddNGram(*j, i->line);
+ }
+ }
+ }
+ annotated_.clear();
+ }
+
+ private:
+ struct Annotated {
+ // If this is empty, send to all systems.
+ // A filter should never send to all systems and send to a single one.
+ std::vector<size_t> systems;
+ StringPiece line;
+ };
+
+ StringPiece last_;
+
+ std::vector<Annotated> annotated_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_FORMAT_H__
diff --git a/lm/filter/main.cc b/lm/filter/main.cc
new file mode 100644
index 000000000..c42243e2b
--- /dev/null
+++ b/lm/filter/main.cc
@@ -0,0 +1,249 @@
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/format.hh"
+#include "lm/filter/phrase.hh"
+#ifndef NTHREAD
+#include "lm/filter/thread.hh"
+#endif
+#include "lm/filter/vocab.hh"
+#include "lm/filter/wrapper.hh"
+#include "util/file_piece.hh"
+
+#include <boost/optional.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+
+#include <err.h>
+
+namespace lm {
+namespace {
+
+void DisplayHelp(const char *name) {
+ std::cerr
+ << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n"
+ "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n"
+ " parser.\n"
+ "single mode treats the entire input as a single sentence.\n"
+ "multiple mode filters to multiple sentences in parallel. Each sentence is on\n"
+ " a separate line. A separate file is created for each file by appending the\n"
+ " 0-indexed line number to the output file name.\n"
+ "union mode produces one filtered model that is the union of models created by\n"
+ " multiple mode.\n\n"
+ "context means only the context (all but last word) has to pass the filter, but\n"
+ " the entire n-gram is output.\n\n"
+ "phrase means that the vocabulary is actually tab-delimited phrases and that the\n"
+ " phrases can generate the n-gram when assembled in arbitrary order and\n"
+ " clipped. Currently works with multiple or union mode.\n\n"
+ "The file format is set by [raw|arpa] with default arpa:\n"
+ "raw means space-separated tokens, optionally followed by a tab and arbitrary\n"
+ " text. This is useful for ngram count files.\n"
+ "arpa means the ARPA file format for n-gram language models.\n\n"
+#ifndef NTHREAD
+ "threads:m sets m threads (default: conccurrency detected by boost)\n"
+ "batch_size:m sets the batch size for threading. Expect memory usage from this\n"
+ " of 2*threads*batch_size n-grams.\n\n"
+#else
+ "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n"
+ " threading, compile without this flag against Boost >=1.42.0.\n\n"
+#endif
+ "There are two inputs: vocabulary and model. Either may be given as a file\n"
+ " while the other is on stdin. Specify the type given as a file using\n"
+ " vocab: or model: before the file name. \n\n"
+ "For ARPA format, the output must be seekable. For raw format, it can be a\n"
+ " stream i.e. /dev/stdout\n";
+}
+
+typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION} FilterMode;
+typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format;
+
+struct Config {
+ Config() :
+#ifndef NTHREAD
+ batch_size(25000),
+ threads(boost::thread::hardware_concurrency()),
+#endif
+ phrase(false),
+ context(false),
+ format(FORMAT_ARPA)
+ {
+#ifndef NTHREAD
+ if (!threads) threads = 1;
+#endif
+ }
+
+#ifndef NTHREAD
+ size_t batch_size;
+ size_t threads;
+#endif
+ bool phrase;
+ bool context;
+ FilterMode mode;
+ Format format;
+};
+
+template <class Format, class Filter, class OutputBuffer, class Output> void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) {
+#ifndef NTHREAD
+ if (config.threads == 1) {
+#endif
+ Format::RunFilter(in_lm, filter, output);
+#ifndef NTHREAD
+ } else {
+ typedef Controller<Filter, OutputBuffer, Output> Threaded;
+ Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output);
+ Format::RunFilter(in_lm, threading, output);
+ }
+#endif
+}
+
+template <class Format, class Filter, class OutputBuffer, class Output> void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) {
+ if (config.context) {
+ ContextFilter<Filter> context_filter(filter);
+ RunThreadedFilter<Format, ContextFilter<Filter>, OutputBuffer, Output>(config, in_lm, context_filter, output);
+ } else {
+ RunThreadedFilter<Format, Filter, OutputBuffer, Output>(config, in_lm, filter, output);
+ }
+}
+
+template <class Format, class Binary> void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) {
+ typedef BinaryFilter<Binary> Filter;
+ RunContextFilter<Format, Filter, BinaryOutputBuffer, typename Format::Output>(config, in_lm, Filter(binary), out);
+}
+
+template <class Format> void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) {
+ if (config.mode == MODE_MULTIPLE) {
+ if (config.phrase) {
+ typedef phrase::Multiple Filter;
+ phrase::Substrings substrings;
+ typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings));
+ RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(substrings), out);
+ } else {
+ typedef vocab::Multiple Filter;
+ boost::unordered_map<std::string, std::vector<unsigned int> > words;
+ typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words));
+ RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(words), out);
+ }
+ return;
+ }
+
+ typename Format::Output out(out_name);
+
+ if (config.mode == MODE_COPY) {
+ Format::Copy(in_lm, out);
+ return;
+ }
+
+ if (config.mode == MODE_SINGLE) {
+ vocab::Single::Words words;
+ vocab::ReadSingle(in_vocab, words);
+ DispatchBinaryFilter<Format, vocab::Single>(config, in_lm, vocab::Single(words), out);
+ return;
+ }
+
+ if (config.mode == MODE_UNION) {
+ if (config.phrase) {
+ phrase::Substrings substrings;
+ phrase::ReadMultiple(in_vocab, substrings);
+ DispatchBinaryFilter<Format, phrase::Union>(config, in_lm, phrase::Union(substrings), out);
+ } else {
+ vocab::Union::Words words;
+ vocab::ReadMultiple(in_vocab, words);
+ DispatchBinaryFilter<Format, vocab::Union>(config, in_lm, vocab::Union(words), out);
+ }
+ return;
+ }
+}
+
+} // namespace
+} // namespace lm
+
+int main(int argc, char *argv[]) {
+ if (argc < 4) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+
+ // I used to have boost::program_options, but some users didn't want to compile boost.
+ lm::Config config;
+ boost::optional<lm::FilterMode> mode;
+ for (int i = 1; i < argc - 2; ++i) {
+ const char *str = argv[i];
+ if (!std::strcmp(str, "copy")) {
+ mode = lm::MODE_COPY;
+ } else if (!std::strcmp(str, "single")) {
+ mode = lm::MODE_SINGLE;
+ } else if (!std::strcmp(str, "multiple")) {
+ mode = lm::MODE_MULTIPLE;
+ } else if (!std::strcmp(str, "union")) {
+ mode = lm::MODE_UNION;
+ } else if (!std::strcmp(str, "phrase")) {
+ config.phrase = true;
+ } else if (!std::strcmp(str, "context")) {
+ config.context = true;
+ } else if (!std::strcmp(str, "arpa")) {
+ config.format = lm::FORMAT_ARPA;
+ } else if (!std::strcmp(str, "raw")) {
+ config.format = lm::FORMAT_COUNT;
+#ifndef NTHREAD
+ } else if (!std::strncmp(str, "threads:", 8)) {
+ config.threads = boost::lexical_cast<size_t>(str + 8);
+ if (!config.threads) {
+ std::cerr << "Specify at least one thread." << std::endl;
+ return 1;
+ }
+ } else if (!std::strncmp(str, "batch_size:", 11)) {
+ config.batch_size = boost::lexical_cast<size_t>(str + 11);
+ if (config.batch_size < 5000) {
+ std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
+ if (!config.batch_size) return 1;
+ }
+#endif
+ } else {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+ }
+
+ if (!mode) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+ config.mode = *mode;
+
+  if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
+ std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
+ return 1;
+ }
+
+ bool cmd_is_model = true;
+ const char *cmd_input = argv[argc - 2];
+ if (!strncmp(cmd_input, "vocab:", 6)) {
+ cmd_is_model = false;
+ cmd_input += 6;
+ } else if (!strncmp(cmd_input, "model:", 6)) {
+ cmd_input += 6;
+ } else if (strchr(cmd_input, ':')) {
+ errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input);
+ } else {
+ std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
+ }
+ std::ifstream cmd_file;
+ std::istream *vocab;
+ if (cmd_is_model) {
+ vocab = &std::cin;
+ } else {
+ cmd_file.open(cmd_input, std::ios::in);
+ if (!cmd_file) {
+ err(2, "Could not open input file %s", cmd_input);
+ }
+ vocab = &cmd_file;
+ }
+
+ util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
+
+ if (config.format == lm::FORMAT_ARPA) {
+ lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
+ } else if (config.format == lm::FORMAT_COUNT) {
+ lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ }
+ return 0;
+}
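
Example invocations following the grammar in DisplayHelp (file names illustrative). Exactly one of the vocabulary and the model is named with a vocab: or model: prefix; the other arrives on stdin:

    # Union-filter an ARPA model to the vocabulary; model read from stdin.
    filter union vocab:vocab.txt filtered.arpa < model.arpa

    # Per-sentence filtering of raw counts against tab-delimited phrases;
    # writes out0, out1, ... with one output file per input sentence.
    filter multiple phrase raw model:counts.txt out < phrases.txt
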
diff --git a/lm/filter/phrase.cc b/lm/filter/phrase.cc
new file mode 100644
index 000000000..1bef2a3f3
--- /dev/null
+++ b/lm/filter/phrase.cc
@@ -0,0 +1,281 @@
+#include "lm/filter/phrase.hh"
+
+#include "lm/filter/format.hh"
+
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include <ctype.h>
+
+namespace lm {
+namespace phrase {
+
+unsigned int ReadMultiple(std::istream &in, Substrings &out) {
+ bool sentence_content = false;
+ unsigned int sentence_id = 0;
+ std::vector<Hash> phrase;
+ std::string word;
+ while (in) {
+ char c;
+ // Gather a word.
+ while (!isspace(c = in.get()) && in) word += c;
+ // Treat EOF like a newline.
+ if (!in) c = '\n';
+ // Add the word to the phrase.
+ if (!word.empty()) {
+ phrase.push_back(util::MurmurHashNative(word.data(), word.size()));
+ word.clear();
+ }
+ if (c == ' ') continue;
+ // It's more than just a space. Close out the phrase.
+ if (!phrase.empty()) {
+ sentence_content = true;
+ out.AddPhrase(sentence_id, phrase.begin(), phrase.end());
+ phrase.clear();
+ }
+ if (c == '\t' || c == '\v') continue;
+ // It's more than a space or tab: a newline.
+ if (sentence_content) {
+ ++sentence_id;
+ sentence_content = false;
+ }
+ }
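+  // If reading stopped for a reason other than EOF, enabling stream
+  // exceptions here throws immediately, surfacing the read error.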
+ if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit);
+ return sentence_id + sentence_content;
+}
+
+namespace detail { const StringPiece kEndSentence("</s>"); }
+
+namespace {
+
+typedef unsigned int Sentence;
+typedef std::vector<Sentence> Sentences;
+
+class Vertex;
+
+class Arc {
+ public:
+ Arc() {}
+
+ // For arcs from one vertex to another.
+ void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) {
+ Set(to, intersect);
+ from_ = &from;
+ }
+
+ /* For arcs from before the n-gram begins to somewhere in the n-gram (right
+ * aligned). These have no from_ vertex; it implictly matches every
+ * sentence. This also handles when the n-gram is a substring of a phrase.
+ */
+ void SetRight(Vertex &to, const Sentences &complete) {
+ Set(to, complete);
+ from_ = NULL;
+ }
+
+ Sentence Current() const {
+ return *current_;
+ }
+
+ bool Empty() const {
+ return current_ == last_;
+ }
+
+ /* When this function returns:
+ * If Empty() then there's nothing left from this intersection.
+ *
+ * If Current() == to then to is part of the intersection.
+ *
+ * Otherwise, Current() > to. In this case, to is not part of the
+ * intersection and neither is anything < Current(). To determine if
+ * any value >= Current() is in the intersection, call LowerBound again
+ * with the value.
+ */
+ void LowerBound(const Sentence to);
+
+ private:
+ void Set(Vertex &to, const Sentences &sentences);
+
+ const Sentence *current_;
+ const Sentence *last_;
+ Vertex *from_;
+};
+
+struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> {
+ bool operator()(const Arc *first, const Arc *second) const {
+ return first->Current() > second->Current();
+ }
+};
+
+class Vertex {
+ public:
+ Vertex() : current_(0) {}
+
+ Sentence Current() const {
+ return current_;
+ }
+
+ bool Empty() const {
+ return incoming_.empty();
+ }
+
+ void LowerBound(const Sentence to);
+
+ private:
+ friend class Arc;
+
+ void AddIncoming(Arc *arc) {
+ if (!arc->Empty()) incoming_.push(arc);
+ }
+
+ unsigned int current_;
+ std::priority_queue<Arc*, std::vector<Arc*>, ArcGreater> incoming_;
+};
+
+void Arc::LowerBound(const Sentence to) {
+ current_ = std::lower_bound(current_, last_, to);
+ // If *current_ > to, don't advance from_. The intervening values of
+ // from_ may be useful for another one of its outgoing arcs.
+ if (!from_ || Empty() || (Current() > to)) return;
+ assert(Current() == to);
+ from_->LowerBound(to);
+ if (from_->Empty()) {
+ current_ = last_;
+ return;
+ }
+ assert(from_->Current() >= to);
+ if (from_->Current() > to) {
+ current_ = std::lower_bound(current_ + 1, last_, from_->Current());
+ }
+}
+
+void Arc::Set(Vertex &to, const Sentences &sentences) {
+ current_ = &*sentences.begin();
+ last_ = &*sentences.end();
+ to.AddIncoming(this);
+}
+
+void Vertex::LowerBound(const Sentence to) {
+ if (Empty()) return;
+ // Union lower bound.
+ while (true) {
+ Arc *top = incoming_.top();
+ if (top->Current() > to) {
+ current_ = top->Current();
+ return;
+ }
+ // If top->Current() == to, we still need to verify that's an actual
+ // element and not just a bound.
+ incoming_.pop();
+ top->LowerBound(to);
+ if (!top->Empty()) {
+ incoming_.push(top);
+ if (top->Current() == to) {
+ current_ = to;
+ return;
+ }
+ } else if (Empty()) {
+ return;
+ }
+ }
+}
+
+void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) {
+ assert(!hashes.empty());
+
+ const Hash *const first_word = &*hashes.begin();
+ const Hash *const last_word = &*hashes.end() - 1;
+
+ Hash hash = 0;
+ const Sentences *found;
+ // Phrases starting at or before the first word in the n-gram.
+ {
+ Vertex *vertex = vertices;
+ for (const Hash *word = first_word; ; ++word, ++vertex) {
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word);
+ // Now hash is [hashes.begin(), word].
+ if (word == last_word) {
+ if (phrase.FindSubstring(hash, found))
+ (free_arc++)->SetRight(*vertex, *found);
+ break;
+ }
+ if (!phrase.FindRight(hash, found)) break;
+ (free_arc++)->SetRight(*vertex, *found);
+ }
+ }
+
+ // Phrases starting at the second or later word in the n-gram.
+ Vertex *vertex_from = vertices;
+ for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) {
+ hash = 0;
+ Vertex *vertex_to = vertex_from + 1;
+ for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) {
+ // Notice that word_to and vertex_to have the same index.
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to);
+ // Now hash covers [word_from, word_to].
+ if (word_to == last_word) {
+ if (phrase.FindLeft(hash, found))
+ (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
+ break;
+ }
+ if (!phrase.FindPhrase(hash, found)) break;
+ (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
+ }
+ }
+}
+
+} // namespace
+
+namespace detail {
+
+} // namespace detail
+
+bool Union::Evaluate() {
+ assert(!hashes_.empty());
+ // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
+ Vertex vertices[hashes_.size()];
+ // One for every substring.
+ Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
+ BuildGraph(substrings_, hashes_, vertices, arcs);
+ Vertex &last_vertex = vertices[hashes_.size() - 1];
+
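+  // Reading of the loop (annotation, not in the original source): LowerBound
+  // advances last_vertex to the smallest sentence id >= lower that can cover
+  // the whole n-gram; the n-gram passes exactly when that id equals the
+  // requested bound.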
+ unsigned int lower = 0;
+ while (true) {
+ last_vertex.LowerBound(lower);
+ if (last_vertex.Empty()) return false;
+ if (last_vertex.Current() == lower) return true;
+ lower = last_vertex.Current();
+ }
+}
+
+template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) {
+ assert(!hashes_.empty());
+ // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
+ Vertex vertices[hashes_.size()];
+ // One for every substring.
+ Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
+ BuildGraph(substrings_, hashes_, vertices, arcs);
+ Vertex &last_vertex = vertices[hashes_.size() - 1];
+
+ unsigned int lower = 0;
+ while (true) {
+ last_vertex.LowerBound(lower);
+ if (last_vertex.Empty()) return;
+ if (last_vertex.Current() == lower) {
+ output.SingleAddNGram(lower, line);
+ ++lower;
+ } else {
+ lower = last_vertex.Current();
+ }
+ }
+}
+
+template void Multiple::Evaluate<CountFormat::Multiple>(const StringPiece &line, CountFormat::Multiple &output);
+template void Multiple::Evaluate<ARPAFormat::Multiple>(const StringPiece &line, ARPAFormat::Multiple &output);
+template void Multiple::Evaluate<MultipleOutputBuffer>(const StringPiece &line, MultipleOutputBuffer &output);
+
+} // namespace phrase
+} // namespace lm
diff --git a/lm/filter/phrase.hh b/lm/filter/phrase.hh
new file mode 100644
index 000000000..07479dea6
--- /dev/null
+++ b/lm/filter/phrase.hh
@@ -0,0 +1,153 @@
+#ifndef LM_FILTER_PHRASE_H__
+#define LM_FILTER_PHRASE_H__
+
+#include "util/murmur_hash.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_map.hpp>
+
+#include <iosfwd>
+#include <vector>
+
+#define LM_FILTER_PHRASE_METHOD(caps, lower) \
+bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\
+ Table::const_iterator i(table_.find(key));\
+ if (i==table_.end()) return false; \
+ out = &i->second.lower; \
+ return true; \
+}
+
+namespace lm {
+namespace phrase {
+
+typedef uint64_t Hash;
+
+class Substrings {
+ private:
+ /* This is the value in a hash table where the key is a string. It indicates
+ * four sets of sentences:
+ * substring is sentences with a phrase containing the key as a substring.
+ * left is sentences with a phrase that begins with the key (left aligned).
+ * right is sentences with a phrase that ends with the key (right aligned).
+ * phrase is sentences where the key is a phrase.
+ * Each set is encoded as a vector of sentence ids in increasing order.
+ */
+ struct SentenceRelation {
+ std::vector<unsigned int> substring, left, right, phrase;
+ };
+  /* Most of the CPU time goes to hash table lookups, so let's not complicate them with
+ * vector equality comparisons. If a collision happens, the SentenceRelation
+ * structure will contain the union of sentence ids over the colliding strings.
+ * In that case, the filter will be slightly more permissive.
+ * The key here is the same as boost's hash of std::vector<std::string>.
+ */
+ typedef boost::unordered_map<Hash, SentenceRelation> Table;
+
+ public:
+ Substrings() {}
+
+ /* If the string isn't a substring of any phrase, return NULL. Otherwise,
+ * return a pointer to std::vector<unsigned int> listing sentences with
+ * matching phrases. This set may be empty for Left, Right, or Phrase.
+ * Example: const std::vector<unsigned int> *FindSubstring(Hash key)
+ */
+ LM_FILTER_PHRASE_METHOD(Substring, substring)
+ LM_FILTER_PHRASE_METHOD(Left, left)
+ LM_FILTER_PHRASE_METHOD(Right, right)
+ LM_FILTER_PHRASE_METHOD(Phrase, phrase)
+
+ // sentence_id must be non-decreasing. Iterators are over words in the phrase.
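+  // Illustrative example: for a two-word phrase "the cat" in sentence 0,
+  // substring gains {"the", "cat", "the cat"}, left gains {"the", "the cat"},
+  // right gains {"cat", "the cat"}, and phrase gains {"the cat"}, all keyed
+  // by hash.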
+ template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) {
+ // Iterate over all substrings.
+ for (Iterator start = begin; start != end; ++start) {
+ Hash hash = 0;
+ SentenceRelation *relation;
+ for (Iterator finish = start; finish != end; ++finish) {
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish);
+ // Now hash is of [start, finish].
+ relation = &table_[hash];
+ AppendSentence(relation->substring, sentence_id);
+ if (start == begin) AppendSentence(relation->left, sentence_id);
+ }
+ AppendSentence(relation->right, sentence_id);
+ if (start == begin) AppendSentence(relation->phrase, sentence_id);
+ }
+ }
+
+ private:
+ void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) {
+ if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id);
+ }
+
+ Table table_;
+};
+
+// Read a file with one sentence per line containing tab-delimited phrases of
+// space-separated words.
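+// Illustrative example: the input line "the cat<TAB>sat on" defines two
+// phrases for one sentence, "the cat" and "sat on"; the return value is the
+// number of sentences read.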
+unsigned int ReadMultiple(std::istream &in, Substrings &out);
+
+namespace detail {
+extern const StringPiece kEndSentence;
+
+template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) {
+ hashes.clear();
+ if (i == end) return;
+ // TODO: check strict phrase boundaries after <s> and before </s>. For now, just skip tags.
+ if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) {
+ ++i;
+ }
+ for (; i != end && (*i != kEndSentence); ++i) {
+ hashes.push_back(util::MurmurHashNative(i->data(), i->size()));
+ }
+}
+
+} // namespace detail
+
+class Union {
+ public:
+ explicit Union(const Substrings &substrings) : substrings_(substrings) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ detail::MakeHashes(begin, end, hashes_);
+ return hashes_.empty() || Evaluate();
+ }
+
+ private:
+ bool Evaluate();
+
+ std::vector<Hash> hashes_;
+
+ const Substrings &substrings_;
+};
+
+class Multiple {
+ public:
+ explicit Multiple(const Substrings &substrings) : substrings_(substrings) {}
+
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ detail::MakeHashes(begin, end, hashes_);
+ if (hashes_.empty()) {
+ output.AddNGram(line);
+ return;
+ }
+ Evaluate(line, output);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ template <class Output> void Evaluate(const StringPiece &line, Output &output);
+
+ std::vector<Hash> hashes_;
+
+ const Substrings &substrings_;
+};
+
+} // namespace phrase
+} // namespace lm
+#endif // LM_FILTER_PHRASE_H__
diff --git a/lm/filter/thread.hh b/lm/filter/thread.hh
new file mode 100644
index 000000000..e785b2637
--- /dev/null
+++ b/lm/filter/thread.hh
@@ -0,0 +1,167 @@
+#ifndef LM_FILTER_THREAD_H__
+#define LM_FILTER_THREAD_H__
+
+#include "util/thread_pool.hh"
+
+#include <boost/utility/in_place_factory.hpp>
+
+#include <deque>
+#include <stack>
+
+namespace lm {
+
+template <class OutputBuffer> class ThreadBatch {
+ public:
+ ThreadBatch() {}
+
+ void Reserve(size_t size) {
+ input_.Reserve(size);
+ output_.Reserve(size);
+ }
+
+ // File reading thread.
+ InputBuffer &Fill(uint64_t sequence) {
+ sequence_ = sequence;
+    // Why wait until now to clear instead of right after output? So memory
+    // is freed in the same thread that allocated it.
+ input_.Clear();
+ return input_;
+ }
+
+ // Filter worker thread.
+ template <class Filter> void CallFilter(Filter &filter) {
+ input_.CallFilter(filter, output_);
+ }
+
+ uint64_t Sequence() const { return sequence_; }
+
+ // File writing thread.
+ template <class RealOutput> void Flush(RealOutput &output) {
+ output_.Flush(output);
+ }
+
+ private:
+ InputBuffer input_;
+ OutputBuffer output_;
+
+ uint64_t sequence_;
+};
+
+template <class Batch, class Filter> class FilterWorker {
+ public:
+ typedef Batch *Request;
+
+ FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {}
+
+ void operator()(Request request) {
+ request->CallFilter(filter_);
+ done_.Produce(request);
+ }
+
+ private:
+ Filter filter_;
+
+ util::PCQueue<Request> &done_;
+};
+
+// There should only be one OutputWorker.
+template <class Batch, class Output> class OutputWorker {
+ public:
+ typedef Batch *Request;
+
+ OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {}
+
+ void operator()(Request request) {
+ assert(request->Sequence() >= base_sequence_);
+ // Assemble the output in order.
+ uint64_t pos = request->Sequence() - base_sequence_;
+ if (pos >= ordering_.size()) {
+ ordering_.resize(pos + 1, NULL);
+ }
+ ordering_[pos] = request;
+ while (!ordering_.empty() && ordering_.front()) {
+ ordering_.front()->Flush(output_);
+ done_.Produce(ordering_.front());
+ ordering_.pop_front();
+ ++base_sequence_;
+ }
+ }
+
+ private:
+ Output &output_;
+
+ util::PCQueue<Request> &done_;
+
+ std::deque<Request> ordering_;
+
+ uint64_t base_sequence_;
+};
+
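+// How the pieces fit together (summary, not in the original source): the
+// calling thread fills batches, FilterWorker threads run the filter, and the
+// single OutputWorker writes results back in sequence order; spent batches
+// are recycled to the reader through to_read_.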
+template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
+ private:
+ typedef ThreadBatch<OutputBuffer> Batch;
+
+ public:
+ Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
+ : batch_size_(batch_size), queue_size_(queue),
+ batches_(queue),
+ to_read_(queue),
+ output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
+ filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
+ sequence_(0) {
+ for (size_t i = 0; i < queue; ++i) {
+ batches_[i].Reserve(batch_size);
+ local_read_.push(&batches_[i]);
+ }
+ NewInput();
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
+ input_->AddNGram(ngram, line, output);
+ if (input_->Size() == batch_size_) {
+ FlushInput();
+ NewInput();
+ }
+ }
+
+ void Flush() {
+ FlushInput();
+ while (local_read_.size() < queue_size_) {
+ MoveRead();
+ }
+ NewInput();
+ }
+
+ private:
+ void FlushInput() {
+ if (input_->Empty()) return;
+ filter_.Produce(local_read_.top());
+ local_read_.pop();
+ if (local_read_.empty()) MoveRead();
+ }
+
+ void NewInput() {
+ input_ = &local_read_.top()->Fill(sequence_++);
+ }
+
+ void MoveRead() {
+ local_read_.push(to_read_.Consume());
+ }
+
+ const size_t batch_size_;
+ const size_t queue_size_;
+
+ std::vector<Batch> batches_;
+
+ util::PCQueue<Batch*> to_read_;
+ std::stack<Batch*> local_read_;
+ util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
+ util::ThreadPool<FilterWorker<Batch, Filter> > filter_;
+
+ uint64_t sequence_;
+ InputBuffer *input_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_THREAD_H__
diff --git a/lm/filter/vocab.cc b/lm/filter/vocab.cc
new file mode 100644
index 000000000..7ee4e84ba
--- /dev/null
+++ b/lm/filter/vocab.cc
@@ -0,0 +1,54 @@
+#include "lm/filter/vocab.hh"
+
+#include <istream>
+#include <iostream>
+
+#include <ctype.h>
+#include <err.h>
+
+namespace lm {
+namespace vocab {
+
+void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out) {
+ in.exceptions(std::istream::badbit);
+ std::string word;
+ while (in >> word) {
+ out.insert(word);
+ }
+}
+
+namespace {
+bool IsLineEnd(std::istream &in) {
+ int got;
+ do {
+ got = in.get();
+ if (!in) return true;
+ if (got == '\n') return true;
+ } while (isspace(got));
+ in.unget();
+ return false;
+}
+}// namespace
+
+// Read space-separated words on newline-separated lines. These lines can be
+// very long, so don't read an entire line at a time.
+unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out) {
+ in.exceptions(std::istream::badbit);
+ unsigned int sentence = 0;
+ bool used_id = false;
+ std::string word;
+ while (in >> word) {
+ used_id = true;
+ std::vector<unsigned int> &posting = out[word];
+ if (posting.empty() || (posting.back() != sentence))
+ posting.push_back(sentence);
+ if (IsLineEnd(in)) {
+ ++sentence;
+ used_id = false;
+ }
+ }
+ return sentence + used_id;
+}
+
+} // namespace vocab
+} // namespace lm
diff --git a/lm/filter/vocab.hh b/lm/filter/vocab.hh
new file mode 100644
index 000000000..e2b6adffa
--- /dev/null
+++ b/lm/filter/vocab.hh
@@ -0,0 +1,132 @@
+#ifndef LM_FILTER_VOCAB_H__
+#define LM_FILTER_VOCAB_H__
+
+// Vocabulary-based filters for language models.
+
+#include "util/multi_intersection.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/range/iterator_range.hpp>
+#include <boost/unordered/unordered_map.hpp>
+#include <boost/unordered/unordered_set.hpp>
+
+#include <string>
+#include <vector>
+
+namespace lm {
+namespace vocab {
+
+void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out);
+
+// Read one sentence vocabulary per line. Return the number of sentences.
+unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out);
+
+/* Is this a special tag like <s> or <UNK>? This actually includes anything
+ * surrounded with < and >, which most tokenizers separate for real words, so
+ * this should not catch real words as it looks at a single token.
+ */
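+// For example, IsTag("<s>") and IsTag("<unk>") are true; IsTag("cat") is false.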
+inline bool IsTag(const StringPiece &value) {
+ // The parser should never give an empty string.
+ assert(!value.empty());
+ return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>');
+}
+
+class Single {
+ public:
+ typedef boost::unordered_set<std::string> Words;
+
+ explicit Single(const Words &vocab) : vocab_(vocab) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ for (Iterator i = begin; i != end; ++i) {
+ if (IsTag(*i)) continue;
+ if (FindStringPiece(vocab_, *i) == vocab_.end()) return false;
+ }
+ return true;
+ }
+
+ private:
+ const Words &vocab_;
+};
+
+class Union {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ explicit Union(const Words &vocabs) : vocabs_(vocabs) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ sets_.clear();
+
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return false;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ return (sets_.empty() || util::FirstIntersection(sets_));
+ }
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+class Multiple {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ Multiple(const Words &vocabs) : vocabs_(vocabs) {}
+
+ private:
+ // Callback from AllIntersection that does AddNGram.
+ template <class Output> class Callback {
+ public:
+ Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {}
+
+ void operator()(unsigned int index) {
+ out_.SingleAddNGram(index, line_);
+ }
+
+ private:
+ Output &out_;
+ const StringPiece &line_;
+ };
+
+ public:
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ sets_.clear();
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ if (sets_.empty()) {
+ output.AddNGram(line);
+ return;
+ }
+
+ Callback<Output> cb(output, line);
+ util::AllIntersection(sets_, cb);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+} // namespace vocab
+} // namespace lm
+
+#endif // LM_FILTER_VOCAB_H__
diff --git a/lm/filter/wrapper.hh b/lm/filter/wrapper.hh
new file mode 100644
index 000000000..90b07a08f
--- /dev/null
+++ b/lm/filter/wrapper.hh
@@ -0,0 +1,58 @@
+#ifndef LM_FILTER_WRAPPER_H__
+#define LM_FILTER_WRAPPER_H__
+
+#include "util/string_piece.hh"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+namespace lm {
+
+// Provide a single-output filter with the same interface as a
+// multiple-output filter so that clients can code against one interface.
+template <class Binary> class BinaryFilter {
+ public:
+  // Binary filters are just references (and a set), so it keeps the API cleaner to copy them.
+ explicit BinaryFilter(Binary binary) : binary_(binary) {}
+
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ if (binary_.PassNGram(begin, end))
+ output.AddNGram(line);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ Binary binary_;
+};
+
+// Wrap another filter to pay attention only to context words
+template <class FilterT> class ContextFilter {
+ public:
+ typedef FilterT Filter;
+
+ explicit ContextFilter(Filter &backend) : backend_(backend) {}
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ pieces_.clear();
+ // TODO: this copy could be avoided by a lookahead iterator.
+ std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_));
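+    // Drop the final token so the backend filter sees only the context words.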
+ backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ std::vector<StringPiece> pieces_;
+
+ Filter backend_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_WRAPPER_H__
diff --git a/lm/read_arpa.cc b/lm/read_arpa.cc
index 4723ab3a2..9ea087985 100644
--- a/lm/read_arpa.cc
+++ b/lm/read_arpa.cc
@@ -46,8 +46,14 @@ uint64_t ReadCount(const std::string &from) {
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
number.clear();
- StringPiece line;
- while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
+ StringPiece line = in.ReadLine();
+  // In general, ARPA files can have arbitrary text before "\data\".
+  // In KenLM we require any such lines to start with "#" so that we
+  // can do stricter error checking.
+ while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) {
+ line = in.ReadLine();
+ }
+
if (line != "\\data\\") {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
diff --git a/lm/sizes.cc b/lm/sizes.cc
new file mode 100644
index 000000000..55ad586c4
--- /dev/null
+++ b/lm/sizes.cc
@@ -0,0 +1,63 @@
+#include "lm/sizes.hh"
+#include "lm/model.hh"
+#include "util/file_piece.hh"
+
+#include <vector>
+#include <iomanip>
+
+namespace lm {
+namespace ngram {
+
+void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config) {
+ uint64_t sizes[6];
+ sizes[0] = ProbingModel::Size(counts, config);
+ sizes[1] = RestProbingModel::Size(counts, config);
+ sizes[2] = TrieModel::Size(counts, config);
+ sizes[3] = QuantTrieModel::Size(counts, config);
+ sizes[4] = ArrayTrieModel::Size(counts, config);
+ sizes[5] = QuantArrayTrieModel::Size(counts, config);
+ uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
+ uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
+ uint64_t divide;
+ char prefix;
+ if (min_length < (1 << 10) * 10) {
+ prefix = ' ';
+ divide = 1;
+ } else if (min_length < (1 << 20) * 10) {
+ prefix = 'k';
+ divide = 1 << 10;
+ } else if (min_length < (1ULL << 30) * 10) {
+ prefix = 'M';
+ divide = 1 << 20;
+ } else {
+ prefix = 'G';
+ divide = 1 << 30;
+ }
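+  // For example, a minimum estimate near 300 MB selects prefix 'M', so all
+  // six sizes below print in MiB.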
+ long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide))));
+ std::cerr << "Memory estimate for binary LM:\ntype ";
+
+ // right align bytes.
+ for (long int i = 0; i < length - 2; ++i) std::cerr << ' ';
+
+ std::cerr << prefix << "B\n"
+ "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
+ "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n"
+ "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
+ "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
+ "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
+ "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
+}
+
+void ShowSizes(const std::vector<uint64_t> &counts) {
+ lm::ngram::Config config;
+ ShowSizes(counts, config);
+}
+
+void ShowSizes(const char *file, const lm::ngram::Config &config) {
+ std::vector<uint64_t> counts;
+ util::FilePiece f(file);
+ lm::ReadARPACounts(f, counts);
+ ShowSizes(counts, config);
+}
+
+}} // namespaces
diff --git a/lm/sizes.hh b/lm/sizes.hh
new file mode 100644
index 000000000..85abade77
--- /dev/null
+++ b/lm/sizes.hh
@@ -0,0 +1,17 @@
+#ifndef LM_SIZES__
+#define LM_SIZES__
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace lm { namespace ngram {
+
+struct Config;
+
+void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config);
+void ShowSizes(const std::vector<uint64_t> &counts);
+void ShowSizes(const char *file, const lm::ngram::Config &config);
+
+}} // namespaces
+#endif // LM_SIZES__
diff --git a/util/Jamfile b/util/Jamfile
index 35cc06d04..816b6bcd9 100644
--- a/util/Jamfile
+++ b/util/Jamfile
@@ -16,7 +16,7 @@ alias read_compressed : read_compressed.o $(compressed_deps) ;
obj read_compressed_test.o : read_compressed_test.cc /top//boost_unit_test_framework : $(compressed_flags) ;
obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(compressed_flags) ;
-fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc pool.cc read_compressed string_piece.cc usage.cc : <include>.. : : <include>.. ;
+fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. : : <include>.. ;
import testing ;
@@ -27,3 +27,4 @@ unit-test joint_sort_test : joint_sort_test.cc kenutil /top//boost_unit_test_fra
unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil /top//boost_unit_test_framework ;
unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil /top//boost_unit_test_framework ;
unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil /top//boost_unit_test_framework ;
+unit-test multi_intersection_test : multi_intersection_test.cc kenutil /top//boost_unit_test_framework ;
diff --git a/util/double-conversion/Jamfile b/util/double-conversion/Jamfile
index 8c422f88c..7a92afaa4 100644
--- a/util/double-conversion/Jamfile
+++ b/util/double-conversion/Jamfile
@@ -1 +1 @@
-alias double-conversion : [ glob *.cc ] : : : <include>. ;
+fakelib double-conversion : [ glob *.cc ] : : : <include>. ;
diff --git a/util/double-conversion/double-conversion.cc b/util/double-conversion/double-conversion.cc
index a79fe92d2..febba6cd7 100644
--- a/util/double-conversion/double-conversion.cc
+++ b/util/double-conversion/double-conversion.cc
@@ -577,7 +577,7 @@ double StringToDoubleConverter::StringToIeee(
const char* input,
int length,
int* processed_characters_count,
- bool read_as_double) {
+ bool read_as_double) const {
const char* current = input;
const char* end = input + length;
diff --git a/util/double-conversion/double-conversion.h b/util/double-conversion/double-conversion.h
index f98edae75..1c3387d4f 100644
--- a/util/double-conversion/double-conversion.h
+++ b/util/double-conversion/double-conversion.h
@@ -502,7 +502,7 @@ class StringToDoubleConverter {
// in the 'processed_characters_count'. Trailing junk is never included.
double StringToDouble(const char* buffer,
int length,
- int* processed_characters_count) {
+ int* processed_characters_count) const {
return StringToIeee(buffer, length, processed_characters_count, true);
}
@@ -511,7 +511,7 @@ class StringToDoubleConverter {
// due to potential double-rounding.
float StringToFloat(const char* buffer,
int length,
- int* processed_characters_count) {
+ int* processed_characters_count) const {
return static_cast<float>(StringToIeee(buffer, length,
processed_characters_count, false));
}
@@ -526,7 +526,7 @@ class StringToDoubleConverter {
double StringToIeee(const char* buffer,
int length,
int* processed_characters_count,
- bool read_as_double);
+ bool read_as_double) const;
DISALLOW_IMPLICIT_CONSTRUCTORS(StringToDoubleConverter);
};
diff --git a/util/ersatz_progress.cc b/util/ersatz_progress.cc
index cb9338190..498ab5c58 100644
--- a/util/ersatz_progress.cc
+++ b/util/ersatz_progress.cc
@@ -9,6 +9,8 @@ namespace util {
namespace { const unsigned char kWidth = 100; }
+const char kProgressBanner[] = "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n";
+
ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<uint64_t>::max()), complete_(next_), out_(NULL) {}
ErsatzProgress::~ErsatzProgress() {
@@ -22,7 +24,7 @@ ErsatzProgress::ErsatzProgress(uint64_t complete, std::ostream *to, const std::s
return;
}
if (!message.empty()) *out_ << message << '\n';
- *out_ << "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n";
+ *out_ << kProgressBanner;
}
void ErsatzProgress::Milestone() {
diff --git a/util/ersatz_progress.hh b/util/ersatz_progress.hh
index f81ee21ae..b94399a80 100644
--- a/util/ersatz_progress.hh
+++ b/util/ersatz_progress.hh
@@ -10,6 +10,9 @@
// boost. Also adds option to print nothing.
namespace util {
+
+extern const char kProgressBanner[];
+
class ErsatzProgress {
public:
// No output.
diff --git a/util/file.cc b/util/file.cc
index 904e2dc22..9a6d2e647 100644
--- a/util/file.cc
+++ b/util/file.cc
@@ -199,7 +199,7 @@ void WriteOrThrow(FILE *to, const void *data, std::size_t size) {
void FSyncOrThrow(int fd) {
// Apparently windows doesn't have fsync?
#if !defined(_WIN32) && !defined(_WIN64)
- UTIL_THROW_IF_ARG(-1 == fsync(fd), FDException, (fd), "Syncing");
+ UTIL_THROW_IF_ARG(-1 == fsync(fd), FDException, (fd), "while syncing");
#endif
}
@@ -378,6 +378,17 @@ mkstemp_and_unlink(char *tmpl) {
}
#endif
+// If it's a directory, add a /. This lets users say -T /tmp without creating
+// /tmpAAAAAA
+void NormalizeTempPrefix(std::string &base) {
+ if (base.empty()) return;
+ if (base[base.size() - 1] == '/') return;
+ struct stat sb;
+ // It's fine for it to not exist.
+ if (-1 == stat(base.c_str(), &sb)) return;
+ if (S_ISDIR(sb.st_mode)) base += '/';
+}
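+// For example, "/tmp" becomes "/tmp/" when /tmp is an existing directory;
+// trailing-slash and nonexistent prefixes are left unchanged.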
+
int MakeTemp(const std::string &base) {
std::string name(base);
name += "XXXXXX";
diff --git a/util/file.hh b/util/file.hh
index 471198b1c..be88431dd 100644
--- a/util/file.hh
+++ b/util/file.hh
@@ -123,6 +123,8 @@ std::FILE *FDOpenOrThrow(scoped_fd &file);
std::FILE *FDOpenReadOrThrow(scoped_fd &file);
// Temporary files
+// Append a / if base is a directory.
+void NormalizeTempPrefix(std::string &base);
int MakeTemp(const std::string &prefix);
std::FILE *FMakeTemp(const std::string &prefix);
diff --git a/util/file_piece.cc b/util/file_piece.cc
index 5783c5fd0..fbfa0e0e8 100644
--- a/util/file_piece.cc
+++ b/util/file_piece.cc
@@ -1,13 +1,15 @@
#include "util/file_piece.hh"
+#include "util/double-conversion/double-conversion.h"
#include "util/exception.hh"
#include "util/file.hh"
#include "util/mmap.hh"
-#ifdef WIN32
+
+#if defined(_WIN32) || defined(_WIN64)
#include <io.h>
#else
#include <unistd.h>
-#endif // WIN32
+#endif
#include <iostream>
#include <string>
@@ -110,21 +112,33 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s
}
namespace {
-void ParseNumber(const char *begin, char *&end, float &out) {
-#if defined(sun) || defined(WIN32)
- out = static_cast<float>(strtod(begin, &end));
-#else
- out = strtof(begin, &end);
-#endif
+
+static const double_conversion::StringToDoubleConverter kConverter(
+ double_conversion::StringToDoubleConverter::ALLOW_TRAILING_JUNK | double_conversion::StringToDoubleConverter::ALLOW_LEADING_SPACES,
+ std::numeric_limits<double>::quiet_NaN(),
+ std::numeric_limits<double>::quiet_NaN(),
+ "inf",
+ "NaN");
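+// The flags above accept trailing junk and leading spaces, so ParseNumber can
+// stop at the first non-numeric character, mirroring strtod's behavior.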
+
+void ParseNumber(const char *begin, const char *&end, float &out) {
+ int count;
+ out = kConverter.StringToFloat(begin, end - begin, &count);
+ end = begin + count;
}
-void ParseNumber(const char *begin, char *&end, double &out) {
- out = strtod(begin, &end);
+void ParseNumber(const char *begin, const char *&end, double &out) {
+ int count;
+ out = kConverter.StringToDouble(begin, end - begin, &count);
+ end = begin + count;
}
-void ParseNumber(const char *begin, char *&end, long int &out) {
- out = strtol(begin, &end, 10);
+void ParseNumber(const char *begin, const char *&end, long int &out) {
+ char *silly_end;
+ out = strtol(begin, &silly_end, 10);
+ end = silly_end;
}
-void ParseNumber(const char *begin, char *&end, unsigned long int &out) {
- out = strtoul(begin, &end, 10);
+void ParseNumber(const char *begin, const char *&end, unsigned long int &out) {
+ char *silly_end;
+ out = strtoul(begin, &silly_end, 10);
+ end = silly_end;
}
} // namespace
@@ -134,16 +148,17 @@ template <class T> T FilePiece::ReadNumber() {
if (at_end_) {
// Hallucinate a null off the end of the file.
std::string buffer(position_, position_end_);
- char *end;
+ const char *buf = buffer.c_str();
+ const char *end = buf + buffer.size();
T ret;
- ParseNumber(buffer.c_str(), end, ret);
- if (buffer.c_str() == end) throw ParseNumberException(buffer);
- position_ += end - buffer.c_str();
+ ParseNumber(buf, end, ret);
+ if (buf == end) throw ParseNumberException(buffer);
+ position_ += end - buf;
return ret;
}
Shift();
}
- char *end;
+ const char *end = last_space_;
T ret;
ParseNumber(position_, end, ret);
if (end == position_) throw ParseNumberException(ReadDelimited());
diff --git a/util/multi_intersection.hh b/util/multi_intersection.hh
new file mode 100644
index 000000000..8334d39df
--- /dev/null
+++ b/util/multi_intersection.hh
@@ -0,0 +1,80 @@
+#ifndef UTIL_MULTI_INTERSECTION__
+#define UTIL_MULTI_INTERSECTION__
+
+#include <boost/optional.hpp>
+#include <boost/range/iterator_range.hpp>
+
+#include <algorithm>
+#include <functional>
+#include <vector>
+
+namespace util {
+
+namespace detail {
+template <class Range> struct RangeLessBySize : public std::binary_function<const Range &, const Range &, bool> {
+ bool operator()(const Range &left, const Range &right) const {
+ return left.size() < right.size();
+ }
+};
+
+/* Takes sets specified by their iterators and returns a boost::optional
+ * containing the lowest element of the intersection, if any. Each set must
+ * be sorted in increasing
+ * order. sets is changed to truncate the beginning of each sequence to the
+ * location of the match or an empty set. Precondition: sets is not empty
+ * since the intersection over null is the universe and this function does not
+ * know the universe.
+ */
+template <class Iterator, class Less> boost::optional<typename std::iterator_traits<Iterator>::value_type> FirstIntersectionSorted(std::vector<boost::iterator_range<Iterator> > &sets, const Less &less = std::less<typename std::iterator_traits<Iterator>::value_type>()) {
+ typedef std::vector<boost::iterator_range<Iterator> > Sets;
+ typedef typename std::iterator_traits<Iterator>::value_type Value;
+
+ assert(!sets.empty());
+
+ if (sets.front().empty()) return boost::optional<Value>();
+ // Possibly suboptimal to copy for general Value; makes unsigned int go slightly faster.
+ Value highest(sets.front().front());
+ for (typename Sets::iterator i(sets.begin()); i != sets.end(); ) {
+ i->advance_begin(std::lower_bound(i->begin(), i->end(), highest, less) - i->begin());
+ if (i->empty()) return boost::optional<Value>();
+ if (less(highest, i->front())) {
+ highest = i->front();
+ // start over
+ i = sets.begin();
+ } else {
+ ++i;
+ }
+ }
+ return boost::optional<Value>(highest);
+}
+
+} // namespace detail
+
+template <class Iterator, class Less> boost::optional<typename std::iterator_traits<Iterator>::value_type> FirstIntersection(std::vector<boost::iterator_range<Iterator> > &sets, const Less less) {
+ assert(!sets.empty());
+
+ std::sort(sets.begin(), sets.end(), detail::RangeLessBySize<boost::iterator_range<Iterator> >());
+ return detail::FirstIntersectionSorted(sets, less);
+}
+
+template <class Iterator> boost::optional<typename std::iterator_traits<Iterator>::value_type> FirstIntersection(std::vector<boost::iterator_range<Iterator> > &sets) {
+ return FirstIntersection(sets, std::less<typename std::iterator_traits<Iterator>::value_type>());
+}
+
+template <class Iterator, class Output, class Less> void AllIntersection(std::vector<boost::iterator_range<Iterator> > &sets, Output &out, const Less less) {
+ typedef typename std::iterator_traits<Iterator>::value_type Value;
+ assert(!sets.empty());
+
+ std::sort(sets.begin(), sets.end(), detail::RangeLessBySize<boost::iterator_range<Iterator> >());
+  for (boost::optional<Value> ret; (ret = detail::FirstIntersectionSorted(sets, less)); sets.front().advance_begin(1)) {
+ out(*ret);
+ }
+}
+
+template <class Iterator, class Output> void AllIntersection(std::vector<boost::iterator_range<Iterator> > &sets, Output &out) {
+ AllIntersection(sets, out, std::less<typename std::iterator_traits<Iterator>::value_type>());
+}
+
+} // namespace util
+
+#endif // UTIL_MULTI_INTERSECTION__
diff --git a/util/multi_intersection_test.cc b/util/multi_intersection_test.cc
new file mode 100644
index 000000000..970afc171
--- /dev/null
+++ b/util/multi_intersection_test.cc
@@ -0,0 +1,63 @@
+#include "util/multi_intersection.hh"
+
+#define BOOST_TEST_MODULE MultiIntersectionTest
+#include <boost/test/unit_test.hpp>
+
+namespace util {
+namespace {
+
+BOOST_AUTO_TEST_CASE(Empty) {
+ std::vector<boost::iterator_range<const unsigned int*> > sets;
+
+ sets.push_back(boost::iterator_range<const unsigned int*>(static_cast<const unsigned int*>(NULL), static_cast<const unsigned int*>(NULL)));
+ BOOST_CHECK(!FirstIntersection(sets));
+}
+
+BOOST_AUTO_TEST_CASE(Single) {
+ std::vector<unsigned int> nums;
+ nums.push_back(1);
+ nums.push_back(4);
+ nums.push_back(100);
+ std::vector<boost::iterator_range<std::vector<unsigned int>::const_iterator> > sets;
+ sets.push_back(nums);
+
+ boost::optional<unsigned int> ret(FirstIntersection(sets));
+
+ BOOST_REQUIRE(ret);
+ BOOST_CHECK_EQUAL(static_cast<unsigned int>(1), *ret);
+}
+
+template <class T, unsigned int len> boost::iterator_range<const T*> RangeFromArray(const T (&arr)[len]) {
+ return boost::iterator_range<const T*>(arr, arr + len);
+}
+
+BOOST_AUTO_TEST_CASE(MultiNone) {
+ unsigned int nums0[] = {1, 3, 4, 22};
+ unsigned int nums1[] = {2, 5, 12};
+ unsigned int nums2[] = {4, 17};
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets;
+ sets.push_back(RangeFromArray(nums0));
+ sets.push_back(RangeFromArray(nums1));
+ sets.push_back(RangeFromArray(nums2));
+
+ BOOST_CHECK(!FirstIntersection(sets));
+}
+
+BOOST_AUTO_TEST_CASE(MultiOne) {
+ unsigned int nums0[] = {1, 3, 4, 17, 22};
+ unsigned int nums1[] = {2, 5, 12, 17};
+ unsigned int nums2[] = {4, 17};
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets;
+ sets.push_back(RangeFromArray(nums0));
+ sets.push_back(RangeFromArray(nums1));
+ sets.push_back(RangeFromArray(nums2));
+
+ boost::optional<unsigned int> ret(FirstIntersection(sets));
+ BOOST_REQUIRE(ret);
+ BOOST_CHECK_EQUAL(static_cast<unsigned int>(17), *ret);
+}
+
+} // namespace
+} // namespace util
diff --git a/util/pcqueue.hh b/util/pcqueue.hh
new file mode 100644
index 000000000..3df8749b1
--- /dev/null
+++ b/util/pcqueue.hh
@@ -0,0 +1,105 @@
+#ifndef UTIL_PCQUEUE__
+#define UTIL_PCQUEUE__
+
+#include <boost/interprocess/sync/interprocess_semaphore.hpp>
+#include <boost/scoped_array.hpp>
+#include <boost/thread/mutex.hpp>
+#include <boost/utility.hpp>
+
+#include <errno.h>
+
+namespace util {
+
+inline void WaitSemaphore(boost::interprocess::interprocess_semaphore &on) {
+ while (1) {
+ try {
+ on.wait();
+ break;
+ }
+ catch (boost::interprocess::interprocess_exception &e) {
+ if (e.get_native_error() != EINTR) throw;
+ }
+ }
+}
+
+/* Producer consumer queue safe for multiple producers and multiple consumers.
+ * T must be default constructable and have operator=.
+ * The value is copied twice for Consume(T &out) or three times for Consume(),
+ * so larger objects should be passed via pointer.
+ * Strong exception guarantee if operator= throws. Undefined if semaphores throw.
+ */
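+// Usage sketch (illustrative, not part of the original header):
+//   util::PCQueue<int> queue(16);
+//   queue.Produce(42);          // blocks while the queue is full
+//   int got = queue.Consume();  // blocks while the queue is empty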
+template <class T> class PCQueue : boost::noncopyable {
+ public:
+ explicit PCQueue(size_t size)
+ : empty_(size), used_(0),
+ storage_(new T[size]),
+ end_(storage_.get() + size),
+ produce_at_(storage_.get()),
+ consume_at_(storage_.get()) {}
+
+ // Add a value to the queue.
+ void Produce(const T &val) {
+ WaitSemaphore(empty_);
+ {
+ boost::unique_lock<boost::mutex> produce_lock(produce_at_mutex_);
+ try {
+ *produce_at_ = val;
+ }
+ catch (...) {
+ empty_.post();
+ throw;
+ }
+ if (++produce_at_ == end_) produce_at_ = storage_.get();
+ }
+ used_.post();
+ }
+
+ // Consume a value, assigning it to out.
+ T& Consume(T &out) {
+ WaitSemaphore(used_);
+ {
+ boost::unique_lock<boost::mutex> consume_lock(consume_at_mutex_);
+ try {
+ out = *consume_at_;
+ }
+ catch (...) {
+ used_.post();
+ throw;
+ }
+ if (++consume_at_ == end_) consume_at_ = storage_.get();
+ }
+ empty_.post();
+ return out;
+ }
+
+ // Convenience version of Consume that copies the value to return.
+ // The other version is faster.
+ T Consume() {
+ T ret;
+ Consume(ret);
+ return ret;
+ }
+
+ private:
+ // Number of empty spaces in storage_.
+ boost::interprocess::interprocess_semaphore empty_;
+ // Number of occupied spaces in storage_.
+ boost::interprocess::interprocess_semaphore used_;
+
+ boost::scoped_array<T> storage_;
+
+ T *const end_;
+
+ // Index for next write in storage_.
+ T *produce_at_;
+ boost::mutex produce_at_mutex_;
+
+ // Index for next read from storage_.
+ T *consume_at_;
+ boost::mutex consume_at_mutex_;
+
+};
+
+} // namespace util
+
+#endif // UTIL_PCQUEUE__
diff --git a/util/pool.cc b/util/pool.cc
index 2dffd06f6..429ba1585 100644
--- a/util/pool.cc
+++ b/util/pool.cc
@@ -1,5 +1,7 @@
#include "util/pool.hh"
+#include "util/scoped.hh"
+
#include <stdlib.h>
namespace util {
@@ -24,8 +26,7 @@ void Pool::FreeAll() {
void *Pool::More(std::size_t size) {
std::size_t amount = std::max(static_cast<size_t>(32) << free_list_.size(), size);
- uint8_t *ret = static_cast<uint8_t*>(malloc(amount));
- if (!ret) throw std::bad_alloc();
+ uint8_t *ret = static_cast<uint8_t*>(MallocOrThrow(amount));
free_list_.push_back(ret);
current_ = ret + size;
current_end_ = ret + amount;
diff --git a/util/probing_hash_table.hh b/util/probing_hash_table.hh
index 4a8aff353..6780489df 100644
--- a/util/probing_hash_table.hh
+++ b/util/probing_hash_table.hh
@@ -126,6 +126,11 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
}
}
+ void Clear(Entry invalid) {
+ std::fill(begin_, end_, invalid);
+ entries_ = 0;
+ }
+
private:
MutableIterator begin_;
std::size_t buckets_;
diff --git a/util/scoped.cc b/util/scoped.cc
new file mode 100644
index 000000000..e7066ee47
--- /dev/null
+++ b/util/scoped.cc
@@ -0,0 +1,29 @@
+#include "util/scoped.hh"
+
+#include <cstdlib>
+
+namespace util {
+
+MallocException::MallocException(std::size_t requested) throw() {
+ *this << "for " << requested << " bytes ";
+}
+
+MallocException::~MallocException() throw() {}
+
+void *MallocOrThrow(std::size_t requested) {
+ void *ret;
+ UTIL_THROW_IF_ARG(!(ret = std::malloc(requested)), MallocException, (requested), "in malloc");
+ return ret;
+}
+
+scoped_malloc::~scoped_malloc() {
+ std::free(p_);
+}
+
+void scoped_malloc::call_realloc(std::size_t to) {
+ void *ret;
+ UTIL_THROW_IF_ARG(!(ret = std::realloc(p_, to)) && to, MallocException, (to), "in realloc");
+ p_ = ret;
+}
+
+} // namespace util
diff --git a/util/scoped.hh b/util/scoped.hh
index 37bc4744f..d0a5aabd0 100644
--- a/util/scoped.hh
+++ b/util/scoped.hh
@@ -4,28 +4,31 @@
#include "util/exception.hh"
#include <cstddef>
-#include <cstdlib>
namespace util {
+class MallocException : public ErrnoException {
+ public:
+ explicit MallocException(std::size_t requested) throw();
+ ~MallocException() throw();
+};
+
+void *MallocOrThrow(std::size_t requested);
+
class scoped_malloc {
public:
scoped_malloc() : p_(NULL) {}
scoped_malloc(void *p) : p_(p) {}
- ~scoped_malloc() { std::free(p_); }
+ ~scoped_malloc();
void reset(void *p = NULL) {
scoped_malloc other(p_);
p_ = p;
}
- void call_realloc(std::size_t to) {
- void *ret;
- UTIL_THROW_IF(!(ret = std::realloc(p_, to)) && to, ErrnoException, "realloc to " << to << " bytes failed.");
- p_ = ret;
- }
+ void call_realloc(std::size_t to);
void *get() { return p_; }
const void *get() const { return p_; }
diff --git a/util/stream/Jamfile b/util/stream/Jamfile
new file mode 100644
index 000000000..469f58762
--- /dev/null
+++ b/util/stream/Jamfile
@@ -0,0 +1,12 @@
+if $(BOOST-VERSION) >= 104800 {
+ timer-link = <library>/top//boost_timer ;
+} else {
+ timer-link = ;
+}
+
+fakelib stream : chain.cc io.cc line_input.cc multi_progress.cc ..//kenutil /top//boost_thread : : : <library>/top//boost_thread $(timer-link) ;
+
+import testing ;
+unit-test io_test : io_test.cc stream /top//boost_unit_test_framework ;
+unit-test stream_test : stream_test.cc stream /top//boost_unit_test_framework ;
+unit-test sort_test : sort_test.cc stream /top//boost_unit_test_framework ;
diff --git a/util/stream/block.hh b/util/stream/block.hh
new file mode 100644
index 000000000..11aa991e4
--- /dev/null
+++ b/util/stream/block.hh
@@ -0,0 +1,43 @@
+#ifndef UTIL_STREAM_BLOCK__
+#define UTIL_STREAM_BLOCK__
+
+#include <cstddef>
+#include <stdint.h>
+
+namespace util {
+namespace stream {
+
+class Block {
+ public:
+ Block() : mem_(NULL), valid_size_(0) {}
+
+ Block(void *mem, std::size_t size) : mem_(mem), valid_size_(size) {}
+
+ void SetValidSize(std::size_t to) { valid_size_ = to; }
+ // Read might fill in less than Allocated at EOF.
+ std::size_t ValidSize() const { return valid_size_; }
+
+ void *Get() { return mem_; }
+ const void *Get() const { return mem_; }
+
+ const void *ValidEnd() const {
+ return reinterpret_cast<const uint8_t*>(mem_) + valid_size_;
+ }
+
+ operator bool() const { return mem_ != NULL; }
+ bool operator!() const { return mem_ == NULL; }
+
+ private:
+ friend class Link;
+ void SetToPoison() {
+ mem_ = NULL;
+ }
+
+ void *mem_;
+ std::size_t valid_size_;
+};
+
+} // namespace stream
+} // namespace util
+
+#endif // UTIL_STREAM_BLOCK__
diff --git a/util/stream/chain.cc b/util/stream/chain.cc
new file mode 100644
index 000000000..46708c601
--- /dev/null
+++ b/util/stream/chain.cc
@@ -0,0 +1,155 @@
+#include "util/stream/chain.hh"
+
+#include "util/stream/io.hh"
+
+#include "util/exception.hh"
+#include "util/pcqueue.hh"
+
+#include <cstdlib>
+#include <new>
+#include <iostream>
+
+#include <stdint.h>
+#include <stdlib.h>
+
+namespace util {
+namespace stream {
+
+ChainConfigException::ChainConfigException() throw() { *this << "Chain configured with "; }
+ChainConfigException::~ChainConfigException() throw() {}
+
+Thread::~Thread() {
+ thread_.join();
+}
+
+void Thread::UnhandledException(const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ abort();
+}
+
+void Recycler::Run(const ChainPosition &position) {
+ for (Link l(position); l; ++l) {
+ l->SetValidSize(position.GetChain().BlockSize());
+ }
+}
+
+const Recycler kRecycle = Recycler();
+
+Chain::Chain(const ChainConfig &config) : config_(config), complete_called_(false) {
+ UTIL_THROW_IF(!config.entry_size, ChainConfigException, "zero-size entries.");
+ UTIL_THROW_IF(!config.block_count, ChainConfigException, "block count zero");
+  UTIL_THROW_IF(config.total_memory < config.entry_size * config.block_count, ChainConfigException, config.total_memory << " total memory, too small for " << config.block_count << " blocks containing entries of size " << config.entry_size);
+ // Round down block size to a multiple of entry size.
+ block_size_ = config.total_memory / (config.block_count * config.entry_size) * config.entry_size;
+}
+
+Chain::~Chain() {
+ Wait();
+}
+
+ChainPosition Chain::Add() {
+ if (!Running()) Start();
+ PCQueue<Block> &in = queues_.back();
+ queues_.push_back(new PCQueue<Block>(config_.block_count));
+ return ChainPosition(in, queues_.back(), this, progress_);
+}
+
+Chain &Chain::operator>>(const WriteAndRecycle &writer) {
+ threads_.push_back(new Thread(Complete(), writer));
+ return *this;
+}
+
+void Chain::Wait(bool release_memory) {
+ if (queues_.empty()) {
+ assert(threads_.empty());
+ return; // Nothing to wait for.
+ }
+ if (!complete_called_) CompleteLoop();
+ threads_.clear();
+ for (std::size_t i = 0; queues_.front().Consume(); ++i) {
+ if (i == config_.block_count) {
+ std::cerr << "Chain ending without poison." << std::endl;
+ abort();
+ }
+ }
+ queues_.clear();
+ progress_.Finished();
+ complete_called_ = false;
+ if (release_memory) memory_.reset();
+}
+
+void Chain::Start() {
+ Wait(false);
+ if (!memory_.get()) {
+ // Allocate memory.
+ assert(threads_.empty());
+ assert(queues_.empty());
+ std::size_t malloc_size = block_size_ * config_.block_count;
+ memory_.reset(MallocOrThrow(malloc_size));
+ }
+  // This queue can accommodate all blocks.
+ queues_.push_back(new PCQueue<Block>(config_.block_count));
+ // Populate the lead queue with blocks.
+ uint8_t *base = static_cast<uint8_t*>(memory_.get());
+ for (std::size_t i = 0; i < config_.block_count; ++i) {
+ queues_.front().Produce(Block(base, block_size_));
+ base += block_size_;
+ }
+}
+
+ChainPosition Chain::Complete() {
+ assert(Running());
+ UTIL_THROW_IF(complete_called_, util::Exception, "CompleteLoop() called twice");
+ complete_called_ = true;
+ return ChainPosition(queues_.back(), queues_.front(), this, progress_);
+}
+
+Link::Link() : in_(NULL), out_(NULL), poisoned_(true) {}
+
+void Link::Init(const ChainPosition &position) {
+ UTIL_THROW_IF(in_, util::Exception, "Link::Init twice");
+ in_ = position.in_;
+ out_ = position.out_;
+ poisoned_ = false;
+ progress_ = position.progress_;
+ in_->Consume(current_);
+}
+
+Link::Link(const ChainPosition &position) : in_(NULL) {
+ Init(position);
+}
+
+Link::~Link() {
+ if (current_) {
+ // Probably an exception unwinding.
+ std::cerr << "Last input should have been poison." << std::endl;
+// abort();
+ } else {
+ if (!poisoned_) {
+ // Pass the poison!
+ out_->Produce(current_);
+ }
+ }
+}
+
+Link &Link::operator++() {
+ assert(current_);
+ progress_ += current_.ValidSize();
+ out_->Produce(current_);
+ in_->Consume(current_);
+ if (!current_) {
+ poisoned_ = true;
+ out_->Produce(current_);
+ }
+ return *this;
+}
+
+void Link::Poison() {
+ assert(!poisoned_);
+ current_.SetToPoison();
+ out_->Produce(current_);
+ poisoned_ = true;
+}
+
+} // namespace stream
+} // namespace util
diff --git a/util/stream/chain.hh b/util/stream/chain.hh
new file mode 100644
index 000000000..154b9b334
--- /dev/null
+++ b/util/stream/chain.hh
@@ -0,0 +1,198 @@
+#ifndef UTIL_STREAM_CHAIN__
+#define UTIL_STREAM_CHAIN__
+
+#include "util/stream/block.hh"
+#include "util/stream/config.hh"
+#include "util/stream/multi_progress.hh"
+#include "util/scoped.hh"
+
+#include <boost/ptr_container/ptr_vector.hpp>
+#include <boost/thread/thread.hpp>
+
+#include <cstddef>
+
+#include <assert.h>
+
+namespace util {
+template <class T> class PCQueue;
+namespace stream {
+
+class ChainConfigException : public Exception {
+ public:
+ ChainConfigException() throw();
+ ~ChainConfigException() throw();
+};
+
+class Chain;
+// Specifies position in chain for Link constructor.
+class ChainPosition {
+ public:
+ const Chain &GetChain() const { return *chain_; }
+ private:
+ friend class Chain;
+ friend class Link;
+ ChainPosition(PCQueue<Block> &in, PCQueue<Block> &out, Chain *chain, MultiProgress &progress)
+ : in_(&in), out_(&out), chain_(chain), progress_(progress.Add()) {}
+
+ PCQueue<Block> *in_, *out_;
+
+ Chain *chain_;
+
+ WorkerProgress progress_;
+};
+
+// Position is usually ChainPosition but if there are multiple streams involved, this can be ChainPositions.
+class Thread {
+ public:
+ template <class Position, class Worker> Thread(const Position &position, const Worker &worker)
+ : thread_(boost::ref(*this), position, worker) {}
+
+ ~Thread();
+
+ template <class Position, class Worker> void operator()(const Position &position, Worker &worker) {
+ try {
+ worker.Run(position);
+ } catch (const std::exception &e) {
+ UnhandledException(e);
+ }
+ }
+
+ private:
+ void UnhandledException(const std::exception &e);
+
+ boost::thread thread_;
+};
+
+class Recycler {
+ public:
+ void Run(const ChainPosition &position);
+};
+
+extern const Recycler kRecycle;
+class WriteAndRecycle;
+
+class Chain {
+ private:
+ template <class T, void (T::*ptr)(const ChainPosition &) = &T::Run> struct CheckForRun {
+ typedef Chain type;
+ };
+
+ public:
+ explicit Chain(const ChainConfig &config);
+
+ ~Chain();
+
+ void ActivateProgress() {
+ assert(!Running());
+ progress_.Activate();
+ }
+
+ void SetProgressTarget(uint64_t target) {
+ progress_.SetTarget(target);
+ }
+
+ std::size_t EntrySize() const {
+ return config_.entry_size;
+ }
+ std::size_t BlockSize() const {
+ return block_size_;
+ }
+
+ // Two ways to add to the chain: Add() or operator>>.
+ ChainPosition Add();
+
+ // This is for adding threaded workers with a Run method.
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
+ assert(!complete_called_);
+ threads_.push_back(new Thread(Add(), worker));
+ return *this;
+ }
+
+ // Avoid copying the worker.
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) {
+ assert(!complete_called_);
+ threads_.push_back(new Thread(Add(), worker));
+ return *this;
+ }
+
+ // Note that Link and Stream also define operator>> outside this class.
+
+ // To complete the loop, call CompleteLoop(), >> kRecycle, or the destructor.
+ void CompleteLoop() {
+ threads_.push_back(new Thread(Complete(), kRecycle));
+ }
+
+ Chain &operator>>(const Recycler &recycle) {
+ CompleteLoop();
+ return *this;
+ }
+
+ Chain &operator>>(const WriteAndRecycle &writer);
+
+ // Chains are reusable. Call Wait to wait for everything to finish and free memory.
+ void Wait(bool release_memory = true);
+
+ // Waits for the current chain to complete (if any) then starts again.
+ void Start();
+
+ bool Running() const { return !queues_.empty(); }
+
+ private:
+ ChainPosition Complete();
+
+ ChainConfig config_;
+
+ std::size_t block_size_;
+
+ scoped_malloc memory_;
+
+ boost::ptr_vector<PCQueue<Block> > queues_;
+
+ bool complete_called_;
+
+ boost::ptr_vector<Thread> threads_;
+
+ MultiProgress progress_;
+};
+
+// Create the link in the worker thread using the position token.
+class Link {
+ public:
+ // Either default construct and Init or just construct all at once.
+ Link();
+ void Init(const ChainPosition &position);
+
+ explicit Link(const ChainPosition &position);
+
+ ~Link();
+
+ Block &operator*() { return current_; }
+ const Block &operator*() const { return current_; }
+
+ Block *operator->() { return &current_; }
+ const Block *operator->() const { return &current_; }
+
+ Link &operator++();
+
+ operator bool() const { return current_; }
+
+ void Poison();
+
+ private:
+ Block current_;
+ PCQueue<Block> *in_, *out_;
+
+ bool poisoned_;
+
+ WorkerProgress progress_;
+};
+
+inline Chain &operator>>(Chain &chain, Link &link) {
+ link.Init(chain.Add());
+ return chain;
+}
+
+} // namespace stream
+} // namespace util
+
+#endif // UTIL_STREAM_CHAIN__
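For orientation, a minimal sketch (not part of this commit) of how the pieces in this header compose, assuming the Read and Write workers from util/stream/io.hh below and illustrative file descriptors in_fd and out_fd:

    // Hypothetical in-place worker; any type with Run(const ChainPosition &) qualifies.
    struct Doubler {
      void Run(const util::stream::ChainPosition &position) {
        for (util::stream::Link link(position); link; ++link) {
          uint64_t *i = static_cast<uint64_t*>(link->Get());
          uint64_t *end = i + link->ValidSize() / sizeof(uint64_t);
          for (; i != end; ++i) *i *= 2;
        }
      }
    };

    util::stream::ChainConfig config(sizeof(uint64_t), 4, 1 << 16);
    util::stream::Chain chain(config);
    // Read feeds blocks in, Doubler rewrites them, Write drains them, kRecycle returns the memory.
    chain >> util::stream::Read(in_fd) >> Doubler() >> util::stream::Write(out_fd) >> util::stream::kRecycle;
    chain.Wait();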
diff --git a/util/stream/config.hh b/util/stream/config.hh
new file mode 100644
index 000000000..1eeb3a8a1
--- /dev/null
+++ b/util/stream/config.hh
@@ -0,0 +1,32 @@
+#ifndef UTIL_STREAM_CONFIG__
+#define UTIL_STREAM_CONFIG__
+
+#include <cstddef>
+#include <string>
+
+namespace util { namespace stream {
+
+struct ChainConfig {
+ ChainConfig() {}
+
+ ChainConfig(std::size_t in_entry_size, std::size_t in_block_count, std::size_t in_total_memory)
+ : entry_size(in_entry_size), block_count(in_block_count), total_memory(in_total_memory) {}
+
+ std::size_t entry_size;
+ std::size_t block_count;
+ // Chain's constructor will make this a multiple of entry_size.
+ std::size_t total_memory;
+};
+
+struct SortConfig {
+ std::string temp_prefix;
+
+ // Size of each input/output buffer.
+ std::size_t buffer_size;
+
+ // Total memory to use when running alone.
+ std::size_t total_memory;
+};
+
+}} // namespaces
+#endif // UTIL_STREAM_CONFIG__
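Illustrative settings (values arbitrary, not from this commit); per the comments above, Chain rounds total_memory down to whole entries, and SortConfig's buffers govern merge I/O:

    util::stream::ChainConfig chain_config(sizeof(uint64_t), 2, 1 << 20);

    util::stream::SortConfig sort_config;
    sort_config.temp_prefix = "/tmp/sort_tmp";  // hypothetical location
    sort_config.buffer_size = 64 << 10;         // each merge read/write buffer
    sort_config.total_memory = 1 << 26;         // budget when sorting runs alone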
diff --git a/util/stream/io.cc b/util/stream/io.cc
new file mode 100644
index 000000000..c7ad29804
--- /dev/null
+++ b/util/stream/io.cc
@@ -0,0 +1,64 @@
+#include "util/stream/io.hh"
+
+#include "util/file.hh"
+#include "util/stream/chain.hh"
+
+#include <cstddef>
+
+namespace util {
+namespace stream {
+
+ReadSizeException::ReadSizeException() throw() {}
+ReadSizeException::~ReadSizeException() throw() {}
+
+void Read::Run(const ChainPosition &position) {
+ const std::size_t block_size = position.GetChain().BlockSize();
+ const std::size_t entry_size = position.GetChain().EntrySize();
+ for (Link link(position); link; ++link) {
+ std::size_t got = util::ReadOrEOF(file_, link->Get(), block_size);
+ UTIL_THROW_IF(got % entry_size, ReadSizeException, "File ended with " << got << " bytes, not a multiple of " << entry_size << ".");
+ if (got == 0) {
+ link.Poison();
+ return;
+ } else {
+ link->SetValidSize(got);
+ }
+ }
+}
+
+void PRead::Run(const ChainPosition &position) {
+ scoped_fd owner;
+ if (own_) owner.reset(file_);
+ uint64_t size = SizeOrThrow(file_);
+ UTIL_THROW_IF(size % static_cast<uint64_t>(position.GetChain().EntrySize()), ReadSizeException, "File " << file_ << " has size " << size << ", not a multiple of the entry size " << position.GetChain().EntrySize());
+ std::size_t block_size = position.GetChain().BlockSize();
+ Link link(position);
+ uint64_t offset = 0;
+ for (; offset + block_size < size; offset += block_size, ++link) {
+ PReadOrThrow(file_, link->Get(), block_size, offset);
+ link->SetValidSize(block_size);
+ }
+ if (size - offset) {
+ PReadOrThrow(file_, link->Get(), size - offset, offset);
+ link->SetValidSize(size - offset);
+ ++link;
+ }
+ link.Poison();
+}
+
+void Write::Run(const ChainPosition &position) {
+ for (Link link(position); link; ++link) {
+ WriteOrThrow(file_, link->Get(), link->ValidSize());
+ }
+}
+
+void WriteAndRecycle::Run(const ChainPosition &position) {
+ const std::size_t block_size = position.GetChain().BlockSize();
+ for (Link link(position); link; ++link) {
+ WriteOrThrow(file_, link->Get(), link->ValidSize());
+ link->SetValidSize(block_size);
+ }
+}
+
+} // namespace stream
+} // namespace util
diff --git a/util/stream/io.hh b/util/stream/io.hh
new file mode 100644
index 000000000..934b6b3fe
--- /dev/null
+++ b/util/stream/io.hh
@@ -0,0 +1,76 @@
+#ifndef UTIL_STREAM_IO__
+#define UTIL_STREAM_IO__
+
+#include "util/exception.hh"
+#include "util/file.hh"
+
+namespace util {
+namespace stream {
+
+class ChainPosition;
+
+class ReadSizeException : public util::Exception {
+ public:
+ ReadSizeException() throw();
+ ~ReadSizeException() throw();
+};
+
+class Read {
+ public:
+ explicit Read(int fd) : file_(fd) {}
+ void Run(const ChainPosition &position);
+ private:
+ int file_;
+};
+
+// Like Read, but uses pread, so the file can be accessed from multiple threads.
+class PRead {
+ public:
+ explicit PRead(int fd, bool take_own = false) : file_(fd), own_(take_own) {}
+ void Run(const ChainPosition &position);
+ private:
+ int file_;
+ bool own_;
+};
+
+class Write {
+ public:
+ explicit Write(int fd) : file_(fd) {}
+ void Run(const ChainPosition &position);
+ private:
+ int file_;
+};
+
+class WriteAndRecycle {
+ public:
+ explicit WriteAndRecycle(int fd) : file_(fd) {}
+ void Run(const ChainPosition &position);
+ private:
+ int file_;
+};
+
+// Reuse the same file over and over again to buffer output.
+class FileBuffer {
+ public:
+ explicit FileBuffer(int fd) : file_(fd) {}
+
+ WriteAndRecycle Sink() const {
+ util::SeekOrThrow(file_.get(), 0);
+ return WriteAndRecycle(file_.get());
+ }
+
+ PRead Source() const {
+ return PRead(file_.get());
+ }
+
+ uint64_t Size() const {
+ return SizeOrThrow(file_.get());
+ }
+
+ private:
+ scoped_fd file_;
+};
+
+} // namespace stream
+} // namespace util
+#endif // UTIL_STREAM_IO__
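A sketch of the FileBuffer round trip implied by Sink() and Source(); producer, consumer, and the two chain configs are hypothetical:

    util::stream::FileBuffer buffer(util::MakeTemp("buffer_temp"));
    // Pass 1: drain a chain into the buffer (Sink seeks to the beginning first).
    util::stream::Chain(write_config) >> producer >> buffer.Sink();
    // Pass 2: replay the same bytes into another chain via pread.
    util::stream::Chain(read_config) >> buffer.Source() >> consumer >> util::stream::kRecycle;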
diff --git a/util/stream/io_test.cc b/util/stream/io_test.cc
new file mode 100644
index 000000000..82108335a
--- /dev/null
+++ b/util/stream/io_test.cc
@@ -0,0 +1,38 @@
+#include "util/stream/io.hh"
+
+#include "util/stream/chain.hh"
+#include "util/file.hh"
+
+#define BOOST_TEST_MODULE IOTest
+#include <boost/test/unit_test.hpp>
+
+#include <unistd.h>
+
+namespace util { namespace stream { namespace {
+
+BOOST_AUTO_TEST_CASE(CopyFile) {
+ std::string temps("io_test_temp");
+
+ scoped_fd in(MakeTemp(temps));
+ for (uint64_t i = 0; i < 100000; ++i) {
+ WriteOrThrow(in.get(), &i, sizeof(uint64_t));
+ }
+ SeekOrThrow(in.get(), 0);
+ scoped_fd out(MakeTemp(temps));
+
+ ChainConfig config;
+ config.entry_size = 8;
+ config.total_memory = 1024;
+ config.block_count = 10;
+
+ Chain(config) >> PRead(in.get()) >> Write(out.get());
+
+ SeekOrThrow(out.get(), 0);
+ for (uint64_t i = 0; i < 100000; ++i) {
+ uint64_t got;
+ ReadOrThrow(out.get(), &got, sizeof(uint64_t));
+ BOOST_CHECK_EQUAL(i, got);
+ }
+}
+
+}}} // namespaces
diff --git a/util/stream/line_input.cc b/util/stream/line_input.cc
new file mode 100644
index 000000000..dafa50207
--- /dev/null
+++ b/util/stream/line_input.cc
@@ -0,0 +1,52 @@
+#include "util/stream/line_input.hh"
+
+#include "util/exception.hh"
+#include "util/file.hh"
+#include "util/read_compressed.hh"
+#include "util/stream/chain.hh"
+
+#include <algorithm>
+#include <vector>
+
+namespace util { namespace stream {
+
+void LineInput::Run(const ChainPosition &position) {
+ ReadCompressed reader(fd_);
+ // Holding area for beginning of line to be placed in next block.
+ std::vector<char> carry;
+
+ for (Link block(position); ; ++block) {
+ char *to = static_cast<char*>(block->Get());
+ char *begin = to;
+ char *end = to + position.GetChain().BlockSize();
+ std::copy(carry.begin(), carry.end(), to);
+ to += carry.size();
+ while (to != end) {
+ std::size_t got = reader.Read(to, end - to);
+ if (!got) {
+ // EOF
+ block->SetValidSize(to - begin);
+ ++block;
+ block.Poison();
+ return;
+ }
+ to += got;
+ }
+
+ // Find the last newline.
+ char *newline;
+ for (newline = to - 1; ; --newline) {
+ UTIL_THROW_IF(newline < begin, Exception, "Did not find a newline in " << position.GetChain().BlockSize() << " bytes of input of " << NameFromFD(fd_) << ". Is this a text file?");
+ if (*newline == '\n') break;
+ }
+
+ // Copy everything after the last newline to the carry.
+ carry.clear();
+ carry.resize(to - (newline + 1));
+ std::copy(newline + 1, to, &*carry.begin());
+
+ block->SetValidSize(newline + 1 - begin);
+ }
+}
+
+}} // namespaces
diff --git a/util/stream/line_input.hh b/util/stream/line_input.hh
new file mode 100644
index 000000000..86db1dd06
--- /dev/null
+++ b/util/stream/line_input.hh
@@ -0,0 +1,22 @@
+#ifndef UTIL_STREAM_LINE_INPUT__
+#define UTIL_STREAM_LINE_INPUT__
+namespace util {namespace stream {
+
+class ChainPosition;
+
+/* Worker that reads input into blocks, ensuring that blocks contain whole
+ * lines. Assumes that the maximum size of a line is less than the block size.
+ */
+class LineInput {
+ public:
+ // Takes ownership upon thread execution.
+ explicit LineInput(int fd);
+
+ void Run(const ChainPosition &position);
+
+ private:
+ int fd_;
+};
+
+}} // namespaces
+#endif // UTIL_STREAM_LINE_INPUT__
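A sketch of a consumer for LineInput's whole-line blocks; LineCounter and fd are illustrative, not from this commit:

    #include <algorithm>  // std::count

    struct LineCounter {
      explicit LineCounter(uint64_t &count) : count_(&count) {}
      void Run(const util::stream::ChainPosition &position) {
        for (util::stream::Link link(position); link; ++link) {
          const char *data = static_cast<const char*>(link->Get());
          *count_ += std::count(data, data + link->ValidSize(), '\n');
        }
      }
      uint64_t *count_;
    };

    util::stream::ChainConfig config(1, 2, 1 << 20);  // byte-sized entries
    util::stream::Chain chain(config);
    uint64_t lines = 0;
    chain >> util::stream::LineInput(fd) >> LineCounter(lines) >> util::stream::kRecycle;
    chain.Wait();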
diff --git a/util/stream/multi_progress.cc b/util/stream/multi_progress.cc
new file mode 100644
index 000000000..8ba10386b
--- /dev/null
+++ b/util/stream/multi_progress.cc
@@ -0,0 +1,86 @@
+#include "util/stream/multi_progress.hh"
+
+// TODO: merge some functionality with the simple progress bar?
+#include "util/ersatz_progress.hh"
+
+#include <iostream>
+#include <limits>
+
+#include <string.h>
+
+#if !defined(_WIN32) && !defined(_WIN64)
+#include <unistd.h>
+#endif
+
+namespace util { namespace stream {
+
+namespace {
+const char kDisplayCharacters[] = "-+*#0123456789";
+
+uint64_t Next(unsigned char stone, uint64_t complete) {
+ return (static_cast<uint64_t>(stone + 1) * complete + MultiProgress::kWidth - 1) / MultiProgress::kWidth;
+}
+
+} // namespace
+
+MultiProgress::MultiProgress() : active_(false), complete_(std::numeric_limits<uint64_t>::max()), character_handout_(0) {}
+
+MultiProgress::~MultiProgress() {
+ if (active_ && complete_ != std::numeric_limits<uint64_t>::max())
+ std::cerr << '\n';
+}
+
+void MultiProgress::Activate() {
+ active_ =
+#if !defined(_WIN32) && !defined(_WIN64)
+ // Is stderr a terminal?
+ (isatty(2) == 1)
+#else
+ true
+#endif
+ ;
+}
+
+void MultiProgress::SetTarget(uint64_t complete) {
+ if (!active_) return;
+ complete_ = complete;
+ if (!complete) complete_ = 1;
+ memset(display_, 0, sizeof(display_));
+ character_handout_ = 0;
+ std::cerr << kProgressBanner;
+}
+
+WorkerProgress MultiProgress::Add() {
+ if (!active_)
+ return WorkerProgress(std::numeric_limits<uint64_t>::max(), *this, '\0');
+ std::size_t character_index;
+ {
+ boost::unique_lock<boost::mutex> lock(mutex_);
+ character_index = character_handout_++;
+ if (character_handout_ == sizeof(kDisplayCharacters) - 1)
+ character_handout_ = 0;
+ }
+ return WorkerProgress(Next(0, complete_), *this, kDisplayCharacters[character_index]);
+}
+
+void MultiProgress::Finished() {
+ if (!active_ || complete_ == std::numeric_limits<uint64_t>::max()) return;
+ std::cerr << '\n';
+ complete_ = std::numeric_limits<uint64_t>::max();
+}
+
+void MultiProgress::Milestone(WorkerProgress &worker) {
+ if (!active_ || complete_ == std::numeric_limits<uint64_t>::max()) return;
+ unsigned char stone = std::min(static_cast<uint64_t>(kWidth), worker.current_ * kWidth / complete_);
+ for (char *i = &display_[worker.stone_]; i < &display_[stone]; ++i) {
+ *i = worker.character_;
+ }
+ worker.next_ = Next(stone, complete_);
+ worker.stone_ = stone;
+ {
+ boost::unique_lock<boost::mutex> lock(mutex_);
+ std::cerr << '\r' << display_ << std::flush;
+ }
+}
+
+}} // namespaces
diff --git a/util/stream/multi_progress.hh b/util/stream/multi_progress.hh
new file mode 100644
index 000000000..c4dd45a9b
--- /dev/null
+++ b/util/stream/multi_progress.hh
@@ -0,0 +1,90 @@
+/* Progress bar suitable for chains of workers */
+#ifndef UTIL_MULTI_PROGRESS__
+#define UTIL_MULTI_PROGRESS__
+
+#include <boost/thread/mutex.hpp>
+
+#include <cstddef>
+
+#include <stdint.h>
+
+namespace util { namespace stream {
+
+class WorkerProgress;
+
+class MultiProgress {
+ public:
+ static const unsigned char kWidth = 100;
+
+ MultiProgress();
+
+ ~MultiProgress();
+
+ // Turns on showing (requires SetTarget too).
+ void Activate();
+
+ void SetTarget(uint64_t complete);
+
+ WorkerProgress Add();
+
+ void Finished();
+
+ private:
+ friend class WorkerProgress;
+ void Milestone(WorkerProgress &worker);
+
+ bool active_;
+
+ uint64_t complete_;
+
+ boost::mutex mutex_;
+
+ // \0 at the end.
+ char display_[kWidth + 1];
+
+ std::size_t character_handout_;
+
+ MultiProgress(const MultiProgress &);
+ MultiProgress &operator=(const MultiProgress &);
+};
+
+class WorkerProgress {
+ public:
+ // A default-constructed WorkerProgress must be assigned with operator= later.
+ WorkerProgress() : parent_(NULL) {}
+
+ // Not threadsafe for the same worker by default.
+ WorkerProgress &operator++() {
+ if (++current_ >= next_) {
+ parent_->Milestone(*this);
+ }
+ return *this;
+ }
+
+ WorkerProgress &operator+=(uint64_t amount) {
+ current_ += amount;
+ if (current_ >= next_) {
+ parent_->Milestone(*this);
+ }
+ return *this;
+ }
+
+ private:
+ friend class MultiProgress;
+ WorkerProgress(uint64_t next, MultiProgress &parent, char character)
+ : current_(0), next_(next), parent_(&parent), stone_(0), character_(character) {}
+
+ uint64_t current_, next_;
+
+ MultiProgress *parent_;
+
+ // Previous milestone reached.
+ unsigned char stone_;
+
+ // Character to display in bar.
+ char character_;
+};
+
+}} // namespaces
+
+#endif // UTIL_MULTI_PROGRESS__
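A sketch of driving MultiProgress by hand; inside a Chain this wiring happens automatically through ChainPosition, and DoSomeWork is hypothetical:

    util::stream::MultiProgress progress;
    progress.Activate();                   // only displays if stderr is a terminal
    progress.SetTarget(total_bytes);       // denominator for the 100-character bar
    util::stream::WorkerProgress mine = progress.Add();
    while (uint64_t got = DoSomeWork()) {  // hypothetical; returns bytes processed, 0 when done
      mine += got;                         // redraws the bar when a milestone is crossed
    }
    progress.Finished();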
diff --git a/util/stream/sort.hh b/util/stream/sort.hh
new file mode 100644
index 000000000..be6c11eaf
--- /dev/null
+++ b/util/stream/sort.hh
@@ -0,0 +1,542 @@
+/* Usage:
+ * Chain chain(config);
+ * chain >> Read(file); // or any other producer
+ * Sort<Compare> sorter(chain, sort_config, compare); // appends sorting workers to chain
+ * chain.Wait(true);
+ * sorter.Output(chain); // restarts the chain with lazily merged sorted output
+ * Stream stream;
+ * chain >> stream >> kRecycle;
+ *
+ * Note that sorter must outlive any threads that consume its output.
+ *
+ * Combiners take the form:
+ * bool operator()(void *into, const void *option, const Compare &compare) const
+ * which returns true iff a combination happened. The sorting algorithm
+ * guarantees !compare(option, into), i.e. into sorts no later than option,
+ * but it does not guarantee compare(into, option): the two may be equivalent.
+ * Currently, combining is only done in merge steps, not during on-the-fly
+ * sort. Use a hash table for that.
+ */
+
+#ifndef UTIL_STREAM_SORT__
+#define UTIL_STREAM_SORT__
+
+#include "util/stream/chain.hh"
+#include "util/stream/config.hh"
+#include "util/stream/io.hh"
+#include "util/stream/stream.hh"
+#include "util/stream/timer.hh"
+
+#include "util/file.hh"
+#include "util/scoped.hh"
+#include "util/sized_iterator.hh"
+
+#include <algorithm>
+#include <iostream>
+#include <queue>
+#include <string>
+
+namespace util {
+namespace stream {
+
+struct NeverCombine {
+ template <class Compare> bool operator()(const void *, const void *, const Compare &) const {
+ return false;
+ }
+};
+
+// Manage the offsets of sorted blocks in a file.
+class Offsets {
+ public:
+ explicit Offsets(int fd) : log_(fd) {
+ Reset();
+ }
+
+ int File() const { return log_; }
+
+ void Append(uint64_t length) {
+ if (!length) return;
+ ++block_count_;
+ if (length == cur_.length) {
+ ++cur_.run;
+ return;
+ }
+ WriteOrThrow(log_, &cur_, sizeof(Entry));
+ cur_.length = length;
+ cur_.run = 1;
+ }
+
+ void FinishedAppending() {
+ WriteOrThrow(log_, &cur_, sizeof(Entry));
+ SeekOrThrow(log_, sizeof(Entry)); // Skip 0,0 at beginning.
+ cur_.run = 0;
+ if (block_count_) {
+ ReadOrThrow(log_, &cur_, sizeof(Entry));
+ assert(cur_.length);
+ assert(cur_.run);
+ }
+ }
+
+ uint64_t RemainingBlocks() const { return block_count_; }
+
+ uint64_t TotalOffset() const { return output_sum_; }
+
+ uint64_t PeekSize() const {
+ return cur_.length;
+ }
+
+ uint64_t NextSize() {
+ assert(block_count_);
+ uint64_t ret = cur_.length;
+ output_sum_ += ret;
+
+ --cur_.run;
+ --block_count_;
+ if (!cur_.run && block_count_) {
+ ReadOrThrow(log_, &cur_, sizeof(Entry));
+ assert(cur_.length);
+ assert(cur_.run);
+ }
+ return ret;
+ }
+
+ void Reset() {
+ SeekOrThrow(log_, 0);
+ ResizeOrThrow(log_, 0);
+ cur_.length = 0;
+ cur_.run = 0;
+ block_count_ = 0;
+ output_sum_ = 0;
+ }
+
+ private:
+ int log_;
+
+ struct Entry {
+ uint64_t length;
+ uint64_t run;
+ };
+ Entry cur_;
+
+ uint64_t block_count_;
+
+ uint64_t output_sum_;
+};
+
+// A priority queue of entries backed by file buffers
+template <class Compare> class MergeQueue {
+ public:
+ MergeQueue(int fd, std::size_t buffer_size, std::size_t entry_size, const Compare &compare)
+ : queue_(Greater(compare)), in_(fd), buffer_size_(buffer_size), entry_size_(entry_size) {}
+
+ void Push(void *base, uint64_t offset, uint64_t amount) {
+ queue_.push(Entry(base, in_, offset, amount, buffer_size_));
+ }
+
+ const void *Top() const {
+ return queue_.top().Current();
+ }
+
+ void Pop() {
+ Entry top(queue_.top());
+ queue_.pop();
+ if (top.Increment(in_, buffer_size_, entry_size_))
+ queue_.push(top);
+ }
+
+ std::size_t Size() const {
+ return queue_.size();
+ }
+
+ bool Empty() const {
+ return queue_.empty();
+ }
+
+ private:
+ // Priority queue contains these entries.
+ class Entry {
+ public:
+ Entry() {}
+
+ Entry(void *base, int fd, uint64_t offset, uint64_t amount, std::size_t buf_size) {
+ offset_ = offset;
+ remaining_ = amount;
+ buffer_end_ = static_cast<uint8_t*>(base) + buf_size;
+ Read(fd, buf_size);
+ }
+
+ bool Increment(int fd, std::size_t buf_size, std::size_t entry_size) {
+ current_ += entry_size;
+ if (current_ != buffer_end_) return true;
+ return Read(fd, buf_size);
+ }
+
+ const void *Current() const { return current_; }
+
+ private:
+ bool Read(int fd, std::size_t buf_size) {
+ current_ = buffer_end_ - buf_size;
+ std::size_t amount;
+ if (static_cast<uint64_t>(buf_size) < remaining_) {
+ amount = buf_size;
+ } else if (!remaining_) {
+ return false;
+ } else {
+ amount = remaining_;
+ buffer_end_ = current_ + remaining_;
+ }
+ PReadOrThrow(fd, current_, amount, offset_);
+ offset_ += amount;
+ assert(current_ <= buffer_end_);
+ remaining_ -= amount;
+ return true;
+ }
+
+ // Buffer
+ uint8_t *current_, *buffer_end_;
+ // File
+ uint64_t remaining_, offset_;
+ };
+
+ // Wrapper comparison function for queue entries.
+ class Greater : public std::binary_function<const Entry &, const Entry &, bool> {
+ public:
+ explicit Greater(const Compare &compare) : compare_(compare) {}
+
+ bool operator()(const Entry &first, const Entry &second) const {
+ return compare_(second.Current(), first.Current());
+ }
+
+ private:
+ const Compare compare_;
+ };
+
+ typedef std::priority_queue<Entry, std::vector<Entry>, Greater> Queue;
+ Queue queue_;
+
+ const int in_;
+ const std::size_t buffer_size_;
+ const std::size_t entry_size_;
+};
+
+/* A worker object that merges. If the number of pieces to merge exceeds the
+ * arity, it outputs multiple sorted blocks, recording to out_offsets.
+ * However, users will only ever see a single sorted block of output because
+ * Sort::Output merges until the arity is at least the number of remaining
+ * pieces before running this reader lazily.
+ */
+template <class Compare, class Combine> class MergingReader {
+ public:
+ MergingReader(int in, Offsets *in_offsets, Offsets *out_offsets, std::size_t buffer_size, std::size_t total_memory, const Compare &compare, const Combine &combine) :
+ compare_(compare), combine_(combine),
+ in_(in),
+ in_offsets_(in_offsets), out_offsets_(out_offsets),
+ buffer_size_(buffer_size), total_memory_(total_memory) {}
+
+ void Run(const ChainPosition &position) {
+ Run(position, false);
+ }
+
+ void Run(const ChainPosition &position, bool assert_one) {
+ // Special case: nothing to read.
+ if (!in_offsets_->RemainingBlocks()) {
+ Link l(position);
+ l.Poison();
+ return;
+ }
+ // If there's just one block, read it directly.
+ if (in_offsets_->RemainingBlocks() == 1) {
+ // Sequencing is important.
+ uint64_t offset = in_offsets_->TotalOffset();
+ uint64_t amount = in_offsets_->NextSize();
+ ReadSingle(offset, amount, position);
+ if (out_offsets_) out_offsets_->Append(amount);
+ return;
+ }
+
+ Stream str(position);
+ scoped_malloc buffer(MallocOrThrow(total_memory_));
+ uint8_t *const buffer_end = static_cast<uint8_t*>(buffer.get()) + total_memory_;
+
+ const std::size_t entry_size = position.GetChain().EntrySize();
+
+ while (in_offsets_->RemainingBlocks()) {
+ // Use bigger buffers if there's less remaining.
+ uint64_t per_buffer = std::max(buffer_size_, total_memory_ / in_offsets_->RemainingBlocks());
+ per_buffer -= per_buffer % entry_size;
+ assert(per_buffer);
+
+ // Populate queue.
+ MergeQueue<Compare> queue(in_, per_buffer, entry_size, compare_);
+ for (uint8_t *buf = static_cast<uint8_t*>(buffer.get());
+ in_offsets_->RemainingBlocks() && (buf + std::min(per_buffer, in_offsets_->PeekSize()) <= buffer_end);) {
+ uint64_t offset = in_offsets_->TotalOffset();
+ uint64_t size = in_offsets_->NextSize();
+ queue.Push(buf, offset, size);
+ buf += static_cast<std::size_t>(std::min<uint64_t>(size, per_buffer));
+ }
+ // This shouldn't happen but it's probably better to die than loop indefinitely.
+ if (queue.Size() < 2 && in_offsets_->RemainingBlocks()) {
+ std::cerr << "Bug in sort implementation: not merging at least two stripes." << std::endl;
+ abort();
+ }
+ if (assert_one && in_offsets_->RemainingBlocks()) {
+ std::cerr << "Bug in sort implementation: should only be one merge group for lazy sort" << std::endl;
+ abort();
+ }
+
+ uint64_t written = 0;
+ // Merge including combiner support.
+ memcpy(str.Get(), queue.Top(), entry_size);
+ for (queue.Pop(); !queue.Empty(); queue.Pop()) {
+ if (!combine_(str.Get(), queue.Top(), compare_)) {
+ ++written; ++str;
+ memcpy(str.Get(), queue.Top(), entry_size);
+ }
+ }
+ ++written; ++str;
+ if (out_offsets_)
+ out_offsets_->Append(written * entry_size);
+ }
+ str.Poison();
+ }
+
+ private:
+ void ReadSingle(uint64_t offset, const uint64_t size, const ChainPosition &position) {
+ // Special case: only one to read.
+ const uint64_t end = offset + size;
+ const uint64_t block_size = position.GetChain().BlockSize();
+ Link l(position);
+ for (; offset + block_size < end; ++l, offset += block_size) {
+ PReadOrThrow(in_, l->Get(), block_size, offset);
+ l->SetValidSize(block_size);
+ }
+ PReadOrThrow(in_, l->Get(), end - offset, offset);
+ l->SetValidSize(end - offset);
+ (++l).Poison();
+ return;
+ }
+
+ Compare compare_;
+ Combine combine_;
+
+ int in_;
+
+ protected:
+ Offsets *in_offsets_;
+
+ private:
+ Offsets *out_offsets_;
+
+ std::size_t buffer_size_;
+ std::size_t total_memory_;
+};
+
+// The lazy step owns the remaining files. This keeps track of them.
+template <class Compare, class Combine> class OwningMergingReader : public MergingReader<Compare, Combine> {
+ private:
+ typedef MergingReader<Compare, Combine> P;
+ public:
+ OwningMergingReader(int data, const Offsets &offsets, std::size_t buffer, std::size_t lazy, const Compare &compare, const Combine &combine)
+ : P(data, NULL, NULL, buffer, lazy, compare, combine),
+ data_(data),
+ offsets_(offsets) {}
+
+ void Run(const ChainPosition &position) {
+ P::in_offsets_ = &offsets_;
+ scoped_fd data(data_);
+ scoped_fd offsets_file(offsets_.File());
+ P::Run(position, true);
+ }
+
+ private:
+ int data_;
+ Offsets offsets_;
+};
+
+// Don't use this directly. Worker that sorts blocks.
+template <class Compare> class BlockSorter {
+ public:
+ BlockSorter(Offsets &offsets, const Compare &compare) :
+ offsets_(&offsets), compare_(compare) {}
+
+ void Run(const ChainPosition &position) {
+ const std::size_t entry_size = position.GetChain().EntrySize();
+ for (Link link(position); link; ++link) {
+ // Record the size of each block in a separate file.
+ offsets_->Append(link->ValidSize());
+ void *end = static_cast<uint8_t*>(link->Get()) + link->ValidSize();
+ std::sort(
+ SizedIt(link->Get(), entry_size),
+ SizedIt(end, entry_size),
+ compare_);
+ }
+ offsets_->FinishedAppending();
+ }
+
+ private:
+ Offsets *offsets_;
+ SizedCompare<Compare> compare_;
+};
+
+class BadSortConfig : public Exception {
+ public:
+ BadSortConfig() throw() {}
+ ~BadSortConfig() throw() {}
+};
+
+template <class Compare, class Combine = NeverCombine> class Sort {
+ public:
+ Sort(Chain &in, const SortConfig &config, const Compare &compare = Compare(), const Combine &combine = Combine())
+ : config_(config),
+ data_(MakeTemp(config.temp_prefix)),
+ offsets_file_(MakeTemp(config.temp_prefix)), offsets_(offsets_file_.get()),
+ compare_(compare), combine_(combine),
+ entry_size_(in.EntrySize()) {
+ UTIL_THROW_IF(!entry_size_, BadSortConfig, "Sorting entries of size 0");
+ // Make buffer_size a multiple of the entry_size.
+ config_.buffer_size -= config_.buffer_size % entry_size_;
+ UTIL_THROW_IF(!config_.buffer_size, BadSortConfig, "Sort buffer too small");
+ UTIL_THROW_IF(config_.total_memory < config_.buffer_size * 4, BadSortConfig, "Sorting memory " << config_.total_memory << " is too small for four buffers (two read and two write).");
+ in >> BlockSorter<Compare>(offsets_, compare_) >> WriteAndRecycle(data_.get());
+ }
+
+ uint64_t Size() const {
+ return SizeOrThrow(data_.get());
+ }
+
+ // Do merge sort, terminating when lazy merge could be done with the
+ // specified memory. Return the minimum memory necessary to do lazy merge.
+ std::size_t Merge(std::size_t lazy_memory) {
+ if (offsets_.RemainingBlocks() <= 1) return 0;
+ const uint64_t lazy_arity = std::max<uint64_t>(1, lazy_memory / config_.buffer_size);
+ uint64_t size = Size();
+ /* No overflow because
+ * offsets_.RemainingBlocks() * config_.buffer_size <= lazy_memory ||
+ * size < lazy_memory
+ */
+ if (offsets_.RemainingBlocks() <= lazy_arity || size <= static_cast<uint64_t>(lazy_memory))
+ return std::min<std::size_t>(size, offsets_.RemainingBlocks() * config_.buffer_size);
+
+ scoped_fd data2(MakeTemp(config_.temp_prefix));
+ int fd_in = data_.get(), fd_out = data2.get();
+ scoped_fd offsets2_file(MakeTemp(config_.temp_prefix));
+ Offsets offsets2(offsets2_file.get());
+ Offsets *offsets_in = &offsets_, *offsets_out = &offsets2;
+
+ // Double buffered writing.
+ ChainConfig chain_config;
+ chain_config.entry_size = entry_size_;
+ chain_config.block_count = 2;
+ chain_config.total_memory = config_.buffer_size * 2;
+ Chain chain(chain_config);
+
+ while (offsets_in->RemainingBlocks() > lazy_arity) {
+ if (size <= static_cast<uint64_t>(lazy_memory)) break;
+ std::size_t reading_memory = config_.total_memory - 2 * config_.buffer_size;
+ if (size < static_cast<uint64_t>(reading_memory)) {
+ reading_memory = static_cast<std::size_t>(size);
+ }
+ SeekOrThrow(fd_in, 0);
+ chain >>
+ MergingReader<Compare, Combine>(
+ fd_in,
+ offsets_in, offsets_out,
+ config_.buffer_size,
+ reading_memory,
+ compare_, combine_) >>
+ WriteAndRecycle(fd_out);
+ chain.Wait();
+ offsets_out->FinishedAppending();
+ ResizeOrThrow(fd_in, 0);
+ offsets_in->Reset();
+ std::swap(fd_in, fd_out);
+ std::swap(offsets_in, offsets_out);
+ size = SizeOrThrow(fd_in);
+ }
+
+ SeekOrThrow(fd_in, 0);
+ if (fd_in == data2.get()) {
+ data_.reset(data2.release());
+ offsets_file_.reset(offsets2_file.release());
+ offsets_ = offsets2;
+ }
+ if (offsets_.RemainingBlocks() <= 1) return 0;
+ // No overflow because the while loop exited.
+ return std::min(size, offsets_.RemainingBlocks() * static_cast<uint64_t>(config_.buffer_size));
+ }
+
+ // Output to chain, using at most lazy_memory bytes for the lazy merge sort.
+ void Output(Chain &out, std::size_t lazy_memory) {
+ Merge(lazy_memory);
+ out.SetProgressTarget(Size());
+ out >> OwningMergingReader<Compare, Combine>(data_.get(), offsets_, config_.buffer_size, lazy_memory, compare_, combine_);
+ data_.release();
+ offsets_file_.release();
+ }
+
+ /* If a pipeline step is reading sorted input and writing to a different
+ * sort order, then there's a trade-off between using RAM to read lazily
+ * (avoiding copying the file) and using RAM to increase block size and,
+ * therefore, decrease the number of merge sort passes in the next
+ * iteration.
+ *
+ * Merge sort takes log_{arity}(pieces) passes. Thus, each time the chain
+ * block size is multiplied by arity, the number of output passes decreases
+ * by one. Up to a constant, then, log_{arity}(chain) is the number of
+ * passes saved. Chain simply divides the memory evenly over all blocks.
+ *
+ * Lazy sort saves this many passes (up to a constant)
+ * log_{arity}((memory-lazy)/block_count) + 1
+ * Non-lazy sort saves this many passes (up to the same constant):
+ * log_{arity}(memory/block_count)
+ * Add log_{arity}(block_count) to both:
+ * log_{arity}(memory-lazy) + 1 versus log_{arity}(memory)
+ * Take arity to the power of both sizes (arity > 1)
+ * (memory - lazy)*arity versus memory
+ * Solve for lazy
+ * lazy = memory * (arity - 1) / arity
+ */
+ std::size_t DefaultLazy() {
+ float arity = static_cast<float>(config_.total_memory / config_.buffer_size);
+ return static_cast<std::size_t>(static_cast<float>(config_.total_memory) * (arity - 1.0) / arity);
+ }
+
+ // Same as Output with default lazy memory setting.
+ void Output(Chain &out) {
+ Output(out, DefaultLazy());
+ }
+
+ // Completely merge sort and transfer ownership to the caller.
+ int StealCompleted() {
+ // Merge all the way.
+ Merge(0);
+ SeekOrThrow(data_.get(), 0);
+ offsets_file_.reset();
+ return data_.release();
+ }
+
+ private:
+ SortConfig config_;
+
+ scoped_fd data_;
+
+ scoped_fd offsets_file_;
+ Offsets offsets_;
+
+ const Compare compare_;
+ const Combine combine_;
+ const std::size_t entry_size_;
+};
+
+// Returns the number of bytes to be read on demand.
+template <class Compare, class Combine> uint64_t BlockingSort(Chain &chain, const SortConfig &config, const Compare &compare = Compare(), const Combine &combine = NeverCombine()) {
+ Sort<Compare, Combine> sorter(chain, config, compare, combine);
+ chain.Wait(true);
+ uint64_t size = sorter.Size();
+ sorter.Output(chain);
+ return size;
+}
+
+} // namespace stream
+} // namespace util
+
+#endif // UTIL_STREAM_SORT__
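As a sketch of the combiner contract described at the top of this file: a combiner over illustrative (key, count) records that folds equal keys together. Record, CombineCounts, and CompareKey are assumptions for the example, not part of this commit:

    struct Record {
      uint64_t key;
      uint64_t count;
    };

    struct CombineCounts {
      template <class Compare> bool operator()(void *into_void, const void *option_void, const Compare &compare) const {
        Record *into = static_cast<Record*>(into_void);
        const Record *option = static_cast<const Record*>(option_void);
        // Different keys: no combination, so the merge emits option separately.
        if (compare(into, option) || compare(option, into)) return false;
        // Equal keys: fold option into the already-emitted record.
        into->count += option->count;
        return true;
      }
    };

    // Then, with a CompareKey ordering records by key:
    // Sort<CompareKey, CombineCounts> sorter(chain, sort_config);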
diff --git a/util/stream/sort_test.cc b/util/stream/sort_test.cc
new file mode 100644
index 000000000..fd7705cd9
--- /dev/null
+++ b/util/stream/sort_test.cc
@@ -0,0 +1,62 @@
+#include "util/stream/sort.hh"
+
+#define BOOST_TEST_MODULE SortTest
+#include <boost/test/unit_test.hpp>
+
+#include <algorithm>
+
+#include <unistd.h>
+
+namespace util { namespace stream { namespace {
+
+struct CompareUInt64 : public std::binary_function<const void *, const void *, bool> {
+ bool operator()(const void *first, const void *second) const {
+ return *static_cast<const uint64_t*>(first) < *static_cast<const uint64_t*>(second);
+ }
+};
+
+const uint64_t kSize = 100000;
+
+struct Putter {
+ Putter(std::vector<uint64_t> &shuffled) : shuffled_(shuffled) {}
+
+ void Run(const ChainPosition &position) {
+ Stream put_shuffled(position);
+ for (uint64_t i = 0; i < shuffled_.size(); ++i, ++put_shuffled) {
+ *static_cast<uint64_t*>(put_shuffled.Get()) = shuffled_[i];
+ }
+ put_shuffled.Poison();
+ }
+ std::vector<uint64_t> &shuffled_;
+};
+
+BOOST_AUTO_TEST_CASE(FromShuffled) {
+ std::vector<uint64_t> shuffled;
+ shuffled.reserve(kSize);
+ for (uint64_t i = 0; i < kSize; ++i) {
+ shuffled.push_back(i);
+ }
+ std::random_shuffle(shuffled.begin(), shuffled.end());
+
+ ChainConfig config;
+ config.entry_size = 8;
+ config.total_memory = 800;
+ config.block_count = 3;
+
+ SortConfig merge_config;
+ merge_config.temp_prefix = "sort_test_temp";
+ merge_config.buffer_size = 800;
+ merge_config.total_memory = 3300;
+
+ Chain chain(config);
+ chain >> Putter(shuffled);
+ BlockingSort(chain, merge_config, CompareUInt64(), NeverCombine());
+ Stream sorted;
+ chain >> sorted >> kRecycle;
+ for (uint64_t i = 0; i < kSize; ++i, ++sorted) {
+ BOOST_CHECK_EQUAL(i, *static_cast<const uint64_t*>(sorted.Get()));
+ }
+ BOOST_CHECK(!sorted);
+}
+
+}}} // namespaces
diff --git a/util/stream/stream.hh b/util/stream/stream.hh
new file mode 100644
index 000000000..6ff45b820
--- /dev/null
+++ b/util/stream/stream.hh
@@ -0,0 +1,74 @@
+#ifndef UTIL_STREAM_STREAM__
+#define UTIL_STREAM_STREAM__
+
+#include "util/stream/chain.hh"
+
+#include <boost/noncopyable.hpp>
+
+#include <assert.h>
+#include <stdint.h>
+
+namespace util {
+namespace stream {
+
+class Stream : boost::noncopyable {
+ public:
+ Stream() : current_(NULL), end_(NULL) {}
+
+ void Init(const ChainPosition &position) {
+ entry_size_ = position.GetChain().EntrySize();
+ block_size_ = position.GetChain().BlockSize();
+ block_it_.Init(position);
+ StartBlock();
+ }
+
+ explicit Stream(const ChainPosition &position) {
+ Init(position);
+ }
+
+ operator bool() const { return current_ != NULL; }
+ bool operator!() const { return current_ == NULL; }
+
+ const void *Get() const { return current_; }
+ void *Get() { return current_; }
+
+ void Poison() {
+ block_it_->SetValidSize(current_ - static_cast<uint8_t*>(block_it_->Get()));
+ ++block_it_;
+ block_it_.Poison();
+ }
+
+ Stream &operator++() {
+ assert(*this);
+ assert(current_ < end_);
+ current_ += entry_size_;
+ if (current_ == end_) {
+ ++block_it_;
+ StartBlock();
+ }
+ return *this;
+ }
+
+ private:
+ void StartBlock() {
+ for (; block_it_ && !block_it_->ValidSize(); ++block_it_) {}
+ current_ = static_cast<uint8_t*>(block_it_->Get());
+ end_ = current_ + block_it_->ValidSize();
+ }
+
+ uint8_t *current_, *end_;
+
+ std::size_t entry_size_;
+ std::size_t block_size_;
+
+ Link block_it_;
+};
+
+inline Chain &operator>>(Chain &chain, Stream &stream) {
+ stream.Init(chain.Add());
+ return chain;
+}
+
+} // namespace stream
+} // namespace util
+#endif // UTIL_STREAM_STREAM__
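A sketch of a producing worker built on Stream; Poison() truncates the final block at the current position, as the Putter worker in sort_test.cc below also demonstrates:

    struct Counter {
      void Run(const util::stream::ChainPosition &position) {
        util::stream::Stream out(position);
        for (uint64_t i = 0; i < 1000; ++i, ++out) {
          *static_cast<uint64_t*>(out.Get()) = i;
        }
        out.Poison();  // flush the partial block, then poison the chain
      }
    };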
diff --git a/util/stream/stream_test.cc b/util/stream/stream_test.cc
new file mode 100644
index 000000000..6575d50d0
--- /dev/null
+++ b/util/stream/stream_test.cc
@@ -0,0 +1,35 @@
+#include "util/stream/io.hh"
+
+#include "util/stream/stream.hh"
+#include "util/file.hh"
+
+#define BOOST_TEST_MODULE StreamTest
+#include <boost/test/unit_test.hpp>
+
+#include <unistd.h>
+
+namespace util { namespace stream { namespace {
+
+BOOST_AUTO_TEST_CASE(StreamTest) {
+ scoped_fd in(MakeTemp("io_test_temp"));
+ for (uint64_t i = 0; i < 100000; ++i) {
+ WriteOrThrow(in.get(), &i, sizeof(uint64_t));
+ }
+ SeekOrThrow(in.get(), 0);
+
+ ChainConfig config;
+ config.entry_size = 8;
+ config.total_memory = 100;
+ config.block_count = 12;
+
+ Stream s;
+ Chain chain(config);
+ chain >> Read(in.get()) >> s >> kRecycle;
+ uint64_t i = 0;
+ for (; s; ++s, ++i) {
+ BOOST_CHECK_EQUAL(i, *static_cast<const uint64_t*>(s.Get()));
+ }
+ BOOST_CHECK_EQUAL(100000ULL, i);
+}
+
+}}} // namespaces
diff --git a/util/stream/timer.hh b/util/stream/timer.hh
new file mode 100644
index 000000000..50e94fe8a
--- /dev/null
+++ b/util/stream/timer.hh
@@ -0,0 +1,14 @@
+#ifndef UTIL_STREAM_TIMER__
+#define UTIL_STREAM_TIMER__
+
+#include <boost/version.hpp>
+
+#if BOOST_VERSION >= 104800
+#include <boost/timer/timer.hpp>
+#define UTIL_TIMER(str) boost::timer::auto_cpu_timer timer(std::cerr, 1, (str))
+#else
+//#warning Using Boost older than 1.48. Timing information will not be available.
+#define UTIL_TIMER(str)
+#endif
+
+#endif // UTIL_STREAM_TIMER__
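A sketch of the macro in use; the string is passed through as boost::timer::auto_cpu_timer's format, so %w should expand to wall-clock seconds, and timing covers the enclosing scope:

    void MergeSortedFiles() {
      UTIL_TIMER("Merged in %w seconds\n");
      // ... timed work here; the report prints when the scope exits ...
    }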
diff --git a/util/thread_pool.hh b/util/thread_pool.hh
new file mode 100644
index 000000000..84e257ea6
--- /dev/null
+++ b/util/thread_pool.hh
@@ -0,0 +1,95 @@
+#ifndef UTIL_THREAD_POOL__
+#define UTIL_THREAD_POOL__
+
+#include "util/pcqueue.hh"
+
+#include <boost/ptr_container/ptr_vector.hpp>
+#include <boost/optional.hpp>
+#include <boost/thread.hpp>
+
+#include <iostream>
+
+#include <stdlib.h>
+
+namespace util {
+
+template <class HandlerT> class Worker : boost::noncopyable {
+ public:
+ typedef HandlerT Handler;
+ typedef typename Handler::Request Request;
+
+ template <class Construct> Worker(PCQueue<Request> &in, Construct &construct, Request &poison)
+ : in_(in), handler_(construct), thread_(boost::ref(*this)), poison_(poison) {}
+
+ // Only call from thread.
+ void operator()() {
+ Request request;
+ while (1) {
+ in_.Consume(request);
+ if (request == poison_) return;
+ try {
+ (*handler_)(request);
+ }
+ catch(std::exception &e) {
+ std::cerr << "Handler threw " << e.what() << std::endl;
+ abort();
+ }
+ catch(...) {
+ std::cerr << "Handler threw an exception, dropping request" << std::endl;
+ abort();
+ }
+ }
+ }
+
+ void Join() {
+ thread_.join();
+ }
+
+ private:
+ PCQueue<Request> &in_;
+
+ boost::optional<Handler> handler_;
+
+ boost::thread thread_;
+
+ Request poison_;
+};
+
+template <class HandlerT> class ThreadPool : boost::noncopyable {
+ public:
+ typedef HandlerT Handler;
+ typedef typename Handler::Request Request;
+
+ template <class Construct> ThreadPool(size_t queue_length, size_t workers, Construct handler_construct, Request poison) : in_(queue_length), poison_(poison) {
+ for (size_t i = 0; i < workers; ++i) {
+ workers_.push_back(new Worker<Handler>(in_, handler_construct, poison));
+ }
+ }
+
+ ~ThreadPool() {
+ for (size_t i = 0; i < workers_.size(); ++i) {
+ Produce(poison_);
+ }
+ for (typename boost::ptr_vector<Worker<Handler> >::iterator i = workers_.begin(); i != workers_.end(); ++i) {
+ i->Join();
+ }
+ }
+
+ void Produce(const Request &request) {
+ in_.Produce(request);
+ }
+
+ // For adding to the queue.
+ PCQueue<Request> &In() { return in_; }
+
+ private:
+ PCQueue<Request> in_;
+
+ boost::ptr_vector<Worker<Handler> > workers_;
+
+ Request poison_;
+};
+
+} // namespace util
+
+#endif // UTIL_THREAD_POOL__
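A sketch of the Handler contract ThreadPool expects: a Request typedef, construction from the pool's Construct argument, and operator(). PrintHandler is illustrative; the poison request must compare equal only to itself:

    struct PrintHandler {
      typedef int Request;
      PrintHandler(std::ostream &to) : to_(&to) {}  // converting constructor, built from Construct
      void operator()(Request request) { (*to_) << request << '\n'; }
      std::ostream *to_;
    };

    util::ThreadPool<PrintHandler> pool(8 /*queue slots*/, 2 /*workers*/, boost::ref(std::cout), -1 /*poison*/);
    for (int i = 0; i < 100; ++i) pool.Produce(i);
    // The destructor sends one poison per worker and joins them.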
diff --git a/util/usage.cc b/util/usage.cc
index e5cf76f05..16a004bb4 100644
--- a/util/usage.cc
+++ b/util/usage.cc
@@ -1,13 +1,17 @@
#include "util/usage.hh"
+#include "util/exception.hh"
+
#include <fstream>
#include <ostream>
+#include <sstream>
#include <string.h>
#include <ctype.h>
#if !defined(_WIN32) && !defined(_WIN64)
#include <sys/resource.h>
#include <sys/time.h>
+#include <unistd.h>
#endif
namespace util {
@@ -43,4 +47,60 @@ void PrintUsage(std::ostream &out) {
#endif
}
+uint64_t GuessPhysicalMemory() {
+#if defined(_WIN32) || defined(_WIN64)
+ return 0;
+#elif defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE)
+ long pages = sysconf(_SC_PHYS_PAGES);
+ if (pages == -1) return 0;
+ long page_size = sysconf(_SC_PAGESIZE);
+ if (page_size == -1) return 0;
+ return static_cast<uint64_t>(pages) * static_cast<uint64_t>(page_size);
+#else
+ return 0;
+#endif
+}
+
+namespace {
+class SizeParseError : public Exception {
+ public:
+ explicit SizeParseError(const std::string &str) throw() {
+ *this << "Failed to parse " << str << " into a memory size ";
+ }
+};
+
+template <class Num> uint64_t ParseNum(const std::string &arg) {
+ std::stringstream stream(arg);
+ Num value;
+ stream >> value;
+ UTIL_THROW_IF_ARG(!stream, SizeParseError, (arg), "for the leading number.");
+ std::string after;
+ stream >> after;
+ UTIL_THROW_IF_ARG(after.size() > 1, SizeParseError, (arg), "because there is more than one character after the number.");
+ std::string throwaway;
+ UTIL_THROW_IF_ARG(stream >> throwaway, SizeParseError, (arg), "because there was more cruft " << throwaway << " after the number.");
+
+ // Match sort(1)'s silly default of kilobytes for bare numbers.
+ if (after.empty()) after = "K";
+ if (after == "%") {
+ uint64_t mem = GuessPhysicalMemory();
+ UTIL_THROW_IF_ARG(!mem, SizeParseError, (arg), "because % was specified but the physical memory size could not be determined.");
+ return static_cast<double>(value) * static_cast<double>(mem) / 100.0;
+ }
+
+ std::string units("bKMGTPEZY");
+ std::string::size_type index = units.find(after[0]);
+ UTIL_THROW_IF_ARG(index == std::string::npos, SizeParseError, (arg), "the allowed suffixes are " << units << "%.");
+ for (std::string::size_type i = 0; i < index; ++i) {
+ value *= 1024;
+ }
+ return value;
+}
+
+} // namespace
+
+uint64_t ParseSize(const std::string &arg) {
+ // Integers parse exactly; a decimal point selects double parsing.
+ return arg.find('.') == std::string::npos ? ParseNum<uint64_t>(arg) : ParseNum<double>(arg);
+}
+
} // namespace util
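Illustrative results for ParseSize under the rules above (bare numbers default to the kilobyte multiplier, matching sort(1)):

    util::ParseSize("100");   // 102400: bare number, default K
    util::ParseSize("16M");   // 16777216
    util::ParseSize("1.5G");  // 1610612736: a decimal point selects double parsing
    util::ParseSize("512b");  // 512: explicit bytes suffix
    util::ParseSize("10%");   // 10% of GuessPhysicalMemory(); throws if that is 0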
diff --git a/util/usage.hh b/util/usage.hh
index d331ff74c..e19eda7bf 100644
--- a/util/usage.hh
+++ b/util/usage.hh
@@ -1,8 +1,18 @@
#ifndef UTIL_USAGE__
#define UTIL_USAGE__
+#include <cstddef>
#include <iosfwd>
+#include <string>
+
+#include <stdint.h>
namespace util {
void PrintUsage(std::ostream &to);
+
+// Determine how much physical memory there is. Return 0 on failure.
+uint64_t GuessPhysicalMemory();
+
+// Parse a size like unix sort. Sadly, this means the default multiplier is K.
+uint64_t ParseSize(const std::string &arg);
} // namespace util
#endif // UTIL_USAGE__