#include "SearchNormalBatch.h"

#include <map>
#include <vector>

#include "LM/Base.h"
#include "Manager.h"
#include "Hypothesis.h"
#include "util/exception.hh"

//#include   (commented-out include; the header name is missing from this file)

using namespace std;

namespace Moses
{

// NOTE(review): the template arguments used for the feature-function
// containers and casts in this file (map key type `int`, element types
// `LanguageModel*`, `StatefulFeatureFunction*`, `const StatelessFeatureFunction*`,
// `Hypothesis*`) were reconstructed from how the values are used below.
// Confirm them against the member declarations in SearchNormalBatch.h.

/**
 * Build a batch-mode search on top of SearchNormal.
 *
 * Caches the maximum hypothesis stack size and splits the registered feature
 * functions into three groups: distributed language models (m_dlm_ffs),
 * the remaining stateful feature functions (m_stateful_ffs), and the
 * stateless feature functions (m_stateless_ffs). The two stateful groups are
 * keyed by the feature function's position in the global stateful list.
 *
 * \param manager      manager that owns this search
 * \param source       input being translated
 * \param transOptColl translation options for the input
 */
SearchNormalBatch::SearchNormalBatch(Manager& manager, const InputType &source, const TranslationOptionCollection &transOptColl)
  :SearchNormal(manager, source, transOptColl)
  ,m_batch_size(10000)
{
  m_max_stack_size = StaticData::Instance().GetMaxHypoStackSize();

  // Split the feature functions into sets of stateless, stateful
  // distributed lm, and stateful non-distributed.
  const std::vector<const StatefulFeatureFunction*>& ffs =
    StatefulFeatureFunction::GetStatefulFeatureFunctions();
  for (unsigned i = 0; i < ffs.size(); ++i) {
    // TODO: a distributed LM is recognised purely by its description string;
    // the downcast below is unchecked and relies on that string being right.
    if (ffs[i]->GetScoreProducerDescription() == "DLM_5gram") {
      LanguageModel *dlm =
        const_cast<LanguageModel*>(static_cast<const LanguageModel*>(ffs[i]));
      m_dlm_ffs[i] = dlm;
      dlm->SetFFStateIdx(i);
    } else {
      m_stateful_ffs[i] = const_cast<StatefulFeatureFunction*>(ffs[i]);
    }
  }
  m_stateless_ffs = StatelessFeatureFunction::GetStatelessFeatureFunctions();
}

/** Nothing to release here. */
SearchNormalBatch::~SearchNormalBatch()
{
}

/**
 * Main decoder loop that translates a sentence by expanding
 * hypotheses stack by stack, until the end of the sentence.
 *
 * Each stack is pruned before its hypotheses are expanded. New hypotheses are
 * collected as partial hypotheses and merged into their stacks by
 * EvalAndMergePartialHypos() once the stack has been processed.
 * If the elapsed user time exceeds the configured timeout threshold,
 * interrupted_flag is set to 1 and the function returns early.
 */
void SearchNormalBatch::Decode()
{
  const StaticData &staticData = StaticData::Instance();
  SentenceStats &stats = m_manager.GetSentenceStats();

  // initial seed hypothesis: nothing translated, no words produced
  Hypothesis *hypo = Hypothesis::Create(m_manager,m_source, m_initialTransOpt);
  m_hypoStackColl[0]->AddPrune(hypo);

  // go through each stack
  std::vector < HypothesisStack* >::iterator iterStack;
  for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack) {
    // check if decoding ran out of time
    double _elapsed_time = GetUserTime();
    if (_elapsed_time > staticData.GetTimeoutThreshold()) {
      VERBOSE(1,"Decoding is out of time (" << _elapsed_time << "," << staticData.GetTimeoutThreshold() << ")" << std::endl);
      interrupted_flag = 1;
      return;
    }
    HypothesisStackNormal &sourceHypoColl = *static_cast<HypothesisStackNormal*>(*iterStack);

    // the stack is pruned before processing (lazy pruning):
    VERBOSE(3,"processing hypothesis from next stack");
    IFVERBOSE(2) {
      stats.StartTimeStack();
    }
    sourceHypoColl.PruneToSize(staticData.GetMaxHypoStackSize());
    VERBOSE(3,std::endl);
    sourceHypoColl.CleanupArcList();
    IFVERBOSE(2) {
      stats.StopTimeStack();
    }

    // go through each hypothesis on the stack and try to expand it
    HypothesisStackNormal::const_iterator iterHypo;
    for (iterHypo = sourceHypoColl.begin() ; iterHypo != sourceHypoColl.end() ; ++iterHypo) {
      Hypothesis &hypothesis = **iterHypo;
      ProcessOneHypothesis(hypothesis); // expand the hypothesis
    }
    // score and merge whatever is still waiting in the current batch
    EvalAndMergePartialHypos();

    // some logging
    IFVERBOSE(2) {
      OutputHypoStackSize();
    }

    // this stack is fully expanded;
    actual_hypoStack = &sourceHypoColl;
  }

  // Final flush. Every loop iteration already ends with one, so this
  // presumably only re-prunes the stacks (TODO confirm it is still needed).
  EvalAndMergePartialHypos();
}

/**
 * Expand one hypothesis with a translation option.
 * This creates the new hypothesis, issues the distributed LM requests for it
 * and queues it as a partial hypothesis; scoring and insertion into the
 * proper stack happen later in EvalAndMergePartialHypos().
 *
 * If the queue of partial hypotheses has reached the batch size, it is
 * evaluated and merged first.
 *
 * \param hypothesis hypothesis to be expanded upon
 * \param transOpt translation option (phrase translation)
 *        that is applied to create the new hypothesis
 * \param expectedScore base score for early discarding
 *        (base hypothesis score plus future score estimation);
 *        not used here because early discarding is rejected
 * \throws util::Exception (via UTIL_THROW2) if early discarding is enabled
 */
void
SearchNormalBatch::
ExpandHypothesis(const Hypothesis &hypothesis,
                 const TranslationOption &transOpt, float expectedScore)
{
  // Check if the number of partial hypotheses exceeds the batch size.
  if (m_partial_hypos.size() >= m_batch_size) {
    EvalAndMergePartialHypos();
  }

  const StaticData &staticData = StaticData::Instance();
  SentenceStats &stats = m_manager.GetSentenceStats();

  if (staticData.UseEarlyDiscarding()) {
    UTIL_THROW2("can't use early discarding with batch decoding!");
  }

  // simple build, no questions asked
  IFVERBOSE(2) {
    stats.StartTimeBuildHyp();
  }
  Hypothesis *newHypo = hypothesis.CreateNext(transOpt);
  IFVERBOSE(2) {
    stats.StopTimeBuildHyp();
  }
  if (newHypo==NULL) return;
  //newHypo->Evaluate(m_transOptColl.GetFutureScore());

  // Issue DLM requests for new hypothesis and put into the list of
  // partial hypotheses.
  const Hypothesis *prevHypo = newHypo->GetPrevHypo();
  for (std::map<int, LanguageModel*>::iterator dlm_iter = m_dlm_ffs.begin();
       dlm_iter != m_dlm_ffs.end();
       ++dlm_iter) {
    // the request starts from the predecessor's state for this LM, if any
    const FFState* input_state =
      prevHypo ? prevHypo->GetFFState(dlm_iter->first) : NULL;
    dlm_iter->second->IssueRequestsFor(*newHypo, input_state);
  }
  m_partial_hypos.push_back(newHypo);
}

/**
 * Score all queued partial hypotheses and move them onto their stacks.
 *
 * Steps, in order:
 *  1. apply the non-distributed stateful and the stateless feature functions
 *     to every queued hypothesis;
 *  2. wait for the outstanding distributed LM requests (sync());
 *  3. apply the distributed LM scores and add each hypothesis to the stack
 *     indexed by its number of covered source words;
 *  4. empty the queue and prune every stack to m_max_stack_size.
 */
void SearchNormalBatch::EvalAndMergePartialHypos()
{
  for (std::vector<Hypothesis*>::iterator partial_hypo_iter = m_partial_hypos.begin();
       partial_hypo_iter != m_partial_hypos.end();
       ++partial_hypo_iter) {
    Hypothesis* hypo = *partial_hypo_iter;

    // Evaluate with other ffs.
    for (std::map<int, StatefulFeatureFunction*>::iterator sfff_iter = m_stateful_ffs.begin();
         sfff_iter != m_stateful_ffs.end();
         ++sfff_iter) {
      const StatefulFeatureFunction &ff = *(sfff_iter->second);
      int state_idx = sfff_iter->first;
      hypo->EvaluateWhenApplied(ff, state_idx);
    }
    for (std::vector<const StatelessFeatureFunction*>::iterator slff_iter = m_stateless_ffs.begin();
         slff_iter != m_stateless_ffs.end();
         ++slff_iter) {
      hypo->EvaluateWhenApplied(**slff_iter);
    }
  }

  // Wait for all requests from the distributed LM to come back.
  for (std::map<int, LanguageModel*>::iterator dlm_iter = m_dlm_ffs.begin();
       dlm_iter != m_dlm_ffs.end();
       ++dlm_iter) {
    dlm_iter->second->sync();
  }

  // Incorporate the DLM scores into all hypotheses and put into their
  // stacks.
  for (std::vector<Hypothesis*>::iterator partial_hypo_iter = m_partial_hypos.begin();
       partial_hypo_iter != m_partial_hypos.end();
       ++partial_hypo_iter) {
    Hypothesis* hypo = *partial_hypo_iter;

    // Calculate DLM scores.
    for (std::map<int, LanguageModel*>::iterator dlm_iter = m_dlm_ffs.begin();
         dlm_iter != m_dlm_ffs.end();
         ++dlm_iter) {
      LanguageModel &lm = *(dlm_iter->second);
      hypo->EvaluateWhenApplied(lm, dlm_iter->first);
    }

    // Put completed hypothesis onto its stack.
    size_t wordsTranslated = hypo->GetWordsBitmap().GetNumWordsCovered();
    m_hypoStackColl[wordsTranslated]->AddPrune(hypo);
  }
  m_partial_hypos.clear();

  // Merging may have grown any stack, so bring all of them back to size.
  for (std::vector < HypothesisStack* >::iterator stack_iter = m_hypoStackColl.begin();
       stack_iter != m_hypoStackColl.end();
       ++stack_iter) {
    HypothesisStackNormal* stack = static_cast<HypothesisStackNormal*>(*stack_iter);
    stack->PruneToSize(m_max_stack_size);
  }
}

}