github.com/moses-smt/mosesdecoder.git
path: root/lm
author    Kenneth Heafield <github@kheafield.com>  2012-02-28 22:58:00 +0400
committer Kenneth Heafield <github@kheafield.com>  2012-02-28 22:58:00 +0400
commit    e48de47c2381547f78f4dbd89f4fa3e76ba0c6bf (patch)
tree      cdcbb888209bee7dd9c02a7d678cce4262c35416 /lm
parent    7927979298644923cf02ad6c757c3d7c209e365a (diff)
KenLM 98814b2 including faster malloc-backed building and portability improvements
Diffstat (limited to 'lm')
-rw-r--r--  lm/binary_format.cc   46
-rw-r--r--  lm/binary_format.hh    2
-rw-r--r--  lm/build_binary.cc    78
-rw-r--r--  lm/config.cc           1
-rw-r--r--  lm/config.hh           8
-rw-r--r--  lm/model.cc           12
-rw-r--r--  lm/quantize.cc        12
-rw-r--r--  lm/search_hashed.cc    6
-rw-r--r--  lm/search_hashed.hh    6
-rw-r--r--  lm/search_trie.cc     19
-rw-r--r--  lm/search_trie.hh      2
-rw-r--r--  lm/trie_sort.cc       21
-rw-r--r--  lm/vocab.cc           16
-rw-r--r--  lm/vocab.hh            6
14 files changed, 149 insertions(+), 86 deletions(-)
diff --git a/lm/binary_format.cc b/lm/binary_format.cc
index ab0166a65..4796f6d1b 100644
--- a/lm/binary_format.cc
+++ b/lm/binary_format.cc
@@ -87,7 +87,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
- backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ util::MapAnonymous(memory_size, backing.vocab);
return reinterpret_cast<uint8_t*>(backing.vocab.get());
}
}
@@ -103,32 +103,44 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
throw e;
}
+ if (config.write_method == Config::WRITE_AFTER) {
+ util::MapAnonymous(memory_size, backing.search);
+ return reinterpret_cast<uint8_t*>(backing.search.get());
+ }
+ // mmap it now.
// We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
std::size_t page_size = util::SizePage();
std::size_t alignment_cruft = adjusted_vocab % page_size;
backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
-
return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
} else {
- backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ util::MapAnonymous(memory_size, backing.search);
return reinterpret_cast<uint8_t*>(backing.search.get());
}
}
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing) {
- if (config.write_mmap) {
- util::SyncOrThrow(backing.search.get(), backing.search.size());
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params = Parameters();
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- params.fixed.search_version = search_version;
- WriteHeader(backing.vocab.get(), params);
+void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
+ if (!config.write_mmap) return;
+ util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
+ switch (config.write_method) {
+ case Config::WRITE_MMAP:
+ util::SyncOrThrow(backing.search.get(), backing.search.size());
+ break;
+ case Config::WRITE_AFTER:
+ util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
+ util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
+ util::FSyncOrThrow(backing.file.get());
+ break;
}
+ // header and vocab share the same mmap. The header is written here because we know the counts.
+ Parameters params = Parameters();
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ params.fixed.search_version = search_version;
+ WriteHeader(backing.vocab.get(), params);
}
namespace detail {
@@ -172,7 +184,7 @@ void ReadHeader(int fd, Parameters &out) {
UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
out.counts.resize(static_cast<std::size_t>(out.fixed.order));
- util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
+ if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
}
void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params) {
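
The new switch in FinishFile is the heart of the commit: WRITE_MMAP builds the search structure directly in a file-backed mapping and msyncs it, while WRITE_AFTER builds in anonymous memory and streams it out once at the end. A minimal POSIX sketch of the two strategies, illustrative only (not KenLM's util::MapAnonymous / util::WriteOrThrow wrappers; MAP_ANONYMOUS and error handling elided):

#include <sys/mman.h>
#include <unistd.h>
#include <cstring>

// WRITE_MMAP in miniature: build inside a file-backed mapping, then msync.
void write_mmap_style(int fd, std::size_t size) {
  ftruncate(fd, size);                          // reserve the file space up front
  void *mem = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
  std::memset(mem, 0, size);                    // stand-in for building the search table
  msync(mem, size, MS_SYNC);                    // push dirty pages to disk
  munmap(mem, size);
}

// WRITE_AFTER in miniature: build in anonymous (malloc-like) memory, then do
// one sequential write at the end.
void write_after_style(int fd, std::size_t size) {
  void *mem = mmap(NULL, size, PROT_READ | PROT_WRITE,
                   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  std::memset(mem, 0, size);                    // building touches only RAM
  write(fd, mem, size);                         // single streaming write
  fsync(fd);
  munmap(mem, size);
}

Random writes into a shared file mapping can trigger writeback while the structure is still being built; building in anonymous memory and paying for one sequential write afterwards is presumably the "faster malloc-backed building" the commit message refers to.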
diff --git a/lm/binary_format.hh b/lm/binary_format.hh
index 71209b2a6..dd795f620 100644
--- a/lm/binary_format.hh
+++ b/lm/binary_format.hh
@@ -58,7 +58,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
// Write header to binary file. This is done last to prevent incomplete files
// from loading.
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing);
+void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing);
namespace detail {
diff --git a/lm/build_binary.cc b/lm/build_binary.cc
index e235cc5a3..8cbb69d0a 100644
--- a/lm/build_binary.cc
+++ b/lm/build_binary.cc
@@ -18,11 +18,14 @@ namespace ngram {
namespace {
void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
-"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n"
+"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
+"-w mmap|after determines how writing is done.\n"
+" mmap maps the binary file and writes to it. Default for trie.\n"
+" after allocates anonymous memory, builds, and writes. Default for probing.\n\n"
"type is either probing or trie. Default is probing.\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
@@ -58,7 +61,7 @@ uint8_t ParseBitCount(const char *from) {
unsigned long val = ParseUInt(from);
if (val > 25) {
util::ParseNumberException e(from);
- e << " bit counts are limited to 256.";
+ e << " bit counts are limited to 25.";
}
return val;
}
@@ -115,10 +118,10 @@ int main(int argc, char *argv[]) {
using namespace lm::ngram;
try {
- bool quantize = false, set_backoff_bits = false, bhiksha = false;
+ bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false;
lm::ngram::Config config;
int opt;
- while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) {
+ while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:si")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -132,6 +135,7 @@ int main(int argc, char *argv[]) {
case 'a':
config.pointer_bhiksha_bits = ParseBitCount(optarg);
bhiksha = true;
+ break;
case 'u':
config.unknown_missing_logprob = ParseFloat(optarg);
break;
@@ -144,6 +148,16 @@ int main(int argc, char *argv[]) {
case 'm':
config.building_memory = ParseUInt(optarg) * 1048576;
break;
+ case 'w':
+ set_write_method = true;
+ if (!strcmp(optarg, "mmap")) {
+ config.write_method = Config::WRITE_MMAP;
+ } else if (!strcmp(optarg, "after")) {
+ config.write_method = Config::WRITE_AFTER;
+ } else {
+ Usage(argv[0]);
+ }
+ break;
case 's':
config.sentence_marker_missing = lm::SILENT;
break;
@@ -160,45 +174,45 @@ int main(int argc, char *argv[]) {
}
if (optind + 1 == argc) {
ShowSizes(argv[optind], config);
- return 0;
- }
- const char *model_type, *from_file;
- if (optind + 2 == argc) {
- model_type = "probing";
- from_file = argv[optind];
+ } else if (optind + 2 == argc) {
config.write_mmap = argv[optind + 1];
+ if (quantize || set_backoff_bits) ProbingQuantizationUnsupported();
+ ProbingModel(argv[optind], config);
} else if (optind + 3 == argc) {
- model_type = argv[optind];
- from_file = argv[optind + 1];
+ const char *model_type = argv[optind];
+ const char *from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
- } else {
- Usage(argv[0]);
- }
- if (!strcmp(model_type, "probing")) {
- if (quantize || set_backoff_bits) ProbingQuantizationUnsupported();
- ProbingModel(from_file, config);
- } else if (!strcmp(model_type, "trie")) {
- if (quantize) {
- if (bhiksha) {
- QuantArrayTrieModel(from_file, config);
+ if (!strcmp(model_type, "probing")) {
+ if (!set_write_method) config.write_method = Config::WRITE_AFTER;
+ if (quantize || set_backoff_bits) ProbingQuantizationUnsupported();
+ ProbingModel(from_file, config);
+ } else if (!strcmp(model_type, "trie")) {
+ if (!set_write_method) config.write_method = Config::WRITE_MMAP;
+ if (quantize) {
+ if (bhiksha) {
+ QuantArrayTrieModel(from_file, config);
+ } else {
+ QuantTrieModel(from_file, config);
+ }
} else {
- QuantTrieModel(from_file, config);
+ if (bhiksha) {
+ ArrayTrieModel(from_file, config);
+ } else {
+ TrieModel(from_file, config);
+ }
}
} else {
- if (bhiksha) {
- ArrayTrieModel(from_file, config);
- } else {
- TrieModel(from_file, config);
- }
+ Usage(argv[0]);
}
} else {
Usage(argv[0]);
}
- std::cerr << "Built " << config.write_mmap << " successfully." << std::endl;
- } catch (const std::exception &e) {
+ }
+ catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
+ std::cerr << "ERROR" << std::endl;
return 1;
}
-
+ std::cerr << "SUCCESS" << std::endl;
return 0;
}
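
Two small things in this file are easy to miss: the added break after case 'a' (previously -a fell through into the next case), and the new -w option itself. A standalone sketch of the -w parsing, trimmed to just that option (a hypothetical demo program, not the real build_binary):

#include <unistd.h>   // getopt, optarg
#include <string.h>
#include <stdio.h>

enum WriteMethod { WRITE_MMAP, WRITE_AFTER };

int main(int argc, char *argv[]) {
  WriteMethod method = WRITE_AFTER;   // mirrors the new Config default
  int opt;
  while ((opt = getopt(argc, argv, "w:")) != -1) {
    switch (opt) {
      case 'w':
        if (!strcmp(optarg, "mmap")) {
          method = WRITE_MMAP;
        } else if (!strcmp(optarg, "after")) {
          method = WRITE_AFTER;
        } else {
          fprintf(stderr, "-w takes mmap or after, not %s\n", optarg);
          return 1;
        }
        break;   // without a break, control falls into the next case
      default:
        return 1;
    }
  }
  fprintf(stderr, "write method: %s\n", method == WRITE_MMAP ? "mmap" : "after");
  return 0;
}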
diff --git a/lm/config.cc b/lm/config.cc
index 297589a47..dbe762b32 100644
--- a/lm/config.cc
+++ b/lm/config.cc
@@ -17,6 +17,7 @@ Config::Config() :
temporary_directory_prefix(NULL),
arpa_complain(ALL),
write_mmap(NULL),
+ write_method(WRITE_AFTER),
include_vocab(true),
prob_bits(8),
backoff_bits(8),
diff --git a/lm/config.hh b/lm/config.hh
index 8564661bf..01b756322 100644
--- a/lm/config.hh
+++ b/lm/config.hh
@@ -70,9 +70,17 @@ struct Config {
// to NULL to disable.
const char *write_mmap;
+ typedef enum {
+ WRITE_MMAP, // Map the file directly.
+ WRITE_AFTER // Write after we're done.
+ } WriteMethod;
+ WriteMethod write_method;
+
// Include the vocab in the binary file? Only effective if write_mmap != NULL.
bool include_vocab;
+
+
// Quantization options. Only effective for QuantTrieModel. One value is
// reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
// to quantize (and one of the remaining backoffs will be 0).
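
With the new field, code that builds binaries programmatically can pick the strategy itself instead of relying on the per-type default. A sketch of how the option would be set, assuming the usual lm/model.hh entry points:

#include "lm/model.hh"   // lm::ngram::Config and the model typedefs

void build_probing_binary(const char *arpa, const char *binary) {
  lm::ngram::Config config;
  config.write_mmap = binary;                             // enable binary output
  config.write_method = lm::ngram::Config::WRITE_AFTER;   // build in RAM, write once
  // Constructing the model from an ARPA file writes the binary as a side effect.
  lm::ngram::ProbingModel model(arpa, config);
}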
diff --git a/lm/model.cc b/lm/model.cc
index 042955efd..478ebed1b 100644
--- a/lm/model.cc
+++ b/lm/model.cc
@@ -46,7 +46,7 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
SetupMemory(start, params.counts, config);
- vocab_.LoadedBinary(fd, config.enumerate_vocab);
+ vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
search_.LoadedBinary();
}
@@ -82,7 +82,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.unigram.Unknown().backoff = 0.0;
search_.unigram.Unknown().prob = config.unknown_missing_logprob;
}
- FinishFile(config, kModelType, kVersion, counts, backing_);
+ FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
} catch (util::Exception &e) {
e << " Byte: " << f.Offset();
throw;
@@ -119,7 +119,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
float backoff;
// i is the order of the backoff we're looking for.
- const Middle *mid_iter = search_.MiddleBegin() + start - 2;
+ typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2;
for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) {
if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break;
ret.prob += backoff;
@@ -139,7 +139,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored);
out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
float *backoff_out = out_state.backoff + 1;
- const typename Search::Middle *mid = search_.MiddleBegin();
+ typename Search::MiddleIter mid(search_.MiddleBegin());
for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
@@ -166,7 +166,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// If this function is called, then it does depend on left words.
ret.independent_left = false;
ret.extend_left = extend_pointer;
- const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1;
+ typename Search::MiddleIter mid_iter(search_.MiddleBegin() + extend_length - 1);
const WordIndex *i = add_rbegin;
for (; ; ++i, ++backoff_out, ++mid_iter) {
if (i == add_rend) {
@@ -235,7 +235,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// Ok start by looking up the bigram.
const WordIndex *hist_iter = context_rbegin;
- const typename Search::Middle *mid_iter = search_.MiddleBegin();
+ typename Search::MiddleIter mid_iter(search_.MiddleBegin());
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == context_rend) {
// Ran out of history. Typically no backoff, but this could be a blank.
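
The switch from const Middle * to typename Search::MiddleIter lets this template code iterate over whatever the backend actually stores: hashed search keeps its middles in a std::vector (iterator type std::vector<Middle>::const_iterator, per the search_hashed.hh hunk below), while the trie keeps a raw array (iterator type const Middle *, per search_trie.hh). A minimal sketch of the pattern with hypothetical types:

#include <vector>
#include <cstdio>

struct Middle { int payload; };

// Backend storing middles in a vector, like TemplateHashedSearch.
struct VectorSearch {
  typedef std::vector<Middle>::const_iterator MiddleIter;
  std::vector<Middle> middle_;
  MiddleIter MiddleBegin() const { return middle_.begin(); }
  MiddleIter MiddleEnd() const { return middle_.end(); }
};

// Backend storing middles in a raw array, like TrieSearch.
struct ArraySearch {
  typedef const Middle *MiddleIter;
  Middle storage_[3];
  MiddleIter MiddleBegin() const { return storage_; }
  MiddleIter MiddleEnd() const { return storage_ + 3; }
};

// Generic code compiles against either backend through the typedef.
template <class Search> void Visit(const Search &s) {
  for (typename Search::MiddleIter i = s.MiddleBegin(); i != s.MiddleEnd(); ++i)
    std::printf("%d\n", i->payload);
}

int main() {
  VectorSearch v; v.middle_.push_back(Middle());
  ArraySearch a = ArraySearch();
  Visit(v);
  Visit(a);
}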
diff --git a/lm/quantize.cc b/lm/quantize.cc
index 8de37e827..a8e0cb21c 100644
--- a/lm/quantize.cc
+++ b/lm/quantize.cc
@@ -20,11 +20,11 @@ namespace ngram {
namespace {
-void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) {
- std::sort(values, values_end);
- const float *start = values, *finish;
+void MakeBins(std::vector<float> &values, float *centers, uint32_t bins) {
+ std::sort(values.begin(), values.end());
+ std::vector<float>::const_iterator start = values.begin(), finish;
for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) {
- finish = values + (((values_end - values) * static_cast<uint64_t>(i + 1)) / bins);
+ finish = values.begin() + ((values.size() * static_cast<uint64_t>(i + 1)) / bins);
if (finish == start) {
// zero length bucket.
*centers = i ? *(centers - 1) : -std::numeric_limits<float>::infinity();
@@ -66,12 +66,12 @@ void SeparatelyQuantize::Train(uint8_t order, std::vector<float> &prob, std::vec
float *centers = start_ + TableStart(order) + ProbTableLength();
*(centers++) = kNoExtensionBackoff;
*(centers++) = kExtensionBackoff;
- MakeBins(&*backoff.begin(), &*backoff.end(), centers, (1ULL << backoff_bits_) - 2);
+ MakeBins(backoff, centers, (1ULL << backoff_bits_) - 2);
}
void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) {
float *centers = start_ + TableStart(order);
- MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_));
+ MakeBins(prob, centers, (1ULL << prob_bits_));
}
void SeparatelyQuantize::FinishedLoading(const Config &config) {
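
MakeBins performs equal-frequency quantization: sort the values, split them into bins chunks of roughly equal population, and emit one representative center per chunk. The rewrite above only changes it to take the vector directly, which avoids dereferencing begin()/end() of a possibly empty vector. A self-contained sketch of the binning idea (taking the chunk median as the center is an illustrative choice, not necessarily what KenLM does):

#include <algorithm>
#include <vector>
#include <limits>
#include <stdint.h>
#include <cstdio>

void MakeBinsSketch(std::vector<float> &values, float *centers, uint32_t bins) {
  std::sort(values.begin(), values.end());
  std::vector<float>::const_iterator start = values.begin(), finish;
  for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) {
    finish = values.begin() + ((values.size() * static_cast<uint64_t>(i + 1)) / bins);
    if (finish == start) {
      // Zero-length bucket: reuse the previous center (or -inf for the first).
      *centers = i ? *(centers - 1) : -std::numeric_limits<float>::infinity();
    } else {
      *centers = *(start + (finish - start) / 2);   // bucket median
    }
  }
}

int main() {
  float raw[] = {0.1f, 0.2f, 0.2f, 0.5f, 0.9f, 1.3f, 2.0f, 2.1f};
  std::vector<float> values(raw, raw + 8);
  float centers[4];
  MakeBinsSketch(values, centers, 4);
  for (int i = 0; i < 4; ++i) std::printf("bin %d center %f\n", i, centers[i]);
}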
diff --git a/lm/search_hashed.cc b/lm/search_hashed.cc
index f803b632e..1d6fb5be7 100644
--- a/lm/search_hashed.cc
+++ b/lm/search_hashed.cc
@@ -84,9 +84,11 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign
}
template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) {
+ assert(n >= 2);
ReadNGramHeader(f, n);
- // vocab ids of words in reverse order
+ // Both vocab_ids and keys are non-empty because n >= 2.
+ // vocab ids of words in reverse order.
std::vector<WordIndex> vocab_ids(n);
std::vector<uint64_t> keys(n-1);
typename Store::Entry::Value value;
@@ -147,7 +149,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
// TODO: fix sorted.
- SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config);
+ SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
diff --git a/lm/search_hashed.hh b/lm/search_hashed.hh
index 96b03013e..4352c72dd 100644
--- a/lm/search_hashed.hh
+++ b/lm/search_hashed.hh
@@ -91,8 +91,10 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
- const Middle *MiddleBegin() const { return &*middle_.begin(); }
- const Middle *MiddleEnd() const { return &*middle_.end(); }
+ typedef typename std::vector<Middle>::const_iterator MiddleIter;
+
+ MiddleIter MiddleBegin() const { return middle_.begin(); }
+ MiddleIter MiddleEnd() const { return middle_.end(); }
Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
util::FloatEnc val;
diff --git a/lm/search_trie.cc b/lm/search_trie.cc
index f36d9c53c..ffadfa944 100644
--- a/lm/search_trie.cc
+++ b/lm/search_trie.cc
@@ -197,7 +197,7 @@ class SRISucks {
void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) {
for (unsigned char i = 0; i < kMaxOrder - 1; ++i) {
- it_[i] = &*values_[i].begin();
+ it_[i] = values_[i].empty() ? NULL : &*values_[i].begin();
}
messages_[0].Apply(it_, unigram_file);
BackoffMessages *messages = messages_ + 1;
@@ -229,8 +229,8 @@ class SRISucks {
class FindBlanks {
public:
- FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
- : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {}
+ FindBlanks(unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
+ : counts_(order), unigrams_(unigrams), sri_(messages) {}
float UnigramProb(WordIndex index) const {
return unigrams_[index].prob;
@@ -250,7 +250,7 @@ class FindBlanks {
}
void Longest(const void * /*data*/) {
- ++*longest_counts_;
+ ++counts_.back();
}
// Unigrams wrote one past.
@@ -258,8 +258,12 @@ class FindBlanks {
--counts_[0];
}
+ const std::vector<uint64_t> &Counts() const {
+ return counts_;
+ }
+
private:
- uint64_t *const counts_, *const longest_counts_;
+ std::vector<uint64_t> counts_;
const ProbBackoff *unigrams_;
@@ -473,14 +477,15 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
}
SRISucks sri;
- std::vector<uint64_t> fixed_counts(counts.size());
+ std::vector<uint64_t> fixed_counts;
util::scoped_FILE unigram_file;
util::scoped_fd unigram_fd(files.StealUnigram());
{
util::scoped_memory unigrams;
MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
- FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
+ FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
+ fixed_counts = finder.Counts();
}
unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) {
diff --git a/lm/search_trie.hh b/lm/search_trie.hh
index caa7a05e2..5155ca020 100644
--- a/lm/search_trie.hh
+++ b/lm/search_trie.hh
@@ -62,6 +62,8 @@ template <class Quant, class Bhiksha> class TrieSearch {
void LoadedBinary();
+ typedef const Middle *MiddleIter;
+
const Middle *MiddleBegin() const { return middle_begin_; }
const Middle *MiddleEnd() const { return middle_end_; }
diff --git a/lm/trie_sort.cc b/lm/trie_sort.cc
index 9d1d5f27f..b80fed02e 100644
--- a/lm/trie_sort.cc
+++ b/lm/trie_sort.cc
@@ -83,7 +83,12 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
- std::sort(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
+#if defined(_WIN32) || defined(_WIN64)
+ std::stable_sort
+#else
+ std::sort
+#endif
+ (context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
util::scoped_FILE out(maker.MakeFile());
@@ -157,7 +162,10 @@ void RecordReader::Overwrite(const void *start, std::size_t amount) {
UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
WriteOrThrow(file_, start, amount);
long forward = entry_size_ - internal - amount;
- if (forward) UTIL_THROW_IF(fseek(file_, forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision");
+#if !defined(_WIN32) && !defined(_WIN64)
+ if (forward)
+#endif
+ UTIL_THROW_IF(fseek(file_, forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision");
}
void RecordReader::Rewind() {
@@ -244,8 +252,13 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
}
// Sort full records by full n-gram.
util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
- // parallel_sort uses too much RAM
- std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
+ // parallel_sort uses too much RAM. TODO: figure out why windows sort doesn't like my proxies.
+#if defined(_WIN32) || defined(_WIN64)
+ std::stable_sort
+#else
+ std::sort
+#endif
+ (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
files.push_back(DiskFlush(begin, out_end, maker));
contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order));
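
Both hunks above select the sort at preprocessing time; per the TODO, the author does not know exactly why MSVC's std::sort rejects these proxy iterators, only that std::stable_sort tolerates them. Rather than repeating the conditional at every call site, the same effect could be centralized in a small wrapper; a sketch (PortableSort is a hypothetical helper, not part of the commit):

#include <algorithm>

// One place to encode "stable_sort on Windows, sort elsewhere".
template <class Iterator, class Compare>
void PortableSort(Iterator begin, Iterator end, const Compare &comp) {
#if defined(_WIN32) || defined(_WIN64)
  // MSVC's std::sort doesn't like the proxy iterators (cause unknown);
  // std::stable_sort happens to accept them.
  std::stable_sort(begin, end, comp);
#else
  std::sort(begin, end, comp);
#endif
}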
diff --git a/lm/vocab.cc b/lm/vocab.cc
index c10743ceb..9fd698bbf 100644
--- a/lm/vocab.cc
+++ b/lm/vocab.cc
@@ -125,8 +125,10 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
if (enumerate_) {
- util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
- util::JointSort(begin_, end_, values);
+ if (!strings_to_enumerate_.empty()) {
+ util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::JointSort(begin_, end_, values);
+ }
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
// <unk> strikes again: +1 here.
enumerate_->Add(i + 1, strings_to_enumerate_[i]);
@@ -142,11 +144,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
-void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
- ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_);
}
namespace {
@@ -201,12 +203,12 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
-void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
lookup_.LoadedBinary();
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
- ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_);
}
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
@@ -229,7 +231,7 @@ void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialW
if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>.";
break;
case THROW_UP:
- UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models. If you built your APRA with IRSTLM and forgot to run add-start-end.sh, complain to <bertoldi at fbk.eu> stating that you think build-lm.sh should do this by default, then go back and retrain your model from the start. To bypass this check and treat " << str << " as an OOV, pass -s. The resulting model will not work with e.g. Moses.");
+ UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models. Run build_binary -s to disable this check.");
}
}
diff --git a/lm/vocab.hh b/lm/vocab.hh
index 48db3d627..06fdefe49 100644
--- a/lm/vocab.hh
+++ b/lm/vocab.hh
@@ -82,7 +82,7 @@ class SortedVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
private:
uint64_t *begin_, *end_;
@@ -143,9 +143,11 @@ class ProbingVocabulary : public base::Vocabulary {
void FinishedLoading(ProbBackoff *reorder_vocab);
+ std::size_t UnkCountChangePadding() const { return 0; }
+
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
private:
typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup;
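
The have_words flag threads from the binary header down to both vocabulary loaders: binaries built with include_vocab = false have no word list appended, so ReadWords must be skipped. The flag originates as params.fixed.has_vocabulary in the model.cc hunk above. A minimal sketch of the guard (names illustrative, not the real ReadWords):

#include <cstdio>

struct Header { bool has_vocabulary; };

void LoadVocab(bool have_words, int fd) {
  // ... index setup from the mmap happens unconditionally ...
  if (have_words)
    std::printf("reading word strings from fd %d\n", fd);  // stand-in for ReadWords
}

int main() {
  Header header = {false};              // binary built with include_vocab = false
  LoadVocab(header.has_vocabulary, 0);  // safely skips the absent word section
}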