Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/lm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2013-06-24 19:05:47 +0400
committerKenneth Heafield <github@kheafield.com>2013-06-24 19:05:47 +0400
commit794867c555e82edf3bd12ffb3faa35fb24d6e0a1 (patch)
tree8fa6a1051d3f878ff5981a66d2e951148743fded /lm
parentf3cd72537c48ad584ee375195d27871932d3805c (diff)
KenLM 6b4a1c7940a36026de1d96693ccb6ec0f16de8dc
Diffstat (limited to 'lm')
-rw-r--r--lm/builder/lmplz_main.cc31
-rw-r--r--lm/builder/ngram.hh2
-rw-r--r--lm/model.cc21
-rw-r--r--lm/model.hh5
-rw-r--r--lm/search_hashed.cc29
-rw-r--r--lm/search_hashed.hh19
-rw-r--r--lm/virtual_interface.hh3
7 files changed, 73 insertions, 37 deletions
diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index 1e086dcce..2e3002d12 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -33,6 +33,8 @@ int main(int argc, char *argv[]) {
po::options_description options("Language model building options");
lm::builder::PipelineConfig pipeline;
+ std::string text, arpa;
+
options.add_options()
("order,o", po::value<std::size_t>(&pipeline.order)
#if BOOST_VERSION >= 104200
@@ -47,18 +49,21 @@ int main(int argc, char *argv[]) {
("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
- ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
+ ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
+ ("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
+ ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout");
if (argc == 1) {
std::cerr <<
"Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
"Please cite:\n"
- "@inproceedings{kenlm,\n"
- "author = {Kenneth Heafield},\n"
- "title = {{KenLM}: Faster and Smaller Language Model Queries},\n"
- "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
- "month = {July}, year={2011},\n"
- "address = {Edinburgh, UK},\n"
- "publisher = {Association for Computational Linguistics},\n"
+ "@inproceedings{Heafield-estimate,\n"
+ " author = {Kenneth Heafield and Ivan Pouzyrevsky and Jonathan H. Clark and Philipp Koehn},\n"
+ " title = {Scalable Modified {Kneser-Ney} Language Model Estimation},\n"
+ " year = {2013},\n"
+ " month = {8},\n"
+ " booktitle = {Proceedings of the 51st Annual Meeting of the Association for Computational Linguistics},\n"
+ " address = {Sofia, Bulgaria},\n"
+ " url = {http://kheafield.com/professional/edinburgh/estimate\\_paper.pdf},\n"
"}\n\n"
"Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n"
"the model (-o) is the only mandatory option. As this is an on-disk program,\n"
@@ -91,9 +96,17 @@ int main(int argc, char *argv[]) {
initial.adder_out.block_count = 2;
pipeline.read_backoffs = initial.adder_out;
+ util::scoped_fd in(0), out(1);
+ if (vm.count("text")) {
+ in.reset(util::OpenReadOrThrow(text.c_str()));
+ }
+ if (vm.count("arpa")) {
+ out.reset(util::CreateOrThrow(arpa.c_str()));
+ }
+
// Read from stdin
try {
- lm::builder::Pipeline(pipeline, 0, 1);
+ lm::builder::Pipeline(pipeline, in.release(), out.release());
} catch (const util::MallocException &e) {
std::cerr << e.what() << std::endl;
std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
diff --git a/lm/builder/ngram.hh b/lm/builder/ngram.hh
index 2984ed0b6..f5681516a 100644
--- a/lm/builder/ngram.hh
+++ b/lm/builder/ngram.hh
@@ -53,7 +53,7 @@ class NGram {
Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
uint64_t &Count() { return Value().count; }
- const uint64_t Count() const { return Value().count; }
+ uint64_t Count() const { return Value().count; }
std::size_t Order() const { return end_ - begin_; }
diff --git a/lm/model.cc b/lm/model.cc
index a40fd2fb0..a26654a6f 100644
--- a/lm/model.cc
+++ b/lm/model.cc
@@ -304,5 +304,26 @@ template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiks
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
} // namespace detail
+
+base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) {
+ RecognizeBinary(file_name, model_type);
+ switch (model_type) {
+ case PROBING:
+ return new ProbingModel(file_name, config);
+ case REST_PROBING:
+ return new RestProbingModel(file_name, config);
+ case TRIE:
+ return new TrieModel(file_name, config);
+ case QUANT_TRIE:
+ return new QuantTrieModel(file_name, config);
+ case ARRAY_TRIE:
+ return new ArrayTrieModel(file_name, config);
+ case QUANT_ARRAY_TRIE:
+ return new QuantArrayTrieModel(file_name, config);
+ default:
+ UTIL_THROW(FormatLoadException, "Confused by model type " << model_type);
+ }
+}
+
} // namespace ngram
} // namespace lm
diff --git a/lm/model.hh b/lm/model.hh
index 13ff864e1..60f55110b 100644
--- a/lm/model.hh
+++ b/lm/model.hh
@@ -153,6 +153,11 @@ LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<Separat
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef ProbingModel Model;
+/* Autorecognize the file type, load, and return the virtual base class. Don't
+ * use the virtual base class if you can avoid it. Instead, use the above
+ * classes as template arguments to your own virtual feature function.*/
+base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING);
+
} // namespace ngram
} // namespace lm
diff --git a/lm/search_hashed.cc b/lm/search_hashed.cc
index 2d6f15b23..62275d277 100644
--- a/lm/search_hashed.cc
+++ b/lm/search_hashed.cc
@@ -54,7 +54,7 @@ template <class Weights> class ActivateUnigram {
Weights *modify_;
};
-// Find the lower order entry, inserting blanks along the way as necessary.
+// Find the lower order entry, inserting blanks along the way as necessary.
template <class Value> void FindLower(
const std::vector<uint64_t> &keys,
typename Value::Weights &unigram,
@@ -64,7 +64,7 @@ template <class Value> void FindLower(
typename Value::ProbingEntry entry;
// Backoff will always be 0.0. We'll get the probability and rest in another pass.
entry.value.backoff = kNoExtensionBackoff;
- // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
+ // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
for (int lower = keys.size() - 2; ; --lower) {
if (lower == -1) {
between.push_back(&unigram);
@@ -77,11 +77,11 @@ template <class Value> void FindLower(
}
}
-// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
+// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
template <class Added, class Build> void AdjustLower(
const Added &added,
const Build &build,
- std::vector<typename Build::Value::Weights *> &between,
+ std::vector<typename Build::Value::Weights *> &between,
const unsigned int n,
const std::vector<WordIndex> &vocab_ids,
typename Build::Value::Weights *unigrams,
@@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower(
}
typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
float prob = -fabs(between.back()->prob);
- // Order of the n-gram on which probabilities are based.
+ // Order of the n-gram on which probabilities are based.
unsigned char basis = n - between.size();
assert(basis != 0);
typename Build::Value::Weights **change = &between.back();
// Skip the basis.
--change;
if (basis == 1) {
- // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
+ // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
float &backoff = unigrams[vocab_ids[1]].backoff;
SetExtension(backoff);
prob += backoff;
@@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower(
typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());
build.MarkExtends(**i, added);
const typename Value::Weights *longer = *i;
- // Everything has probability but is not marked as extending.
+ // Everything has probability but is not marked as extending.
for (++i; i != between.end(); ++i) {
build.MarkExtends(**i, *longer);
longer = *i;
}
}
-// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds.
+// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds.
template <class Build> void MarkLower(
const std::vector<uint64_t> &keys,
const Build &build,
@@ -144,15 +144,15 @@ template <class Build> void MarkLower(
int start_order,
const typename Build::Value::Weights &longer) {
if (start_order == 0) return;
- typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
- // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
+ // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) {
if (even_lower == -1) {
build.MarkExtends(unigram, longer);
return;
}
- middle[even_lower].UnsafeMutableFind(keys[even_lower], iter);
- if (!build.MarkExtends(iter->value, longer)) return;
+ if (!build.MarkExtends(
+ middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value,
+ longer)) return;
}
}
@@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams(
Store &store,
PositiveProbWarn &warn) {
typedef typename Build::Value Value;
- typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
assert(n >= 2);
ReadNGramHeader(f, n);
@@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
for (unsigned int h = 1; h < n - 1; ++h) {
keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
}
- // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
+ // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
util::SetSign(entry.value.prob);
entry.key = keys[n-2];
@@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
} // namespace
namespace detail {
-
+
template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
std::size_t allocated = Unigram::Size(counts[0]);
unigram_ = Unigram(start, counts[0], allocated);
diff --git a/lm/search_hashed.hh b/lm/search_hashed.hh
index 005957967..9d067bc2e 100644
--- a/lm/search_hashed.hh
+++ b/lm/search_hashed.hh
@@ -71,7 +71,7 @@ template <class Value> class HashedSearch {
static const bool kDifferentRest = Value::kDifferentRest;
static const unsigned int kVersion = 0;
- // TODO: move probing_multiplier here with next binary file format update.
+ // TODO: move probing_multiplier here with next binary file format update.
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -102,14 +102,9 @@ template <class Value> class HashedSearch {
return ret;
}
-#pragma GCC diagnostic ignored "-Wuninitialized"
MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
node = extend_pointer;
- typename Middle::ConstIterator found;
- bool got = middle_[extend_length - 2].Find(extend_pointer, found);
- assert(got);
- (void)got;
- return MiddlePointer(found->value);
+ return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value);
}
MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
@@ -126,14 +121,14 @@ template <class Value> class HashedSearch {
}
LongestPointer LookupLongest(WordIndex word, const Node &node) const {
- // Sign bit is always on because longest n-grams do not extend left.
+ // Sign bit is always on because longest n-grams do not extend left.
typename Longest::ConstIterator found;
if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
return LongestPointer(found->value.prob);
}
- // Generate a node without necessarily checking that it actually exists.
- // Optionally return false if it's know to not exist.
+ // Generate a node without necessarily checking that it actually exists.
+ // Optionally return false if it's know to not exist.
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
assert(begin != end);
node = static_cast<Node>(*begin);
@@ -144,7 +139,7 @@ template <class Value> class HashedSearch {
}
private:
- // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
+ // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
@@ -153,7 +148,7 @@ template <class Value> class HashedSearch {
public:
Unigram() {}
- Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
+ Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
unigram_(static_cast<typename Value::Weights*>(start))
#ifdef DEBUG
, count_(count)
diff --git a/lm/virtual_interface.hh b/lm/virtual_interface.hh
index 6a5a0196f..17f064b2c 100644
--- a/lm/virtual_interface.hh
+++ b/lm/virtual_interface.hh
@@ -6,6 +6,7 @@
#include "util/string_piece.hh"
#include <string>
+#include <string.h>
namespace lm {
namespace base {
@@ -119,7 +120,9 @@ class Model {
size_t StateSize() const { return state_size_; }
const void *BeginSentenceMemory() const { return begin_sentence_memory_; }
+ void BeginSentenceWrite(void *to) const { memcpy(to, begin_sentence_memory_, StateSize()); }
const void *NullContextMemory() const { return null_context_memory_; }
+ void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }
// Requires in_state != out_state
virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;