diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2014-11-13 01:18:52 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2014-11-13 01:18:52 +0300 |
commit | 1598bf4402ee0677aa58e5f1091ff74531f88809 (patch) | |
tree | e61eb45722ec8d55ea744537149159a16b5a248f | |
parent | 1468d734ff9a58771fa3c3b711141258e867dd65 (diff) | |
parent | 8bff98a4a0ec0804f11c4fd4574bd28ae59bd82d (diff) |
some refactoring (branch: unigram_pruning)
-rw-r--r-- | lm/builder/adjust_counts.cc | 57 | ||||
-rw-r--r-- | lm/builder/corpus_count.cc | 13 |
2 files changed, 36 insertions, 34 deletions
diff --git a/lm/builder/adjust_counts.cc b/lm/builder/adjust_counts.cc index bbe6e37..9eb18d2 100644 --- a/lm/builder/adjust_counts.cc +++ b/lm/builder/adjust_counts.cc @@ -246,28 +246,29 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) { // Initialization: <unk> has count 0 and so does <s>. NGramStream *lower_valid = streams.begin(); + const NGramStream *const streams_begin = streams.begin(); streams[0]->Count() = 0; *streams[0]->begin() = kUNK; stats.Add(0, 0); (++streams[0])->Count() = 0; *streams[0]->begin() = kBOS; - // not in stats because it will get put in later. + // <s> is not in stats yet because it will get put in later. + // This keeps track of actual counts for lower orders. It is not output + // (only adjusted counts are), but used to determine pruning. std::vector<uint64_t> lower_counts(positions.size(), 0); - // iterate over full (the stream of the highest order ngrams) - for (; full; ++full) { + // Iterate over full (the stream of the highest order ngrams) + for (; full; ++full) { const WordIndex *different = FindDifference(*full, **lower_valid); std::size_t same = full->end() - 1 - different; - // Increment the adjusted count. - if (same) ++streams[same - 1]->Count(); - // Output all the valid ones that changed. + // STEP 1: Output all the n-grams that changed. 
for (; lower_valid >= &streams[same]; --lower_valid) { - - uint64_t lower_order = (*lower_valid)->Order(); - uint64_t lower_count = lower_counts[lower_order - 1]; - if(lower_count <= prune_thresholds_[lower_order - 1] && (lower_order > 1 || (lower_order == 1 && *(*lower_valid)->begin() > 2))) + + uint64_t lower_order_minus_1 = lower_valid - streams_begin; + if(lower_counts[lower_order_minus_1] <= prune_thresholds_[lower_order_minus_1] + && (lower_order_minus_1 || (lower_order_minus_1 == 0 && *(*lower_valid)->begin() > 2))) (*lower_valid)->Mark(); if(!prune_words_.empty()) { @@ -278,28 +279,30 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) { } } } - - stats.Add(lower_valid - streams.begin(), (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked()); + + stats.Add(lower_order_minus_1, (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked()); ++*lower_valid; } - - // Count the true occurrences of lower-order n-grams - for (std::size_t i = 0; i < lower_counts.size(); ++i) { - if (i >= same) { - lower_counts[i] = 0; - } - lower_counts[i] += full->UnmarkedCount(); + + // STEP 2: Update n-grams that still match. + // n-grams that match get count from the full entry. + for (std::size_t i = 0; i < same; ++i) { + lower_counts[i] += full->UnmarkedCount(); } + // Increment the number of unique extensions for the longest match. + if (same) ++streams[same - 1]->Count(); + // STEP 3: Initialize new n-grams. // This is here because bos is also const WordIndex *, so copy gets // consistent argument types. const WordIndex *full_end = full->end(); // Initialize and mark as valid up to bos. 
const WordIndex *bos; for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) { - ++lower_valid; - std::copy(bos, full_end, (*lower_valid)->begin()); - (*lower_valid)->Count() = 1; + NGramStream &to = *++lower_valid; + std::copy(bos, full_end, to->begin()); + to->Count() = 1; + lower_counts[lower_valid - streams_begin] = full->UnmarkedCount(); } // Now bos indicates where <s> is or is the 0th word of full. if (bos != full->begin()) { @@ -307,15 +310,17 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) { NGramStream &to = *++lower_valid; std::copy(bos, full_end, to->begin()); - to->Count() = full->UnmarkedCount(); + // Anything that begins with <s> has full non adjusted count. + to->Count() = full->UnmarkedCount(); + lower_counts[lower_valid - streams_begin] = full->UnmarkedCount(); } else { - stats.AddFull(full->UnmarkedCount(), full->IsMarked()); + stats.AddFull(full->UnmarkedCount(), full->IsMarked()); } assert(lower_valid >= &streams[0]); } - // mjd: what is this actually doing? - // Output everything valid. + // The above loop outputs n-grams when it observes changes. This outputs + // the last n-grams. 
for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { uint64_t lower_count = lower_counts[(*s)->Order() - 1]; if(lower_count <= prune_thresholds_[(*s)->Order() - 1]) diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc index dc03860..7f3dafa 100644 --- a/lm/builder/corpus_count.cc +++ b/lm/builder/corpus_count.cc @@ -234,18 +234,15 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) { try { while (true) { StringPiece line(prune_vocab_file.ReadLine()); - for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) { - WordIndex i = vocab.Index(*w); - if (i > 2) - prune_words_[i] = false; - } + for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) + prune_words_[vocab.Index(*w)] = false; } } catch (const util::EndOfFileException &e) {} // Never prune <unk>, <s>, </s> - prune_words_[0] = false; - prune_words_[1] = false; - prune_words_[2] = false; + prune_words_[kUNK] = false; + prune_words_[kBOS] = false; + prune_words_[kEOS] = false; } catch (const util::Exception &e) { std::cerr << e.what() << std::endl; |