github.com/moses-smt/mosesdecoder.git
author     marcinj <junczys@amu.edu.pl>   2012-11-16 20:37:38 +0400
committer  marcinj <junczys@amu.edu.pl>   2012-11-16 20:37:38 +0400
commit     0edd6b4a61d3fec9601c2870cd5dfcc69b5367dc (patch)
tree       55b0a1d1524a609d99de3a54b0bb9cfa2bea2ece
parent     f858f919b3c457fab93929a20638934518dfbbc0 (diff)
parent     4f3afbf41ff52fa1398b74ba62a5f15ef3fea8f2 (diff)
Merge branch 'master' of github.com:moses-smt/mosesdecoder
-rw-r--r--  moses-chart-cmd/IOWrapper.cpp  | 259
-rw-r--r--  moses-chart-cmd/IOWrapper.h    |   9
-rw-r--r--  moses-chart-cmd/Main.cpp       |  14
-rw-r--r--  moses/ChartCellLabel.h         |   4
-rw-r--r--  moses/Incremental.cpp          | 296
-rw-r--r--  moses/Incremental.h            |  60
-rw-r--r--  moses/Incremental/Fill.cpp     | 143
-rw-r--r--  moses/Incremental/Fill.h       |  54
-rw-r--r--  moses/Incremental/Manager.cpp  | 122
-rw-r--r--  moses/Incremental/Manager.h    |  35
-rw-r--r--  moses/Jamfile                  |   1
-rw-r--r--  moses/LM/Ken.cpp               |   2
-rw-r--r--  search/Jamfile                 |   6
-rw-r--r--  search/applied.hh              |  86
-rw-r--r--  search/config.hh               |  25
-rw-r--r--  search/context.hh              |  28
-rw-r--r--  search/edge_generator.cc       |   3
-rw-r--r--  search/edge_generator.hh       |   1
-rw-r--r--  search/final.hh                |  36
-rw-r--r--  search/header.hh               |   9
-rw-r--r--  search/nbest.cc                | 106
-rw-r--r--  search/nbest.hh                |  81
-rw-r--r--  search/note.hh                 |  12
-rw-r--r--  search/rule.cc                 |  52
-rw-r--r--  search/rule.hh                 |  11
-rw-r--r--  search/types.hh                |  17
-rw-r--r--  search/vertex.cc               |  27
-rw-r--r--  search/vertex.hh               |  37
-rw-r--r--  search/vertex_generator.cc     |  50
-rw-r--r--  search/vertex_generator.hh     |  72
-rw-r--r--  search/weights.cc              |  71
-rw-r--r--  search/weights.hh              |  52
-rw-r--r--  search/weights_test.cc         |  38
33 files changed, 996 insertions(+), 823 deletions(-)
diff --git a/moses-chart-cmd/IOWrapper.cpp b/moses-chart-cmd/IOWrapper.cpp
index 401636fa5..571f7e32e 100644
--- a/moses-chart-cmd/IOWrapper.cpp
+++ b/moses-chart-cmd/IOWrapper.cpp
@@ -40,6 +40,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include "moses/StaticData.h"
#include "moses/DummyScoreProducers.h"
#include "moses/InputFileStream.h"
+#include "moses/Incremental.h"
#include "moses/PhraseDictionary.h"
#include "moses/ChartTrellisPathList.h"
#include "moses/ChartTrellisPath.h"
@@ -377,8 +378,136 @@ void IOWrapper::OutputBestHypo(const ChartHypothesis *hypo, long translationId)
m_singleBestOutputCollector->Write(translationId, out.str());
}
-void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const ChartHypothesis *bestHypo, const TranslationSystem* system, long translationId)
-{
+void IOWrapper::OutputBestHypo(search::Applied applied, long translationId) {
+ if (!m_singleBestOutputCollector) return;
+ std::ostringstream out;
+ IOWrapper::FixPrecision(out);
+ if (StaticData::Instance().GetOutputHypoScore()) {
+ out << applied.GetScore() << ' ';
+ }
+ Phrase outPhrase;
+ Incremental::ToPhrase(applied, outPhrase);
+ // delete 1st & last
+ CHECK(outPhrase.GetSize() >= 2);
+ outPhrase.RemoveWord(0);
+ outPhrase.RemoveWord(outPhrase.GetSize() - 1);
+ out << outPhrase.GetStringRep(StaticData::Instance().GetOutputFactorOrder());
+ out << '\n';
+ m_singleBestOutputCollector->Write(translationId, out.str());
+}
+
+void IOWrapper::OutputBestNone(long translationId) {
+ if (!m_singleBestOutputCollector) return;
+ if (StaticData::Instance().GetOutputHypoScore()) {
+ m_singleBestOutputCollector->Write(translationId, "0 \n");
+ } else {
+ m_singleBestOutputCollector->Write(translationId, "\n");
+ }
+}
+
+namespace {
+
+void OutputSparseFeatureScores(std::ostream& out, const ScoreComponentCollection &features, const FeatureFunction *ff, std::string &lastName) {
+ const StaticData &staticData = StaticData::Instance();
+ bool labeledOutput = staticData.IsLabeledNBestList();
+ const FVector scores = features.GetVectorForProducer( ff );
+
+ // report weighted aggregate
+ if (! ff->GetSparseFeatureReporting()) {
+ const FVector &weights = staticData.GetAllWeights().GetScoresVector();
+ if (labeledOutput && !boost::contains(ff->GetScoreProducerDescription(), ":"))
+ out << " " << ff->GetScoreProducerWeightShortName() << ":";
+ out << " " << scores.inner_product(weights);
+ }
+
+ // report each feature
+ else {
+ for(FVector::FNVmap::const_iterator i = scores.cbegin(); i != scores.cend(); i++) {
+ if (i->second != 0) { // do not report zero-valued features
+ if (labeledOutput)
+ out << " " << i->first << ":";
+ out << " " << i->second;
+ }
+ }
+ }
+}
+
+void WriteFeatures(const TranslationSystem &system, const ScoreComponentCollection &features, std::ostream &out) {
+ bool labeledOutput = StaticData::Instance().IsLabeledNBestList();
+ // lm
+ const LMList& lml = system.GetLanguageModels();
+ if (lml.size() > 0) {
+ if (labeledOutput)
+ out << "lm:";
+ LMList::const_iterator lmi = lml.begin();
+ for (; lmi != lml.end(); ++lmi) {
+ out << " " << features.GetScoreForProducer(*lmi);
+ }
+ }
+
+ std::string lastName = "";
+
+ // output stateful sparse features
+ const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
+ for( size_t i=0; i<sff.size(); i++ )
+ if (sff[i]->GetNumScoreComponents() == ScoreProducer::unlimited)
+ OutputSparseFeatureScores(out, features, sff[i], lastName);
+
+ // translation components
+ const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
+ if (pds.size() > 0) {
+ for( size_t i=0; i<pds.size(); i++ ) {
+ size_t pd_numinputscore = pds[i]->GetNumInputScores();
+ vector<float> scores = features.GetScoresForProducer( pds[i] );
+ for (size_t j = 0; j<scores.size(); ++j){
+ if (labeledOutput && (i == 0) ){
+ if ((j == 0) || (j == pd_numinputscore)){
+ lastName = pds[i]->GetScoreProducerWeightShortName(j);
+ out << " " << lastName << ":";
+ }
+ }
+ out << " " << scores[j];
+ }
+ }
+ }
+
+ // word penalty
+ if (labeledOutput)
+ out << " w:";
+ out << " " << features.GetScoreForProducer(system.GetWordPenaltyProducer());
+
+ // generation
+ const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
+ if (gds.size() > 0) {
+ for( size_t i=0; i<gds.size(); i++ ) {
+ size_t pd_numinputscore = gds[i]->GetNumInputScores();
+ vector<float> scores = features.GetScoresForProducer( gds[i] );
+ for (size_t j = 0; j<scores.size(); ++j){
+ if (labeledOutput && (i == 0) ){
+ if ((j == 0) || (j == pd_numinputscore)){
+ lastName = gds[i]->GetScoreProducerWeightShortName(j);
+ out << " " << lastName << ":";
+ }
+ }
+ out << " " << scores[j];
+ }
+ }
+ }
+
+ // output stateless sparse features
+ lastName = "";
+
+ const vector<const StatelessFeatureFunction*>& slf = system.GetStatelessFeatureFunctions();
+ for( size_t i=0; i<slf.size(); i++ ) {
+ if (slf[i]->GetNumScoreComponents() == ScoreProducer::unlimited) {
+ OutputSparseFeatureScores(out, features, slf[i], lastName);
+ }
+ }
+}
+
+} // namespace
+
+void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const TranslationSystem* system, long translationId) {
std::ostringstream out;
// Check if we're writing to std::cout.
@@ -387,17 +516,10 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha
// preserve existing behaviour, but should probably be done either way.
IOWrapper::FixPrecision(out);
- // The output from -output-hypo-score is always written to std::cout.
- if (StaticData::Instance().GetOutputHypoScore()) {
- if (bestHypo != NULL) {
- out << bestHypo->GetTotalScore() << " ";
- } else {
- out << "0 ";
- }
- }
+ // Used to check StaticData's GetOutputHypoScore(), but it makes no sense with nbest output.
}
- bool labeledOutput = StaticData::Instance().IsLabeledNBestList();
+ //bool includeAlignment = StaticData::Instance().NBestIncludesAlignment();
bool includeWordAlignment = StaticData::Instance().PrintAlignmentInfoInNbest();
ChartTrellisPathList::const_iterator iter;
@@ -421,75 +543,7 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha
// before each model type, the corresponding command-line-like name must be emitted
// MERT script relies on this
- // lm
- const LMList& lml = system->GetLanguageModels();
- if (lml.size() > 0) {
- if (labeledOutput)
- out << "lm:";
- LMList::const_iterator lmi = lml.begin();
- for (; lmi != lml.end(); ++lmi) {
- out << " " << path.GetScoreBreakdown().GetScoreForProducer(*lmi);
- }
- }
-
- std::string lastName = "";
-
- // output stateful sparse features
- const vector<const StatefulFeatureFunction*>& sff = system->GetStatefulFeatureFunctions();
- for( size_t i=0; i<sff.size(); i++ )
- if (sff[i]->GetNumScoreComponents() == ScoreProducer::unlimited)
- OutputSparseFeatureScores( out, path, sff[i], lastName );
-
- // translation components
- const vector<PhraseDictionaryFeature*>& pds = system->GetPhraseDictionaries();
- if (pds.size() > 0) {
- for( size_t i=0; i<pds.size(); i++ ) {
- size_t pd_numinputscore = pds[i]->GetNumInputScores();
- vector<float> scores = path.GetScoreBreakdown().GetScoresForProducer( pds[i] );
- for (size_t j = 0; j<scores.size(); ++j){
- if (labeledOutput && (i == 0) ){
- if ((j == 0) || (j == pd_numinputscore)){
- lastName = pds[i]->GetScoreProducerWeightShortName(j);
- out << " " << lastName << ":";
- }
- }
- out << " " << scores[j];
- }
- }
- }
-
- // word penalty
- if (labeledOutput)
- out << " w:";
- out << " " << path.GetScoreBreakdown().GetScoreForProducer(system->GetWordPenaltyProducer());
-
- // generation
- const vector<GenerationDictionary*>& gds = system->GetGenerationDictionaries();
- if (gds.size() > 0) {
- for( size_t i=0; i<gds.size(); i++ ) {
- size_t pd_numinputscore = gds[i]->GetNumInputScores();
- vector<float> scores = path.GetScoreBreakdown().GetScoresForProducer( gds[i] );
- for (size_t j = 0; j<scores.size(); ++j){
- if (labeledOutput && (i == 0) ){
- if ((j == 0) || (j == pd_numinputscore)){
- lastName = gds[i]->GetScoreProducerWeightShortName(j);
- out << " " << lastName << ":";
- }
- }
- out << " " << scores[j];
- }
- }
- }
-
- // output stateless sparse features
- lastName = "";
-
- const vector<const StatelessFeatureFunction*>& slf = system->GetStatelessFeatureFunctions();
- for( size_t i=0; i<slf.size(); i++ ) {
- if (slf[i]->GetNumScoreComponents() == ScoreProducer::unlimited) {
- OutputSparseFeatureScores( out, path, slf[i], lastName );
- }
- }
+ WriteFeatures(*system, path.GetScoreBreakdown(), out);
// total
out << " ||| " << path.GetTotalScore();
@@ -524,34 +578,33 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha
out <<std::flush;
- CHECK(m_nBestOutputCollector);
+ assert(m_nBestOutputCollector);
m_nBestOutputCollector->Write(translationId, out.str());
}
-void IOWrapper::OutputSparseFeatureScores( std::ostream& out, const ChartTrellisPath &path, const FeatureFunction *ff, std::string &lastName )
-{
- const StaticData &staticData = StaticData::Instance();
- bool labeledOutput = staticData.IsLabeledNBestList();
- const FVector scores = path.GetScoreBreakdown().GetVectorForProducer( ff );
-
- // report weighted aggregate
- if (! ff->GetSparseFeatureReporting()) {
- const FVector &weights = staticData.GetAllWeights().GetScoresVector();
- if (labeledOutput && !boost::contains(ff->GetScoreProducerDescription(), ":"))
- out << " " << ff->GetScoreProducerWeightShortName() << ":";
- out << " " << scores.inner_product(weights);
+void IOWrapper::OutputNBestList(const std::vector<search::Applied> &nbest, const TranslationSystem &system, long translationId) {
+ std::ostringstream out;
+ // wtf? copied from the original OutputNBestList
+ if (m_nBestOutputCollector->OutputIsCout()) {
+ IOWrapper::FixPrecision(out);
}
-
- // report each feature
- else {
- for(FVector::FNVmap::const_iterator i = scores.cbegin(); i != scores.cend(); i++) {
- if (i->second != 0) { // do not report zero-valued features
- if (labeledOutput)
- out << " " << i->first << ":";
- out << " " << i->second;
- }
- }
+ Phrase outputPhrase;
+ ScoreComponentCollection features;
+ for (std::vector<search::Applied>::const_iterator i = nbest.begin(); i != nbest.end(); ++i) {
+ Incremental::PhraseAndFeatures(system, *i, outputPhrase, features);
+ // <s> and </s>
+ CHECK(outputPhrase.GetSize() >= 2);
+ outputPhrase.RemoveWord(0);
+ outputPhrase.RemoveWord(outputPhrase.GetSize() - 1);
+ out << translationId << " ||| ";
+ OutputSurface(out, outputPhrase, m_outputFactorOrder, false);
+ out << " ||| ";
+ WriteFeatures(system, features, out);
+ out << " ||| " << i->GetScore() << '\n';
}
+ out << std::flush;
+ assert(m_nBestOutputCollector);
+ m_nBestOutputCollector->Write(translationId, out.str());
}
void IOWrapper::FixPrecision(std::ostream &stream, size_t size)
diff --git a/moses-chart-cmd/IOWrapper.h b/moses-chart-cmd/IOWrapper.h
index dea7355d0..5686d5728 100644
--- a/moses-chart-cmd/IOWrapper.h
+++ b/moses-chart-cmd/IOWrapper.h
@@ -44,6 +44,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include "moses/OutputCollector.h"
#include "moses/ChartHypothesis.h"
#include "moses/ChartTrellisPath.h"
+#include "search/applied.hh"
namespace Moses
{
@@ -92,14 +93,14 @@ public:
Moses::InputType* GetInput(Moses::InputType *inputType);
void OutputBestHypo(const Moses::ChartHypothesis *hypo, long translationId);
+ void OutputBestHypo(search::Applied applied, long translationId);
void OutputBestHypo(const std::vector<const Moses::Factor*>& mbrBestHypo, long translationId);
- void OutputNBestList(const Moses::ChartTrellisPathList &nBestList, const Moses::ChartHypothesis *bestHypo, const Moses::TranslationSystem* system, long translationId);
- void OutputSparseFeatureScores(std::ostream& out, const Moses::ChartTrellisPath &path, const Moses::FeatureFunction *ff, std::string &lastName);
+ void OutputBestNone(long translationId);
+ void OutputNBestList(const Moses::ChartTrellisPathList &nBestList, const Moses::TranslationSystem* system, long translationId);
+ void OutputNBestList(const std::vector<search::Applied> &nbest, const Moses::TranslationSystem &system, long translationId);
void OutputDetailedTranslationReport(const Moses::ChartHypothesis *hypo, const Moses::Sentence &sentence, long translationId);
void Backtrack(const Moses::ChartHypothesis *hypo);
- Moses::OutputCollector *ExposeSingleBest() { return m_singleBestOutputCollector; }
-
void ResetTranslationId();
Moses::OutputCollector *GetSearchGraphOutputCollector() {
diff --git a/moses-chart-cmd/Main.cpp b/moses-chart-cmd/Main.cpp
index ee8099e3f..278783926 100644
--- a/moses-chart-cmd/Main.cpp
+++ b/moses-chart-cmd/Main.cpp
@@ -59,7 +59,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include "moses/ChartHypothesis.h"
#include "moses/ChartTrellisPath.h"
#include "moses/ChartTrellisPathList.h"
-#include "moses/Incremental/Manager.h"
+#include "moses/Incremental.h"
#include "util/usage.hh"
@@ -91,10 +91,14 @@ public:
if (staticData.GetSearchAlgorithm() == ChartIncremental) {
Incremental::Manager manager(*m_source, system);
- manager.ProcessSentence();
- if (m_ioWrapper.ExposeSingleBest()) {
- m_ioWrapper.ExposeSingleBest()->Write(translationId, manager.String() + '\n');
+ const std::vector<search::Applied> &nbest = manager.ProcessSentence();
+ if (!nbest.empty()) {
+ m_ioWrapper.OutputBestHypo(nbest[0], translationId);
+ } else {
+ m_ioWrapper.OutputBestNone(translationId);
}
+ if (staticData.GetNBestSize() > 0)
+ m_ioWrapper.OutputNBestList(nbest, system, translationId);
return;
}
@@ -125,7 +129,7 @@ public:
VERBOSE(2,"WRITING " << nBestSize << " TRANSLATION ALTERNATIVES TO " << staticData.GetNBestFilePath() << endl);
ChartTrellisPathList nBestList;
manager.CalcNBest(nBestSize, nBestList,staticData.GetDistinctNBest());
- m_ioWrapper.OutputNBestList(nBestList, bestHypo, &system, translationId);
+ m_ioWrapper.OutputNBestList(nBestList, &system, translationId);
IFVERBOSE(2) {
PrintUserTime("N-Best Hypotheses Generation Time:");
}
diff --git a/moses/ChartCellLabel.h b/moses/ChartCellLabel.h
index c44462fcc..9fccf71e9 100644
--- a/moses/ChartCellLabel.h
+++ b/moses/ChartCellLabel.h
@@ -23,7 +23,7 @@
#include "Word.h"
#include "WordsRange.h"
-namespace search { class Vertex; class VertexGenerator; }
+namespace search { class Vertex; }
namespace Moses
{
@@ -41,7 +41,7 @@ class ChartCellLabel
union Stack {
const HypoList *cube; // cube pruning
const search::Vertex *incr; // incremental search after filling.
- search::VertexGenerator *incr_generator; // incremental search during filling.
+ void *incr_generator; // incremental search during filling.
};
diff --git a/moses/Incremental.cpp b/moses/Incremental.cpp
new file mode 100644
index 000000000..770b0d67e
--- /dev/null
+++ b/moses/Incremental.cpp
@@ -0,0 +1,296 @@
+#include "moses/Incremental.h"
+
+#include "moses/ChartCell.h"
+#include "moses/ChartParserCallback.h"
+#include "moses/FeatureVector.h"
+#include "moses/StaticData.h"
+#include "moses/TranslationSystem.h"
+#include "moses/Util.h"
+
+#include "lm/model.hh"
+#include "search/applied.hh"
+#include "search/config.hh"
+#include "search/context.hh"
+#include "search/edge_generator.hh"
+#include "search/rule.hh"
+#include "search/vertex_generator.hh"
+
+#include <boost/lexical_cast.hpp>
+
+namespace Moses {
+namespace Incremental {
+namespace {
+
+// This is called by EdgeGenerator. Route hypotheses to separate vertices for
+// each left hand side label, populating ChartCellLabelSet out.
+template <class Best> class HypothesisCallback {
+ private:
+ typedef search::VertexGenerator<Best> Gen;
+ public:
+ HypothesisCallback(search::ContextBase &context, Best &best, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool)
+ : context_(context), best_(best), out_(out), vertex_pool_(vertex_pool) {}
+
+ void NewHypothesis(search::PartialEdge partial) {
+ // Get the LHS, look it up in the output ChartCellLabel, and upcast it.
+ // It's not part of the union because it would have been ugly to expose template types in ChartCellLabel.
+ ChartCellLabel::Stack &stack = out_.FindOrInsert(static_cast<const TargetPhrase *>(partial.GetNote().vp)->GetTargetLHS());
+ Gen *entry = static_cast<Gen*>(stack.incr_generator);
+ if (!entry) {
+ entry = generator_pool_.construct(context_, *vertex_pool_.construct(), best_);
+ stack.incr_generator = entry;
+ }
+ entry->NewHypothesis(partial);
+ }
+
+ void FinishedSearch() {
+ for (ChartCellLabelSet::iterator i(out_.mutable_begin()); i != out_.mutable_end(); ++i) {
+ ChartCellLabel::Stack &stack = i->second.MutableStack();
+ Gen *gen = static_cast<Gen*>(stack.incr_generator);
+ gen->FinishedSearch();
+ stack.incr = &gen->Generating();
+ }
+ }
+
+ private:
+ search::ContextBase &context_;
+
+ Best &best_;
+
+ ChartCellLabelSet &out_;
+
+ boost::object_pool<search::Vertex> &vertex_pool_;
+ boost::object_pool<Gen> generator_pool_;
+};
+
+// This is called by the moses parser to collect hypotheses. It converts to my
+// edges (search::PartialEdge).
+template <class Model> class Fill : public ChartParserCallback {
+ public:
+ Fill(search::Context<Model> &context, const std::vector<lm::WordIndex> &vocab_mapping, search::Score oov_weight)
+ : context_(context), vocab_mapping_(vocab_mapping), oov_weight_(oov_weight) {}
+
+ void Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &ignored);
+
+ void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &waste_memory, const WordsRange &range);
+
+ bool Empty() const { return edges_.Empty(); }
+
+ template <class Best> void Search(Best &best, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool) {
+ HypothesisCallback<Best> callback(context_, best, out, vertex_pool);
+ edges_.Search(context_, callback);
+ }
+
+ // Root: everything into one vertex.
+ template <class Best> search::History RootSearch(Best &best) {
+ search::Vertex vertex;
+ search::RootVertexGenerator<Best> gen(vertex, best);
+ edges_.Search(context_, gen);
+ return vertex.BestChild();
+ }
+
+ private:
+ lm::WordIndex Convert(const Word &word) const;
+
+ search::Context<Model> &context_;
+
+ const std::vector<lm::WordIndex> &vocab_mapping_;
+
+ search::EdgeGenerator edges_;
+
+ const search::Score oov_weight_;
+};
+
+template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &) {
+ std::vector<search::PartialVertex> vertices;
+ vertices.reserve(nts.size());
+ float below_score = 0.0;
+ for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) {
+ vertices.push_back((*i)->GetStack().incr->RootPartial());
+ if (vertices.back().Empty()) return;
+ below_score += vertices.back().Bound();
+ }
+
+ std::vector<lm::WordIndex> words;
+ for (TargetPhraseCollection::const_iterator p(targets.begin()); p != targets.end(); ++p) {
+ words.clear();
+ const TargetPhrase &phrase = **p;
+ const AlignmentInfo::NonTermIndexMap &align = phrase.GetAlignNonTerm().GetNonTermIndexMap();
+ search::PartialEdge edge(edges_.AllocateEdge(nts.size()));
+
+ search::PartialVertex *nt = edge.NT();
+ for (size_t i = 0; i < phrase.GetSize(); ++i) {
+ const Word &word = phrase.GetWord(i);
+ if (word.IsNonTerminal()) {
+ *(nt++) = vertices[align[i]];
+ words.push_back(search::kNonTerminal);
+ } else {
+ words.push_back(Convert(word));
+ }
+ }
+
+ edge.SetScore(phrase.GetFutureScore() + below_score);
+ // prob and oov were already accounted for.
+ search::ScoreRule(context_.LanguageModel(), words, edge.Between());
+
+ search::Note note;
+ note.vp = &phrase;
+ edge.SetNote(note);
+
+ edges_.AddEdge(edge);
+ }
+}
+
+template <class Model> void Fill<Model>::AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &, const WordsRange &) {
+ std::vector<lm::WordIndex> words;
+ CHECK(phrase.GetSize() <= 1);
+ if (phrase.GetSize())
+ words.push_back(Convert(phrase.GetWord(0)));
+
+ search::PartialEdge edge(edges_.AllocateEdge(0));
+ // Appears to be a bug that FutureScore does not already include language model.
+ search::ScoreRuleRet scored(search::ScoreRule(context_.LanguageModel(), words, edge.Between()));
+ edge.SetScore(phrase.GetFutureScore() + scored.prob * context_.LMWeight() + static_cast<search::Score>(scored.oov) * oov_weight_);
+
+ search::Note note;
+ note.vp = &phrase;
+ edge.SetNote(note);
+
+ edges_.AddEdge(edge);
+}
+
+// TODO: factors (but chart doesn't seem to support factors anyway).
+template <class Model> lm::WordIndex Fill<Model>::Convert(const Word &word) const {
+ std::size_t factor = word.GetFactor(0)->GetId();
+ return (factor >= vocab_mapping_.size() ? 0 : vocab_mapping_[factor]);
+}
+
+struct ChartCellBaseFactory {
+ ChartCellBase *operator()(size_t startPos, size_t endPos) const {
+ return new ChartCellBase(startPos, endPos);
+ }
+};
+
+} // namespace
+
+Manager::Manager(const InputType &source, const TranslationSystem &system) :
+ source_(source),
+ system_(system),
+ cells_(source, ChartCellBaseFactory()),
+ parser_(source, system, cells_),
+ n_best_(search::NBestConfig(StaticData::Instance().GetNBestSize())) {}
+
+Manager::~Manager() {
+ system_.CleanUpAfterSentenceProcessing(source_);
+}
+
+template <class Model, class Best> search::History Manager::PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out) {
+ const LanguageModel &abstract = **system_.GetLanguageModels().begin();
+ const float oov_weight = abstract.OOVFeatureEnabled() ? abstract.GetOOVWeight() : 0.0;
+ const StaticData &data = StaticData::Instance();
+ search::Config config(abstract.GetWeight(), data.GetCubePruningPopLimit(), search::NBestConfig(data.GetNBestSize()));
+ search::Context<Model> context(config, model);
+
+ size_t size = source_.GetSize();
+ boost::object_pool<search::Vertex> vertex_pool(std::max<size_t>(size * size / 2, 32));
+
+ for (size_t width = 1; width < size; ++width) {
+ for (size_t startPos = 0; startPos <= size-width; ++startPos) {
+ WordsRange range(startPos, startPos + width - 1);
+ Fill<Model> filler(context, words, oov_weight);
+ parser_.Create(range, filler);
+ filler.Search(out, cells_.MutableBase(range).MutableTargetLabelSet(), vertex_pool);
+ }
+ }
+
+ WordsRange range(0, size - 1);
+ Fill<Model> filler(context, words, oov_weight);
+ parser_.Create(range, filler);
+ return filler.RootSearch(out);
+}
+
+template <class Model> void Manager::LMCallback(const Model &model, const std::vector<lm::WordIndex> &words) {
+ std::size_t nbest = StaticData::Instance().GetNBestSize();
+ if (nbest <= 1) {
+ search::History ret = PopulateBest(model, words, single_best_);
+ if (ret) {
+ backing_for_single_.resize(1);
+ backing_for_single_[0] = search::Applied(ret);
+ } else {
+ backing_for_single_.clear();
+ }
+ completed_nbest_ = &backing_for_single_;
+ } else {
+ search::History ret = PopulateBest(model, words, n_best_);
+ if (ret) {
+ completed_nbest_ = &n_best_.Extract(ret);
+ } else {
+ backing_for_single_.clear();
+ completed_nbest_ = &backing_for_single_;
+ }
+ }
+}
+
+template void Manager::LMCallback<lm::ngram::ProbingModel>(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words);
+template void Manager::LMCallback<lm::ngram::RestProbingModel>(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words);
+template void Manager::LMCallback<lm::ngram::TrieModel>(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words);
+template void Manager::LMCallback<lm::ngram::QuantTrieModel>(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words);
+template void Manager::LMCallback<lm::ngram::ArrayTrieModel>(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
+template void Manager::LMCallback<lm::ngram::QuantArrayTrieModel>(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
+
+const std::vector<search::Applied> &Manager::ProcessSentence() {
+ const LMList &lms = system_.GetLanguageModels();
+ UTIL_THROW_IF(lms.size() != 1, util::Exception, "Incremental search only supports one language model.");
+ (*lms.begin())->IncrementalCallback(*this);
+ return *completed_nbest_;
+}
+
+namespace {
+
+struct NoOp {
+ void operator()(const TargetPhrase &) const {}
+};
+struct AccumScore {
+ AccumScore(ScoreComponentCollection &out) : out_(&out) {}
+ void operator()(const TargetPhrase &phrase) {
+ out_->PlusEquals(phrase.GetScoreBreakdown());
+ }
+ ScoreComponentCollection *out_;
+};
+template <class Action> void AppendToPhrase(const search::Applied final, Phrase &out, Action action) {
+ assert(final.Valid());
+ const TargetPhrase &phrase = *static_cast<const TargetPhrase*>(final.GetNote().vp);
+ action(phrase);
+ const search::Applied *child = final.Children();
+ for (std::size_t i = 0; i < phrase.GetSize(); ++i) {
+ const Word &word = phrase.GetWord(i);
+ if (word.IsNonTerminal()) {
+ AppendToPhrase(*child++, out, action);
+ } else {
+ out.AddWord(word);
+ }
+ }
+}
+
+} // namespace
+
+void ToPhrase(const search::Applied final, Phrase &out) {
+ out.Clear();
+ AppendToPhrase(final, out, NoOp());
+}
+
+void PhraseAndFeatures(const TranslationSystem &system, const search::Applied final, Phrase &phrase, ScoreComponentCollection &features) {
+ phrase.Clear();
+ features.ZeroAll();
+ AppendToPhrase(final, phrase, AccumScore(features));
+
+ // If we made it this far, there is only one language model.
+ float full, ignored_ngram;
+ std::size_t ignored_oov;
+ const LanguageModel &model = **system.GetLanguageModels().begin();
+ model.CalcScore(phrase, full, ignored_ngram, ignored_oov);
+ // CalcScore transforms, but EvaluateChart doesn't.
+ features.Assign(&model, UntransformLMScore(full));
+}
+
+} // namespace Incremental
+} // namespace Moses
diff --git a/moses/Incremental.h b/moses/Incremental.h
new file mode 100644
index 000000000..4bfc2dae3
--- /dev/null
+++ b/moses/Incremental.h
@@ -0,0 +1,60 @@
+#pragma once
+
+#include "lm/word_index.hh"
+#include "search/applied.hh"
+#include "search/nbest.hh"
+
+#include "moses/ChartCellCollection.h"
+#include "moses/ChartParser.h"
+
+#include <vector>
+#include <string>
+
+namespace Moses {
+class ScoreComponentCollection;
+class InputType;
+class TranslationSystem;
+namespace Incremental {
+
+class Manager {
+ public:
+ Manager(const InputType &source, const TranslationSystem &system);
+
+ ~Manager();
+
+ template <class Model> void LMCallback(const Model &model, const std::vector<lm::WordIndex> &words);
+
+ const std::vector<search::Applied> &ProcessSentence();
+
+ // Call to get the same value as ProcessSentence returned.
+ const std::vector<search::Applied> &Completed() const {
+ return *completed_nbest_;
+ }
+
+ private:
+ template <class Model, class Best> search::History PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out);
+
+ const InputType &source_;
+ const TranslationSystem &system_;
+ ChartCellCollectionBase cells_;
+ ChartParser parser_;
+
+ // Only one of single_best_ or n_best_ will be used, but it was easier to do this than a template.
+ search::SingleBest single_best_;
+ // ProcessSentence returns a reference to a vector. ProcessSentence
+ // doesn't have one, so this is populated and returned.
+ std::vector<search::Applied> backing_for_single_;
+
+ search::NBest n_best_;
+
+ const std::vector<search::Applied> *completed_nbest_;
+};
+
+// Just get the phrase.
+void ToPhrase(const search::Applied final, Phrase &out);
+// Get the phrase and the features.
+void PhraseAndFeatures(const TranslationSystem &system, const search::Applied final, Phrase &phrase, ScoreComponentCollection &features);
+
+} // namespace Incremental
+} // namespace Moses
+
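For orientation, the Manager declared above is driven from the decoder essentially as the Main.cpp hunk earlier in this patch shows. A minimal caller-side sketch, not part of the patch; the IOWrapper type and translationId come from that hunk, everything else is assumed:

    // Sketch only, mirroring the updated Main.cpp above: run incremental search
    // and hand the n-best list to the existing output code.
    #include "moses/Incremental.h"
    #include "moses/StaticData.h"

    void DecodeIncremental(const Moses::InputType &source,
                           const Moses::TranslationSystem &system,
                           IOWrapper &io, long translationId) {
      Moses::Incremental::Manager manager(source, system);
      // ProcessSentence() now returns the n-best Applied hypotheses instead of a string.
      const std::vector<search::Applied> &nbest = manager.ProcessSentence();
      if (!nbest.empty()) {
        io.OutputBestHypo(nbest[0], translationId);
      } else {
        io.OutputBestNone(translationId);
      }
      if (Moses::StaticData::Instance().GetNBestSize() > 0)
        io.OutputNBestList(nbest, system, translationId);
    }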
diff --git a/moses/Incremental/Fill.cpp b/moses/Incremental/Fill.cpp
deleted file mode 100644
index 6f0baba92..000000000
--- a/moses/Incremental/Fill.cpp
+++ /dev/null
@@ -1,143 +0,0 @@
-#include "Fill.h"
-
-#include "moses/ChartCellLabel.h"
-#include "moses/ChartCellLabelSet.h"
-#include "moses/TargetPhraseCollection.h"
-#include "moses/TargetPhrase.h"
-#include "moses/Word.h"
-
-#include "lm/model.hh"
-#include "search/context.hh"
-#include "search/note.hh"
-#include "search/rule.hh"
-#include "search/vertex.hh"
-#include "search/vertex_generator.hh"
-
-#include <math.h>
-
-namespace Moses {
-namespace Incremental {
-
-template <class Model> Fill<Model>::Fill(search::Context<Model> &context, const std::vector<lm::WordIndex> &vocab_mapping)
- : context_(context), vocab_mapping_(vocab_mapping) {}
-
-template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &) {
- std::vector<search::PartialVertex> vertices;
- vertices.reserve(nts.size());
- float below_score = 0.0;
- for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) {
- vertices.push_back((*i)->GetStack().incr->RootPartial());
- if (vertices.back().Empty()) return;
- below_score += vertices.back().Bound();
- }
-
- std::vector<lm::WordIndex> words;
- for (TargetPhraseCollection::const_iterator p(targets.begin()); p != targets.end(); ++p) {
- words.clear();
- const TargetPhrase &phrase = **p;
- const AlignmentInfo::NonTermIndexMap &align = phrase.GetAlignNonTerm().GetNonTermIndexMap();
- search::PartialEdge edge(edges_.AllocateEdge(nts.size()));
-
- size_t i = 0;
- bool bos = false;
- search::PartialVertex *nt = edge.NT();
- if (phrase.GetSize() && !phrase.GetWord(0).IsNonTerminal()) {
- lm::WordIndex index = Convert(phrase.GetWord(0));
- if (context_.LanguageModel().GetVocabulary().BeginSentence() == index) {
- bos = true;
- } else {
- words.push_back(index);
- }
- i = 1;
- }
- for (; i < phrase.GetSize(); ++i) {
- const Word &word = phrase.GetWord(i);
- if (word.IsNonTerminal()) {
- *(nt++) = vertices[align[i]];
- words.push_back(search::kNonTerminal);
- } else {
- words.push_back(Convert(word));
- }
- }
-
- edge.SetScore(phrase.GetFutureScore() + below_score);
- search::ScoreRule(context_, words, bos, edge.Between());
-
- search::Note note;
- note.vp = &phrase;
- edge.SetNote(note);
-
- edges_.AddEdge(edge);
- }
-}
-
-template <class Model> void Fill<Model>::AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &, const WordsRange &) {
- std::vector<lm::WordIndex> words;
- CHECK(phrase.GetSize() <= 1);
- if (phrase.GetSize())
- words.push_back(Convert(phrase.GetWord(0)));
-
- search::PartialEdge edge(edges_.AllocateEdge(0));
- // Appears to be a bug that FutureScore does not already include language model.
- edge.SetScore(phrase.GetFutureScore() + search::ScoreRule(context_, words, false, edge.Between()));
-
- search::Note note;
- note.vp = &phrase;
- edge.SetNote(note);
-
- edges_.AddEdge(edge);
-}
-
-namespace {
-// Route hypotheses to separate vertices for each left hand side label, populating ChartCellLabelSet out.
-class HypothesisCallback {
- public:
- HypothesisCallback(search::ContextBase &context, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool)
- : context_(context), out_(out), vertex_pool_(vertex_pool) {}
-
- void NewHypothesis(search::PartialEdge partial) {
- search::VertexGenerator *&entry = out_.FindOrInsert(static_cast<const TargetPhrase *>(partial.GetNote().vp)->GetTargetLHS()).incr_generator;
- if (!entry) {
- entry = generator_pool_.construct(context_, *vertex_pool_.construct());
- }
- entry->NewHypothesis(partial);
- }
-
- void FinishedSearch() {
- for (ChartCellLabelSet::iterator i(out_.mutable_begin()); i != out_.mutable_end(); ++i) {
- ChartCellLabel::Stack &stack = i->second.MutableStack();
- stack.incr_generator->FinishedSearch();
- stack.incr = &stack.incr_generator->Generating();
- }
- }
-
- private:
- search::ContextBase &context_;
-
- ChartCellLabelSet &out_;
-
- boost::object_pool<search::Vertex> &vertex_pool_;
- boost::object_pool<search::VertexGenerator> generator_pool_;
-};
-} // namespace
-
-template <class Model> void Fill<Model>::Search(ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool) {
- HypothesisCallback callback(context_, out, vertex_pool);
- edges_.Search(context_, callback);
-}
-
-// TODO: factors (but chart doesn't seem to support factors anyway).
-template <class Model> lm::WordIndex Fill<Model>::Convert(const Word &word) const {
- std::size_t factor = word.GetFactor(0)->GetId();
- return (factor >= vocab_mapping_.size() ? 0 : vocab_mapping_[factor]);
-}
-
-template class Fill<lm::ngram::ProbingModel>;
-template class Fill<lm::ngram::RestProbingModel>;
-template class Fill<lm::ngram::TrieModel>;
-template class Fill<lm::ngram::QuantTrieModel>;
-template class Fill<lm::ngram::ArrayTrieModel>;
-template class Fill<lm::ngram::QuantArrayTrieModel>;
-
-} // namespace Incremental
-} // namespace Moses
diff --git a/moses/Incremental/Fill.h b/moses/Incremental/Fill.h
deleted file mode 100644
index 0f4059d09..000000000
--- a/moses/Incremental/Fill.h
+++ /dev/null
@@ -1,54 +0,0 @@
-#pragma once
-
-#include "moses/ChartParserCallback.h"
-#include "moses/StackVec.h"
-
-#include "lm/word_index.hh"
-#include "search/edge_generator.hh"
-
-#include <boost/pool/object_pool.hpp>
-
-#include <list>
-#include <vector>
-
-namespace search {
-template <class Model> class Context;
-class Vertex;
-} // namespace search
-
-namespace Moses {
-class Word;
-class WordsRange;
-class TargetPhraseCollection;
-class WordsRange;
-class ChartCellLabelSet;
-class TargetPhrase;
-
-namespace Incremental {
-
-// Replacement for ChartTranslationOptionList
-// TODO: implement count and score thresholding.
-template <class Model> class Fill : public ChartParserCallback {
- public:
- Fill(search::Context<Model> &context, const std::vector<lm::WordIndex> &vocab_mapping);
-
- void Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &ignored);
-
- void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &waste_memory, const WordsRange &range);
-
- bool Empty() const { return edges_.Empty(); }
-
- void Search(ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool);
-
- private:
- lm::WordIndex Convert(const Word &word) const ;
-
- search::Context<Model> &context_;
-
- const std::vector<lm::WordIndex> &vocab_mapping_;
-
- search::EdgeGenerator edges_;
-};
-
-} // namespace Incremental
-} // namespace Moses
diff --git a/moses/Incremental/Manager.cpp b/moses/Incremental/Manager.cpp
deleted file mode 100644
index 7d684540c..000000000
--- a/moses/Incremental/Manager.cpp
+++ /dev/null
@@ -1,122 +0,0 @@
-#include "Manager.h"
-
-#include "Fill.h"
-
-#include "moses/ChartCell.h"
-#include "moses/TranslationSystem.h"
-#include "moses/StaticData.h"
-
-#include "search/context.hh"
-#include "search/config.hh"
-#include "search/weights.hh"
-
-#include <boost/lexical_cast.hpp>
-
-namespace Moses {
-namespace Incremental {
-
-namespace {
-struct ChartCellBaseFactory {
- ChartCellBase *operator()(size_t startPos, size_t endPos) const {
- return new ChartCellBase(startPos, endPos);
- }
-};
-} // namespace
-
-Manager::Manager(const InputType &source, const TranslationSystem &system) :
- source_(source),
- system_(system),
- cells_(source, ChartCellBaseFactory()),
- parser_(source, system, cells_) {
-
-}
-
-Manager::~Manager() {
- system_.CleanUpAfterSentenceProcessing(source_);
-}
-
-namespace {
-
-void ConstructString(const search::Final final, std::ostringstream &stream) {
- assert(final.Valid());
- const TargetPhrase &phrase = *static_cast<const TargetPhrase*>(final.GetNote().vp);
- size_t child = 0;
- for (std::size_t i = 0; i < phrase.GetSize(); ++i) {
- const Word &word = phrase.GetWord(i);
- if (word.IsNonTerminal()) {
- assert(child < final.GetArity());
- ConstructString(final.Children()[child++], stream);
- } else {
- stream << word[0]->GetString() << ' ';
- }
- }
-}
-
-void BestString(const ChartCellLabelSet &labels, std::string &out) {
- search::Final best;
- for (ChartCellLabelSet::const_iterator i = labels.begin(); i != labels.end(); ++i) {
- const search::Final child(i->second.GetStack().incr->BestChild());
- if (child.Valid() && (!best.Valid() || (child.GetScore() > best.GetScore()))) {
- best = child;
- }
- }
- if (!best.Valid()) {
- out.clear();
- return;
- }
- std::ostringstream stream;
- ConstructString(best, stream);
- out = stream.str();
- CHECK(out.size() > 9);
- // <s>
- out.erase(0, 4);
- // </s>
- out.erase(out.size() - 5);
- // Hack: include model score
- out += " ||| ";
- out += boost::lexical_cast<std::string>(best.GetScore());
-}
-
-} // namespace
-
-
-template <class Model> void Manager::LMCallback(const Model &model, const std::vector<lm::WordIndex> &words) {
- const LanguageModel &abstract = **system_.GetLanguageModels().begin();
- search::Weights weights(
- abstract.GetWeight(),
- abstract.OOVFeatureEnabled() ? abstract.GetOOVWeight() : 0.0,
- system_.GetWeightWordPenalty());
- search::Config config(weights, StaticData::Instance().GetCubePruningPopLimit());
- search::Context<Model> context(config, model);
-
- size_t size = source_.GetSize();
-
- boost::object_pool<search::Vertex> vertex_pool(std::max<size_t>(size * size / 2, 32));
-
- for (size_t width = 1; width <= size; ++width) {
- for (size_t startPos = 0; startPos <= size-width; ++startPos) {
- size_t endPos = startPos + width - 1;
- WordsRange range(startPos, endPos);
- Fill<Model> filler(context, words);
- parser_.Create(range, filler);
- filler.Search(cells_.MutableBase(range).MutableTargetLabelSet(), vertex_pool);
- }
- }
- BestString(cells_.GetBase(WordsRange(0, source_.GetSize() - 1)).GetTargetLabelSet(), output_);
-}
-
-template void Manager::LMCallback<lm::ngram::ProbingModel>(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words);
-template void Manager::LMCallback<lm::ngram::RestProbingModel>(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words);
-template void Manager::LMCallback<lm::ngram::TrieModel>(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words);
-template void Manager::LMCallback<lm::ngram::QuantTrieModel>(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words);
-template void Manager::LMCallback<lm::ngram::ArrayTrieModel>(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
-template void Manager::LMCallback<lm::ngram::QuantArrayTrieModel>(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
-
-void Manager::ProcessSentence() {
- const LMList &lms = system_.GetLanguageModels();
- UTIL_THROW_IF(lms.size() != 1, util::Exception, "Incremental search only supports one language model.");
- (*lms.begin())->IncrementalCallback(*this);
-}
-
-} // namespace Incremental
-} // namespace Moses
diff --git a/moses/Incremental/Manager.h b/moses/Incremental/Manager.h
deleted file mode 100644
index ac8d76a81..000000000
--- a/moses/Incremental/Manager.h
+++ /dev/null
@@ -1,35 +0,0 @@
-#pragma once
-
-#include "lm/word_index.hh"
-
-#include "moses/ChartCellCollection.h"
-#include "moses/ChartParser.h"
-
-namespace Moses {
-class InputType;
-class TranslationSystem;
-namespace Incremental {
-
-class Manager {
- public:
- Manager(const InputType &source, const TranslationSystem &system);
-
- ~Manager();
-
- template <class Model> void LMCallback(const Model &model, const std::vector<lm::WordIndex> &words);
-
- void ProcessSentence();
-
- const std::string &String() const { return output_; }
-
- private:
- const InputType &source_;
- const TranslationSystem &system_;
- ChartCellCollectionBase cells_;
- ChartParser parser_;
-
- std::string output_;
-};
-} // namespace Incremental
-} // namespace Moses
-
diff --git a/moses/Jamfile b/moses/Jamfile
index c05a9c6ab..9caa4e788 100644
--- a/moses/Jamfile
+++ b/moses/Jamfile
@@ -32,7 +32,6 @@ lib moses :
CYKPlusParser/*.cpp
RuleTable/*.cpp
fuzzy-match/*.cpp
- Incremental/*.cpp
: #exceptions
ThreadPool.cpp
SyntacticLanguageModel.cpp
diff --git a/moses/LM/Ken.cpp b/moses/LM/Ken.cpp
index 25e5a00d3..42e517f17 100644
--- a/moses/LM/Ken.cpp
+++ b/moses/LM/Ken.cpp
@@ -38,7 +38,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "moses/InputFileStream.h"
#include "moses/StaticData.h"
#include "moses/ChartHypothesis.h"
-#include "moses/Incremental/Manager.h"
+#include "moses/Incremental.h"
#include <boost/shared_ptr.hpp>
diff --git a/search/Jamfile b/search/Jamfile
index c00d23828..f6433e0e3 100644
--- a/search/Jamfile
+++ b/search/Jamfile
@@ -1,5 +1 @@
-fakelib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
-
-import testing ;
-
-unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ;
+fakelib search : edge_generator.cc nbest.cc rule.cc vertex.cc vertex_generator.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
diff --git a/search/applied.hh b/search/applied.hh
new file mode 100644
index 000000000..bd659e5c0
--- /dev/null
+++ b/search/applied.hh
@@ -0,0 +1,86 @@
+#ifndef SEARCH_APPLIED__
+#define SEARCH_APPLIED__
+
+#include "search/edge.hh"
+#include "search/header.hh"
+#include "util/pool.hh"
+
+#include <math.h>
+
+namespace search {
+
+// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted.
+template <class Below> class GenericApplied : public Header {
+ public:
+ GenericApplied() {}
+
+ GenericApplied(void *location, PartialEdge partial)
+ : Header(location) {
+ memcpy(Base(), partial.Base(), kHeaderSize);
+ Below *child_out = Children();
+ const PartialVertex *part = partial.NT();
+ const PartialVertex *const part_end_loop = part + partial.GetArity();
+ for (; part != part_end_loop; ++part, ++child_out)
+ *child_out = Below(part->End());
+ }
+
+ GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) {
+ SetScore(score);
+ SetNote(note);
+ }
+
+ explicit GenericApplied(History from) : Header(from) {}
+
+
+ // These are arrays of length GetArity().
+ Below *Children() {
+ return reinterpret_cast<Below*>(After());
+ }
+ const Below *Children() const {
+ return reinterpret_cast<const Below*>(After());
+ }
+
+ static std::size_t Size(Arity arity) {
+ return kHeaderSize + arity * sizeof(const Below);
+ }
+};
+
+// Applied rule that references itself.
+class Applied : public GenericApplied<Applied> {
+ private:
+ typedef GenericApplied<Applied> P;
+
+ public:
+ Applied() {}
+ Applied(void *location, PartialEdge partial) : P(location, partial) {}
+ Applied(History from) : P(from) {}
+};
+
+// How to build single-best hypotheses.
+class SingleBest {
+ public:
+ typedef PartialEdge Combine;
+
+ void Add(PartialEdge &existing, PartialEdge add) const {
+ if (!existing.Valid() || existing.GetScore() < add.GetScore())
+ existing = add;
+ }
+
+ NBestComplete Complete(PartialEdge partial) {
+ if (!partial.Valid())
+ return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY);
+ void *place_final = pool_.Allocate(Applied::Size(partial.GetArity()));
+ Applied(place_final, partial);
+ return NBestComplete(
+ place_final,
+ partial.CompletedState(),
+ partial.GetScore());
+ }
+
+ private:
+ util::Pool pool_;
+};
+
+} // namespace search
+
+#endif // SEARCH_APPLIED__
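The layout above is a header followed by GetArity() child entries, and consumers walk it recursively, much like AppendToPhrase in moses/Incremental.cpp. A small sketch, not part of the patch, assuming (as that file does) that each edge's Note points at the decoder's TargetPhrase:

    // Sketch only: count the terminals below an Applied node by recursing into
    // the children that were substituted for non-terminals.
    #include "moses/TargetPhrase.h"
    #include "search/applied.hh"

    std::size_t CountTerminals(const search::Applied node) {
      const Moses::TargetPhrase &phrase =
          *static_cast<const Moses::TargetPhrase*>(node.GetNote().vp);
      const search::Applied *child = node.Children();
      std::size_t total = 0;
      for (std::size_t i = 0; i < phrase.GetSize(); ++i) {
        if (phrase.GetWord(i).IsNonTerminal()) {
          total += CountTerminals(*child++);  // one child per non-terminal, in order
        } else {
          ++total;                            // terminal emitted by this rule
        }
      }
      return total;
    }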
diff --git a/search/config.hh b/search/config.hh
index ef8e2354a..ba18c09e9 100644
--- a/search/config.hh
+++ b/search/config.hh
@@ -1,23 +1,36 @@
#ifndef SEARCH_CONFIG__
#define SEARCH_CONFIG__
-#include "search/weights.hh"
-#include "util/string_piece.hh"
+#include "search/types.hh"
namespace search {
+struct NBestConfig {
+ explicit NBestConfig(unsigned int in_size) {
+ keep = in_size;
+ size = in_size;
+ }
+
+ unsigned int keep, size;
+};
+
class Config {
public:
- Config(const Weights &weights, unsigned int pop_limit) :
- weights_(weights), pop_limit_(pop_limit) {}
+ Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) :
+ lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {}
- const Weights &GetWeights() const { return weights_; }
+ Score LMWeight() const { return lm_weight_; }
unsigned int PopLimit() const { return pop_limit_; }
+ const NBestConfig &GetNBest() const { return nbest_; }
+
private:
- Weights weights_;
+ Score lm_weight_;
+
unsigned int pop_limit_;
+
+ NBestConfig nbest_;
};
} // namespace search
diff --git a/search/context.hh b/search/context.hh
index 62163144f..08f21bbf0 100644
--- a/search/context.hh
+++ b/search/context.hh
@@ -1,30 +1,16 @@
#ifndef SEARCH_CONTEXT__
#define SEARCH_CONTEXT__
-#include "lm/model.hh"
#include "search/config.hh"
-#include "search/final.hh"
-#include "search/types.hh"
#include "search/vertex.hh"
-#include "util/exception.hh"
-#include "util/pool.hh"
#include <boost/pool/object_pool.hpp>
-#include <boost/ptr_container/ptr_vector.hpp>
-
-#include <vector>
namespace search {
-class Weights;
-
class ContextBase {
public:
- explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {}
-
- util::Pool &FinalPool() {
- return final_pool_;
- }
+ explicit ContextBase(const Config &config) : config_(config) {}
VertexNode *NewVertexNode() {
VertexNode *ret = vertex_node_pool_.construct();
@@ -36,18 +22,16 @@ class ContextBase {
vertex_node_pool_.destroy(node);
}
- unsigned int PopLimit() const { return pop_limit_; }
+ unsigned int PopLimit() const { return config_.PopLimit(); }
- const Weights &GetWeights() const { return weights_; }
+ Score LMWeight() const { return config_.LMWeight(); }
- private:
- util::Pool final_pool_;
+ const Config &GetConfig() const { return config_; }
+ private:
boost::object_pool<VertexNode> vertex_node_pool_;
- unsigned int pop_limit_;
-
- const Weights &weights_;
+ Config config_;
};
template <class Model> class Context : public ContextBase {
diff --git a/search/edge_generator.cc b/search/edge_generator.cc
index 260159b1f..eacf5de5c 100644
--- a/search/edge_generator.cc
+++ b/search/edge_generator.cc
@@ -1,6 +1,7 @@
#include "search/edge_generator.hh"
#include "lm/left.hh"
+#include "lm/model.hh"
#include "lm/partial.hh"
#include "search/context.hh"
#include "search/vertex.hh"
@@ -38,7 +39,7 @@ template <class Model> void FastScore(const Context<Model> &context, Arity victi
*cover = *(cover + 1);
}
}
- update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM());
+ update.SetScore(update.GetScore() + adjustment * context.LMWeight());
}
} // namespace
diff --git a/search/edge_generator.hh b/search/edge_generator.hh
index 582c78b7b..203942c6f 100644
--- a/search/edge_generator.hh
+++ b/search/edge_generator.hh
@@ -2,7 +2,6 @@
#define SEARCH_EDGE_GENERATOR__
#include "search/edge.hh"
-#include "search/note.hh"
#include "search/types.hh"
#include <queue>
diff --git a/search/final.hh b/search/final.hh
deleted file mode 100644
index 50e62cf2e..000000000
--- a/search/final.hh
+++ /dev/null
@@ -1,36 +0,0 @@
-#ifndef SEARCH_FINAL__
-#define SEARCH_FINAL__
-
-#include "search/header.hh"
-#include "util/pool.hh"
-
-namespace search {
-
-// A full hypothesis with pointers to children.
-class Final : public Header {
- public:
- Final() {}
-
- Final(util::Pool &pool, Score score, Arity arity, Note note)
- : Header(pool.Allocate(Size(arity)), arity) {
- SetScore(score);
- SetNote(note);
- }
-
- // These are arrays of length GetArity().
- Final *Children() {
- return reinterpret_cast<Final*>(After());
- }
- const Final *Children() const {
- return reinterpret_cast<const Final*>(After());
- }
-
- private:
- static std::size_t Size(Arity arity) {
- return kHeaderSize + arity * sizeof(const Final);
- }
-};
-
-} // namespace search
-
-#endif // SEARCH_FINAL__
diff --git a/search/header.hh b/search/header.hh
index 25550dbed..69f0eed04 100644
--- a/search/header.hh
+++ b/search/header.hh
@@ -3,7 +3,6 @@
// Header consisting of Score, Arity, and Note
-#include "search/note.hh"
#include "search/types.hh"
#include <stdint.h>
@@ -24,6 +23,9 @@ class Header {
bool operator<(const Header &other) const {
return GetScore() < other.GetScore();
}
+ bool operator>(const Header &other) const {
+ return GetScore() > other.GetScore();
+ }
Arity GetArity() const {
return *reinterpret_cast<const Arity*>(base_ + sizeof(Score));
@@ -36,9 +38,14 @@ class Header {
*reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;
}
+ uint8_t *Base() { return base_; }
+ const uint8_t *Base() const { return base_; }
+
protected:
Header() : base_(NULL) {}
+ explicit Header(void *base) : base_(static_cast<uint8_t*>(base)) {}
+
Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {
*reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;
}
diff --git a/search/nbest.cc b/search/nbest.cc
new file mode 100644
index 000000000..ec3322c97
--- /dev/null
+++ b/search/nbest.cc
@@ -0,0 +1,106 @@
+#include "search/nbest.hh"
+
+#include "util/pool.hh"
+
+#include <algorithm>
+#include <functional>
+#include <queue>
+
+#include <assert.h>
+#include <math.h>
+
+namespace search {
+
+NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) {
+ assert(!partials.empty());
+ std::vector<PartialEdge>::iterator end;
+ if (partials.size() > keep) {
+ end = partials.begin() + keep;
+ std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>());
+ } else {
+ end = partials.end();
+ }
+ for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) {
+ queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i));
+ }
+}
+
+Score NBestList::TopAfterConstructor() const {
+ assert(revealed_.empty());
+ return queue_.top().GetScore();
+}
+
+const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) {
+ while (revealed_.size() < n && !queue_.empty()) {
+ MoveTop(pool);
+ }
+ return revealed_;
+}
+
+Score NBestList::Visit(util::Pool &pool, std::size_t index) {
+ if (index + 1 < revealed_.size())
+ return revealed_[index + 1].GetScore() - revealed_[index].GetScore();
+ if (queue_.empty())
+ return -INFINITY;
+ if (index + 1 == revealed_.size())
+ return queue_.top().GetScore() - revealed_[index].GetScore();
+ assert(index == revealed_.size());
+
+ MoveTop(pool);
+
+ if (queue_.empty()) return -INFINITY;
+ return queue_.top().GetScore() - revealed_[index].GetScore();
+}
+
+Applied NBestList::Get(util::Pool &pool, std::size_t index) {
+ assert(index <= revealed_.size());
+ if (index == revealed_.size()) MoveTop(pool);
+ return revealed_[index];
+}
+
+void NBestList::MoveTop(util::Pool &pool) {
+ assert(!queue_.empty());
+ QueueEntry entry(queue_.top());
+ queue_.pop();
+ RevealedRef *const children_begin = entry.Children();
+ RevealedRef *const children_end = children_begin + entry.GetArity();
+ Score basis = entry.GetScore();
+ for (RevealedRef *child = children_begin; child != children_end; ++child) {
+ Score change = child->in_->Visit(pool, child->index_);
+ if (change != -INFINITY) {
+ assert(change < 0.001);
+ QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote());
+ std::copy(children_begin, child, new_entry.Children());
+ RevealedRef *update = new_entry.Children() + (child - children_begin);
+ update->in_ = child->in_;
+ update->index_ = child->index_ + 1;
+ std::copy(child + 1, children_end, update + 1);
+ queue_.push(new_entry);
+ }
+ // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010.
+ if (child->index_) break;
+ }
+
+ // Convert QueueEntry to Applied. This leaves some unused memory.
+ void *overwrite = entry.Children();
+ for (unsigned int i = 0; i < entry.GetArity(); ++i) {
+ RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i));
+ *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_);
+ }
+ revealed_.push_back(Applied(entry.Base()));
+}
+
+NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) {
+ assert(!partials.empty());
+ NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep);
+ return NBestComplete(
+ list,
+ partials.front().CompletedState(), // All partials have the same state
+ list->TopAfterConstructor());
+}
+
+const std::vector<Applied> &NBest::Extract(History history) {
+ return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size);
+}
+
+} // namespace search
diff --git a/search/nbest.hh b/search/nbest.hh
new file mode 100644
index 000000000..cb7651bc2
--- /dev/null
+++ b/search/nbest.hh
@@ -0,0 +1,81 @@
+#ifndef SEARCH_NBEST__
+#define SEARCH_NBEST__
+
+#include "search/applied.hh"
+#include "search/config.hh"
+#include "search/edge.hh"
+
+#include <boost/pool/object_pool.hpp>
+
+#include <cstddef>
+#include <queue>
+#include <vector>
+
+#include <assert.h>
+
+namespace search {
+
+class NBestList;
+
+class NBestList {
+ private:
+ class RevealedRef {
+ public:
+ explicit RevealedRef(History history)
+ : in_(static_cast<NBestList*>(history)), index_(0) {}
+
+ private:
+ friend class NBestList;
+
+ NBestList *in_;
+ std::size_t index_;
+ };
+
+ typedef GenericApplied<RevealedRef> QueueEntry;
+
+ public:
+ NBestList(std::vector<PartialEdge> &existing, util::Pool &entry_pool, std::size_t keep);
+
+ Score TopAfterConstructor() const;
+
+ const std::vector<Applied> &Extract(util::Pool &pool, std::size_t n);
+
+ private:
+ Score Visit(util::Pool &pool, std::size_t index);
+
+ Applied Get(util::Pool &pool, std::size_t index);
+
+ void MoveTop(util::Pool &pool);
+
+ typedef std::vector<Applied> Revealed;
+ Revealed revealed_;
+
+ typedef std::priority_queue<QueueEntry> Queue;
+ Queue queue_;
+};
+
+class NBest {
+ public:
+ typedef std::vector<PartialEdge> Combine;
+
+ explicit NBest(const NBestConfig &config) : config_(config) {}
+
+ void Add(std::vector<PartialEdge> &existing, PartialEdge addition) const {
+ existing.push_back(addition);
+ }
+
+ NBestComplete Complete(std::vector<PartialEdge> &partials);
+
+ const std::vector<Applied> &Extract(History root);
+
+ private:
+ const NBestConfig config_;
+
+ boost::object_pool<NBestList> list_pool_;
+
+ util::Pool entry_pool_;
+};
+
+} // namespace search
+
+#endif // SEARCH_NBEST__
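NBest is the "Best" policy used when more than one hypothesis is requested: Add() collects competing partial edges, Complete() freezes them into an NBestList, and Extract() lazily reveals up to config.size hypotheses from the root History. A minimal extraction sketch, not part of the patch, assuming root is the handle returned by the root search (see Manager::PopulateBest in moses/Incremental.cpp):

    // Sketch only: pull the fully applied hypotheses out of the root History
    // and print their model scores, best first.
    #include "search/nbest.hh"
    #include <iostream>

    void PrintNBestScores(search::NBest &policy, search::History root) {
      const std::vector<search::Applied> &hyps = policy.Extract(root);
      for (std::size_t i = 0; i < hyps.size(); ++i) {
        std::cout << i << " ||| " << hyps[i].GetScore() << '\n';
      }
    }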
diff --git a/search/note.hh b/search/note.hh
deleted file mode 100644
index 50bed06ec..000000000
--- a/search/note.hh
+++ /dev/null
@@ -1,12 +0,0 @@
-#ifndef SEARCH_NOTE__
-#define SEARCH_NOTE__
-
-namespace search {
-
-union Note {
- const void *vp;
-};
-
-} // namespace search
-
-#endif // SEARCH_NOTE__
diff --git a/search/rule.cc b/search/rule.cc
index 5b00207ef..0244a09f7 100644
--- a/search/rule.cc
+++ b/search/rule.cc
@@ -1,7 +1,7 @@
#include "search/rule.hh"
+#include "lm/model.hh"
#include "search/context.hh"
-#include "search/final.hh"
#include <ostream>
@@ -9,35 +9,35 @@
namespace search {
-template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) {
- unsigned int oov_count = 0;
- float prob = 0.0;
- const Model &model = context.LanguageModel();
- const lm::WordIndex oov = model.GetVocabulary().NotFound();
- for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
- lm::ngram::RuleScore<Model> scorer(model, *(writing++));
- // TODO: optimize
- if (prepend_bos && (word == words.begin())) {
- scorer.BeginSentence();
- }
- for (; ; ++word) {
- if (word == words.end()) {
- prob += scorer.Finish();
- return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();
- }
- if (*word == kNonTerminal) break;
- if (*word == oov) ++oov_count;
+template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) {
+ ScoreRuleRet ret;
+ ret.prob = 0.0;
+ ret.oov = 0;
+ const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence();
+ lm::ngram::RuleScore<Model> scorer(model, *(writing++));
+ std::vector<lm::WordIndex>::const_iterator word = words.begin();
+ if (word != words.end() && *word == bos) {
+ scorer.BeginSentence();
+ ++word;
+ }
+ for (; word != words.end(); ++word) {
+ if (*word == kNonTerminal) {
+ ret.prob += scorer.Finish();
+ scorer.Reset(*(writing++));
+ } else {
+ if (*word == oov) ++ret.oov;
scorer.Terminal(*word);
}
- prob += scorer.Finish();
}
+ ret.prob += scorer.Finish();
+ return ret;
}
-template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
} // namespace search
diff --git a/search/rule.hh b/search/rule.hh
index 0ce2794db..43ca61625 100644
--- a/search/rule.hh
+++ b/search/rule.hh
@@ -9,11 +9,16 @@
namespace search {
-template <class Model> class Context;
-
const lm::WordIndex kNonTerminal = lm::kMaxWordIndex;
-template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out);
+struct ScoreRuleRet {
+ Score prob;
+ unsigned int oov;
+};
+
+// Pass <s> and </s> normally.
+// Indicate non-terminals with kNonTerminal.
+template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *state_out);
} // namespace search
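A hedged sketch of calling the new ScoreRule interface for a one-gap rule, following the comments above (terminals and <s>/</s> passed as vocabulary indices, gaps marked with kNonTerminal); the model type and strings are examples only:

// Sketch only: score the target side "the [X] house" of a rule.
#include <vector>

#include "lm/left.hh"
#include "lm/model.hh"
#include "search/rule.hh"

search::ScoreRuleRet ScoreExample(const lm::ngram::ProbingModel &model) {
  std::vector<lm::WordIndex> words;
  words.push_back(model.GetVocabulary().Index("the"));
  words.push_back(search::kNonTerminal);                  // gap filled later by a child hypothesis
  words.push_back(model.GetVocabulary().Index("house"));
  // One ChartState per terminal segment: number of non-terminals + 1.
  lm::ngram::ChartState states[2];
  search::ScoreRuleRet ret = search::ScoreRule(model, words, states);
  // ret.prob is the unweighted LM log probability; ret.oov counts words mapped to <unk>.
  return ret;
}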
diff --git a/search/types.hh b/search/types.hh
index 06eb5bfa2..f9c849b3f 100644
--- a/search/types.hh
+++ b/search/types.hh
@@ -3,12 +3,29 @@
#include <stdint.h>
+namespace lm { namespace ngram { class ChartState; } }
+
namespace search {
typedef float Score;
typedef uint32_t Arity;
+union Note {
+ const void *vp;
+};
+
+typedef void *History;
+
+struct NBestComplete {
+ NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score)
+ : history(in_history), state(&in_state), score(in_score) {}
+
+ History history;
+ const lm::ngram::ChartState *state;
+ Score score;
+};
+
} // namespace search
#endif // SEARCH_TYPES__
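These are deliberately opaque glue types: Note lets the caller thread its own pointer (typically its grammar rule) through the search untouched, and History is the handle an Output policy returns for a finished vertex and later accepts back in Extract. A minimal sketch of the Note round-trip, with a hypothetical caller-side payload type:

// Sketch only: Note carries a caller-owned, untyped pointer through the search.
#include "search/types.hh"

struct MyRulePayload { int id; };                     // hypothetical caller-side type

search::Note Wrap(const MyRulePayload &payload) {
  search::Note note;
  note.vp = &payload;                                 // stored as const void*
  return note;
}

const MyRulePayload &Unwrap(search::Note note) {
  return *static_cast<const MyRulePayload*>(note.vp);
}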
diff --git a/search/vertex.cc b/search/vertex.cc
index 11f4631fa..45842982c 100644
--- a/search/vertex.cc
+++ b/search/vertex.cc
@@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve
} // namespace
-void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
+void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {
if (Complete()) {
- assert(end_.Valid());
+ assert(end_);
assert(extend_.empty());
- bound_ = end_.GetScore();
return;
}
- if (extend_.size() == 1 && parent_ptr) {
- *parent_ptr = extend_[0];
- extend_[0]->SortAndSet(context, parent_ptr);
+ if (extend_.size() == 1) {
+ parent_ptr = extend_[0];
+ extend_[0]->RecursiveSortAndSet(context, parent_ptr);
context.DeleteVertexNode(this);
return;
}
for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
- (*i)->SortAndSet(context, &*i);
+ (*i)->RecursiveSortAndSet(context, *i);
+ }
+ std::sort(extend_.begin(), extend_.end(), GreaterByBound());
+ bound_ = extend_.front()->Bound();
+}
+
+void VertexNode::SortAndSet(ContextBase &context) {
+ // This is the root. The root might be empty.
+ if (extend_.empty()) {
+ bound_ = -INFINITY;
+ return;
+ }
+ // The root cannot be replaced. There's always one transition.
+ for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ (*i)->RecursiveSortAndSet(context, *i);
}
std::sort(extend_.begin(), extend_.end(), GreaterByBound());
bound_ = extend_.front()->Bound();
diff --git a/search/vertex.hh b/search/vertex.hh
index 52bc1dfe7..10b3339b9 100644
--- a/search/vertex.hh
+++ b/search/vertex.hh
@@ -2,7 +2,6 @@
#define SEARCH_VERTEX__
#include "lm/left.hh"
-#include "search/final.hh"
#include "search/types.hh"
#include <boost/unordered_set.hpp>
@@ -10,6 +9,7 @@
#include <queue>
#include <vector>
+#include <math.h>
#include <stdint.h>
namespace search {
@@ -18,7 +18,7 @@ class ContextBase;
class VertexNode {
public:
- VertexNode() {}
+ VertexNode() : end_() {}
void InitRoot() {
extend_.clear();
@@ -26,7 +26,7 @@ class VertexNode {
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
- end_ = Final();
+ end_ = History();
}
lm::ngram::ChartState &MutableState() { return state_; }
@@ -36,20 +36,21 @@ class VertexNode {
extend_.push_back(next);
}
- void SetEnd(Final end) {
- assert(!end_.Valid());
+ void SetEnd(History end, Score score) {
+ assert(!end_);
end_ = end;
+ bound_ = score;
}
- void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
+ void SortAndSet(ContextBase &context);
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_.Valid() && extend_.empty();
+ return !end_ && extend_.empty();
}
bool Complete() const {
- return end_.Valid();
+ return end_;
}
const lm::ngram::ChartState &State() const { return state_; }
@@ -64,7 +65,7 @@ class VertexNode {
}
// Will be invalid unless this is a leaf.
- const Final End() const { return end_; }
+ const History End() const { return end_; }
const VertexNode &operator[](size_t index) const {
return *extend_[index];
@@ -75,13 +76,15 @@ class VertexNode {
}
private:
+ void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);
+
std::vector<VertexNode*> extend_;
lm::ngram::ChartState state_;
bool right_full_;
Score bound_;
- Final end_;
+ History end_;
};
class PartialVertex {
@@ -97,7 +100,7 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); }
+ Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }
unsigned char Length() const { return back_->Length(); }
@@ -121,7 +124,7 @@ class PartialVertex {
return ret;
}
- const Final End() const {
+ const History End() const {
return back_->End();
}
@@ -130,16 +133,18 @@ class PartialVertex {
unsigned int index_;
};
+template <class Output> class VertexGenerator;
+
class Vertex {
public:
Vertex() {}
PartialVertex RootPartial() const { return PartialVertex(root_); }
- const Final BestChild() const {
+ const History BestChild() const {
PartialVertex top(RootPartial());
if (top.Empty()) {
- return Final();
+ return History();
} else {
PartialVertex continuation;
while (!top.Complete()) {
@@ -150,8 +155,8 @@ class Vertex {
}
private:
- friend class VertexGenerator;
-
+ template <class Output> friend class VertexGenerator;
+ template <class Output> friend class RootVertexGenerator;
VertexNode root_;
};
diff --git a/search/vertex_generator.cc b/search/vertex_generator.cc
index e18010c38..73139ffc5 100644
--- a/search/vertex_generator.cc
+++ b/search/vertex_generator.cc
@@ -11,23 +11,11 @@
namespace search {
-VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {
- gen.root_.InitRoot();
-}
-
#if BOOST_VERSION > 104200
namespace {
const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
-// Parallel structure to VertexNode.
-struct Trie {
- Trie() : under(NULL) {}
-
- VertexNode *under;
- boost::unordered_map<uint64_t, Trie> extend;
-};
-
Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
Trie &next = node.extend[added];
if (!next.under) {
@@ -43,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n
return next;
}
-void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) {
- Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote());
- Final *child_out = final.Children();
- const PartialVertex *part = partial.NT();
- const PartialVertex *const part_end_loop = part + partial.GetArity();
- for (; part != part_end_loop; ++part, ++child_out)
- *child_out = part->End();
-
- starter.under->SetEnd(final);
-}
+} // namespace
-void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
- const lm::ngram::ChartState &state = partial.CompletedState();
+void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) {
+ const lm::ngram::ChartState &state = *end.state;
unsigned char left = 0, right = 0;
Trie *node = &root;
@@ -81,30 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
}
node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
- CompleteTransition(context, *node, partial);
-}
-
-} // namespace
-
-#else // BOOST_VERSION
-
-struct Trie {
- VertexNode *under;
-};
-
-void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
- UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
+ node->under->SetEnd(end.history, end.score);
}
#endif // BOOST_VERSION
-void VertexGenerator::FinishedSearch() {
- Trie root;
- root.under = &gen_.root_;
- for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) {
- AddHypothesis(context_, root, i->second);
- }
- root.under->SortAndSet(context_, NULL);
-}
-
} // namespace search
diff --git a/search/vertex_generator.hh b/search/vertex_generator.hh
index 60e86112a..da563c2df 100644
--- a/search/vertex_generator.hh
+++ b/search/vertex_generator.hh
@@ -2,9 +2,11 @@
#define SEARCH_VERTEX_GENERATOR__
#include "search/edge.hh"
+#include "search/types.hh"
#include "search/vertex.hh"
#include <boost/unordered_map.hpp>
+#include <boost/version.hpp>
namespace lm {
namespace ngram {
@@ -15,21 +17,44 @@ class ChartState;
namespace search {
class ContextBase;
-class Final;
-class VertexGenerator {
+#if BOOST_VERSION > 104200
+// Parallel structure to VertexNode.
+struct Trie {
+ Trie() : under(NULL) {}
+
+ VertexNode *under;
+ boost::unordered_map<uint64_t, Trie> extend;
+};
+
+void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end);
+
+#endif // BOOST_VERSION
+
+// Output makes the single-best or n-best list.
+template <class Output> class VertexGenerator {
public:
- VertexGenerator(ContextBase &context, Vertex &gen);
+ VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {
+ gen.root_.InitRoot();
+ }
void NewHypothesis(PartialEdge partial) {
- const lm::ngram::ChartState &state = partial.CompletedState();
- std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial)));
- if (!ret.second && ret.first->second < partial) {
- ret.first->second = partial;
- }
+ nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);
}
- void FinishedSearch();
+ void FinishedSearch() {
+#if BOOST_VERSION > 104200
+ Trie root;
+ root.under = &gen_.root_;
+ for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) {
+ AddHypothesis(context_, root, nbest_.Complete(i->second));
+ }
+ existing_.clear();
+ root.under->SortAndSet(context_);
+#else
+ UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
+#endif
+ }
const Vertex &Generating() const { return gen_; }
@@ -38,8 +63,35 @@ class VertexGenerator {
Vertex &gen_;
- typedef boost::unordered_map<uint64_t, PartialEdge> Existing;
+ typedef boost::unordered_map<uint64_t, typename Output::Combine> Existing;
Existing existing_;
+
+ Output &nbest_;
+};
+
+// Special case for root vertex: everything should come together into the root
+// node. In theory, this should happen naturally due to state collapsing with
+// <s> and </s>. If that's the case, VertexGenerator is fine, though it will
+// make one connection.
+template <class Output> class RootVertexGenerator {
+ public:
+ RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {}
+
+ void NewHypothesis(PartialEdge partial) {
+ out_.Add(combine_, partial);
+ }
+
+ void FinishedSearch() {
+ gen_.root_.InitRoot();
+ NBestComplete completed(out_.Complete(combine_));
+ gen_.root_.SetEnd(completed.history, completed.score);
+ }
+
+ private:
+ Vertex &gen_;
+
+ typename Output::Combine combine_;
+ Output &out_;
};
} // namespace search
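A hedged sketch of how the templated generator might be driven with NBest as its Output policy; the wrapper function and the order of calls are illustrative only (in the decoder the edge generator does the feeding):

// Sketch only: build one vertex with n-best output.
#include "search/nbest.hh"
#include "search/vertex.hh"
#include "search/vertex_generator.hh"

void BuildVertex(search::ContextBase &context, search::NBest &nbest, search::Vertex &vertex) {
  search::VertexGenerator<search::NBest> gen(context, vertex, nbest);
  // The edge generator would call gen.NewHypothesis(partial) once per popped
  // hypothesis; entries are grouped by the hash of their recombination state.
  gen.FinishedSearch();  // builds the trie of VertexNodes and sorts children by bound
}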
diff --git a/search/weights.cc b/search/weights.cc
deleted file mode 100644
index d65471ad7..000000000
--- a/search/weights.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-#include "search/weights.hh"
-#include "util/tokenize_piece.hh"
-
-#include <cstdlib>
-
-namespace search {
-
-namespace {
-struct Insert {
- void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const {
- std::string copy(name.data(), name.size());
- map[copy] = score;
- }
-};
-
-struct DotProduct {
- search::Score total;
- DotProduct() : total(0.0) {}
-
- void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) {
- boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name));
- if (i != map.end())
- total += score * i->second;
- }
-};
-
-template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) {
- for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) {
- util::TokenIter<util::SingleCharacter> equals(*spaces, '=');
- UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces);
- StringPiece name(*equals);
- UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces);
- char *end;
- // Assumes proper termination.
- double value = std::strtod(equals->data(), &end);
- UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals);
- UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces);
- op(map, name, value);
- }
-}
-
-} // namespace
-
-Weights::Weights(StringPiece text) {
- Insert op;
- Parse<Map, Insert>(text, map_, op);
- lm_ = Steal("LanguageModel");
- oov_ = Steal("OOV");
- word_penalty_ = Steal("WordPenalty");
-}
-
-Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {}
-
-search::Score Weights::DotNoLM(StringPiece text) const {
- DotProduct dot;
- Parse<const Map, DotProduct>(text, map_, dot);
- return dot.total;
-}
-
-float Weights::Steal(const std::string &str) {
- Map::iterator i(map_.find(str));
- if (i == map_.end()) {
- return 0.0;
- } else {
- float ret = i->second;
- map_.erase(i);
- return ret;
- }
-}
-
-} // namespace search
diff --git a/search/weights.hh b/search/weights.hh
deleted file mode 100644
index df1c419f0..000000000
--- a/search/weights.hh
+++ /dev/null
@@ -1,52 +0,0 @@
-// For now, the individual features are not kept.
-#ifndef SEARCH_WEIGHTS__
-#define SEARCH_WEIGHTS__
-
-#include "search/types.hh"
-#include "util/exception.hh"
-#include "util/string_piece.hh"
-
-#include <boost/unordered_map.hpp>
-
-#include <string>
-
-namespace search {
-
-class WeightParseException : public util::Exception {
- public:
- WeightParseException() {}
- ~WeightParseException() throw() {}
-};
-
-class Weights {
- public:
- // Parses weights, sets lm_weight_, removes it from map_.
- explicit Weights(StringPiece text);
-
- // Just the three scores we care about adding.
- Weights(Score lm, Score oov, Score word_penalty);
-
- Score DotNoLM(StringPiece text) const;
-
- Score LM() const { return lm_; }
-
- Score OOV() const { return oov_; }
-
- Score WordPenalty() const { return word_penalty_; }
-
- // Mostly for testing.
- const boost::unordered_map<std::string, Score> &GetMap() const { return map_; }
-
- private:
- float Steal(const std::string &str);
-
- typedef boost::unordered_map<std::string, Score> Map;
-
- Map map_;
-
- Score lm_, oov_, word_penalty_;
-};
-
-} // namespace search
-
-#endif // SEARCH_WEIGHTS__
diff --git a/search/weights_test.cc b/search/weights_test.cc
deleted file mode 100644
index 4811ff060..000000000
--- a/search/weights_test.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-#include "search/weights.hh"
-
-#define BOOST_TEST_MODULE WeightTest
-#include <boost/test/unit_test.hpp>
-#include <boost/test/floating_point_comparison.hpp>
-
-namespace search {
-namespace {
-
-#define CHECK_WEIGHT(value, string) \
- i = parsed.find(string); \
- BOOST_REQUIRE(i != parsed.end()); \
- BOOST_CHECK_CLOSE((value), i->second, 0.001);
-
-BOOST_AUTO_TEST_CASE(parse) {
- // These are not real feature weights.
- Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
- const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap();
- boost::unordered_map<std::string, search::Score>::const_iterator i;
- CHECK_WEIGHT(0.0, "rarity");
- CHECK_WEIGHT(0.0, "phrase-SGT");
- CHECK_WEIGHT(9.45117, "phrase-TGS");
- CHECK_WEIGHT(2.33833, "lexical-SGT");
- BOOST_CHECK(parsed.end() == parsed.find("lm"));
- BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001);
- CHECK_WEIGHT(-28.3317, "lexical-TGS");
- CHECK_WEIGHT(5.0, "glue?");
-}
-
-BOOST_AUTO_TEST_CASE(dot) {
- Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
- BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001);
- BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001);
- BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001);
-}
-
-} // namespace
-} // namespace search