Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/kenlm
diff options
context:
space:
mode:
authorheafield <heafield@1f5c12ca-751b-0410-a591-d2e778427230>2011-08-16 16:57:21 +0400
committerheafield <heafield@1f5c12ca-751b-0410-a591-d2e778427230>2011-08-16 16:57:21 +0400
commit6dae77c3eb7568f095bbe712b7d121172b12e1bd (patch)
tree6c75baf05a929d82f4524370871ea2c5c4cc475b /kenlm
parent436a285f18e42b19e581d1e673bf55766934f10b (diff)
Fix segfault withe trie on models without <unk>. Problem was that trie writes correct counts to
the binary file header, including <unk>. But the vocabulary was sized based on the ARPA file count (excluding <unk>). Then when the binary file was loaded, the vocabulary size was based on the count including <unk>. Fix this by pre-padding vocabulary to the count including <unk>. Also, some minor cleanups: remove a debug message and change some always-true returns to void. git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4145 1f5c12ca-751b-0410-a591-d2e778427230
Diffstat (limited to 'kenlm')
-rw-r--r--kenlm/lm/bhiksha.cc5
-rw-r--r--kenlm/lm/binary_format.cc11
-rw-r--r--kenlm/lm/binary_format.hh2
-rw-r--r--kenlm/lm/model.cc59
-rw-r--r--kenlm/lm/model_test.cc53
-rw-r--r--kenlm/lm/search_hashed.cc2
-rw-r--r--kenlm/lm/search_hashed.hh3
-rw-r--r--kenlm/lm/search_trie.cc10
-rw-r--r--kenlm/lm/search_trie.hh8
-rw-r--r--kenlm/lm/test_nounk.arpa120
-rw-r--r--kenlm/lm/trie.hh3
-rw-r--r--kenlm/lm/vocab.hh4
12 files changed, 221 insertions, 59 deletions
diff --git a/kenlm/lm/bhiksha.cc b/kenlm/lm/bhiksha.cc
index 3d8afbe42..bf86fd4be 100644
--- a/kenlm/lm/bhiksha.cc
+++ b/kenlm/lm/bhiksha.cc
@@ -1,7 +1,6 @@
#include "lm/bhiksha.hh"
#include "lm/config.hh"
-#include <iostream>
#include <limits>
namespace lm {
@@ -50,9 +49,7 @@ std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &con
} // namespace
std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
- std::size_t ret = sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */;
- std::cout << ret << std::endl;
- return ret;
+ return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */;
}
uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) {
diff --git a/kenlm/lm/binary_format.cc b/kenlm/lm/binary_format.cc
index ea42b30b3..e02e621a1 100644
--- a/kenlm/lm/binary_format.cc
+++ b/kenlm/lm/binary_format.cc
@@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
}
}
-uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) {
+uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
+ std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
// Grow the file to accomodate the search, using zeros.
- if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size))
- UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed");
+ if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size))
+ UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed");
// We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
off_t page_size = sysconf(_SC_PAGE_SIZE);
- off_t alignment_cruft = backing.vocab.size() % page_size;
- backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ off_t alignment_cruft = adjusted_vocab % page_size;
+ backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
} else {
diff --git a/kenlm/lm/binary_format.hh b/kenlm/lm/binary_format.hh
index dbb8d3606..d28cb6c52 100644
--- a/kenlm/lm/binary_format.hh
+++ b/kenlm/lm/binary_format.hh
@@ -60,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off);
// Create just enough of a binary file to write vocabulary to it.
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
-uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing);
+uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing);
// Write header to binary file. This is done last to prevent incomplete files
// from loading.
diff --git a/kenlm/lm/model.cc b/kenlm/lm/model.cc
index fa20cb771..27e24b1c3 100644
--- a/kenlm/lm/model.cc
+++ b/kenlm/lm/model.cc
@@ -58,35 +58,40 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
util::FilePiece f(backing_.file.release(), file, config.messages);
- std::vector<uint64_t> counts;
- // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
- ReadARPACounts(f, counts);
-
- if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
- if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
- if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
-
- std::size_t vocab_size = VocabularyT::Size(counts[0], config);
- // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
- vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
-
- if (config.write_mmap) {
- WriteWordsWrapper wrap(config.enumerate_vocab);
- vocab_.ConfigureEnumerate(&wrap, counts[0]);
- search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get());
- } else {
- vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
- search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- }
+ try {
+ std::vector<uint64_t> counts;
+ // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
+ ReadARPACounts(f, counts);
+
+ if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
+ if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+ if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
+
+ std::size_t vocab_size = VocabularyT::Size(counts[0], config);
+ // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
+ vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
+
+ if (config.write_mmap) {
+ WriteWordsWrapper wrap(config.enumerate_vocab);
+ vocab_.ConfigureEnumerate(&wrap, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+ wrap.Write(backing_.file.get());
+ } else {
+ vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+ }
- if (!vocab_.SawUnk()) {
- assert(config.unknown_missing != THROW_UP);
- // Default probabilities for unknown.
- search_.unigram.Unknown().backoff = 0.0;
- search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ if (!vocab_.SawUnk()) {
+ assert(config.unknown_missing != THROW_UP);
+ // Default probabilities for unknown.
+ search_.unigram.Unknown().backoff = 0.0;
+ search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ }
+ FinishFile(config, kModelType, counts, backing_);
+ } catch (util::Exception &e) {
+ e << " Byte: " << f.Offset();
+ throw;
}
- FinishFile(config, kModelType, counts, backing_);
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
diff --git a/kenlm/lm/model_test.cc b/kenlm/lm/model_test.cc
index 2f1066f1a..598605772 100644
--- a/kenlm/lm/model_test.cc
+++ b/kenlm/lm/model_test.cc
@@ -193,6 +193,14 @@ template <class M> void Stateless(const M &model) {
BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);
}
+template <class M> void NoUnkCheck(const M &model) {
+ WordIndex unk_index = 0;
+ State state;
+
+ FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state);
+ BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001);
+}
+
template <class M> void Everything(const M &m) {
Starters(m);
Continuation(m);
@@ -231,12 +239,21 @@ template <class ModelT> void LoadingTest() {
Config config;
config.arpa_complain = Config::NONE;
config.messages = NULL;
- ExpectEnumerateVocab enumerate;
- config.enumerate_vocab = &enumerate;
config.probing_multiplier = 2.0;
- ModelT m("test.arpa", config);
- enumerate.Check(m.GetVocabulary());
- Everything(m);
+ {
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+ ModelT m("test.arpa", config);
+ enumerate.Check(m.GetVocabulary());
+ Everything(m);
+ }
+ {
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+ ModelT m("test_nounk.arpa", config);
+ enumerate.Check(m.GetVocabulary());
+ NoUnkCheck(m);
+ }
}
BOOST_AUTO_TEST_CASE(probing) {
@@ -275,9 +292,29 @@ template <class ModelT> void BinaryTest() {
BOOST_REQUIRE(RecognizeBinary("test.binary", type));
BOOST_CHECK_EQUAL(ModelT::kModelType, type);
- ModelT binary("test.binary", config);
- enumerate.Check(binary.GetVocabulary());
- Everything(binary);
+ {
+ ModelT binary("test.binary", config);
+ enumerate.Check(binary.GetVocabulary());
+ Everything(binary);
+ }
+ unlink("test.binary");
+
+ // Now test without <unk>.
+ config.write_mmap = "test_nounk.binary";
+ config.messages = NULL;
+ enumerate.Clear();
+ {
+ ModelT copy_model("test_nounk.arpa", config);
+ enumerate.Check(copy_model.GetVocabulary());
+ enumerate.Clear();
+ NoUnkCheck(copy_model);
+ }
+ config.write_mmap = NULL;
+ {
+ ModelT binary("test_nounk.binary", config);
+ enumerate.Check(binary.GetVocabulary());
+ NoUnkCheck(binary);
+ }
unlink("test.binary");
}
diff --git a/kenlm/lm/search_hashed.cc b/kenlm/lm/search_hashed.cc
index c56ba7b85..82c53ec87 100644
--- a/kenlm/lm/search_hashed.cc
+++ b/kenlm/lm/search_hashed.cc
@@ -98,7 +98,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
// TODO: fix sorted.
- SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config);
+ SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
diff --git a/kenlm/lm/search_hashed.hh b/kenlm/lm/search_hashed.hh
index f3acdefc8..c62985e4b 100644
--- a/kenlm/lm/search_hashed.hh
+++ b/kenlm/lm/search_hashed.hh
@@ -52,12 +52,11 @@ struct HashedSearch {
Unigram unigram;
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
+ void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
const ProbBackoff &entry = unigram.Lookup(word);
prob = entry.prob;
backoff = entry.backoff;
next = static_cast<Node>(word);
- return true;
}
};
diff --git a/kenlm/lm/search_trie.cc b/kenlm/lm/search_trie.cc
index 422421d83..05059ffbe 100644
--- a/kenlm/lm/search_trie.cc
+++ b/kenlm/lm/search_trie.cc
@@ -544,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
// In case <unk> appears.
- size_t extra_count = counts[0] + 1;
- util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff));
+ size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff);
+ util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out);
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
CheckSpecials(config, vocab);
if (!vocab.SawUnk()) ++counts[0];
@@ -822,7 +822,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, So
} // namespace
-template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
std::vector<SortedFileReader> inputs(counts.size() - 1);
std::vector<ContextReader> contexts(counts.size() - 1);
@@ -847,7 +847,7 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
- out.SetupMemory(GrowForSearch(config, TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+ out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
if (Quant::kTrain) {
util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
@@ -969,7 +969,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ
// At least 1MB sorting memory.
ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
- BuildTrie(temporary_directory, counts, config, *this, quant_, backing);
+ BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {
*config.messages << "Failed to delete " << temporary_directory << std::endl;
}
diff --git a/kenlm/lm/search_trie.hh b/kenlm/lm/search_trie.hh
index b308f9b46..2f39c09ff 100644
--- a/kenlm/lm/search_trie.hh
+++ b/kenlm/lm/search_trie.hh
@@ -14,7 +14,7 @@ class SortedVocabulary;
namespace trie {
template <class Quant, class Bhiksha> class TrieSearch;
-template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, Backing &backing);
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
template <class Quant, class Bhiksha> class TrieSearch {
public:
@@ -57,8 +57,8 @@ template <class Quant, class Bhiksha> class TrieSearch {
void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
- return unigram.Find(word, prob, backoff, node);
+ void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+ unigram.Find(word, prob, backoff, node);
}
bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
@@ -85,7 +85,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
}
private:
- friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
diff --git a/kenlm/lm/test_nounk.arpa b/kenlm/lm/test_nounk.arpa
new file mode 100644
index 000000000..060733d98
--- /dev/null
+++ b/kenlm/lm/test_nounk.arpa
@@ -0,0 +1,120 @@
+
+\data\
+ngram 1=36
+ngram 2=45
+ngram 3=10
+ngram 4=6
+ngram 5=4
+
+\1-grams:
+-1.383514 , -0.30103
+-1.139057 . -0.845098
+-1.029493 </s>
+-99 <s> -0.4149733
+-1.285941 a -0.69897
+-1.687872 also -0.30103
+-1.687872 beyond -0.30103
+-1.687872 biarritz -0.30103
+-1.687872 call -0.30103
+-1.687872 concerns -0.30103
+-1.687872 consider -0.30103
+-1.687872 considering -0.30103
+-1.687872 for -0.30103
+-1.509559 higher -0.30103
+-1.687872 however -0.30103
+-1.687872 i -0.30103
+-1.687872 immediate -0.30103
+-1.687872 in -0.30103
+-1.687872 is -0.30103
+-1.285941 little -0.69897
+-1.383514 loin -0.30103
+-1.687872 look -0.30103
+-1.285941 looking -0.4771212
+-1.206319 more -0.544068
+-1.509559 on -0.4771212
+-1.509559 screening -0.4771212
+-1.687872 small -0.30103
+-1.687872 the -0.30103
+-1.687872 to -0.30103
+-1.687872 watch -0.30103
+-1.687872 watching -0.30103
+-1.687872 what -0.30103
+-1.687872 would -0.30103
+-3.141592 foo
+-2.718281 bar 3.0
+-6.535897 baz -0.0
+
+\2-grams:
+-0.6925742 , .
+-0.7522095 , however
+-0.7522095 , is
+-0.0602359 . </s>
+-0.4846522 <s> looking -0.4771214
+-1.051485 <s> screening
+-1.07153 <s> the
+-1.07153 <s> watching
+-1.07153 <s> what
+-0.09132547 a little -0.69897
+-0.2922095 also call
+-0.2922095 beyond immediate
+-0.2705918 biarritz .
+-0.2922095 call for
+-0.2922095 concerns in
+-0.2922095 consider watch
+-0.2922095 considering consider
+-0.2834328 for ,
+-0.5511513 higher more
+-0.5845945 higher small
+-0.2834328 however ,
+-0.2922095 i would
+-0.2922095 immediate concerns
+-0.2922095 in biarritz
+-0.2922095 is to
+-0.09021038 little more -0.1998621
+-0.7273645 loin ,
+-0.6925742 loin .
+-0.6708385 loin </s>
+-0.2922095 look beyond
+-0.4638903 looking higher
+-0.4638903 looking on -0.4771212
+-0.5136299 more . -0.4771212
+-0.3561665 more loin
+-0.1649931 on a -0.4771213
+-0.1649931 screening a -0.4771213
+-0.2705918 small .
+-0.287799 the screening
+-0.2922095 to look
+-0.2622373 watch </s>
+-0.2922095 watching considering
+-0.2922095 what i
+-0.2922095 would also
+-2 also would -6
+-6 foo bar
+
+\3-grams:
+-0.01916512 more . </s>
+-0.0283603 on a little -0.4771212
+-0.0283603 screening a little -0.4771212
+-0.01660496 a little more -0.09409451
+-0.3488368 <s> looking higher
+-0.3488368 <s> looking on -0.4771212
+-0.1892331 little more loin
+-0.04835128 looking on a -0.4771212
+-3 also would consider -7
+-7 to look good
+
+\4-grams:
+-0.009249173 looking on a little -0.4771212
+-0.005464747 on a little more -0.4771212
+-0.005464747 screening a little more
+-0.1453306 a little more loin
+-0.01552657 <s> looking on a -0.4771212
+-4 also would consider higher -8
+
+\5-grams:
+-0.003061223 <s> looking on a little
+-0.001813953 looking on a little more
+-0.0432557 on a little more loin
+-5 also would consider higher looking
+
+\end\
diff --git a/kenlm/lm/trie.hh b/kenlm/lm/trie.hh
index 3ea9b8f4f..536120648 100644
--- a/kenlm/lm/trie.hh
+++ b/kenlm/lm/trie.hh
@@ -47,13 +47,12 @@ class Unigram {
void LoadedBinary() {}
- bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+ void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
prob = val->weights.prob;
backoff = val->weights.backoff;
next.begin = val->next;
next.end = (val+1)->next;
- return true;
}
private:
diff --git a/kenlm/lm/vocab.hh b/kenlm/lm/vocab.hh
index c92518e48..9d218fff0 100644
--- a/kenlm/lm/vocab.hh
+++ b/kenlm/lm/vocab.hh
@@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary {
}
}
+ // Size for purposes of file writing
static size_t Size(std::size_t entries, const Config &config);
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
@@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary {
// Reorders reorder_vocab so that the IDs are sorted.
void FinishedLoading(ProbBackoff *reorder_vocab);
+ // Trie stores the correct counts including <unk> in the header. If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>.
+ std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }
+
bool SawUnk() const { return saw_unk_; }
void LoadedBinary(int fd, EnumerateVocab *to);