diff options
author | Ulrich Germann <Ulrich.Germann@gmail.com> | 2015-12-07 19:07:11 +0300 |
---|---|---|
committer | Ulrich Germann <Ulrich.Germann@gmail.com> | 2015-12-07 19:07:11 +0300 |
commit | c4e45fb128e096f255a624b57b7826febdf06f2e (patch) | |
tree | 74455d64b0e45877c91dc2488838cfe01732b224 /moses/Manager.cpp | |
parent | 2be2481feb2d68d6e4ba366d06fcfa51f7ff664e (diff) |
Code cleanup.
Diffstat (limited to 'moses/Manager.cpp')
-rw-r--r-- | moses/Manager.cpp | 101 |
1 files changed, 32 insertions, 69 deletions
diff --git a/moses/Manager.cpp b/moses/Manager.cpp index 61813479e..b3607d190 100644 --- a/moses/Manager.cpp +++ b/moses/Manager.cpp @@ -47,7 +47,8 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include "moses/HypergraphOutput.h" #include "moses/mbr.h" #include "moses/LatticeMBR.h" - +#include "moses/SearchNormal.h" +#include "moses/SearchCubePruning.h" #include <boost/foreach.hpp> #ifdef HAVE_PROTOBUF @@ -72,10 +73,16 @@ Manager::Manager(ttasksptr const& ttask) boost::shared_ptr<InputType> source = ttask->GetSource(); m_transOptColl = source->CreateTranslationOptionCollection(ttask); - const StaticData &staticData = StaticData::Instance(); - SearchAlgorithm searchAlgorithm = options().search.algo; - m_search = Search::CreateSearch(*this, *source, searchAlgorithm, - *m_transOptColl); + switch(options().search.algo) { + case Normal: + m_search = new SearchNormal(*this, *m_transOptColl); + break; + case CubePruning: + m_search = new SearchCubePruning(*this, *m_transOptColl); + break; + default: + UTIL_THROW2("ERROR: search. Aborting\n"); + } StaticData::Instance().InitializeForInput(ttask); } @@ -814,50 +821,24 @@ size_t Manager::OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* } } -size_t Manager::OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const +size_t +Manager:: +OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, + const FeatureFunction* ff, std::ostream &out) const { - - // { const FeatureFunction* sp = ff; - // const FVector& m_scores = scoreCollection.GetScoresVector(); - // FVector& scores = const_cast<FVector&>(m_scores); - // std::string prefix = sp->GetScoreProducerDescription() + FName::SEP; - // // std::cout << "prefix==" << prefix << endl; - // // cout << "m_scores==" << m_scores << endl; - // // cout << "m_scores.size()==" << m_scores.size() << endl; - // // cout << "m_scores.coreSize()==" << m_scores.coreSize() << endl; - // // cout << "m_scores.cbegin() ?= m_scores.cend()\t" << (m_scores.cbegin() == m_scores.cend()) << endl; - - - // // for(FVector::FNVmap::const_iterator i = m_scores.cbegin(); i != m_scores.cend(); i++) { - // // std::cout<<prefix << "\t" << (i->first) << "\t" << (i->second) << std::endl; - // // } - // for(int i=0, n=v.size(); i<n; i+=1) { - // // outputSearchGraphStream << prefix << i << "==" << v[i] << std::endl; - - // } - // } - - // FVector featureValues = scoreCollection.GetVectorForProducer(ff); - // outputSearchGraphStream << featureValues << endl; const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown(); - vector<float> featureValues = scoreCollection.GetScoresForProducer(ff); - size_t numScoreComps = featureValues.size();//featureValues.coreSize(); - // if (numScoreComps != ScoreProducer::unlimited) { - // vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff); + size_t numScoreComps = featureValues.size(); for (size_t i = 0; i < numScoreComps; ++i) { - outputSearchGraphStream << "x" << (index+i) << "=" << ((zeros) ? 0.0 : featureValues[i]) << " "; - } - return index+numScoreComps; - // } else { - // cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl; - // assert(false); - // return 0; - // } + out << "x" << (index+i) << "=" << ((zeros) ? 0.0 : featureValues[i]) << " "; + } + return index + numScoreComps; } /**! Output search graph in hypergraph format of Kenneth Heafield's lazy hypergraph decoder */ -void Manager::OutputSearchGraphAsHypergraph(std::ostream &outputSearchGraphStream) const +void +Manager:: +OutputSearchGraphAsHypergraph(std::ostream &outputSearchGraphStream) const { VERBOSE(2,"Getting search graph to output as hypergraph for sentence " << m_source.GetTranslationId() << std::endl) @@ -1108,7 +1089,7 @@ OutputSearchNode(AllOptions const& opts, long translationId, std::ostream &out, SearchGraphNode const& searchNode) { - const vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder(); + const vector<FactorType> &outputFactorOrder = opts.output.factor_order; bool extendedFormat = opts.output.SearchGraphExtended.size(); out << translationId; @@ -1172,7 +1153,8 @@ void Manager::GetConnectedGraph( std::vector< const Hypothesis *>& connectedList = *pConnectedList; // start with the ones in the final stack - const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks(); + const std::vector < HypothesisStack* > &hypoStackColl + = m_search->GetHypothesisStacks(); const HypothesisStack &finalStack = *hypoStackColl.back(); HypothesisStack::const_iterator iterHypo; for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo) { @@ -1520,11 +1502,7 @@ void Manager::OutputBest(OutputCollector *collector) const if (options().output.ReportSegmentation == 2) { GetOutputLanguageModelOrder(out, bestHypo); } - OutputSurface(out,*bestHypo,true); - // bestHypo->OutputBestSurface( - // out, - // staticData.GetOutputFactorOrder(), - // options().output); + OutputSurface(out,*bestHypo, true); if (options().output.PrintAlignmentInfo) { out << "||| "; bestHypo->OutputAlignment(out, options().output.WA_SortOrder); @@ -1557,7 +1535,7 @@ void Manager::OutputBest(OutputCollector *collector) const // lattice MBR if (options().lmbr.enabled) { - if (staticData.options().nbest.enabled) { + if (options().nbest.enabled) { //lattice mbr nbest vector<LatticeMBRSolution> solutions; size_t n = min(nBestSize, options().nbest.nbest_size); @@ -1585,7 +1563,7 @@ void Manager::OutputBest(OutputCollector *collector) const // n-best MBR decoding else { - const TrellisPath &mbrBestHypo = doMBR(nBestList); + const TrellisPath &mbrBestHypo = doMBR(nBestList, options()); OutputBestHypo(mbrBestHypo, out); OutputAlignment(m_alignmentOut, mbrBestHypo); IFVERBOSE(2) { @@ -1840,11 +1818,10 @@ std::map<size_t, const Factor*> Manager::GetPlaceholders(const Hypothesis &hypo, void Manager::OutputLatticeSamples(OutputCollector *collector) const { - const StaticData &staticData = StaticData::Instance(); if (collector) { TrellisPathList latticeSamples; ostringstream out; - CalcLatticeSamples(staticData.GetLatticeSamplesSize(), latticeSamples); + CalcLatticeSamples(options().output.lattice_sample_size, latticeSamples); OutputNBest(out,latticeSamples); collector->Write(m_source.GetTranslationId(), out.str()); } @@ -1951,7 +1928,7 @@ OutputSearchGraph(OutputCollector *collector) const void Manager::OutputSearchGraphSLF() const { - const StaticData &staticData = StaticData::Instance(); + // const StaticData &staticData = StaticData::Instance(); long translationId = m_source.GetTranslationId(); // Output search graph in HTK standard lattice format (SLF) @@ -1975,20 +1952,6 @@ void Manager::OutputSearchGraphSLF() const } -// void Manager::OutputSearchGraphHypergraph() const -// { -// const StaticData &staticData = StaticData::Instance(); -// if (!staticData.GetOutputSearchGraphHypergraph()) return; - -// static char const* key = "output-search-graph-hypergraph"; -// PARAM_VEC const* p = staticData.GetParameter().GetParam(key); -// ScoreComponentCollection const& weights = staticData.GetAllWeights(); -// string const& nBestFile = staticData.GetNBestFilePath(); -// HypergraphOutput<Manager> hypergraphOutput(PRECISION, p, nBestFile, weights); -// hypergraphOutput.Write(*this); - -// } - void Manager::OutputLatticeMBRNBest(std::ostream& out, const vector<LatticeMBRSolution>& solutions,long translationId) const { for (vector<LatticeMBRSolution>::const_iterator si = solutions.begin(); si != solutions.end(); ++si) { @@ -1996,7 +1959,7 @@ void Manager::OutputLatticeMBRNBest(std::ostream& out, const vector<LatticeMBRSo out << " |||"; const vector<Word> mbrHypo = si->GetWords(); for (size_t i = 0 ; i < mbrHypo.size() ; i++) { - const Factor *factor = mbrHypo[i].GetFactor(StaticData::Instance().GetOutputFactorOrder()[0]); + const Factor *factor = mbrHypo[i].GetFactor(options().output.factor_order[0]); if (i>0) out << " " << *factor; else out << *factor; } |