diff options
author | Michael Denkowski <mdenkows@amazon.com> | 2015-10-26 18:24:53 +0300 |
---|---|---|
committer | Michael Denkowski <mdenkows@amazon.com> | 2015-10-26 19:42:42 +0300 |
commit | d3f3389f20bb49830b3837370ca96543e4093660 (patch) | |
tree | f75bac6d6952b756a115e34cde051a17f8fc4473 /moses/BitmapContainer.h | |
parent | 6a37dfd2ce279e8493e151b505215eb9f21865f9 (diff) |
More deterministic tie-breaking for cube pruning (--cbds)
Doesn't slow down regular non-deterministic cube pruning
Diffstat (limited to 'moses/BitmapContainer.h')
-rw-r--r-- | moses/BitmapContainer.h | 34 |
1 files changed, 30 insertions, 4 deletions
```diff
diff --git a/moses/BitmapContainer.h b/moses/BitmapContainer.h
index 2840e62d9..5f301a1e8 100644
--- a/moses/BitmapContainer.h
+++ b/moses/BitmapContainer.h
@@ -118,6 +118,11 @@ public:
       return false;
     } else {
       // Equal scores: break ties by comparing target phrases (if they exist)
+      // *Important*: these are pointers to copies of the target phrases from the
+      // hypotheses. This class is used to keep priority queues ordered in the
+      // background, so comparisons made as those data structures are cleaned up
+      // may occur *after* the target phrases in hypotheses have been cleaned up,
+      // leading to segfaults if relying on hypotheses to provide target phrases.
       boost::shared_ptr<TargetPhrase> phrA = itemA->GetTargetPhrase();
       boost::shared_ptr<TargetPhrase> phrB = itemB->GetTargetPhrase();
       if (!phrA || !phrB) {
@@ -137,12 +142,30 @@ public:
 
 class HypothesisScoreOrderer
 {
+private:
+  bool m_deterministic;
+
 public:
+  HypothesisScoreOrderer(const bool deterministic = false)
+    : m_deterministic(deterministic) {}
+
   bool operator()(const Hypothesis* hypoA, const Hypothesis* hypoB) const {
+
     float scoreA = hypoA->GetTotalScore();
     float scoreB = hypoB->GetTotalScore();
 
-    return (scoreA > scoreB);
+    if (scoreA > scoreB) {
+      return true;
+    } else if (scoreA < scoreB) {
+      return false;
+    } else {
+      if (m_deterministic) {
+        // Equal scores: break ties by comparing target phrases
+        return (hypoA->GetCurrTargetPhrase().Compare(hypoB->GetCurrTargetPhrase()) < 0);
+      }
+      // Fallback: scoreA > scoreB == false, non-deterministic sort
+      return false;
+    }
   }
 };
 
@@ -164,6 +187,8 @@ private:
   const SquareMatrix &m_futureScores;
   float m_futureScore;
 
+  bool m_deterministic;
+
   std::vector< const Hypothesis* > m_hypotheses;
   boost::unordered_set< int > m_seenPosition;
 
@@ -181,8 +206,9 @@ public:
   BackwardsEdge(const BitmapContainer &prevBitmapContainer
                 , BitmapContainer &parent
                 , const TranslationOptionList &translations
-                , const SquareMatrix &futureScores,
-                const InputType& source);
+                , const SquareMatrix &futureScores
+                , const InputType& source
+                , const bool deterministic = false);
   ~BackwardsEdge();
 
   bool GetInitialized();
@@ -216,7 +242,7 @@ private:
 public:
   BitmapContainer(const Bitmap &bitmap
                   , HypothesisStackCubePruning &stack
-                  , bool deterministic_sort = false);
+                  , bool deterministic = false);
 
   // The destructor will also delete all the edges that are
   // connected to this BitmapContainer.
```