Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author Kenneth Heafield <github@kheafield.com> 2014-02-06 01:42:57 +0400
committer Kenneth Heafield <github@kheafield.com> 2014-02-06 01:42:57 +0400
commit 4ffaf8034598fb208e01103c6ed04833d0010998 (patch)
tree 4a5340b087dc8749b1d8f56cf612483f97fea27f
parent 71add7895757e478f09a2751cf9f77e8d6672300 (diff)
Use FilePiece for query program, print perplexity with/without OOVs
-rw-r--r-- lm/ngram_query.hh 48
-rw-r--r-- lm/query_main.cc 16
-rw-r--r-- util/Jamfile 1
-rw-r--r-- util/file_piece.hh 20
4 files changed, 55 insertions, 30 deletions
diff --git a/lm/ngram_query.hh b/lm/ngram_query.hh
index ec2590f..efb509d 100644
--- a/lm/ngram_query.hh
+++ b/lm/ngram_query.hh
@@ -3,6 +3,7 @@
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
+#include "util/file_piece.hh"
#include "util/usage.hh"
#include <cstdlib>
@@ -16,42 +17,41 @@
namespace lm {
namespace ngram {
-template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
+template <class Model> void Query(const Model &model, bool sentence_context) {
typename Model::State state, out;
lm::FullScoreReturn ret;
- std::string word;
+ StringPiece word;
+
+ util::FilePiece in(0);
+ std::ostream &out_stream = std::cout;
double corpus_total = 0.0;
+ double corpus_total_oov_only = 0.0;
uint64_t corpus_oov = 0;
uint64_t corpus_tokens = 0;
- while (in_stream) {
+ while (true) {
state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
float total = 0.0;
- bool got = false;
uint64_t oov = 0;
- while (in_stream >> word) {
- got = true;
+
+ while (in.ReadWordSameLine(word)) {
lm::WordIndex vocab = model.GetVocabulary().Index(word);
- if (vocab == 0) ++oov;
ret = model.FullScore(state, vocab, out);
+ if (vocab == model.GetVocabulary().NotFound()) {
+ ++oov;
+ corpus_total_oov_only += ret.prob;
+ }
total += ret.prob;
out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
++corpus_tokens;
state = out;
- char c;
- while (true) {
- c = in_stream.get();
- if (!in_stream) break;
- if (c == '\n') break;
- if (!isspace(c)) {
- in_stream.unget();
- break;
- }
- }
- if (c == '\n') break;
}
- if (!got && !in_stream) break;
+ // If people don't have a newline after their last query, this won't add a </s>.
+ // Sue me.
+ try {
+ UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused.");
+ } catch (const util::EndOfFileException &e) { break; }
if (sentence_context) {
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
@@ -62,13 +62,17 @@ template <class Model> void Query(const Model &model, bool sentence_context, std
corpus_total += total;
corpus_oov += oov;
}
- out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl;
+ out_stream <<
+ "Perplexity including OOVs:\t" << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << "\n"
+ "Perplexity excluding OOVs:\t" << pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))) << "\n"
+ "OOVs:\t" << corpus_oov << "\n"
+ ;
}
-template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
+template <class M> void Query(const char *file, bool sentence_context) {
Config config;
M model(file, config);
- Query(model, sentence_context, in_stream, out_stream);
+ Query(model, sentence_context);
}
} // namespace ngram
diff --git a/lm/query_main.cc b/lm/query_main.cc
index bd4fde6..cd661f7 100644
--- a/lm/query_main.cc
+++ b/lm/query_main.cc
@@ -32,22 +32,22 @@ int main(int argc, char *argv[]) {
if (RecognizeBinary(file, model_type)) {
switch(model_type) {
case PROBING:
- Query<lm::ngram::ProbingModel>(file, sentence_context, std::cin, std::cout);
+ Query<lm::ngram::ProbingModel>(file, sentence_context);
break;
case REST_PROBING:
- Query<lm::ngram::RestProbingModel>(file, sentence_context, std::cin, std::cout);
+ Query<lm::ngram::RestProbingModel>(file, sentence_context);
break;
case TRIE:
- Query<TrieModel>(file, sentence_context, std::cin, std::cout);
+ Query<TrieModel>(file, sentence_context);
break;
case QUANT_TRIE:
- Query<QuantTrieModel>(file, sentence_context, std::cin, std::cout);
+ Query<QuantTrieModel>(file, sentence_context);
break;
case ARRAY_TRIE:
- Query<ArrayTrieModel>(file, sentence_context, std::cin, std::cout);
+ Query<ArrayTrieModel>(file, sentence_context);
break;
case QUANT_ARRAY_TRIE:
- Query<QuantArrayTrieModel>(file, sentence_context, std::cin, std::cout);
+ Query<QuantArrayTrieModel>(file, sentence_context);
break;
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
@@ -56,10 +56,10 @@ int main(int argc, char *argv[]) {
#ifdef WITH_NPLM
} else if (lm::np::Model::Recognize(file)) {
lm::np::Model model(file);
- Query(model, sentence_context, std::cin, std::cout);
+ Query(model, sentence_context);
#endif
} else {
- Query<ProbingModel>(file, sentence_context, std::cin, std::cout);
+ Query<ProbingModel>(file, sentence_context);
}
std::cerr << "Total time including destruction:\n";
util::PrintUsage(std::cerr);
diff --git a/util/Jamfile b/util/Jamfile
index 77b5438..afab916 100644
--- a/util/Jamfile
+++ b/util/Jamfile
@@ -22,6 +22,7 @@ obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(c
fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ;
exe cat_compressed : cat_compressed_main.cc kenutil ;
+exe file_piece_cat : file_piece_main.cc kenutil ;
alias programs : cat_compressed ;
diff --git a/util/file_piece.hh b/util/file_piece.hh
index 1054c18..83bcd4f 100644
--- a/util/file_piece.hh
+++ b/util/file_piece.hh
@@ -56,6 +56,26 @@ class FilePiece {
return Consume(FindDelimiterOrEOF(delim));
}
+ // Reads the next word on the current line into `to`. Returns true if a word was read; returns false, leaving `to` unassigned, if a '\n' or the end of the file is reached first. The '\n' itself is not consumed.
+ bool ReadWordSameLine(StringPiece &to, const bool *delim = kSpaces) {
+ assert(delim[static_cast<unsigned char>('\n')]); // '\n' must be marked in the delimiter table; presumably so the word found below cannot span a newline.
+ // Skip delimiters other than '\n'.
+ for (; ; ++position_) {
+ if (position_ == position_end_) { // Nothing left between position_ and position_end_.
+ try {
+ Shift(); // Presumably makes more input available; may throw EndOfFileException.
+ } catch (const util::EndOfFileException &e) { return false; }
+ // Shift() returning without any new character to read is also treated as end of file.
+ if (position_ == position_end_) return false;
+ }
+ if (!delim[static_cast<unsigned char>(*position_)]) break; // Non-delimiter: a word starts here.
+ if (*position_ == '\n') return false; // End of line: stop with position_ still on the '\n'.
+ }
+ // The loop only breaks on a non-delimiter at position_, so there is at least one character to consume.
+ to = Consume(FindDelimiterOrEOF(delim));
+ return true;
+ }
+
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way.
StringPiece ReadLine(char delim = '\n');