Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2014-11-13 01:18:52 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2014-11-13 01:18:52 +0300
commit1598bf4402ee0677aa58e5f1091ff74531f88809 (patch)
treee61eb45722ec8d55ea744537149159a16b5a248f
parent1468d734ff9a58771fa3c3b711141258e867dd65 (diff)
parent8bff98a4a0ec0804f11c4fd4574bd28ae59bd82d (diff)
some refactoringunigram_pruning
-rw-r--r--lm/builder/adjust_counts.cc57
-rw-r--r--lm/builder/corpus_count.cc13
2 files changed, 36 insertions, 34 deletions
diff --git a/lm/builder/adjust_counts.cc b/lm/builder/adjust_counts.cc
index bbe6e37..9eb18d2 100644
--- a/lm/builder/adjust_counts.cc
+++ b/lm/builder/adjust_counts.cc
@@ -246,28 +246,29 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
// Initialization: <unk> has count 0 and so does <s>.
NGramStream *lower_valid = streams.begin();
+ const NGramStream *const streams_begin = streams.begin();
streams[0]->Count() = 0;
*streams[0]->begin() = kUNK;
stats.Add(0, 0);
(++streams[0])->Count() = 0;
*streams[0]->begin() = kBOS;
- // not in stats because it will get put in later.
+ // <s> is not in stats yet because it will get put in later.
+ // This keeps track of actual counts for lower orders. It is not output
+ // (only adjusted counts are), but used to determine pruning.
std::vector<uint64_t> lower_counts(positions.size(), 0);
- // iterate over full (the stream of the highest order ngrams)
- for (; full; ++full) {
+ // Iterate over full (the stream of the highest order ngrams)
+ for (; full; ++full) {
const WordIndex *different = FindDifference(*full, **lower_valid);
std::size_t same = full->end() - 1 - different;
- // Increment the adjusted count.
- if (same) ++streams[same - 1]->Count();
- // Output all the valid ones that changed.
+ // STEP 1: Output all the n-grams that changed.
for (; lower_valid >= &streams[same]; --lower_valid) {
-
- uint64_t lower_order = (*lower_valid)->Order();
- uint64_t lower_count = lower_counts[lower_order - 1];
- if(lower_count <= prune_thresholds_[lower_order - 1] && (lower_order > 1 || (lower_order == 1 && *(*lower_valid)->begin() > 2)))
+
+ uint64_t lower_order_minus_1 = lower_valid - streams_begin;
+ if(lower_counts[lower_order_minus_1] <= prune_thresholds_[lower_order_minus_1]
+ && (lower_order_minus_1 || (lower_order_minus_1 == 0 && *(*lower_valid)->begin() > 2)))
(*lower_valid)->Mark();
if(!prune_words_.empty()) {
@@ -278,28 +279,30 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
}
}
}
-
- stats.Add(lower_valid - streams.begin(), (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked());
+
+ stats.Add(lower_order_minus_1, (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked());
++*lower_valid;
}
-
- // Count the true occurrences of lower-order n-grams
- for (std::size_t i = 0; i < lower_counts.size(); ++i) {
- if (i >= same) {
- lower_counts[i] = 0;
- }
- lower_counts[i] += full->UnmarkedCount();
+
+ // STEP 2: Update n-grams that still match.
+ // n-grams that match get count from the full entry.
+ for (std::size_t i = 0; i < same; ++i) {
+ lower_counts[i] += full->UnmarkedCount();
}
+ // Increment the number of unique extensions for the longest match.
+ if (same) ++streams[same - 1]->Count();
+ // STEP 3: Initialize new n-grams.
// This is here because bos is also const WordIndex *, so copy gets
// consistent argument types.
const WordIndex *full_end = full->end();
// Initialize and mark as valid up to bos.
const WordIndex *bos;
for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
- ++lower_valid;
- std::copy(bos, full_end, (*lower_valid)->begin());
- (*lower_valid)->Count() = 1;
+ NGramStream &to = *++lower_valid;
+ std::copy(bos, full_end, to->begin());
+ to->Count() = 1;
+ lower_counts[lower_valid - streams_begin] = full->UnmarkedCount();
}
// Now bos indicates where <s> is or is the 0th word of full.
if (bos != full->begin()) {
@@ -307,15 +310,17 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
NGramStream &to = *++lower_valid;
std::copy(bos, full_end, to->begin());
- to->Count() = full->UnmarkedCount();
+ // Anything that begins with <s> has full non adjusted count.
+ to->Count() = full->UnmarkedCount();
+ lower_counts[lower_valid - streams_begin] = full->UnmarkedCount();
} else {
- stats.AddFull(full->UnmarkedCount(), full->IsMarked());
+ stats.AddFull(full->UnmarkedCount(), full->IsMarked());
}
assert(lower_valid >= &streams[0]);
}
- // mjd: what is this actually doing?
- // Output everything valid.
+ // The above loop outputs n-grams when it observes changes. This outputs
+ // the last n-grams.
for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
uint64_t lower_count = lower_counts[(*s)->Order() - 1];
if(lower_count <= prune_thresholds_[(*s)->Order() - 1])
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc
index dc03860..7f3dafa 100644
--- a/lm/builder/corpus_count.cc
+++ b/lm/builder/corpus_count.cc
@@ -234,18 +234,15 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
try {
while (true) {
StringPiece line(prune_vocab_file.ReadLine());
- for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
- WordIndex i = vocab.Index(*w);
- if (i > 2)
- prune_words_[i] = false;
- }
+ for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w)
+ prune_words_[vocab.Index(*w)] = false;
}
} catch (const util::EndOfFileException &e) {}
// Never prune <unk>, <s>, </s>
- prune_words_[0] = false;
- prune_words_[1] = false;
- prune_words_[2] = false;
+ prune_words_[kUNK] = false;
+ prune_words_[kBOS] = false;
+ prune_words_[kEOS] = false;
} catch (const util::Exception &e) {
std::cerr << e.what() << std::endl;