Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/lm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2011-11-17 23:12:19 +0400
committerKenneth Heafield <github@kheafield.com>2011-11-17 23:12:19 +0400
commit974a708dddab2b4c6836a176d95f8455d0ed5f51 (patch)
tree0916139ae18b032250dc84a6703dabbedc7dc8ac /lm
parent17cec851dfcf5712c66e6304308273b62dc532e8 (diff)
Updated kenlm 96ef3f2c11.
Invalidates old gcc and 32-bit formats, replacing these with one consistent format: 64-bit new gcc. Backwards compatible with these files.
Diffstat (limited to 'lm')
-rw-r--r--lm/binary_format.cc43
-rw-r--r--lm/build_binary.cc3
-rw-r--r--lm/left_test.cc11
-rw-r--r--lm/model_test.cc24
-rw-r--r--lm/ngram_query.cc59
-rw-r--r--lm/search_hashed.cc16
-rw-r--r--lm/search_hashed.hh57
-rw-r--r--lm/search_trie.cc2
-rw-r--r--lm/vocab.cc28
-rw-r--r--lm/vocab.hh30
10 files changed, 194 insertions, 79 deletions
diff --git a/lm/binary_format.cc b/lm/binary_format.cc
index 5aa274216..05a0dff03 100644
--- a/lm/binary_format.cc
+++ b/lm/binary_format.cc
@@ -20,19 +20,39 @@ const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 5;
-// Test values.
-struct Sanity {
+// Old binary files built on 32-bit machines have this header.
+// TODO: eliminate with next binary release.
+struct OldSanity {
char magic[sizeof(kMagicBytes)];
float zero_f, one_f, minus_half_f;
WordIndex one_word_index, max_word_index;
uint64_t one_uint64;
void SetToReference() {
+ std::memset(this, 0, sizeof(OldSanity));
+ std::memcpy(magic, kMagicBytes, sizeof(magic));
+ zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
+ one_word_index = 1;
+ max_word_index = std::numeric_limits<WordIndex>::max();
+ one_uint64 = 1;
+ }
+};
+
+
+// Test values aligned to 8 bytes.
+struct Sanity {
+ char magic[ALIGN8(sizeof(kMagicBytes))];
+ float zero_f, one_f, minus_half_f;
+ WordIndex one_word_index, max_word_index, padding_to_8;
+ uint64_t one_uint64;
+
+ void SetToReference() {
std::memset(this, 0, sizeof(Sanity));
std::memcpy(magic, kMagicBytes, sizeof(magic));
zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
one_word_index = 1;
max_word_index = std::numeric_limits<WordIndex>::max();
+ padding_to_8 = 0;
one_uint64 = 1;
}
};
@@ -76,8 +96,12 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
    // Grow the file to accommodate the search, using zeros.
- if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size))
- UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed");
+ try {
+ util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
+ } catch (util::ErrnoException &e) {
+ e << " for file " << config.write_mmap;
+ throw e;
+ }
// We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
std::size_t page_size = util::SizePage();
@@ -96,7 +120,7 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_
util::SyncOrThrow(backing.search.get(), backing.search.size());
util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
// header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params;
+ Parameters params = Parameters();
params.counts = counts;
params.fixed.order = counts.size();
params.fixed.probing_multiplier = config.probing_multiplier;
@@ -132,6 +156,10 @@ bool IsBinaryFormat(int fd) {
if ((end_ptr != begin_version) && version != kMagicVersion) {
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
+
+ OldSanity old_sanity = OldSanity();
+ old_sanity.SetToReference();
+ UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;
@@ -172,9 +200,8 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t
if (config.enumerate_vocab && !params.fixed.has_vocabulary)
UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
- if (config.enumerate_vocab) {
- util::SeekOrThrow(backing.file.get(), total_map);
- }
+ // Seek to vocabulary words
+ util::SeekOrThrow(backing.file.get(), total_map);
return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
}
diff --git a/lm/build_binary.cc b/lm/build_binary.cc
index f313002fe..400746dfc 100644
--- a/lm/build_binary.cc
+++ b/lm/build_binary.cc
@@ -1,6 +1,5 @@
#include "lm/model.hh"
#include "util/file_piece.hh"
-#include "util/portability.hh"
#include <cstdlib>
#include <exception>
@@ -91,7 +90,7 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
prefix = 'G';
divide = 1 << 30;
}
- long int length = std::max<long int>(2, lrint(ceil(log10((double) max_length / divide))));
+ long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide))));
std::cout << "Memory estimate:\ntype ";
// right align bytes.
for (long int i = 0; i < length - 2; ++i) std::cout << ' ';
diff --git a/lm/left_test.cc b/lm/left_test.cc
index 8bb91cb37..c85e5efa8 100644
--- a/lm/left_test.cc
+++ b/lm/left_test.cc
@@ -142,7 +142,7 @@ template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &wo
template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vector<WordIndex> &out) {
out.clear();
- for (util::PieceIterator<' '> i(str); i; ++i) {
+ for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) {
out.push_back(m.GetVocabulary().Index(*i));
}
}
@@ -326,10 +326,17 @@ template <class M> void FullGrow(const M &m) {
}
}
+const char *FileLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ return "test.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[1];
+}
+
template <class M> void Everything() {
Config config;
config.messages = NULL;
- M m("test.arpa", config);
+ M m(FileLocation(), config);
Short(m);
Charge(m);
diff --git a/lm/model_test.cc b/lm/model_test.cc
index 2654071f8..461704d43 100644
--- a/lm/model_test.cc
+++ b/lm/model_test.cc
@@ -19,6 +19,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) {
namespace {
+const char *TestLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ return "test.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[1];
+}
+const char *TestNoUnkLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 3) {
+ return "test_nounk.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[2];
+
+}
+
#define StartTest(word, ngram, score, indep_left) \
ret = model.FullScore( \
state, \
@@ -307,7 +321,7 @@ template <class ModelT> void LoadingTest() {
{
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
- ModelT m("test.arpa", config);
+ ModelT m(TestLocation(), config);
enumerate.Check(m.GetVocabulary());
BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
Everything(m);
@@ -315,7 +329,7 @@ template <class ModelT> void LoadingTest() {
{
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
- ModelT m("test_nounk.arpa", config);
+ ModelT m(TestNoUnkLocation(), config);
enumerate.Check(m.GetVocabulary());
BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
NoUnkCheck(m);
@@ -346,7 +360,7 @@ template <class ModelT> void BinaryTest() {
config.enumerate_vocab = &enumerate;
{
- ModelT copy_model("test.arpa", config);
+ ModelT copy_model(TestLocation(), config);
enumerate.Check(copy_model.GetVocabulary());
enumerate.Clear();
Everything(copy_model);
@@ -370,14 +384,14 @@ template <class ModelT> void BinaryTest() {
config.messages = NULL;
enumerate.Clear();
{
- ModelT copy_model("test_nounk.arpa", config);
+ ModelT copy_model(TestNoUnkLocation(), config);
enumerate.Check(copy_model.GetVocabulary());
enumerate.Clear();
NoUnkCheck(copy_model);
}
config.write_mmap = NULL;
{
- ModelT binary("test_nounk.binary", config);
+ ModelT binary(TestNoUnkLocation(), config);
enumerate.Check(binary.GetVocabulary());
NoUnkCheck(binary);
}
diff --git a/lm/ngram_query.cc b/lm/ngram_query.cc
index 6e9874673..1b2cd5db3 100644
--- a/lm/ngram_query.cc
+++ b/lm/ngram_query.cc
@@ -94,34 +94,39 @@ int main(int argc, char *argv[]) {
std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl;
return 1;
}
- bool sentence_context = (argc == 2);
- lm::ngram::ModelType model_type;
- if (lm::ngram::RecognizeBinary(argv[1], model_type)) {
- switch(model_type) {
- case lm::ngram::HASH_PROBING:
- Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
- break;
- case lm::ngram::TRIE_SORTED:
- Query<lm::ngram::TrieModel>(argv[1], sentence_context);
- break;
- case lm::ngram::QUANT_TRIE_SORTED:
- Query<lm::ngram::QuantTrieModel>(argv[1], sentence_context);
- break;
- case lm::ngram::ARRAY_TRIE_SORTED:
- Query<lm::ngram::ArrayTrieModel>(argv[1], sentence_context);
- break;
- case lm::ngram::QUANT_ARRAY_TRIE_SORTED:
- Query<lm::ngram::QuantArrayTrieModel>(argv[1], sentence_context);
- break;
- case lm::ngram::HASH_SORTED:
- default:
- std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
- abort();
+ try {
+ bool sentence_context = (argc == 2);
+ lm::ngram::ModelType model_type;
+ if (lm::ngram::RecognizeBinary(argv[1], model_type)) {
+ switch(model_type) {
+ case lm::ngram::HASH_PROBING:
+ Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
+ break;
+ case lm::ngram::TRIE_SORTED:
+ Query<lm::ngram::TrieModel>(argv[1], sentence_context);
+ break;
+ case lm::ngram::QUANT_TRIE_SORTED:
+ Query<lm::ngram::QuantTrieModel>(argv[1], sentence_context);
+ break;
+ case lm::ngram::ARRAY_TRIE_SORTED:
+ Query<lm::ngram::ArrayTrieModel>(argv[1], sentence_context);
+ break;
+ case lm::ngram::QUANT_ARRAY_TRIE_SORTED:
+ Query<lm::ngram::QuantArrayTrieModel>(argv[1], sentence_context);
+ break;
+ case lm::ngram::HASH_SORTED:
+ default:
+ std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
+ abort();
+ }
+ } else {
+ Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
}
- } else {
- Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
- }
- PrintUsage("Total time including destruction:\n");
+ PrintUsage("Total time including destruction:\n");
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
+ }
return 0;
}
diff --git a/lm/search_hashed.cc b/lm/search_hashed.cc
index 247832b0a..f803b632e 100644
--- a/lm/search_hashed.cc
+++ b/lm/search_hashed.cc
@@ -30,7 +30,7 @@ template <class Middle> class ActivateLowerMiddle {
// TODO: somehow get text of n-gram for this error message.
if (!modify_.UnsafeMutableFind(hash, i))
UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram");
- SetExtension(i->MutableValue().backoff);
+ SetExtension(i->value.backoff);
}
private:
@@ -65,7 +65,7 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign
blank.prob -= unigrams[vocab_ids[1]].backoff;
SetExtension(unigrams[vocab_ids[1]].backoff);
// Bigram including a unigram's backoff
- middle[0].Insert(Middle::Packing::Make(keys[0], blank));
+ middle[0].Insert(detail::ProbBackoffEntry::Make(keys[0], blank));
fix = 1;
} else {
for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]);
@@ -74,11 +74,11 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign
for (; fix <= n - 3; ++fix) {
typename Middle::MutableIterator gotit;
if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) {
- float &backoff = gotit->MutableValue().backoff;
+ float &backoff = gotit->value.backoff;
SetExtension(backoff);
blank.prob -= backoff;
}
- middle[fix].Insert(Middle::Packing::Make(keys[fix], blank));
+ middle[fix].Insert(detail::ProbBackoffEntry::Make(keys[fix], blank));
backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]);
}
}
@@ -89,7 +89,7 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
// vocab ids of words in reverse order
std::vector<WordIndex> vocab_ids(n);
std::vector<uint64_t> keys(n-1);
- typename Store::Packing::Value value;
+ typename Store::Entry::Value value;
typename Middle::MutableIterator found;
for (size_t i = 0; i < count; ++i) {
ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn);
@@ -100,7 +100,7 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
}
// Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
util::SetSign(value.prob);
- store.Insert(Store::Packing::Make(keys[n-2], value));
+ store.Insert(Store::Entry::Make(keys[n-2], value));
// Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
int lower;
util::FloatEnc fix_prob;
@@ -113,9 +113,9 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
}
if (middle[lower].UnsafeMutableFind(keys[lower], found)) {
// Turn off sign bit to indicate that it extends left.
- fix_prob.f = found->MutableValue().prob;
+ fix_prob.f = found->value.prob;
fix_prob.i &= ~util::kSignBit;
- found->MutableValue().prob = fix_prob.f;
+ found->value.prob = fix_prob.f;
// We don't need to recurse further down because this entry already set the bits for lower entries.
break;
}
diff --git a/lm/search_hashed.hh b/lm/search_hashed.hh
index e289fd114..96b03013e 100644
--- a/lm/search_hashed.hh
+++ b/lm/search_hashed.hh
@@ -8,7 +8,6 @@
#include "lm/weights.hh"
#include "util/bit_packing.hh"
-#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include <algorithm>
@@ -105,7 +104,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl;
abort();
}
- val.f = found->GetValue().prob;
+ val.f = found->value.prob;
}
val.i |= util::kSignBit;
prob = val.f;
@@ -117,12 +116,12 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
typename Middle::ConstIterator found;
if (!middle.Find(node, found)) return false;
util::FloatEnc enc;
- enc.f = found->GetValue().prob;
+ enc.f = found->value.prob;
ret.independent_left = (enc.i & util::kSignBit);
ret.extend_left = node;
enc.i |= util::kSignBit;
ret.prob = enc.f;
- backoff = found->GetValue().backoff;
+ backoff = found->value.backoff;
return true;
}
@@ -132,7 +131,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
node = CombineWordHash(node, word);
typename Middle::ConstIterator found;
if (!middle.Find(node, found)) return false;
- backoff = found->GetValue().backoff;
+ backoff = found->value.backoff;
return true;
}
@@ -141,7 +140,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
node = CombineWordHash(node, word);
typename Longest::ConstIterator found;
if (!longest.Find(node, found)) return false;
- prob = found->GetValue().prob;
+ prob = found->value.prob;
return true;
}
@@ -160,14 +159,50 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
std::vector<Middle> middle_;
};
-// std::identity is an SGI extension :-(
-struct IdentityHash : public std::unary_function<uint64_t, size_t> {
- size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
+/* These look like perfect candidates for a template, right? Ancient gcc (4.1
+ * on RedHat stale linux) doesn't pack templates correctly. ProbBackoffEntry
+ * is a multiple of 8 bytes anyway. ProbEntry is 12 bytes so it's set to pack.
+ */
+struct ProbBackoffEntry {
+ uint64_t key;
+ ProbBackoff value;
+ typedef uint64_t Key;
+ typedef ProbBackoff Value;
+ uint64_t GetKey() const {
+ return key;
+ }
+ static ProbBackoffEntry Make(uint64_t key, ProbBackoff value) {
+ ProbBackoffEntry ret;
+ ret.key = key;
+ ret.value = value;
+ return ret;
+ }
+};
+
+#pragma pack(push)
+#pragma pack(4)
+struct ProbEntry {
+ uint64_t key;
+ Prob value;
+ typedef uint64_t Key;
+ typedef Prob Value;
+ uint64_t GetKey() const {
+ return key;
+ }
+ static ProbEntry Make(uint64_t key, Prob value) {
+ ProbEntry ret;
+ ret.key = key;
+ ret.value = value;
+ return ret;
+ }
};
+#pragma pack(pop)
+
+
struct ProbingHashedSearch : public TemplateHashedSearch<
- util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
- util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {
+ util::ProbingHashTable<ProbBackoffEntry, util::IdentityHash>,
+ util::ProbingHashTable<ProbEntry, util::IdentityHash> > {
static const ModelType kModelType = HASH_PROBING;
};
diff --git a/lm/search_trie.cc b/lm/search_trie.cc
index 8cb6984b0..f36d9c53c 100644
--- a/lm/search_trie.cc
+++ b/lm/search_trie.cc
@@ -377,7 +377,7 @@ template <class Doing> class BlankManager {
template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) {
util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
- unsigned int unigram = 0;
+ WordIndex unigram = 0;
std::priority_queue<Gram> grams;
grams.push(Gram(&unigram, 1));
for (unsigned char i = 2; i <= total_order; ++i) {
diff --git a/lm/vocab.cc b/lm/vocab.cc
index 5ac828178..3fefe6b13 100644
--- a/lm/vocab.cc
+++ b/lm/vocab.cc
@@ -13,6 +13,8 @@
#include <string>
+#include <string.h>
+
namespace lm {
namespace ngram {
@@ -30,16 +32,26 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
-WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
- if (!enumerate) return std::numeric_limits<WordIndex>::max();
+void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) {
+ // Check that we're at the right place by reading <unk> which is always first.
+ char check_unk[6];
+ util::ReadOrThrow(fd, check_unk, 6);
+ UTIL_THROW_IF(
+ memcmp(check_unk, "<unk>", 6),
+ FormatLoadException,
+ "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure.");
+ if (!enumerate) return;
+ enumerate->Add(0, "<unk>");
+
+ // Read all the words after unk.
const std::size_t kInitialRead = 16384;
std::string buf;
buf.reserve(kInitialRead + 100);
buf.resize(kInitialRead);
- WordIndex index = 0;
+ WordIndex index = 1; // Read <unk> already.
while (true) {
std::size_t got = util::ReadOrEOF(fd, &buf[0], kInitialRead);
- if (got == 0) return index;
+ if (got == 0) break;
buf.resize(got);
while (buf[buf.size() - 1]) {
char next_char;
@@ -53,6 +65,8 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
i += length + 1 /* null byte */;
}
}
+
+ UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file.");
}
} // namespace
@@ -130,9 +144,9 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
- ReadWords(fd, to);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
+ ReadWords(fd, to, bound_);
}
namespace {
@@ -175,7 +189,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
return 0;
} else {
if (enumerate_) enumerate_->Add(bound_, str);
- lookup_.Insert(Lookup::Packing::Make(hashed, bound_));
+ lookup_.Insert(ProbingVocabuaryEntry::Make(hashed, bound_));
return bound_++;
}
}
@@ -190,9 +204,9 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
lookup_.LoadedBinary();
- ReadWords(fd, to);
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
+ ReadWords(fd, to, bound_);
}
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
diff --git a/lm/vocab.hh b/lm/vocab.hh
index 3c3414fb9..48db3d627 100644
--- a/lm/vocab.hh
+++ b/lm/vocab.hh
@@ -4,7 +4,6 @@
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
-#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
@@ -100,6 +99,26 @@ class SortedVocabulary : public base::Vocabulary {
std::vector<std::string> strings_to_enumerate_;
};
+#pragma pack(push)
+#pragma pack(4)
+struct ProbingVocabuaryEntry {
+ uint64_t key;
+ WordIndex value;
+
+ typedef uint64_t Key;
+ uint64_t GetKey() const {
+ return key;
+ }
+
+ static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) {
+ ProbingVocabuaryEntry ret;
+ ret.key = key;
+ ret.value = value;
+ return ret;
+ }
+};
+#pragma pack(pop)
+
// Vocabulary storing a map from uint64_t to WordIndex.
class ProbingVocabulary : public base::Vocabulary {
public:
@@ -107,7 +126,7 @@ class ProbingVocabulary : public base::Vocabulary {
WordIndex Index(const StringPiece &str) const {
Lookup::ConstIterator i;
- return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0;
+ return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
}
static size_t Size(std::size_t entries, const Config &config);
@@ -129,12 +148,7 @@ class ProbingVocabulary : public base::Vocabulary {
void LoadedBinary(int fd, EnumerateVocab *to);
private:
- // std::identity is an SGI extension :-(
- struct IdentityHash : public std::unary_function<uint64_t, std::size_t> {
- std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); }
- };
-
- typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup;
+ typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup;
Lookup lookup_;