github.com/moses-smt/mosesdecoder.git
-rw-r--r--  lm/CMakeLists.txt | 183
-rw-r--r--  lm/builder/CMakeLists.txt | 87
-rw-r--r--  lm/builder/corpus_count.cc | 3
-rw-r--r--  lm/builder/corpus_count_test.cc | 5
-rw-r--r--  lm/builder/debug_print.hh (renamed from lm/builder/print.hh) | 61
-rw-r--r--  lm/builder/dump_counts_main.cc | 4
-rw-r--r--  lm/builder/header_info.hh | 4
-rw-r--r--  lm/builder/initial_probabilities.cc | 2
-rw-r--r--  lm/builder/initial_probabilities.hh | 3
-rw-r--r--  lm/builder/interpolate.cc | 11
-rw-r--r--  lm/builder/interpolate.hh | 2
-rw-r--r--  lm/builder/lmplz_main.cc | 35
-rw-r--r--  lm/builder/output.cc | 27
-rw-r--r--  lm/builder/output.hh | 48
-rw-r--r--  lm/builder/pipeline.cc | 35
-rw-r--r--  lm/builder/pipeline.hh | 1
-rw-r--r--  lm/builder/print.cc | 64
-rw-r--r--  lm/common/CMakeLists.txt | 40
-rw-r--r--  lm/common/Jamfile | 2
-rw-r--r--  lm/common/joint_order.hh (renamed from lm/builder/joint_order.hh) | 29
-rw-r--r--  lm/common/model_buffer.cc | 49
-rw-r--r--  lm/common/model_buffer.hh | 46
-rw-r--r--  lm/common/ngram.hh | 5
-rw-r--r--  lm/common/ngram_stream.hh | 41
-rw-r--r--  lm/common/print.cc | 62
-rw-r--r--  lm/common/print.hh | 58
-rw-r--r--  lm/common/size_option.cc | 24
-rw-r--r--  lm/common/size_option.hh | 11
-rw-r--r--  lm/common/special.hh (renamed from lm/builder/special.hh) | 10
-rw-r--r--  lm/filter/CMakeLists.txt | 62
-rw-r--r--  util/CMakeLists.txt | 109
-rw-r--r--  util/double-conversion/CMakeLists.txt | 39
-rw-r--r--  util/file.cc | 13
-rw-r--r--  util/file.hh | 22
-rw-r--r--  util/fixed_array.hh | 35
-rw-r--r--  util/float_to_string.hh | 4
-rw-r--r--  util/probing_hash_table.hh | 3
-rw-r--r--  util/probing_hash_table_benchmark_main.cc | 14
-rw-r--r--  util/stream/CMakeLists.txt | 74
-rw-r--r--  util/stream/Jamfile | 8
-rw-r--r--  util/stream/chain.cc | 2
-rw-r--r--  util/stream/chain.hh | 15
-rw-r--r--  util/stream/count_records.cc | 12
-rw-r--r--  util/stream/count_records.hh | 20
-rw-r--r--  util/stream/multi_stream.hh | 17
-rw-r--r--  util/stream/rewindable_stream.cc | 157
-rw-r--r--  util/stream/rewindable_stream.hh | 46
-rw-r--r--  util/stream/rewindable_stream_test.cc | 2
48 files changed, 1139 insertions(+), 467 deletions(-)
diff --git a/lm/CMakeLists.txt b/lm/CMakeLists.txt
index 62de6f0b5..195fc730c 100644
--- a/lm/CMakeLists.txt
+++ b/lm/CMakeLists.txt
@@ -1,46 +1,139 @@
+cmake_minimum_required(VERSION 2.8.8)
+#
+# The KenLM cmake files make use of add_library(... OBJECTS ...)
+#
+# This syntax allows grouping of source files when compiling
+# (effectively creating "fake" libraries based on source subdirs).
+#
+# This syntax was only added in cmake version 2.8.8
+#
+# see http://www.cmake.org/Wiki/CMake/Tutorials/Object_Library
+
+
+# This CMake file was created by Lane Schwartz <dowobeha@gmail.com>
+
+
+set(KENLM_MAX_ORDER 6)
+
+add_definitions(-DKENLM_MAX_ORDER=${KENLM_MAX_ORDER})
+
+
+# Explicitly list the source files for this subdirectory
+#
+# If you add any source files to this subdirectory
+# that should be included in the kenlm library,
+# (this excludes any unit test files)
+# you should add them to the following list:
+set(KENLM_SOURCE
+ bhiksha.cc
+ binary_format.cc
+ config.cc
+ lm_exception.cc
+ model.cc
+ quantize.cc
+ read_arpa.cc
+ search_hashed.cc
+ search_trie.cc
+ sizes.cc
+ trie.cc
+ trie_sort.cc
+ value_build.cc
+ virtual_interface.cc
+ vocab.cc
+)
+
+
+# Group these objects together for later use.
+#
+# Given add_library(foo OBJECT ${my_foo_sources}),
+# refer to these objects as $<TARGET_OBJECTS:foo>
+#
+add_library(kenlm OBJECT ${KENLM_SOURCE})
+
+# This directory has children that need to be processed
+add_subdirectory(builder)
+add_subdirectory(common)
+add_subdirectory(filter)
+
+
+
+# Explicitly list the executable files to be compiled
+set(EXE_LIST
+ query
+ fragment
+ build_binary
+)
+
+# Iterate through the executable list
+foreach(exe ${EXE_LIST})
+
+ # Compile the executable, linking against the requisite dependent object files
+ add_executable(${exe} ${exe}_main.cc $<TARGET_OBJECTS:kenlm> $<TARGET_OBJECTS:kenlm_util>)
+
+ # Link the executable against boost
+ target_link_libraries(${exe} ${Boost_LIBRARIES})
+
+ # Group executables together
+ set_target_properties(${exe} PROPERTIES FOLDER executables)
+
+# End for loop
+endforeach(exe)
+
+
+# Install the executable files
+install(TARGETS ${EXE_LIST} DESTINATION bin)
+
+
+
+if(BUILD_TESTING)
+
+ # Explicitly list the Boost test files to be compiled
+ set(KENLM_BOOST_TESTS_LIST
+ left_test
+ model_test
+ partial_test
+ )
+
+ # Iterate through the Boost tests list
+ foreach(test ${KENLM_BOOST_TESTS_LIST})
+
+ # Compile the executable, linking against the requisite dependent object files
+ add_executable(${test} ${test}.cc $<TARGET_OBJECTS:kenlm> $<TARGET_OBJECTS:kenlm_util>)
+
+ # Require the following compile flag
+ set_target_properties(${test} PROPERTIES COMPILE_FLAGS -DBOOST_TEST_DYN_LINK)
+
+ # Link the executable against boost
+ target_link_libraries(${test} ${Boost_LIBRARIES})
+
+ # model_test requires an extra command line parameter
+ if ("${test}" STREQUAL "model_test")
+ set(test_params
+ ${CMAKE_CURRENT_SOURCE_DIR}/test.arpa
+ ${CMAKE_CURRENT_SOURCE_DIR}/test_nounk.arpa
+ )
+ else()
+ set(test_params
+ ${CMAKE_CURRENT_SOURCE_DIR}/test.arpa
+ )
+ endif()
+
+ # Specify command arguments for how to run each unit test
+ #
+ # Assuming that foo was defined via add_executable(foo ...),
+ # the syntax $<TARGET_FILE:foo> gives the full path to the executable.
+ #
+ add_test(NAME ${test}_test
+ COMMAND $<TARGET_FILE:${test}> ${test_params})
+
+ # Group unit tests together
+ set_target_properties(${test} PROPERTIES FOLDER "unit_tests")
+
+ # End for loop
+ endforeach(test)
+
+endif()
+
+
+
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/bhiksha.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/bhiksha.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/binary_format.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/binary_format.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/blank.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/config.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/config.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/enumerate_vocab.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/facade.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/left.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/lm_exception.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/lm_exception.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/max_order.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/model.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/model.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/model_type.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/ngram_query.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/partial.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/quantize.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/quantize.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/read_arpa.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/read_arpa.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/return.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/search_hashed.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/search_hashed.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/search_trie.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/search_trie.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/sizes.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/sizes.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/state.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/trie.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/trie.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/trie_sort.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/trie_sort.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/value.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/value_build.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/value_build.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/virtual_interface.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/virtual_interface.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/vocab.cc")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/vocab.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/weights.hh")
-list(APPEND SOURCE_KENLM "${CMAKE_CURRENT_SOURCE_DIR}/word_index.hh")
-
-add_library(kenlm OBJECT ${SOURCE_KENLM})
\ No newline at end of file
diff --git a/lm/builder/CMakeLists.txt b/lm/builder/CMakeLists.txt
new file mode 100644
index 000000000..d84a7f7da
--- /dev/null
+++ b/lm/builder/CMakeLists.txt
@@ -0,0 +1,87 @@
+cmake_minimum_required(VERSION 2.8.8)
+#
+# The KenLM cmake files make use of add_library(... OBJECTS ...)
+#
+# This syntax allows grouping of source files when compiling
+# (effectively creating "fake" libraries based on source subdirs).
+#
+# This syntax was only added in cmake version 2.8.8
+#
+# see http://www.cmake.org/Wiki/CMake/Tutorials/Object_Library
+
+
+# This CMake file was created by Lane Schwartz <dowobeha@gmail.com>
+
+# Explicitly list the source files for this subdirectory
+#
+# If you add any source files to this subdirectory
+# that should be included in the kenlm library,
+# (this excludes any unit test files)
+# you should add them to the following list:
+#
+# In order to set correct paths to these files
+# in case this variable is referenced by CMake files in the parent directory,
+# we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}.
+#
+set(KENLM_BUILDER_SOURCE
+ ${CMAKE_CURRENT_SOURCE_DIR}/adjust_counts.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/corpus_count.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/initial_probabilities.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/interpolate.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/output.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/pipeline.cc
+ )
+
+
+# Group these objects together for later use.
+#
+# Given add_library(foo OBJECT ${my_foo_sources}),
+# refer to these objects as $<TARGET_OBJECTS:foo>
+#
+add_library(kenlm_builder OBJECT ${KENLM_BUILDER_SOURCE})
+
+
+# Compile the executable, linking against the requisite dependent object files
+add_executable(lmplz lmplz_main.cc $<TARGET_OBJECTS:kenlm> $<TARGET_OBJECTS:kenlm_common> $<TARGET_OBJECTS:kenlm_builder> $<TARGET_OBJECTS:kenlm_util>)
+
+# Link the executable against boost
+target_link_libraries(lmplz ${Boost_LIBRARIES})
+
+# Group executables together
+set_target_properties(lmplz PROPERTIES FOLDER executables)
+
+if(BUILD_TESTING)
+
+ # Explicitly list the Boost test files to be compiled
+ set(KENLM_BOOST_TESTS_LIST
+ adjust_counts_test
+ corpus_count_test
+ )
+
+ # Iterate through the Boost tests list
+ foreach(test ${KENLM_BOOST_TESTS_LIST})
+
+ # Compile the executable, linking against the requisite dependent object files
+ add_executable(${test} ${test}.cc $<TARGET_OBJECTS:kenlm> $<TARGET_OBJECTS:kenlm_common> $<TARGET_OBJECTS:kenlm_builder> $<TARGET_OBJECTS:kenlm_util>)
+
+ # Require the following compile flag
+ set_target_properties(${test} PROPERTIES COMPILE_FLAGS "-DBOOST_TEST_DYN_LINK -DBOOST_PROGRAM_OPTIONS_DYN_LINK")
+
+ # Link the executable against boost
+ target_link_libraries(${test} ${Boost_LIBRARIES})
+
+ # Specify command arguments for how to run each unit test
+ #
+ # Assuming that foo was defined via add_executable(foo ...),
+ # the syntax $<TARGET_FILE:foo> gives the full path to the executable.
+ #
+ add_test(NAME ${test}_test
+ COMMAND $<TARGET_FILE:${test}>)
+
+ # Group unit tests together
+ set_target_properties(${test} PROPERTIES FOLDER "unit_tests")
+
+ # End for loop
+ endforeach(test)
+
+endif()
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc
index 9f23b28a8..04815d805 100644
--- a/lm/builder/corpus_count.cc
+++ b/lm/builder/corpus_count.cc
@@ -15,9 +15,6 @@
#include "util/stream/timer.hh"
#include "util/tokenize_piece.hh"
-#include <boost/unordered_set.hpp>
-#include <boost/unordered_map.hpp>
-
#include <functional>
#include <stdint.h>
diff --git a/lm/builder/corpus_count_test.cc b/lm/builder/corpus_count_test.cc
index 82f859690..88bcf9657 100644
--- a/lm/builder/corpus_count_test.cc
+++ b/lm/builder/corpus_count_test.cc
@@ -43,12 +43,13 @@ BOOST_AUTO_TEST_CASE(Short) {
util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab"));
util::stream::Chain chain(config);
- NGramStream<BuildingPayload> stream;
uint64_t token_count;
WordIndex type_count = 10;
std::vector<bool> prune_words;
CorpusCount counter(input_piece, vocab.get(), token_count, type_count, prune_words, "", chain.BlockSize() / chain.EntrySize(), SILENT);
- chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
+ chain >> boost::ref(counter);
+ NGramStream<BuildingPayload> stream(chain.Add());
+ chain >> util::stream::kRecycle;
const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"};
diff --git a/lm/builder/print.hh b/lm/builder/debug_print.hh
index 5f293de85..193a6892c 100644
--- a/lm/builder/print.hh
+++ b/lm/builder/debug_print.hh
@@ -1,54 +1,18 @@
-#ifndef LM_BUILDER_PRINT_H
-#define LM_BUILDER_PRINT_H
+#ifndef LM_BUILDER_DEBUG_PRINT_H
+#define LM_BUILDER_DEBUG_PRINT_H
-#include "lm/common/ngram_stream.hh"
-#include "lm/builder/output.hh"
#include "lm/builder/payload.hh"
-#include "lm/common/ngram.hh"
+#include "lm/common/print.hh"
+#include "lm/common/ngram_stream.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
-#include "util/mmap.hh"
-#include "util/string_piece.hh"
#include <boost/lexical_cast.hpp>
-#include <ostream>
-#include <cassert>
-
-// Warning: print routines read all unigrams before all bigrams before all
-// trigrams etc. So if other parts of the chain move jointly, you'll have to
-// buffer.
-
namespace lm { namespace builder {
-
-class VocabReconstitute {
- public:
- // fd must be alive for life of this object; does not take ownership.
- explicit VocabReconstitute(int fd);
-
- const char *Lookup(WordIndex index) const {
- assert(index < map_.size() - 1);
- return map_[index];
- }
-
- StringPiece LookupPiece(WordIndex index) const {
- return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]);
- }
-
- std::size_t Size() const {
- // There's an extra entry to support StringPiece lengths.
- return map_.size() - 1;
- }
-
- private:
- util::scoped_memory memory_;
- std::vector<const char*> map_;
-};
-
// Not defined, only specialized.
template <class T> void PrintPayload(util::FakeOFStream &to, const BuildingPayload &payload);
template <> inline void PrintPayload<uint64_t>(util::FakeOFStream &to, const BuildingPayload &payload) {
- // TODO slow
to << payload.count;
}
template <> inline void PrintPayload<Uninterpolated>(util::FakeOFStream &to, const BuildingPayload &payload) {
@@ -101,19 +65,6 @@ template <class V> class Print {
int to_;
};
-class PrintARPA : public OutputHook {
- public:
- explicit PrintARPA(int fd, bool verbose_header)
- : OutputHook(PROB_SEQUENTIAL_HOOK), out_fd_(fd), verbose_header_(verbose_header) {}
-
- void Sink(util::stream::Chains &chains);
-
- void Run(const util::stream::ChainPositions &positions);
-
- private:
- util::scoped_fd out_fd_;
- bool verbose_header_;
-};
-
}} // namespaces
-#endif // LM_BUILDER_PRINT_H
+
+#endif // LM_BUILDER_DEBUG_PRINT_H
diff --git a/lm/builder/dump_counts_main.cc b/lm/builder/dump_counts_main.cc
index fa0016792..a4c9478b6 100644
--- a/lm/builder/dump_counts_main.cc
+++ b/lm/builder/dump_counts_main.cc
@@ -1,4 +1,4 @@
-#include "lm/builder/print.hh"
+#include "lm/common/print.hh"
#include "lm/word_index.hh"
#include "util/file.hh"
#include "util/read_compressed.hh"
@@ -20,7 +20,7 @@ int main(int argc, char *argv[]) {
}
util::ReadCompressed counts(util::OpenReadOrThrow(argv[1]));
util::scoped_fd vocab_file(util::OpenReadOrThrow(argv[2]));
- lm::builder::VocabReconstitute vocab(vocab_file.get());
+ lm::VocabReconstitute vocab(vocab_file.get());
unsigned int order = boost::lexical_cast<unsigned int>(argv[3]);
std::vector<char> record(sizeof(uint32_t) * order + sizeof(uint64_t));
while (std::size_t got = counts.ReadOrEOF(&*record.begin(), record.size())) {
diff --git a/lm/builder/header_info.hh b/lm/builder/header_info.hh
index 146195233..d01d0496b 100644
--- a/lm/builder/header_info.hh
+++ b/lm/builder/header_info.hh
@@ -5,6 +5,8 @@
#include <vector>
#include <stdint.h>
+namespace lm { namespace builder {
+
// Some configuration info that is used to add
// comments to the beginning of an ARPA file
struct HeaderInfo {
@@ -21,4 +23,6 @@ struct HeaderInfo {
// TODO: More info if multiple models were interpolated
};
+}} // namespaces
+
#endif
diff --git a/lm/builder/initial_probabilities.cc b/lm/builder/initial_probabilities.cc
index ef8a8ecfd..5b8d86d33 100644
--- a/lm/builder/initial_probabilities.cc
+++ b/lm/builder/initial_probabilities.cc
@@ -1,9 +1,9 @@
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/discount.hh"
-#include "lm/builder/special.hh"
#include "lm/builder/hash_gamma.hh"
#include "lm/builder/payload.hh"
+#include "lm/common/special.hh"
#include "lm/common/ngram_stream.hh"
#include "util/murmur_hash.hh"
#include "util/file.hh"
diff --git a/lm/builder/initial_probabilities.hh b/lm/builder/initial_probabilities.hh
index dddbbb913..caeea58c5 100644
--- a/lm/builder/initial_probabilities.hh
+++ b/lm/builder/initial_probabilities.hh
@@ -10,9 +10,8 @@
namespace util { namespace stream { class Chains; } }
namespace lm {
-namespace builder {
-
class SpecialVocab;
+namespace builder {
struct InitialProbabilitiesConfig {
// These should be small buffers to keep the adder from getting too far ahead
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
index 5f0a339bc..6374bcf04 100644
--- a/lm/builder/interpolate.cc
+++ b/lm/builder/interpolate.cc
@@ -1,16 +1,16 @@
#include "lm/builder/interpolate.hh"
#include "lm/builder/hash_gamma.hh"
-#include "lm/builder/joint_order.hh"
-#include "lm/common/ngram_stream.hh"
+#include "lm/builder/payload.hh"
#include "lm/common/compare.hh"
+#include "lm/common/joint_order.hh"
+#include "lm/common/ngram_stream.hh"
#include "lm/lm_exception.hh"
#include "util/fixed_array.hh"
#include "util/murmur_hash.hh"
#include <cassert>
#include <cmath>
-#include <iostream>
namespace lm { namespace builder {
namespace {
@@ -91,7 +91,8 @@ template <class Output> class Callback {
}
}
- void Enter(unsigned order_minus_1, NGram<BuildingPayload> &gram) {
+ void Enter(unsigned order_minus_1, void *data) {
+ NGram<BuildingPayload> gram(data, order_minus_1 + 1);
BuildingPayload &pay = gram.Value();
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
probs_[order_minus_1 + 1] = pay.complete.prob;
@@ -125,7 +126,7 @@ template <class Output> class Callback {
output_.Gram(order_minus_1, out_backoff, pay.complete);
}
- void Exit(unsigned, const NGram<BuildingPayload> &) const {}
+ void Exit(unsigned, void *) const {}
private:
util::FixedArray<util::stream::Stream> backoffs_;
diff --git a/lm/builder/interpolate.hh b/lm/builder/interpolate.hh
index dcee75adb..d20cd545c 100644
--- a/lm/builder/interpolate.hh
+++ b/lm/builder/interpolate.hh
@@ -1,7 +1,7 @@
#ifndef LM_BUILDER_INTERPOLATE_H
#define LM_BUILDER_INTERPOLATE_H
-#include "lm/builder/special.hh"
+#include "lm/common/special.hh"
#include "lm/word_index.hh"
#include "util/stream/multi_stream.hh"
diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index c27490665..cc3f381ca 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -1,6 +1,6 @@
#include "lm/builder/output.hh"
#include "lm/builder/pipeline.hh"
-#include "lm/builder/print.hh"
+#include "lm/common/size_option.hh"
#include "lm/lm_exception.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
@@ -13,21 +13,6 @@
#include <vector>
namespace {
-class SizeNotify {
- public:
- SizeNotify(std::size_t &out) : behind_(out) {}
-
- void operator()(const std::string &from) {
- behind_ = util::ParseSize(from);
- }
-
- private:
- std::size_t &behind_;
-};
-
-boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) {
- return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
-}
// Parse and validate pruning thresholds then return vector of threshold counts
// for each n-grams order.
@@ -106,17 +91,16 @@ int main(int argc, char *argv[]) {
("interpolate_unigrams", po::value<bool>(&pipeline.initial_probs.interpolate_unigrams)->default_value(true)->implicit_value(true), "Interpolate the unigrams (default) as opposed to giving lots of mass to <unk> like SRI. If you want SRI's behavior with a large <unk> and the old lmplz default, use --interpolate_unigrams 0.")
("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
- ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
- ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
- ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+ ("memory,S", lm:: SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
+ ("minimum_block", lm::SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
+ ("sort_block", lm::SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
- ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes")
("vocab_pad", po::value<uint64_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams")
("verbose_header", po::bool_switch(&verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
- ("intermediate", po::value<std::string>(&intermediate), "Write ngrams to an intermediate file. Turns off ARPA output (which can be reactivated by --arpa file). Forces --renumber on. Implicitly makes --vocab_file be the provided name + .vocab.")
+ ("intermediate", po::value<std::string>(&intermediate), "Write ngrams to intermediate files. Turns off ARPA output (which can be reactivated by --arpa file). Forces --renumber on.")
("renumber", po::bool_switch(&pipeline.renumber_vocabulary), "Rrenumber the vocabulary identifiers so that they are monotone with the hash of each string. This is consistent with the ordering used by the trie data structure.")
("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.")
("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Default is to not prune, which is equivalent to --prune 0.")
@@ -217,15 +201,10 @@ int main(int argc, char *argv[]) {
bool writing_intermediate = vm.count("intermediate");
if (writing_intermediate) {
pipeline.renumber_vocabulary = true;
- if (!pipeline.vocab_file.empty()) {
- std::cerr << "--intermediate and --vocab_file are incompatible because --intermediate already makes a vocab file." << std::endl;
- return 1;
- }
- pipeline.vocab_file = intermediate + ".vocab";
}
- lm::builder::Output output(writing_intermediate ? intermediate : pipeline.sort.temp_prefix, writing_intermediate);
+ lm::builder::Output output(writing_intermediate ? intermediate : pipeline.sort.temp_prefix, writing_intermediate, pipeline.output_q);
if (!writing_intermediate || vm.count("arpa")) {
- output.Add(new lm::builder::PrintARPA(out.release(), verbose_header));
+ output.Add(new lm::builder::PrintHook(out.release(), verbose_header));
}
lm::builder::Pipeline(pipeline, in.release(), output);
} catch (const util::MallocException &e) {
diff --git a/lm/builder/output.cc b/lm/builder/output.cc
index 76478ad06..c92283ac6 100644
--- a/lm/builder/output.cc
+++ b/lm/builder/output.cc
@@ -1,6 +1,8 @@
#include "lm/builder/output.hh"
#include "lm/common/model_buffer.hh"
+#include "lm/common/print.hh"
+#include "util/fake_ofstream.hh"
#include "util/stream/multi_stream.hh"
#include <iostream>
@@ -9,23 +11,22 @@ namespace lm { namespace builder {
OutputHook::~OutputHook() {}
-Output::Output(StringPiece file_base, bool keep_buffer)
- : file_base_(file_base.data(), file_base.size()), keep_buffer_(keep_buffer) {}
+Output::Output(StringPiece file_base, bool keep_buffer, bool output_q)
+ : buffer_(file_base, keep_buffer, output_q) {}
-void Output::SinkProbs(util::stream::Chains &chains, bool output_q) {
+void Output::SinkProbs(util::stream::Chains &chains) {
Apply(PROB_PARALLEL_HOOK, chains);
- if (!keep_buffer_ && !Have(PROB_SEQUENTIAL_HOOK)) {
+ if (!buffer_.Keep() && !Have(PROB_SEQUENTIAL_HOOK)) {
chains >> util::stream::kRecycle;
chains.Wait(true);
return;
}
- lm::common::ModelBuffer buf(file_base_, keep_buffer_, output_q);
- buf.Sink(chains);
+ buffer_.Sink(chains, header_.counts_pruned);
chains >> util::stream::kRecycle;
chains.Wait(false);
if (Have(PROB_SEQUENTIAL_HOOK)) {
std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
- buf.Source(chains);
+ buffer_.Source(chains);
Apply(PROB_SEQUENTIAL_HOOK, chains);
chains >> util::stream::kRecycle;
chains.Wait(true);
@@ -34,8 +35,18 @@ void Output::SinkProbs(util::stream::Chains &chains, bool output_q) {
void Output::Apply(HookType hook_type, util::stream::Chains &chains) {
for (boost::ptr_vector<OutputHook>::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) {
- entry->Sink(chains);
+ entry->Sink(header_, VocabFile(), chains);
}
}
+void PrintHook::Sink(const HeaderInfo &info, int vocab_file, util::stream::Chains &chains) {
+ if (verbose_header_) {
+ util::FakeOFStream out(file_.get(), 50);
+ out << "# Input file: " << info.input_file << '\n';
+ out << "# Token count: " << info.token_count << '\n';
+ out << "# Smoothing: Modified Kneser-Ney" << '\n';
+ }
+ chains >> PrintARPA(vocab_file, file_.get(), info.counts_pruned);
+}
+
}} // namespaces
diff --git a/lm/builder/output.hh b/lm/builder/output.hh
index c1e0d1469..69d6c6dac 100644
--- a/lm/builder/output.hh
+++ b/lm/builder/output.hh
@@ -2,6 +2,7 @@
#define LM_BUILDER_OUTPUT_H
#include "lm/builder/header_info.hh"
+#include "lm/common/model_buffer.hh"
#include "util/file.hh"
#include <boost/ptr_container/ptr_vector.hpp>
@@ -20,69 +21,64 @@ enum HookType {
NUMBER_OF_HOOKS // Keep this last so we know how many values there are.
};
-class Output;
-
class OutputHook {
public:
- explicit OutputHook(HookType hook_type) : type_(hook_type), master_(NULL) {}
+ explicit OutputHook(HookType hook_type) : type_(hook_type) {}
virtual ~OutputHook();
- virtual void Sink(util::stream::Chains &chains) = 0;
+ virtual void Sink(const HeaderInfo &info, int vocab_file, util::stream::Chains &chains) = 0;
- protected:
- const HeaderInfo &GetHeader() const;
- int GetVocabFD() const;
+ HookType Type() const { return type_; }
private:
- friend class Output;
- const HookType type_;
- const Output *master_;
+ HookType type_;
};
class Output : boost::noncopyable {
public:
- Output(StringPiece file_base, bool keep_buffer);
+ Output(StringPiece file_base, bool keep_buffer, bool output_q);
// Takes ownership.
void Add(OutputHook *hook) {
- hook->master_ = this;
- outputs_[hook->type_].push_back(hook);
+ outputs_[hook->Type()].push_back(hook);
}
bool Have(HookType hook_type) const {
return !outputs_[hook_type].empty();
}
- void SetVocabFD(int to) { vocab_fd_ = to; }
- int GetVocabFD() const { return vocab_fd_; }
+ int VocabFile() const { return buffer_.VocabFile(); }
void SetHeader(const HeaderInfo &header) { header_ = header; }
const HeaderInfo &GetHeader() const { return header_; }
// This is called by the pipeline.
- void SinkProbs(util::stream::Chains &chains, bool output_q);
+ void SinkProbs(util::stream::Chains &chains);
unsigned int Steps() const { return Have(PROB_SEQUENTIAL_HOOK); }
private:
void Apply(HookType hook_type, util::stream::Chains &chains);
+ ModelBuffer buffer_;
+
boost::ptr_vector<OutputHook> outputs_[NUMBER_OF_HOOKS];
- int vocab_fd_;
HeaderInfo header_;
-
- std::string file_base_;
- bool keep_buffer_;
};
-inline const HeaderInfo &OutputHook::GetHeader() const {
- return master_->GetHeader();
-}
+class PrintHook : public OutputHook {
+ public:
+ // Takes ownership
+ PrintHook(int write_fd, bool verbose_header)
+ : OutputHook(PROB_SEQUENTIAL_HOOK), file_(write_fd), verbose_header_(verbose_header) {}
-inline int OutputHook::GetVocabFD() const {
- return master_->GetVocabFD();
-}
+ void Sink(const HeaderInfo &info, int vocab_file, util::stream::Chains &chains);
+
+ private:
+ util::scoped_fd file_;
+ bool verbose_header_;
+};
}} // namespaces
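
With this change an OutputHook no longer reaches back into a master Output for the header and vocabulary descriptor; both are handed to Sink() directly. A minimal sketch of a custom hook against the new interface, for orientation only (the CountReporter class is illustrative and not part of this commit; a real hook would normally also attach a reader to chains, as PrintHook does):

    #include "lm/builder/output.hh"

    #include <cstddef>
    #include <iostream>

    namespace lm { namespace builder {

    // Illustrative hook: report the pruned n-gram counts when the
    // sequential-probability pass runs.  It attaches nothing to the chains.
    class CountReporter : public OutputHook {
      public:
        CountReporter() : OutputHook(PROB_SEQUENTIAL_HOOK) {}

        void Sink(const HeaderInfo &info, int /*vocab_file*/, util::stream::Chains &/*chains*/) {
          for (std::size_t i = 0; i < info.counts_pruned.size(); ++i)
            std::cerr << "ngram " << (i + 1) << '=' << info.counts_pruned[i] << '\n';
        }
    };

    }} // namespaces

    // Registered the same way as PrintHook in lmplz_main.cc:
    //   output.Add(new lm::builder::CountReporter());
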
diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc
index d588beedf..69972e278 100644
--- a/lm/builder/pipeline.cc
+++ b/lm/builder/pipeline.cc
@@ -277,27 +277,27 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste
}
master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.prune_vocab, config.output_q, specials);
gamma_chains >> util::stream::kRecycle;
- output.SinkProbs(master.MutableChains(), config.output_q);
+ output.SinkProbs(master.MutableChains());
}
class VocabNumbering {
public:
- VocabNumbering(StringPiece vocab_file, StringPiece temp_prefix, bool renumber)
- : vocab_file_(vocab_file.data(), vocab_file.size()),
- temp_prefix_(temp_prefix.data(), temp_prefix.size()),
+ VocabNumbering(int final_vocab, StringPiece temp_prefix, bool renumber)
+ : final_vocab_(final_vocab),
renumber_(renumber),
specials_(kBOS, kEOS) {
- InitFile(renumber || vocab_file.empty());
+ if (renumber) {
+ temporary_.reset(util::MakeTemp(temp_prefix));
+ }
}
- int File() const { return null_delimited_.get(); }
+ int WriteOnTheFly() const { return renumber_ ? temporary_.get() : final_vocab_; }
// Compute the vocabulary mapping and return the memory used.
std::size_t ComputeMapping(WordIndex type_count) {
if (!renumber_) return 0;
- util::scoped_fd previous(null_delimited_.release());
- InitFile(vocab_file_.empty());
- ngram::SortedVocabulary::ComputeRenumbering(type_count, previous.get(), null_delimited_.get(), vocab_mapping_);
+ ngram::SortedVocabulary::ComputeRenumbering(type_count, temporary_.get(), final_vocab_, vocab_mapping_);
+ temporary_.reset();
return sizeof(WordIndex) * vocab_mapping_.size();
}
@@ -312,15 +312,9 @@ class VocabNumbering {
const SpecialVocab &Specials() const { return specials_; }
private:
- void InitFile(bool temp) {
- null_delimited_.reset(temp ?
- util::MakeTemp(temp_prefix_) :
- util::CreateOrThrow(vocab_file_.c_str()));
- }
-
- std::string vocab_file_, temp_prefix_;
-
- util::scoped_fd null_delimited_;
+ int final_vocab_;
+ // Out of order vocab file created on the fly.
+ util::scoped_fd temporary_;
bool renumber_;
@@ -349,18 +343,17 @@ void Pipeline(PipelineConfig &config, int text_file, Output &output) {
// master's destructor will wait for chains. But they might be deadlocked if
// this thread dies because e.g. it ran out of memory.
try {
- VocabNumbering numbering(config.vocab_file, config.TempPrefix(), config.renumber_vocabulary);
+ VocabNumbering numbering(output.VocabFile(), config.TempPrefix(), config.renumber_vocabulary);
uint64_t token_count;
WordIndex type_count;
std::string text_file_name;
std::vector<bool> prune_words;
util::scoped_ptr<util::stream::Sort<SuffixOrder, CombineCounts> > sorted_counts(
- CountText(text_file, numbering.File(), master, token_count, type_count, text_file_name, prune_words));
+ CountText(text_file, numbering.WriteOnTheFly(), master, token_count, type_count, text_file_name, prune_words));
std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;
// Create vocab mapping, which uses temporary memory, while nothing else is happening.
std::size_t subtract_for_numbering = numbering.ComputeMapping(type_count);
- output.SetVocabFD(numbering.File());
std::cerr << "=== 2/" << master.Steps() << " Calculating and sorting adjusted counts ===" << std::endl;
master.InitForAdjust(*sorted_counts, type_count, subtract_for_numbering);
diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh
index 695ecf7bd..66f1fd9a8 100644
--- a/lm/builder/pipeline.hh
+++ b/lm/builder/pipeline.hh
@@ -18,7 +18,6 @@ class Output;
struct PipelineConfig {
std::size_t order;
- std::string vocab_file;
util::stream::SortConfig sort;
InitialProbabilitiesConfig initial_probs;
util::stream::ChainConfig read_backoffs;
diff --git a/lm/builder/print.cc b/lm/builder/print.cc
deleted file mode 100644
index 178e54a21..000000000
--- a/lm/builder/print.cc
+++ /dev/null
@@ -1,64 +0,0 @@
-#include "lm/builder/print.hh"
-
-#include "util/fake_ofstream.hh"
-#include "util/file.hh"
-#include "util/mmap.hh"
-#include "util/scoped.hh"
-#include "util/stream/timer.hh"
-
-#include <sstream>
-#include <cstring>
-
-namespace lm { namespace builder {
-
-VocabReconstitute::VocabReconstitute(int fd) {
- uint64_t size = util::SizeOrThrow(fd);
- util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_);
- const char *const start = static_cast<const char*>(memory_.get());
- const char *i;
- for (i = start; i != start + size; i += strlen(i) + 1) {
- map_.push_back(i);
- }
- // Last one for LookupPiece.
- map_.push_back(i);
-}
-
-void PrintARPA::Sink(util::stream::Chains &chains) {
- chains >> boost::ref(*this);
-}
-
-void PrintARPA::Run(const util::stream::ChainPositions &positions) {
- VocabReconstitute vocab(GetVocabFD());
- util::FakeOFStream out(out_fd_.get());
-
- // Write header.
- if (verbose_header_) {
- out << "# Input file: " << GetHeader().input_file << '\n';
- out << "# Token count: " << GetHeader().token_count << '\n';
- out << "# Smoothing: Modified Kneser-Ney" << '\n';
- }
- out << "\\data\\\n";
- for (size_t i = 0; i < positions.size(); ++i) {
- out << "ngram " << (i+1) << '=' << GetHeader().counts_pruned[i] << '\n';
- }
- out << '\n';
-
- for (unsigned order = 1; order <= positions.size(); ++order) {
- out << "\\" << order << "-grams:" << '\n';
- for (NGramStream<BuildingPayload> stream(positions[order - 1]); stream; ++stream) {
- // Correcting for numerical precision issues. Take that IRST.
- out << stream->Value().complete.prob << '\t' << vocab.Lookup(*stream->begin());
- for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
- out << ' ' << vocab.Lookup(*i);
- }
- if (order != positions.size())
- out << '\t' << stream->Value().complete.backoff;
- out << '\n';
-
- }
- out << '\n';
- }
- out << "\\end\\\n";
-}
-
-}} // namespaces
diff --git a/lm/common/CMakeLists.txt b/lm/common/CMakeLists.txt
new file mode 100644
index 000000000..942e24bdc
--- /dev/null
+++ b/lm/common/CMakeLists.txt
@@ -0,0 +1,40 @@
+cmake_minimum_required(VERSION 2.8.8)
+#
+# The KenLM cmake files make use of add_library(... OBJECTS ...)
+#
+# This syntax allows grouping of source files when compiling
+# (effectively creating "fake" libraries based on source subdirs).
+#
+# This syntax was only added in cmake version 2.8.8
+#
+# see http://www.cmake.org/Wiki/CMake/Tutorials/Object_Library
+
+
+# This CMake file was created by Lane Schwartz <dowobeha@gmail.com>
+
+# Explicitly list the source files for this subdirectory
+#
+# If you add any source files to this subdirectory
+# that should be included in the kenlm library,
+# (this excludes any unit test files)
+# you should add them to the following list:
+#
+# In order to set correct paths to these files
+# in case this variable is referenced by CMake files in the parent directory,
+# we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}.
+#
+set(KENLM_COMMON_SOURCE
+ ${CMAKE_CURRENT_SOURCE_DIR}/model_buffer.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/print.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/renumber.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/size_option.cc
+ )
+
+
+# Group these objects together for later use.
+#
+# Given add_library(foo OBJECT ${my_foo_sources}),
+# refer to these objects as $<TARGET_OBJECTS:foo>
+#
+add_library(kenlm_common OBJECT ${KENLM_COMMON_SOURCE})
+
diff --git a/lm/common/Jamfile b/lm/common/Jamfile
index 1c9c37210..c9bdfd0df 100644
--- a/lm/common/Jamfile
+++ b/lm/common/Jamfile
@@ -1,2 +1,2 @@
fakelib common : [ glob *.cc : *test.cc *main.cc ]
- ../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm ;
+ ../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm /top//boost_program_options ;
diff --git a/lm/builder/joint_order.hh b/lm/common/joint_order.hh
index 5f62a4578..6113bb8f1 100644
--- a/lm/builder/joint_order.hh
+++ b/lm/common/joint_order.hh
@@ -1,8 +1,7 @@
-#ifndef LM_BUILDER_JOINT_ORDER_H
-#define LM_BUILDER_JOINT_ORDER_H
+#ifndef LM_COMMON_JOINT_ORDER_H
+#define LM_COMMON_JOINT_ORDER_H
#include "lm/common/ngram_stream.hh"
-#include "lm/builder/payload.hh"
#include "lm/lm_exception.hh"
#ifdef DEBUG
@@ -12,15 +11,19 @@
#include <cstring>
-namespace lm { namespace builder {
+namespace lm {
template <class Callback, class Compare> void JointOrder(const util::stream::ChainPositions &positions, Callback &callback) {
// Allow matching to reference streams[-1].
- NGramStreams<BuildingPayload> streams_with_dummy;
- streams_with_dummy.InitWithDummy(positions);
- NGramStream<BuildingPayload> *streams = streams_with_dummy.begin() + 1;
+ util::FixedArray<ProxyStream<NGramHeader> > streams_with_dummy(positions.size() + 1);
+ // A bogus stream for [-1].
+ streams_with_dummy.push_back();
+ for (std::size_t i = 0; i < positions.size(); ++i) {
+ streams_with_dummy.push_back(positions[i], NGramHeader(NULL, i + 1));
+ }
+ ProxyStream<NGramHeader> *streams = streams_with_dummy.begin() + 1;
- unsigned int order;
+ std::size_t order;
for (order = 0; order < positions.size() && streams[order]; ++order) {}
assert(order); // should always have <unk>.
@@ -31,11 +34,11 @@ template <class Callback, class Compare> void JointOrder(const util::stream::Cha
less_compare.push_back(i + 1);
#endif // DEBUG
- unsigned int current = 0;
+ std::size_t current = 0;
while (true) {
// Does the context match the lower one?
if (!memcmp(streams[static_cast<int>(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) {
- callback.Enter(current, *streams[current]);
+ callback.Enter(current, streams[current].Get());
// Transition to looking for extensions.
if (++current < order) continue;
}
@@ -51,7 +54,7 @@ template <class Callback, class Compare> void JointOrder(const util::stream::Cha
while(true) {
assert(current > 0);
--current;
- callback.Exit(current, *streams[current]);
+ callback.Exit(current, streams[current].Get());
if (++streams[current]) break;
@@ -63,6 +66,6 @@ template <class Callback, class Compare> void JointOrder(const util::stream::Cha
}
}
-}} // namespaces
+} // namespaces
-#endif // LM_BUILDER_JOINT_ORDER_H
+#endif // LM_COMMON_JOINT_ORDER_H
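
JointOrder now lives in namespace lm and hands its callback a raw void* to each record instead of a typed NGram<BuildingPayload>&, so the same template serves payload types beyond the builder. A minimal callback sketch against the new Enter/Exit signatures (CountingCallback is illustrative; the driving code in the trailing comment mirrors how interpolate.cc uses it and assumes SuffixOrder comes from lm/common/compare.hh):

    #include "lm/common/ngram.hh"

    #include <cstddef>
    #include <vector>

    // Illustrative callback: count how many n-grams of each order JointOrder visits.
    class CountingCallback {
      public:
        explicit CountingCallback(std::size_t order) : seen_(order, 0) {}

        void Enter(unsigned order_minus_1, void *data) {
          // Wrap the raw record only if the words or payload are actually needed.
          lm::NGramHeader gram(data, order_minus_1 + 1);
          (void)gram;
          ++seen_[order_minus_1];
        }

        void Exit(unsigned /*order_minus_1*/, void * /*data*/) const {}

        const std::vector<std::size_t> &Seen() const { return seen_; }

      private:
        std::vector<std::size_t> seen_;
    };

    // Driven roughly as interpolate.cc does:
    //   CountingCallback callback(positions.size());
    //   lm::JointOrder<CountingCallback, SuffixOrder>(positions, callback);
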
diff --git a/lm/common/model_buffer.cc b/lm/common/model_buffer.cc
index d4635da51..431d4ae4c 100644
--- a/lm/common/model_buffer.cc
+++ b/lm/common/model_buffer.cc
@@ -8,25 +8,30 @@
#include <boost/lexical_cast.hpp>
-namespace lm { namespace common {
+namespace lm {
namespace {
const char kMetadataHeader[] = "KenLM intermediate binary file";
} // namespace
-ModelBuffer::ModelBuffer(const std::string &file_base, bool keep_buffer, bool output_q)
- : file_base_(file_base), keep_buffer_(keep_buffer), output_q_(output_q) {}
-
-ModelBuffer::ModelBuffer(const std::string &file_base)
- : file_base_(file_base), keep_buffer_(false) {
+ModelBuffer::ModelBuffer(StringPiece file_base, bool keep_buffer, bool output_q)
+ : file_base_(file_base.data(), file_base.size()), keep_buffer_(keep_buffer), output_q_(output_q),
+ vocab_file_(keep_buffer ? util::CreateOrThrow((file_base_ + ".vocab").c_str()) : util::MakeTemp(file_base_)) {}
+
+ModelBuffer::ModelBuffer(StringPiece file_base)
+ : file_base_(file_base.data(), file_base.size()), keep_buffer_(false) {
const std::string full_name = file_base_ + ".kenlm_intermediate";
util::FilePiece in(full_name.c_str());
StringPiece token = in.ReadLine();
UTIL_THROW_IF2(token != kMetadataHeader, "File " << full_name << " begins with \"" << token << "\" not " << kMetadataHeader);
token = in.ReadDelimited();
- UTIL_THROW_IF2(token != "Order", "Expected Order, got \"" << token << "\" in " << full_name);
- unsigned long order = in.ReadULong();
+ UTIL_THROW_IF2(token != "Counts", "Expected Counts, got \"" << token << "\" in " << full_name);
+ char got;
+ while ((got = in.get()) == ' ') {
+ counts_.push_back(in.ReadULong());
+ }
+ UTIL_THROW_IF2(got != '\n', "Expected newline at end of counts.");
token = in.ReadDelimited();
UTIL_THROW_IF2(token != "Payload", "Expected Payload, got \"" << token << "\" in " << full_name);
@@ -39,16 +44,16 @@ ModelBuffer::ModelBuffer(const std::string &file_base)
UTIL_THROW(util::Exception, "Unknown payload " << token);
}
- files_.Init(order);
- for (unsigned long i = 0; i < order; ++i) {
+ vocab_file_.reset(util::OpenReadOrThrow((file_base_ + ".vocab").c_str()));
+
+ files_.Init(counts_.size());
+ for (unsigned long i = 0; i < counts_.size(); ++i) {
files_.push_back(util::OpenReadOrThrow((file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()));
}
}
-// virtual destructor
-ModelBuffer::~ModelBuffer() {}
-
-void ModelBuffer::Sink(util::stream::Chains &chains) {
+void ModelBuffer::Sink(util::stream::Chains &chains, const std::vector<uint64_t> &counts) {
+ counts_ = counts;
// Open files.
files_.Init(chains.size());
for (std::size_t i = 0; i < chains.size(); ++i) {
@@ -64,19 +69,23 @@ void ModelBuffer::Sink(util::stream::Chains &chains) {
if (keep_buffer_) {
util::scoped_fd metadata(util::CreateOrThrow((file_base_ + ".kenlm_intermediate").c_str()));
util::FakeOFStream meta(metadata.get(), 200);
- meta << kMetadataHeader << "\nOrder " << chains.size() << "\nPayload " << (output_q_ ? "q" : "pb") << '\n';
+ meta << kMetadataHeader << "\nCounts";
+ for (std::vector<uint64_t>::const_iterator i = counts_.begin(); i != counts_.end(); ++i) {
+ meta << ' ' << *i;
+ }
+ meta << "\nPayload " << (output_q_ ? "q" : "pb") << '\n';
}
}
void ModelBuffer::Source(util::stream::Chains &chains) {
- assert(chains.size() == files_.size());
- for (unsigned int i = 0; i < files_.size(); ++i) {
+ assert(chains.size() <= files_.size());
+ for (unsigned int i = 0; i < chains.size(); ++i) {
chains[i] >> util::stream::PRead(files_[i].get());
}
}
-std::size_t ModelBuffer::Order() const {
- return files_.size();
+void ModelBuffer::Source(std::size_t order_minus_1, util::stream::Chain &chain) {
+ chain >> util::stream::PRead(files_[order_minus_1].get());
}
-}} // namespaces
+} // namespace
diff --git a/lm/common/model_buffer.hh b/lm/common/model_buffer.hh
index 6a5c7bf49..92662bbf8 100644
--- a/lm/common/model_buffer.hh
+++ b/lm/common/model_buffer.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_MODEL_BUFFER_H
-#define LM_BUILDER_MODEL_BUFFER_H
+#ifndef LM_COMMON_MODEL_BUFFER_H
+#define LM_COMMON_MODEL_BUFFER_H
/* Format with separate files in suffix order. Each file contains
* n-grams of the same order.
@@ -9,37 +9,55 @@
#include "util/fixed_array.hh"
#include <string>
+#include <vector>
-namespace util { namespace stream { class Chains; } }
+namespace util { namespace stream {
+class Chains;
+class Chain;
+}} // namespaces
-namespace lm { namespace common {
+namespace lm {
class ModelBuffer {
public:
- // Construct for writing.
- ModelBuffer(const std::string &file_base, bool keep_buffer, bool output_q);
+ // Construct for writing. Must call VocabFile() and fill it with null-delimited vocab words.
+ ModelBuffer(StringPiece file_base, bool keep_buffer, bool output_q);
// Load from file.
- explicit ModelBuffer(const std::string &file_base);
-
- // explicit for virtual destructor.
- ~ModelBuffer();
+ explicit ModelBuffer(StringPiece file_base);
- void Sink(util::stream::Chains &chains);
+ // Must call VocabFile and populate before calling this function.
+ void Sink(util::stream::Chains &chains, const std::vector<uint64_t> &counts);
+ // Read files and write to the given chains. If fewer chains are provided,
+ // only do the lower orders.
void Source(util::stream::Chains &chains);
+ void Source(std::size_t order_minus_1, util::stream::Chain &chain);
+
// The order of the n-gram model that is associated with the model buffer.
- std::size_t Order() const;
+ std::size_t Order() const { return counts_.size(); }
+ // Requires Sink or load from file.
+ const std::vector<uint64_t> &Counts() const {
+ assert(!counts_.empty());
+ return counts_;
+ }
+
+ int VocabFile() const { return vocab_file_.get(); }
+ int StealVocabFile() { return vocab_file_.release(); }
+
+ bool Keep() const { return keep_buffer_; }
private:
const std::string file_base_;
const bool keep_buffer_;
bool output_q_;
+ std::vector<uint64_t> counts_;
+ util::scoped_fd vocab_file_;
util::FixedArray<util::scoped_fd> files_;
};
-}} // namespaces
+} // namespace lm
-#endif // LM_BUILDER_MODEL_BUFFER_H
+#endif // LM_COMMON_MODEL_BUFFER_H
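
The buffer is now self-describing: the counts live in the metadata file and the vocabulary file is owned by the buffer itself. A minimal read-side sketch against the new interface ("base" is a placeholder file base written earlier with --intermediate; this only inspects metadata and does not stream the n-grams):

    #include "lm/common/model_buffer.hh"

    #include <cstddef>
    #include <iostream>

    int main() {
      // Load an intermediate model previously written by lmplz --intermediate base.
      lm::ModelBuffer buffer("base");

      std::cout << "order " << buffer.Order() << '\n';
      for (std::size_t i = 0; i < buffer.Counts().size(); ++i)
        std::cout << "ngram " << (i + 1) << '=' << buffer.Counts()[i] << '\n';

      // The null-delimited vocabulary is available as a file descriptor.
      std::cout << "vocab fd " << buffer.VocabFile() << '\n';
      return 0;
    }
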
diff --git a/lm/common/ngram.hh b/lm/common/ngram.hh
index 813017640..7a6d1c358 100644
--- a/lm/common/ngram.hh
+++ b/lm/common/ngram.hh
@@ -16,6 +16,8 @@ class NGramHeader {
NGramHeader(void *begin, std::size_t order)
: begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {}
+ NGramHeader() : begin_(NULL), end_(NULL) {}
+
const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); }
uint8_t *Base() { return reinterpret_cast<uint8_t*>(begin_); }
@@ -32,6 +34,7 @@ class NGramHeader {
const WordIndex *end() const { return end_; }
WordIndex *end() { return end_; }
+ std::size_t size() const { return end_ - begin_; }
std::size_t Order() const { return end_ - begin_; }
private:
@@ -42,6 +45,8 @@ template <class PayloadT> class NGram : public NGramHeader {
public:
typedef PayloadT Payload;
+ NGram() : NGramHeader(NULL, 0) {}
+
NGram(void *begin, std::size_t order) : NGramHeader(begin, order) {}
// Would do operator++ but that can get confusing for a stream.
diff --git a/lm/common/ngram_stream.hh b/lm/common/ngram_stream.hh
index 53c4ffcb8..8bdf36e3c 100644
--- a/lm/common/ngram_stream.hh
+++ b/lm/common/ngram_stream.hh
@@ -10,24 +10,21 @@
namespace lm {
-template <class Payload> class NGramStream {
+template <class Proxy> class ProxyStream {
public:
- NGramStream() : gram_(NULL, 0) {}
+ // Make an invalid stream.
+ ProxyStream() {}
- NGramStream(const util::stream::ChainPosition &position) : gram_(NULL, 0) {
- Init(position);
+ explicit ProxyStream(const util::stream::ChainPosition &position, const Proxy &proxy = Proxy())
+ : proxy_(proxy), stream_(position) {
+ proxy_.ReBase(stream_.Get());
}
- void Init(const util::stream::ChainPosition &position) {
- stream_.Init(position);
- gram_ = NGram<Payload>(stream_.Get(), NGram<Payload>::OrderFromSize(position.GetChain().EntrySize()));
- }
-
- NGram<Payload> &operator*() { return gram_; }
- const NGram<Payload> &operator*() const { return gram_; }
+ Proxy &operator*() { return proxy_; }
+ const Proxy &operator*() const { return proxy_; }
- NGram<Payload> *operator->() { return &gram_; }
- const NGram<Payload> *operator->() const { return &gram_; }
+ Proxy *operator->() { return &proxy_; }
+ const Proxy *operator->() const { return &proxy_; }
void *Get() { return stream_.Get(); }
const void *Get() const { return stream_.Get(); }
@@ -36,21 +33,25 @@ template <class Payload> class NGramStream {
bool operator!() const { return !stream_; }
void Poison() { stream_.Poison(); }
- NGramStream &operator++() {
+ ProxyStream<Proxy> &operator++() {
++stream_;
- gram_.ReBase(stream_.Get());
+ proxy_.ReBase(stream_.Get());
return *this;
}
private:
- NGram<Payload> gram_;
+ Proxy proxy_;
util::stream::Stream stream_;
};
-template <class Payload> inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream<Payload> &str) {
- str.Init(chain.Add());
- return chain;
-}
+template <class Payload> class NGramStream : public ProxyStream<NGram<Payload> > {
+ public:
+ // Make an invalid stream.
+ NGramStream() {}
+
+ explicit NGramStream(const util::stream::ChainPosition &position) :
+ ProxyStream<NGram<Payload> >(position, NGram<Payload>(NULL, NGram<Payload>::OrderFromSize(position.GetChain().EntrySize()))) {}
+};
template <class Payload> class NGramStreams : public util::stream::GenericStreams<NGramStream<Payload> > {
private:
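
NGramStream is now a thin wrapper over the new ProxyStream: when the chain's entry size determines the order, the one-argument form suffices, and when it does not (for example the bare Prob payload of the highest order), the proxy is passed explicitly, as the new print.cc does. A small sketch of the explicit form (SumTopOrderProbs is illustrative; the ChainPosition and order are assumed to come from an existing pipeline):

    #include "lm/common/ngram.hh"
    #include "lm/common/ngram_stream.hh"
    #include "lm/weights.hh"
    #include "util/stream/chain.hh"

    #include <cstddef>

    // Illustrative consumer: sum the probabilities of the highest-order n-grams.
    // The entry size (order * sizeof(WordIndex) + sizeof(Prob)) does not encode
    // the order, so the NGram proxy is constructed with the order explicitly.
    float SumTopOrderProbs(const util::stream::ChainPosition &position, std::size_t order) {
      float total = 0.0f;
      for (lm::ProxyStream<lm::NGram<lm::Prob> > stream(position, lm::NGram<lm::Prob>(NULL, order));
           stream; ++stream) {
        total += stream->Value().prob;
      }
      return total;
    }
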
diff --git a/lm/common/print.cc b/lm/common/print.cc
new file mode 100644
index 000000000..cd2a80260
--- /dev/null
+++ b/lm/common/print.cc
@@ -0,0 +1,62 @@
+#include "lm/common/print.hh"
+
+#include "lm/common/ngram_stream.hh"
+#include "util/fake_ofstream.hh"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/scoped.hh"
+
+#include <sstream>
+#include <cstring>
+
+namespace lm {
+
+VocabReconstitute::VocabReconstitute(int fd) {
+ uint64_t size = util::SizeOrThrow(fd);
+ util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_);
+ const char *const start = static_cast<const char*>(memory_.get());
+ const char *i;
+ for (i = start; i != start + size; i += strlen(i) + 1) {
+ map_.push_back(i);
+ }
+ // Last one for LookupPiece.
+ map_.push_back(i);
+}
+
+namespace {
+template <class Payload> void PrintLead(const VocabReconstitute &vocab, ProxyStream<Payload> &stream, util::FakeOFStream &out) {
+ out << stream->Value().prob << '\t' << vocab.Lookup(*stream->begin());
+ for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
+ out << ' ' << vocab.Lookup(*i);
+ }
+}
+} // namespace
+
+void PrintARPA::Run(const util::stream::ChainPositions &positions) {
+ VocabReconstitute vocab(vocab_fd_);
+ util::FakeOFStream out(out_fd_);
+ out << "\\data\\\n";
+ for (size_t i = 0; i < positions.size(); ++i) {
+ out << "ngram " << (i+1) << '=' << counts_[i] << '\n';
+ }
+ out << '\n';
+
+ for (unsigned order = 1; order < positions.size(); ++order) {
+ out << "\\" << order << "-grams:" << '\n';
+ for (ProxyStream<NGram<ProbBackoff> > stream(positions[order - 1], NGram<ProbBackoff>(NULL, order)); stream; ++stream) {
+ PrintLead(vocab, stream, out);
+ out << '\t' << stream->Value().backoff << '\n';
+ }
+ out << '\n';
+ }
+
+ out << "\\" << positions.size() << "-grams:" << '\n';
+ for (ProxyStream<NGram<Prob> > stream(positions.back(), NGram<Prob>(NULL, positions.size())); stream; ++stream) {
+ PrintLead(vocab, stream, out);
+ out << '\n';
+ }
+ out << '\n';
+ out << "\\end\\\n";
+}
+
+} // namespace lm
diff --git a/lm/common/print.hh b/lm/common/print.hh
new file mode 100644
index 000000000..6aa08b32a
--- /dev/null
+++ b/lm/common/print.hh
@@ -0,0 +1,58 @@
+#ifndef LM_COMMON_PRINT_H
+#define LM_COMMON_PRINT_H
+
+#include "lm/word_index.hh"
+#include "util/mmap.hh"
+#include "util/string_piece.hh"
+
+#include <cassert>
+#include <vector>
+
+namespace util { namespace stream { class ChainPositions; }}
+
+// Warning: PrintARPA routines read all unigrams before all bigrams before all
+// trigrams etc. So if other parts of the chain move jointly, you'll have to
+// buffer.
+
+namespace lm {
+
+class VocabReconstitute {
+ public:
+ // fd must be alive for life of this object; does not take ownership.
+ explicit VocabReconstitute(int fd);
+
+ const char *Lookup(WordIndex index) const {
+ assert(index < map_.size() - 1);
+ return map_[index];
+ }
+
+ StringPiece LookupPiece(WordIndex index) const {
+ return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]);
+ }
+
+ std::size_t Size() const {
+ // There's an extra entry to support StringPiece lengths.
+ return map_.size() - 1;
+ }
+
+ private:
+ util::scoped_memory memory_;
+ std::vector<const char*> map_;
+};
+
+class PrintARPA {
+ public:
+ // Does not take ownership of vocab_fd or out_fd.
+ explicit PrintARPA(int vocab_fd, int out_fd, const std::vector<uint64_t> &counts)
+ : vocab_fd_(vocab_fd), out_fd_(out_fd), counts_(counts) {}
+
+ void Run(const util::stream::ChainPositions &positions);
+
+ private:
+ int vocab_fd_;
+ int out_fd_;
+ std::vector<uint64_t> counts_;
+};
+
+} // namespace lm
+#endif // LM_COMMON_PRINT_H
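
Because PrintARPA now takes the vocabulary descriptor and counts in its constructor, ARPA dumping no longer depends on the builder's OutputHook machinery; anything that can feed ChainPositions can drive it. A sketch that replays an intermediate model as ARPA on stdout (the file base "base", the block count, and the memory figure are placeholders; the per-order entry sizes assume the layout that the new print.cc reads, i.e. ProbBackoff below the top order and Prob at the top):

    #include "lm/common/model_buffer.hh"
    #include "lm/common/print.hh"
    #include "lm/weights.hh"
    #include "lm/word_index.hh"
    #include "util/stream/chain.hh"
    #include "util/stream/multi_stream.hh"

    #include <cstddef>

    int main() {
      lm::ModelBuffer buffer("base");  // written earlier by lmplz --intermediate base

      util::stream::Chains chains(buffer.Order());
      for (std::size_t i = 0; i < buffer.Order(); ++i) {
        std::size_t entry_size = (i + 1) * sizeof(lm::WordIndex) +
            (i + 1 == buffer.Order() ? sizeof(lm::Prob) : sizeof(lm::ProbBackoff));
        chains.push_back(util::stream::ChainConfig(entry_size, 2, 64 << 20));
      }

      buffer.Source(chains);  // one PRead per order
      // fd 1 is stdout; PrintARPA does not take ownership of either descriptor.
      chains >> lm::PrintARPA(buffer.VocabFile(), 1, buffer.Counts());
      chains >> util::stream::kRecycle;
      chains.Wait(true);
      return 0;
    }
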
diff --git a/lm/common/size_option.cc b/lm/common/size_option.cc
new file mode 100644
index 000000000..46a920e69
--- /dev/null
+++ b/lm/common/size_option.cc
@@ -0,0 +1,24 @@
+#include <boost/program_options.hpp>
+#include "util/usage.hh"
+
+namespace lm {
+
+namespace {
+class SizeNotify {
+ public:
+ explicit SizeNotify(std::size_t &out) : behind_(out) {}
+
+ void operator()(const std::string &from) {
+ behind_ = util::ParseSize(from);
+ }
+
+ private:
+ std::size_t &behind_;
+};
+}
+
+boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) {
+ return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
+}
+
+} // namespace lm
diff --git a/lm/common/size_option.hh b/lm/common/size_option.hh
new file mode 100644
index 000000000..d3b8e33cb
--- /dev/null
+++ b/lm/common/size_option.hh
@@ -0,0 +1,11 @@
+#include <boost/program_options.hpp>
+
+#include <cstddef>
+#include <string>
+
+namespace lm {
+
+// Create a boost program option for data sizes. This parses sizes like 1T and 10k.
+boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value);
+
+} // namespace lm
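
SizeOption replaces the SizeNotify helper that used to live in lmplz_main.cc, so any tool built on the common code can accept suffixed sizes. A minimal sketch of wiring it into boost::program_options (the option name and default are placeholders; percentages such as 80% work because util::ParseSize understands them):

    #include "lm/common/size_option.hh"

    #include <boost/program_options.hpp>

    #include <cstddef>
    #include <iostream>

    int main(int argc, char *argv[]) {
      namespace po = boost::program_options;

      std::size_t memory = 0;
      po::options_description options("Example");
      options.add_options()
        // Accepts values such as 8K, 64M, 1G, 1T or a percentage of physical memory.
        ("memory,S", lm::SizeOption(memory, "1G"), "Sorting memory");

      po::variables_map vm;
      po::store(po::parse_command_line(argc, argv, options), vm);
      po::notify(vm);  // the notifier parses the string and fills `memory`

      std::cout << "memory = " << memory << " bytes\n";
      return 0;
    }
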
diff --git a/lm/builder/special.hh b/lm/common/special.hh
index c70865ce1..0677cd71b 100644
--- a/lm/builder/special.hh
+++ b/lm/common/special.hh
@@ -1,9 +1,9 @@
-#ifndef LM_BUILDER_SPECIAL_H
-#define LM_BUILDER_SPECIAL_H
+#ifndef LM_COMMON_SPECIAL_H
+#define LM_COMMON_SPECIAL_H
#include "lm/word_index.hh"
-namespace lm { namespace builder {
+namespace lm {
class SpecialVocab {
public:
@@ -22,6 +22,6 @@ class SpecialVocab {
WordIndex eos_;
};
-}} // namespaces
+} // namespace lm
-#endif // LM_BUILDER_SPECIAL_H
+#endif // LM_COMMON_SPECIAL_H
diff --git a/lm/filter/CMakeLists.txt b/lm/filter/CMakeLists.txt
new file mode 100644
index 000000000..4e791cef8
--- /dev/null
+++ b/lm/filter/CMakeLists.txt
@@ -0,0 +1,62 @@
+cmake_minimum_required(VERSION 2.8.8)
+#
+# The KenLM cmake files make use of add_library(... OBJECTS ...)
+#
+# This syntax allows grouping of source files when compiling
+# (effectively creating "fake" libraries based on source subdirs).
+#
+# This syntax was only added in cmake version 2.8.8
+#
+# see http://www.cmake.org/Wiki/CMake/Tutorials/Object_Library
+
+
+# This CMake file was created by Lane Schwartz <dowobeha@gmail.com>
+
+# Explicitly list the source files for this subdirectory
+#
+# If you add any source files to this subdirectory
+# that should be included in the kenlm library,
+# (this excludes any unit test files)
+# you should add them to the following list:
+#
+# In order to set correct paths to these files
+# in case this variable is referenced by CMake files in the parent directory,
+# we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}.
+#
+set(KENLM_FILTER_SOURCE
+ ${CMAKE_CURRENT_SOURCE_DIR}/arpa_io.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/phrase.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/vocab.cc
+ )
+
+
+# Group these objects together for later use.
+#
+# Given add_library(foo OBJECT ${my_foo_sources}),
+# refer to these objects as $<TARGET_OBJECTS:foo>
+#
+add_library(kenlm_filter OBJECT ${KENLM_FILTER_SOURCE})
+
+
+# Explicitly list the executable files to be compiled
+set(EXE_LIST
+ filter
+ phrase_table_vocab
+)
+
+
+# Iterate through the executable list
+foreach(exe ${EXE_LIST})
+
+ # Compile the executable, linking against the requisite dependent object files
+ add_executable(${exe} ${exe}_main.cc $<TARGET_OBJECTS:kenlm> $<TARGET_OBJECTS:kenlm_filter> $<TARGET_OBJECTS:kenlm_util>)
+
+ # Link the executable against boost
+ target_link_libraries(${exe} ${Boost_LIBRARIES})
+
+ # Group executables together
+ set_target_properties(${exe} PROPERTIES FOLDER executables)
+
+# End for loop
+endforeach(exe)
+
diff --git a/util/CMakeLists.txt b/util/CMakeLists.txt
new file mode 100644
index 000000000..c52cdbc06
--- /dev/null
+++ b/util/CMakeLists.txt
@@ -0,0 +1,109 @@
+cmake_minimum_required(VERSION 2.8.8)
+#
+# The KenLM cmake files make use of add_library(... OBJECTS ...)
+#
+# This syntax allows grouping of source files when compiling
+# (effectively creating "fake" libraries based on source subdirs).
+#
+# This syntax was only added in cmake version 2.8.8
+#
+# see http://www.cmake.org/Wiki/CMake/Tutorials/Object_Library
+
+
+# This CMake file was created by Lane Schwartz <dowobeha@gmail.com>
+
+
+# Explicitly list the source files for this subdirectory
+#
+# If you add any source files to this subdirectory
+# that should be included in the kenlm library,
+# (this excludes any unit test files)
+# you should add them to the following list:
+#
+# Because we do not set PARENT_SCOPE in the following definition,
+# CMake files in the parent directory won't be able to access this variable.
+#
+set(KENLM_UTIL_SOURCE
+ bit_packing.cc
+ ersatz_progress.cc
+ exception.cc
+ file.cc
+ file_piece.cc
+ float_to_string.cc
+ integer_to_string.cc
+ mmap.cc
+ murmur_hash.cc
+ parallel_read.cc
+ pool.cc
+ read_compressed.cc
+ scoped.cc
+ string_piece.cc
+ usage.cc
+ )
+
+# This directory has children that need to be processed
+add_subdirectory(double-conversion)
+add_subdirectory(stream)
+
+
+# Group these objects together for later use.
+#
+# Given add_library(foo OBJECT ${my_foo_sources}),
+# refer to these objects as $<TARGET_OBJECTS:foo>
+#
+add_library(kenlm_util OBJECT ${KENLM_UTIL_DOUBLECONVERSION_SOURCE} ${KENLM_UTIL_STREAM_SOURCE} ${KENLM_UTIL_SOURCE})
+
+
+
+# Only compile and run unit tests if tests should be run
+if(BUILD_TESTING)
+
+ # Explicitly list the Boost test files to be compiled
+ set(KENLM_BOOST_TESTS_LIST
+ bit_packing_test
+ file_piece_test
+ joint_sort_test
+ multi_intersection_test
+ probing_hash_table_test
+ read_compressed_test
+ sorted_uniform_test
+ tokenize_piece_test
+ )
+
+ # Iterate through the Boost tests list
+ foreach(test ${KENLM_BOOST_TESTS_LIST})
+
+ # Compile the executable, linking against the requisite dependent object files
+ add_executable(${test} ${test}.cc $<TARGET_OBJECTS:kenlm_util>)
+
+ # Require the following compile flag
+ set_target_properties(${test} PROPERTIES COMPILE_FLAGS -DBOOST_TEST_DYN_LINK)
+
+ # Link the executable against boost
+ target_link_libraries(${test} ${Boost_LIBRARIES})
+
+ # file_piece_test requires an extra command line parameter
+ if ("${test}" STREQUAL "file_piece_test")
+ set(test_params
+ ${CMAKE_CURRENT_SOURCE_DIR}/file_piece.cc
+ )
+ else()
+ set(test_params
+ )
+ endif()
+
+ # Specify command arguments for how to run each unit test
+ #
+ # Assuming that foo was defined via add_executable(foo ...),
+ # the syntax $<TARGET_FILE:foo> gives the full path to the executable.
+ #
+ add_test(NAME ${test}_test
+ COMMAND $<TARGET_FILE:${test}> ${test_params})
+
+ # Group unit tests together
+ set_target_properties(${test} PROPERTIES FOLDER "unit_tests")
+
+ # End for loop
+ endforeach(test)
+
+endif()
diff --git a/util/double-conversion/CMakeLists.txt b/util/double-conversion/CMakeLists.txt
new file mode 100644
index 000000000..e2cf02aa6
--- /dev/null
+++ b/util/double-conversion/CMakeLists.txt
@@ -0,0 +1,39 @@
+cmake_minimum_required(VERSION 2.8.8)
+#
+# The KenLM cmake files make use of add_library(... OBJECTS ...)
+#
+# This syntax allows grouping of source files when compiling
+# (effectively creating "fake" libraries based on source subdirs).
+#
+# This syntax was only added in cmake version 2.8.8
+#
+# see http://www.cmake.org/Wiki/CMake/Tutorials/Object_Library
+
+
+# This CMake file was created by Lane Schwartz <dowobeha@gmail.com>
+
+# Explicitly list the source files for this subdirectory
+#
+# If you add any source files to this subdirectory
+# that should be included in the kenlm library,
+# (this excludes any unit test files)
+# you should add them to the following list:
+#
+# In order to allow CMake files in the parent directory
+# to see this variable definition, we set PARENT_SCOPE.
+#
+# In order to set correct paths to these files
+# when this variable is referenced by CMake files in the parent directory,
+# we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}.
+#
+set(KENLM_UTIL_DOUBLECONVERSION_SOURCE
+ ${CMAKE_CURRENT_SOURCE_DIR}/bignum-dtoa.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/bignum.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/cached-powers.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/diy-fp.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/double-conversion.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/fast-dtoa.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/fixed-dtoa.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/strtod.cc
+ PARENT_SCOPE)
+
diff --git a/util/file.cc b/util/file.cc
index 046b9ff90..be272f9bc 100644
--- a/util/file.cc
+++ b/util/file.cc
@@ -60,6 +60,14 @@ EndOfFileException::EndOfFileException() throw() {
}
EndOfFileException::~EndOfFileException() throw() {}
+bool InputPathIsStdin(StringPiece path) {
+ return path == "-" || path == "/dev/stdin";
+}
+
+bool OutputPathIsStdout(StringPiece path) {
+ return path == "-" || path == "/dev/stdout";
+}
+
int OpenReadOrThrow(const char *name) {
int ret;
#if defined(_WIN32) || defined(_WIN64)
@@ -111,7 +119,10 @@ uint64_t SizeOrThrow(int fd) {
}
void ResizeOrThrow(int fd, uint64_t to) {
-#if defined(_WIN32) || defined(_WIN64)
+#if defined __MINGW32__
+ // Does this handle 64-bit?
+ int ret = ftruncate
+#elif defined(_WIN32) || defined(_WIN64)
errno_t ret = _chsize_s
#elif defined(OS_ANDROID)
int ret = ftruncate64
diff --git a/util/file.hh b/util/file.hh
index bd5873cbc..f7cb4d688 100644
--- a/util/file.hh
+++ b/util/file.hh
@@ -78,6 +78,28 @@ int OpenReadOrThrow(const char *name);
// Create file if it doesn't exist, truncate if it does. Opened for write.
int CreateOrThrow(const char *name);
+/** Does the given input file path denote standard input?
+ *
+ * Returns true if, and only if, path is either "-" or "/dev/stdin".
+ *
+ * Opening standard input as a file may need some special treatment for
+ * portability. There's a convention that a dash ("-") in place of an input
+ * file path denotes standard input, but opening "/dev/stdin" may need special
+ * treatment as well.
+ */
+bool InputPathIsStdin(StringPiece path);
+
+/** Does the given output file path denote standard output?
+ *
+ * Returns true if, and only if, path is either "-" or "/dev/stdout".
+ *
+ * Opening standard output as a file may need some special treatment for
+ * portability. There's a convention that a dash ("-") in place of an output
+ * file path denotes standard output, but opening "/dev/stdout" may need special
+ * treatment as well.
+ */
+bool OutputPathIsStdout(StringPiece path);
+
// Return value for SizeFile when it can't size properly.
const uint64_t kBadSize = (uint64_t)-1;
uint64_t SizeFile(int fd);
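
A sketch of how a caller might combine the new predicate with the existing open helpers is below. The helper name OpenInput is hypothetical, and duplicating the standard input descriptor (so the caller can close the returned fd unconditionally) is an assumption for this example, not something the patch mandates.

    #include "util/file.hh"

    #include <unistd.h>  // dup, STDIN_FILENO

    // Treat "-" and "/dev/stdin" as standard input instead of a regular file.
    int OpenInput(const char *path) {
      if (util::InputPathIsStdin(path))
        return dup(STDIN_FILENO);  // Caller closes the returned descriptor either way.
      return util::OpenReadOrThrow(path);
    }
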
diff --git a/util/fixed_array.hh b/util/fixed_array.hh
index 610cbdf12..9083d39ee 100644
--- a/util/fixed_array.hh
+++ b/util/fixed_array.hh
@@ -100,33 +100,56 @@ template <class T> class FixedArray {
*
* @param i Index of the object to reference
*/
- T &operator[](std::size_t i) { return begin()[i]; }
+ T &operator[](std::size_t i) {
+ assert(i < size());
+ return begin()[i];
+ }
/**
* Gets a const reference to the object with index i currently stored in this data structure.
*
* @param i Index of the object to reference
*/
- const T &operator[](std::size_t i) const { return begin()[i]; }
+ const T &operator[](std::size_t i) const {
+ assert(i < size());
+ return begin()[i];
+ }
/**
* Constructs a new object using the provided parameter,
* and stores it in this data structure.
*
* The memory backing the constructed object is managed by this data structure.
+ * I miss C++11 variadic templates.
*/
+ void push_back() {
+ new (end()) T();
+ Constructed();
+ }
template <class C> void push_back(const C &c) {
- new (end()) T(c); // use "placement new" syntax to initalize T in an already-allocated memory location
+ new (end()) T(c);
+ Constructed();
+ }
+ template <class C> void push_back(C &c) {
+ new (end()) T(c);
+ Constructed();
+ }
+ template <class C, class D> void push_back(const C &c, const D &d) {
+ new (end()) T(c, d);
Constructed();
}
+ void pop_back() {
+ back().~T();
+ --newed_end_;
+ }
+
/**
* Removes all elements from this array.
*/
void clear() {
- for (T *i = begin(); i != end(); ++i)
- i->~T();
- newed_end_ = begin();
+ while (newed_end_ != begin())
+ pop_back();
}
protected:
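
The FixedArray changes add a zero-argument and a two-argument push_back, a pop_back, and range asserts on operator[]; clear() is now expressed as repeated pop_back. A rough usage sketch follows; the capacity-taking constructor is assumed from how FixedArray is used elsewhere in the tree.

    #include "util/fixed_array.hh"

    #include <string>

    void Demo() {
      util::FixedArray<std::string> arr(3);  // fixed capacity; never reallocates
      arr.push_back();                       // default-construct in place (new overload)
      arr.push_back("hello");                // one-argument overload
      arr.push_back("world", 3);             // two-argument overload -> std::string("world", 3)
      arr.pop_back();                        // destroys the last element in place
      arr.clear();                           // now implemented as repeated pop_back
    }
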
diff --git a/util/float_to_string.hh b/util/float_to_string.hh
index d1104e790..930532734 100644
--- a/util/float_to_string.hh
+++ b/util/float_to_string.hh
@@ -8,13 +8,13 @@ namespace util {
template <> struct ToStringBuf<double> {
// DoubleToStringConverter::kBase10MaximalLength + 1 for null paranoia.
- static const unsigned kBytes = 18;
+ static const unsigned kBytes = 19;
};
// Single wasn't documented in double conversion, so be conservative and
// say the same as double.
template <> struct ToStringBuf<float> {
- static const unsigned kBytes = 18;
+ static const unsigned kBytes = 19;
};
char *ToString(double value, char *to);
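
The kBytes bump matters to any caller that sizes a stack buffer from ToStringBuf. A short sketch, assuming ToString returns a pointer just past the last character written:

    #include "util/float_to_string.hh"

    #include <iostream>

    void PrintDouble(double value) {
      // Worst-case formatted length, including room for a trailing NUL.
      char buf[util::ToStringBuf<double>::kBytes];
      char *end = util::ToString(value, buf);
      std::cout.write(buf, end - buf);
    }
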
diff --git a/util/probing_hash_table.hh b/util/probing_hash_table.hh
index f4192577b..f32b64ea3 100644
--- a/util/probing_hash_table.hh
+++ b/util/probing_hash_table.hh
@@ -92,8 +92,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
Key got(i->GetKey());
if (equal_(got, t.GetKey())) { out = i; return true; }
if (equal_(got, invalid_)) {
- UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException,
- "Hash table with " << buckets_ << " buckets is full.");
+ UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full.");
*i = t;
out = i;
return false;
diff --git a/util/probing_hash_table_benchmark_main.cc b/util/probing_hash_table_benchmark_main.cc
index 9deeaac2d..3e12290cf 100644
--- a/util/probing_hash_table_benchmark_main.cc
+++ b/util/probing_hash_table_benchmark_main.cc
@@ -1,18 +1,9 @@
-
#include "util/probing_hash_table.hh"
#include "util/scoped.hh"
#include "util/usage.hh"
#include <boost/random/mersenne_twister.hpp>
-#include <boost/version.hpp>
-#if BOOST_VERSION / 100000 < 1
-#error BOOST_LIB_VERSION is to old. Time to upgrade.
-#elif BOOST_VERSION / 100000 > 1 || BOOST_VERSION / 100 % 1000 > 46
#include <boost/random/uniform_int_distribution.hpp>
-#define have_uniform_int_distribution
-#else
-#include <boost/random/uniform_int.hpp>
-#endif
#include <iostream>
@@ -31,13 +22,8 @@ void Test(uint64_t entries, uint64_t lookups, float multiplier = 1.5) {
std::size_t size = Table::Size(entries, multiplier);
scoped_malloc backing(util::CallocOrThrow(size));
Table table(backing.get(), size);
-#ifdef have_uniform_int_distribution
boost::random::mt19937 gen;
boost::random::uniform_int_distribution<> dist(std::numeric_limits<uint64_t>::min(), std::numeric_limits<uint64_t>::max());
-#else
- boost::mt19937 gen;
- boost::uniform_int<> dist(std::numeric_limits<uint64_t>::min(), std::numeric_limits<uint64_t>::max());
-#endif
double start = UserTime();
for (uint64_t i = 0; i < entries; ++i) {
Entry entry;
diff --git a/util/stream/CMakeLists.txt b/util/stream/CMakeLists.txt
new file mode 100644
index 000000000..d3eddbe5c
--- /dev/null
+++ b/util/stream/CMakeLists.txt
@@ -0,0 +1,74 @@
+cmake_minimum_required(VERSION 2.8.8)
+#
+# The KenLM cmake files make use of add_library(... OBJECTS ...)
+#
+# This syntax allows grouping of source files when compiling
+# (effectively creating "fake" libraries based on source subdirs).
+#
+# This syntax was only added in cmake version 2.8.8
+#
+# see http://www.cmake.org/Wiki/CMake/Tutorials/Object_Library
+
+
+# This CMake file was created by Lane Schwartz <dowobeha@gmail.com>
+
+# Explicitly list the source files for this subdirectory
+#
+# If you add any source files to this subdirectory
+# that should be included in the kenlm library,
+# (this excludes any unit test files)
+# you should add them to the following list:
+#
+# In order to allow CMake files in the parent directory
+# to see this variable definition, we set PARENT_SCOPE.
+#
+# In order to set correct paths to these files
+# when this variable is referenced by CMake files in the parent directory,
+# we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}.
+#
+set(KENLM_UTIL_STREAM_SOURCE
+ ${CMAKE_CURRENT_SOURCE_DIR}/chain.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/io.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/line_input.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/multi_progress.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/rewindable_stream.cc
+ PARENT_SCOPE)
+
+
+
+if(BUILD_TESTING)
+
+ # Explicitly list the Boost test files to be compiled
+ set(KENLM_BOOST_TESTS_LIST
+ io_test
+ sort_test
+ stream_test
+ )
+
+ # Iterate through the Boost tests list
+ foreach(test ${KENLM_BOOST_TESTS_LIST})
+
+ # Compile the executable, linking against the requisite dependent object files
+ add_executable(${test} ${test}.cc $<TARGET_OBJECTS:kenlm_util>)
+
+ # Require the following compile flag
+ set_target_properties(${test} PROPERTIES COMPILE_FLAGS -DBOOST_TEST_DYN_LINK)
+
+ # Link the executable against boost
+ target_link_libraries(${test} ${Boost_LIBRARIES})
+
+ # Specify command arguments for how to run each unit test
+ #
+ # Assuming that foo was defined via add_executable(foo ...),
+ # the syntax $<TARGET_FILE:foo> gives the full path to the executable.
+ #
+ add_test(NAME ${test}_test
+ COMMAND $<TARGET_FILE:${test}>)
+
+ # Group unit tests together
+ set_target_properties(${test} PROPERTIES FOLDER "unit_tests")
+
+ # End for loop
+ endforeach(test)
+
+endif()
diff --git a/util/stream/Jamfile b/util/stream/Jamfile
index cde0247e7..de9d41c5f 100644
--- a/util/stream/Jamfile
+++ b/util/stream/Jamfile
@@ -1,10 +1,4 @@
-#if $(BOOST-VERSION) >= 104800 {
-# timer-link = <library>/top//boost_timer ;
-#} else {
-# timer-link = ;
-#}
-
-fakelib stream : chain.cc rewindable_stream.cc io.cc line_input.cc multi_progress.cc ..//kenutil /top//boost_thread : : : <library>/top//boost_thread ;
+fakelib stream : [ glob *.cc : *_test.cc ] ..//kenutil /top//boost_thread : : : <library>/top//boost_thread ;
import testing ;
unit-test io_test : io_test.cc stream /top//boost_unit_test_framework ;
diff --git a/util/stream/chain.cc b/util/stream/chain.cc
index 39f2f3fbb..6bc000522 100644
--- a/util/stream/chain.cc
+++ b/util/stream/chain.cc
@@ -126,7 +126,7 @@ Link::~Link() {
if (current_) {
// Probably an exception unwinding.
std::cerr << "Last input should have been poison." << std::endl;
- // abort();
+ abort();
} else {
if (!poisoned_) {
// Poison is a block whose memory pointer is NULL.
diff --git a/util/stream/chain.hh b/util/stream/chain.hh
index 8caa1afcb..296982260 100644
--- a/util/stream/chain.hh
+++ b/util/stream/chain.hh
@@ -74,11 +74,11 @@ class Thread {
* This method is called automatically by this class's @ref Thread() "constructor".
*/
template <class Position, class Worker> void operator()(const Position &position, Worker &worker) {
- try {
+// try {
worker.Run(position);
- } catch (const std::exception &e) {
- UnhandledException(e);
- }
+// } catch (const std::exception &e) {
+// UnhandledException(e);
+// }
}
private:
@@ -158,6 +158,13 @@ class Chain {
return block_size_;
}
+ /**
+ * Number of blocks going through the Chain.
+ */
+ std::size_t BlockCount() const {
+ return config_.block_count;
+ }
+
/** Two ways to add to the chain: Add() or operator>>. */
ChainPosition Add();
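
The new BlockCount() accessor simply reports the block_count from the ChainConfig the chain was built with. A small sketch, assuming the constructor stores that field unchanged; the block_count and total_memory fields appear in the rewindable_stream test at the end of this diff, while entry_size is assumed to be the remaining ChainConfig field.

    #include "util/stream/chain.hh"

    #include <cassert>

    void ConfigureChain() {
      util::stream::ChainConfig config;
      config.entry_size = 8;          // bytes per record
      config.block_count = 2;         // RewindableStream now requires at least two
      config.total_memory = 1 << 20;  // split across the blocks
      util::stream::Chain chain(config);
      assert(chain.BlockCount() == 2);  // the new accessor
    }
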
diff --git a/util/stream/count_records.cc b/util/stream/count_records.cc
new file mode 100644
index 000000000..bdadad713
--- /dev/null
+++ b/util/stream/count_records.cc
@@ -0,0 +1,12 @@
+#include "util/stream/count_records.hh"
+#include "util/stream/chain.hh"
+
+namespace util { namespace stream {
+
+void CountRecords::Run(const ChainPosition &position) {
+ for (Link link(position); link; ++link) {
+ *count_ += link->ValidSize() / position.GetChain().EntrySize();
+ }
+}
+
+}} // namespaces
diff --git a/util/stream/count_records.hh b/util/stream/count_records.hh
new file mode 100644
index 000000000..e3f7c94af
--- /dev/null
+++ b/util/stream/count_records.hh
@@ -0,0 +1,20 @@
+#include <stdint.h>
+
+namespace util { namespace stream {
+
+class ChainPosition;
+
+class CountRecords {
+ public:
+ explicit CountRecords(uint64_t *out)
+ : count_(out) {
+ *count_ = 0;
+ }
+
+ void Run(const ChainPosition &position);
+
+ private:
+ uint64_t *count_;
+};
+
+}} // namespaces
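
CountRecords is a plain chain worker: its Run() walks every block at its position and accumulates ValidSize() / EntrySize(). A sketch of attaching it to a chain, mirroring the Read >> worker >> kRecycle pattern used in the rewindable_stream test at the end of this diff; the function CountFileRecords is hypothetical.

    #include "util/stream/chain.hh"
    #include "util/stream/count_records.hh"
    #include "util/stream/io.hh"

    #include <stdint.h>

    uint64_t CountFileRecords(int fd, const util::stream::ChainConfig &config) {
      uint64_t records;
      util::stream::Chain chain(config);
      chain >> util::stream::Read(fd) >> util::stream::CountRecords(&records) >> util::stream::kRecycle;
      chain.Wait();  // records is valid once the chain has drained
      return records;
    }
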
diff --git a/util/stream/multi_stream.hh b/util/stream/multi_stream.hh
index b1461f964..6381fc2ed 100644
--- a/util/stream/multi_stream.hh
+++ b/util/stream/multi_stream.hh
@@ -20,6 +20,9 @@ class ChainPositions : public util::FixedArray<util::stream::ChainPosition> {
public:
ChainPositions() {}
+ explicit ChainPositions(std::size_t bound) :
+ util::FixedArray<util::stream::ChainPosition>(bound) {}
+
void Init(Chains &chains);
explicit ChainPositions(Chains &chains) {
@@ -88,16 +91,6 @@ template <class T> class GenericStreams : public util::FixedArray<T> {
public:
GenericStreams() {}
- // This puts a dummy T at the beginning (useful to algorithms that need to reference something at the beginning).
- void InitWithDummy(const ChainPositions &positions) {
- P::Init(positions.size() + 1);
- new (P::end()) T(); // use "placement new" syntax to initalize T in an already-allocated memory location
- P::Constructed();
- for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) {
- P::push_back(*i);
- }
- }
-
// Limit restricts to positions[0,limit)
void Init(const ChainPositions &positions, std::size_t limit) {
P::Init(limit);
@@ -112,6 +105,10 @@ template <class T> class GenericStreams : public util::FixedArray<T> {
GenericStreams(const ChainPositions &positions) {
Init(positions);
}
+
+ void Init(std::size_t amount) {
+ P::Init(amount);
+ }
};
template <class T> inline Chains &operator>>(Chains &chains, GenericStreams<T> &streams) {
diff --git a/util/stream/rewindable_stream.cc b/util/stream/rewindable_stream.cc
index c7e39231b..2867bf8ab 100644
--- a/util/stream/rewindable_stream.cc
+++ b/util/stream/rewindable_stream.cc
@@ -13,105 +13,120 @@ void RewindableStream::Init(const ChainPosition &position) {
UTIL_THROW_IF2(in_, "RewindableStream::Init twice");
in_ = position.in_;
out_ = position.out_;
+ hit_poison_ = false;
poisoned_ = false;
progress_ = position.progress_;
entry_size_ = position.GetChain().EntrySize();
block_size_ = position.GetChain().BlockSize();
- FetchBlock();
- current_bl_ = &second_bl_;
- current_ = static_cast<uint8_t*>(current_bl_->Get());
- end_ = current_ + current_bl_->ValidSize();
-}
-
-const void *RewindableStream::Get() const {
- return current_;
-}
-
-void *RewindableStream::Get() {
- return current_;
+ block_count_ = position.GetChain().BlockCount();
+ blocks_it_ = 0;
+ marked_ = NULL;
+ UTIL_THROW_IF2(block_count_ < 2, "RewindableStream needs block_count at least two");
+ AppendBlock();
}
RewindableStream &RewindableStream::operator++() {
assert(*this);
- assert(current_ < end_);
+ assert(current_ < block_end_);
+ assert(current_);
+ assert(blocks_it_ < blocks_.size());
current_ += entry_size_;
- if (current_ == end_) {
- // two cases: either we need to fetch the next block, or we've already
- // fetched it before. We can check this by looking at the current_bl_
- // pointer: if it's at the second_bl_, we need to flush and fetch a new
- // block. Otherwise, we can just move over to the second block.
- if (current_bl_ == &second_bl_) {
- if (first_bl_) {
- out_->Produce(first_bl_);
- progress_ += first_bl_.ValidSize();
+ if (UTIL_UNLIKELY(current_ == block_end_)) {
+ // Fetch another block if necessary.
+ if (++blocks_it_ == blocks_.size()) {
+ if (!marked_) {
+ Flush(blocks_.begin() + blocks_it_);
+ blocks_it_ = 0;
}
- first_bl_ = second_bl_;
- FetchBlock();
+ AppendBlock();
+ assert(poisoned_ || (blocks_it_ == blocks_.size() - 1));
+ if (poisoned_) return *this;
}
- current_bl_ = &second_bl_;
- current_ = static_cast<uint8_t *>(second_bl_.Get());
- end_ = current_ + second_bl_.ValidSize();
- }
-
- if (!*current_bl_)
- {
- if (current_bl_ == &second_bl_ && first_bl_)
- {
- out_->Produce(first_bl_);
- progress_ += first_bl_.ValidSize();
- }
- out_->Produce(*current_bl_);
- poisoned_ = true;
+ Block &cur_block = blocks_[blocks_it_];
+ current_ = static_cast<uint8_t*>(cur_block.Get());
+ block_end_ = current_ + cur_block.ValidSize();
}
-
+ assert(current_);
+ assert(current_ >= static_cast<uint8_t*>(blocks_[blocks_it_].Get()));
+ assert(current_ < block_end_);
+ assert(block_end_ == blocks_[blocks_it_].ValidEnd());
return *this;
}
-void RewindableStream::FetchBlock() {
- // The loop is needed since it is *feasible* that we're given 0 sized but
- // valid blocks
- do {
- in_->Consume(second_bl_);
- } while (second_bl_ && second_bl_.ValidSize() == 0);
-}
-
void RewindableStream::Mark() {
marked_ = current_;
+ Flush(blocks_.begin() + blocks_it_);
+ blocks_it_ = 0;
}
void RewindableStream::Rewind() {
- if (marked_ >= first_bl_.Get() && marked_ < first_bl_.ValidEnd()) {
- current_bl_ = &first_bl_;
- current_ = marked_;
- } else if (marked_ >= second_bl_.Get() && marked_ < second_bl_.ValidEnd()) {
- current_bl_ = &second_bl_;
- current_ = marked_;
- } else { UTIL_THROW2("RewindableStream rewound too far"); }
+ if (current_ != marked_) {
+ poisoned_ = false;
+ }
+ blocks_it_ = 0;
+ current_ = marked_;
+ block_end_ = static_cast<const uint8_t*>(blocks_[blocks_it_].ValidEnd());
+
+ assert(current_);
+ assert(current_ >= static_cast<uint8_t*>(blocks_[blocks_it_].Get()));
+ assert(current_ < block_end_);
+ assert(block_end_ == blocks_[blocks_it_].ValidEnd());
}
void RewindableStream::Poison() {
- assert(!poisoned_);
+ if (blocks_.empty()) return;
+ assert(*this);
+ assert(blocks_it_ == blocks_.size() - 1);
- // Three things: if we have a buffered first block, we need to produce it
- // first. Then, produce the partial "current" block, and then send the
- // poison down the chain
+ // Produce all buffered blocks.
+ blocks_.back().SetValidSize(current_ - static_cast<uint8_t*>(blocks_.back().Get()));
+ Flush(blocks_.end());
+ blocks_it_ = 0;
- // if we still have a buffered first block, produce it first
- if (current_bl_ == &second_bl_ && first_bl_) {
- out_->Produce(first_bl_);
- progress_ += first_bl_.ValidSize();
+ Block poison;
+ if (!hit_poison_) {
+ in_->Consume(poison);
}
+ poison.SetToPoison();
+ out_->Produce(poison);
+ hit_poison_ = true;
+ poisoned_ = true;
+}
- // send our partial block
- current_bl_->SetValidSize(current_
- - static_cast<uint8_t *>(current_bl_->Get()));
- out_->Produce(*current_bl_);
- progress_ += current_bl_->ValidSize();
+void RewindableStream::AppendBlock() {
+ if (UTIL_UNLIKELY(blocks_.size() >= block_count_)) {
+ std::cerr << "RewindableStream trying to use more blocks than available" << std::endl;
+ abort();
+ }
+ if (UTIL_UNLIKELY(hit_poison_)) {
+ poisoned_ = true;
+ return;
+ }
+ Block get;
+  // The loop is needed since it is *feasible* that we're given 0-sized but
+  // valid blocks.
+ do {
+ in_->Consume(get);
+ if (UTIL_LIKELY(get)) {
+ blocks_.push_back(get);
+ } else {
+ hit_poison_ = true;
+ poisoned_ = true;
+ return;
+ }
+ } while (UTIL_UNLIKELY(get.ValidSize() == 0));
+ current_ = static_cast<uint8_t*>(blocks_.back().Get());
+ block_end_ = static_cast<const uint8_t*>(blocks_.back().ValidEnd());
+ blocks_it_ = blocks_.size() - 1;
+}
- // send down the poison
- current_bl_->SetToPoison();
- out_->Produce(*current_bl_);
- poisoned_ = true;
+void RewindableStream::Flush(std::deque<Block>::iterator to) {
+ for (std::deque<Block>::iterator i = blocks_.begin(); i != to; ++i) {
+ out_->Produce(*i);
+ progress_ += i->ValidSize();
+ }
+ blocks_.erase(blocks_.begin(), to);
}
+
}
}
diff --git a/util/stream/rewindable_stream.hh b/util/stream/rewindable_stream.hh
index 9ee637c99..560825cde 100644
--- a/util/stream/rewindable_stream.hh
+++ b/util/stream/rewindable_stream.hh
@@ -5,6 +5,8 @@
#include <boost/noncopyable.hpp>
+#include <deque>
+
namespace util {
namespace stream {
@@ -23,6 +25,10 @@ class RewindableStream : boost::noncopyable {
*/
RewindableStream();
+ ~RewindableStream() {
+ Poison();
+ }
+
/**
* Initializes an existing RewindableStream at a specific position in
* a Chain.
@@ -38,21 +44,32 @@ class RewindableStream : boost::noncopyable {
*
* Equivalent to RewindableStream a(); a.Init(....);
*/
- explicit RewindableStream(const ChainPosition &position);
+ explicit RewindableStream(const ChainPosition &position)
+ : in_(NULL) {
+ Init(position);
+ }
/**
* Gets the record at the current stream position. Const version.
*/
- const void *Get() const;
+ const void *Get() const {
+ assert(!poisoned_);
+ assert(current_);
+ return current_;
+ }
/**
* Gets the record at the current stream position.
*/
- void *Get();
+ void *Get() {
+ assert(!poisoned_);
+ assert(current_);
+ return current_;
+ }
- operator bool() const { return current_; }
+ operator bool() const { return !poisoned_; }
- bool operator!() const { return !(*this); }
+ bool operator!() const { return poisoned_; }
/**
* Marks the current position in the stream to be rewound to later.
@@ -80,19 +97,26 @@ class RewindableStream : boost::noncopyable {
void Poison();
private:
- void FetchBlock();
+ void AppendBlock();
+
+ void Flush(std::deque<Block>::iterator to);
+
+ std::deque<Block> blocks_;
+ // current_ is in blocks_[blocks_it_] unless poisoned_.
+ std::size_t blocks_it_;
std::size_t entry_size_;
std::size_t block_size_;
+ std::size_t block_count_;
- uint8_t *marked_, *current_, *end_;
-
- Block first_bl_;
- Block second_bl_;
- Block* current_bl_;
+ uint8_t *marked_, *current_;
+ const uint8_t *block_end_;
PCQueue<Block> *in_, *out_;
+ // Have we hit poison at the end of the stream, even if rewinding?
+ bool hit_poison_;
+  // Is the current position poison?
bool poisoned_;
WorkerProgress progress_;
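
The reworked RewindableStream buffers up to block_count blocks in a deque, so a marked window may span several blocks as long as it fits in the chain's memory. A sketch of the intended Mark/Rewind pattern follows; the TwoPass function, the uint64_t record layout, and the 16-record window are illustrative assumptions, and the window must fit within the chain's block_count blocks or AppendBlock() will abort.

    #include "util/stream/rewindable_stream.hh"

    #include <stdint.h>

    // Two passes over each window of up to 16 records: sum them, then overwrite.
    void TwoPass(util::stream::RewindableStream &s) {
      while (s) {
        s.Mark();                          // remember the start of the window
        uint64_t sum = 0, n = 0;
        for (; s && n < 16; ++s, ++n)      // first pass: accumulate
          sum += *static_cast<const uint64_t*>(s.Get());
        s.Rewind();                        // back to the marked record
        for (uint64_t i = 0; s && i < n; ++s, ++i)
          *static_cast<uint64_t*>(s.Get()) = sum;  // second pass: write in place
      }
    }
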
diff --git a/util/stream/rewindable_stream_test.cc b/util/stream/rewindable_stream_test.cc
index 3ed87f372..f8924c3c7 100644
--- a/util/stream/rewindable_stream_test.cc
+++ b/util/stream/rewindable_stream_test.cc
@@ -22,8 +22,8 @@ BOOST_AUTO_TEST_CASE(RewindableStreamTest) {
config.total_memory = 100;
config.block_count = 6;
- RewindableStream s;
Chain chain(config);
+ RewindableStream s;
chain >> Read(in.get()) >> s >> kRecycle;
uint64_t i = 0;
for (; s; ++s, ++i) {