diff options
author | Kenneth Heafield <github@kheafield.com> | 2014-02-06 01:42:57 +0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2014-02-06 01:42:57 +0400 |
commit | 4ffaf8034598fb208e01103c6ed04833d0010998 (patch) | |
tree | 4a5340b087dc8749b1d8f56cf612483f97fea27f | |
parent | 71add7895757e478f09a2751cf9f77e8d6672300 (diff) |
Use FilePiece for query program, print perplexity with/without OOVs
-rw-r--r-- | lm/ngram_query.hh | 48 | ||||
-rw-r--r-- | lm/query_main.cc | 16 | ||||
-rw-r--r-- | util/Jamfile | 1 | ||||
-rw-r--r-- | util/file_piece.hh | 20 |
4 files changed, 55 insertions, 30 deletions
diff --git a/lm/ngram_query.hh b/lm/ngram_query.hh index ec2590f..efb509d 100644 --- a/lm/ngram_query.hh +++ b/lm/ngram_query.hh @@ -3,6 +3,7 @@ #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "util/file_piece.hh" #include "util/usage.hh" #include <cstdlib> @@ -16,42 +17,41 @@ namespace lm { namespace ngram { -template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { +template <class Model> void Query(const Model &model, bool sentence_context) { typename Model::State state, out; lm::FullScoreReturn ret; - std::string word; + StringPiece word; + + util::FilePiece in(0); + std::ostream &out_stream = std::cout; double corpus_total = 0.0; + double corpus_total_oov_only = 0.0; uint64_t corpus_oov = 0; uint64_t corpus_tokens = 0; - while (in_stream) { + while (true) { state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); float total = 0.0; - bool got = false; uint64_t oov = 0; - while (in_stream >> word) { - got = true; + + while (in.ReadWordSameLine(word)) { lm::WordIndex vocab = model.GetVocabulary().Index(word); - if (vocab == 0) ++oov; ret = model.FullScore(state, vocab, out); + if (vocab == model.GetVocabulary().NotFound()) { + ++oov; + corpus_total_oov_only += ret.prob; + } total += ret.prob; out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; ++corpus_tokens; state = out; - char c; - while (true) { - c = in_stream.get(); - if (!in_stream) break; - if (c == '\n') break; - if (!isspace(c)) { - in_stream.unget(); - break; - } - } - if (c == '\n') break; } - if (!got && !in_stream) break; + // If people don't have a newline after their last query, this won't add a </s>. + // Sue me. + try { + UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused."); + } catch (const util::EndOfFileException &e) { break; } if (sentence_context) { ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; @@ -62,13 +62,17 @@ template <class Model> void Query(const Model &model, bool sentence_context, std corpus_total += total; corpus_oov += oov; } - out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl; + out_stream << + "Perplexity including OOVs:\t" << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << "\n" + "Perplexity excluding OOVs:\t" << pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))) << "\n" + "OOVs:\t" << corpus_oov << "\n" + ; } -template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { +template <class M> void Query(const char *file, bool sentence_context) { Config config; M model(file, config); - Query(model, sentence_context, in_stream, out_stream); + Query(model, sentence_context); } } // namespace ngram diff --git a/lm/query_main.cc b/lm/query_main.cc index bd4fde6..cd661f7 100644 --- a/lm/query_main.cc +++ b/lm/query_main.cc @@ -32,22 +32,22 @@ int main(int argc, char *argv[]) { if (RecognizeBinary(file, model_type)) { switch(model_type) { case PROBING: - Query<lm::ngram::ProbingModel>(file, sentence_context, std::cin, std::cout); + Query<lm::ngram::ProbingModel>(file, sentence_context); break; case REST_PROBING: - Query<lm::ngram::RestProbingModel>(file, sentence_context, std::cin, std::cout); + Query<lm::ngram::RestProbingModel>(file, sentence_context); break; case TRIE: - Query<TrieModel>(file, sentence_context, std::cin, std::cout); + Query<TrieModel>(file, sentence_context); break; case QUANT_TRIE: - Query<QuantTrieModel>(file, sentence_context, std::cin, std::cout); + Query<QuantTrieModel>(file, sentence_context); break; case ARRAY_TRIE: - Query<ArrayTrieModel>(file, sentence_context, std::cin, std::cout); + Query<ArrayTrieModel>(file, sentence_context); break; case QUANT_ARRAY_TRIE: - Query<QuantArrayTrieModel>(file, sentence_context, std::cin, std::cout); + Query<QuantArrayTrieModel>(file, sentence_context); break; default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; @@ -56,10 +56,10 @@ int main(int argc, char *argv[]) { #ifdef WITH_NPLM } else if (lm::np::Model::Recognize(file)) { lm::np::Model model(file); - Query(model, sentence_context, std::cin, std::cout); + Query(model, sentence_context); #endif } else { - Query<ProbingModel>(file, sentence_context, std::cin, std::cout); + Query<ProbingModel>(file, sentence_context); } std::cerr << "Total time including destruction:\n"; util::PrintUsage(std::cerr); diff --git a/util/Jamfile b/util/Jamfile index 77b5438..afab916 100644 --- a/util/Jamfile +++ b/util/Jamfile @@ -22,6 +22,7 @@ obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(c fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ; exe cat_compressed : cat_compressed_main.cc kenutil ; +exe file_piece_cat : file_piece_main.cc kenutil ; alias programs : cat_compressed ; diff --git a/util/file_piece.hh b/util/file_piece.hh index 1054c18..83bcd4f 100644 --- a/util/file_piece.hh +++ b/util/file_piece.hh @@ -56,6 +56,26 @@ class FilePiece { return Consume(FindDelimiterOrEOF(delim)); } + // Read word until the line or file ends. + bool ReadWordSameLine(StringPiece &to, const bool *delim = kSpaces) { + assert(delim[static_cast<unsigned char>('\n')]); + // Skip non-enter spaces. + for (; ; ++position_) { + if (position_ == position_end_) { + try { + Shift(); + } catch (const util::EndOfFileException &e) { return false; } + // And break out at end of file. + if (position_ == position_end_) return false; + } + if (!delim[static_cast<unsigned char>(*position_)]) break; + if (*position_ == '\n') return false; + } + // We can't be at the end of file because there's at least one character open. + to = Consume(FindDelimiterOrEOF(delim)); + return true; + } + // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. // It is similar to getline in that way. StringPiece ReadLine(char delim = '\n'); |