diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2017-06-16 03:21:33 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2017-06-16 03:21:33 +0300 |
commit | 347b8110caaa71db67c47cf4356ca2546c4a58a2 (patch) | |
tree | c99d82dc874b7dbb456250fa2280cc779aaf77a4 /src/amun/cpu/decoder/best_hyps.h | |
parent | 9eb238f6c3b2f711d3a94354d5545982572bd539 (diff) | |
parent | b72076fefdecf7fee8e008091de3a08fdf691eaf (diff) |
merge
Diffstat (limited to 'src/amun/cpu/decoder/best_hyps.h')
-rw-r--r-- | src/amun/cpu/decoder/best_hyps.h | 208 |
1 files changed, 101 insertions, 107 deletions
diff --git a/src/amun/cpu/decoder/best_hyps.h b/src/amun/cpu/decoder/best_hyps.h index 8ebc9ce6..82d0f3e1 100644 --- a/src/amun/cpu/decoder/best_hyps.h +++ b/src/amun/cpu/decoder/best_hyps.h @@ -5,7 +5,6 @@ #include "common/scorer.h" #include "common/god.h" -#include "common/utils.h" #include "common/exception.h" #include "cpu/mblas/matrix.h" @@ -24,134 +23,129 @@ struct ProbCompare { class BestHyps : public BestHypsBase { -public: - void CalcBeam( - const God &god, - const Beam& prevHyps, - const std::vector<ScorerPtr>& scorers, - const Words& filterIndices, - bool returnAlignment, - std::vector<Beam>& beams, - std::vector<uint>& beamSizes - ) - { - using namespace mblas; - - auto& weights = god.GetScorerWeights(); - - mblas::ArrayMatrix& Probs = static_cast<mblas::ArrayMatrix&>(scorers[0]->GetProbs()); - - mblas::ArrayMatrix Costs(Probs.rows(), 1); - for (size_t i = 0; i < prevHyps.size(); ++i) { - Costs.data()[i] = prevHyps[i]->GetCost(); - } - - Probs *= weights.at(scorers[0]->GetName()); - AddBiasVector<byColumn>(Probs, Costs); - - for (size_t i = 1; i < scorers.size(); ++i) { - mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorers[i]->GetProbs()); - - Probs += weights.at(scorers[i]->GetName()) * currProb; - } - - size_t size = Probs.rows() * Probs.columns(); // Probs.size(); - std::vector<size_t> keys(size); - for (size_t i = 0; i < keys.size(); ++i) { - keys[i] = i; - } + public: + BestHyps(const God &god) + : BestHypsBase( + !god.Get<bool>("allow-unk"), + god.Get<bool>("n-best"), + god.Get<std::vector<std::string>>("softmax-filter").size(), + god.Get<bool>("return-alignment") || god.Get<bool>("return-soft-alignment"), + god.GetScorerWeights()) + {} + + void CalcBeam( + const Beam& prevHyps, + const std::vector<ScorerPtr>& scorers, + const Words& filterIndices, + std::vector<Beam>& beams, + std::vector<uint>& beamSizes) + { + using namespace mblas; + + mblas::ArrayMatrix& Probs = static_cast<mblas::ArrayMatrix&>(scorers[0]->GetProbs()); + + mblas::ArrayMatrix Costs(Probs.rows(), 1); + for (size_t i = 0; i < prevHyps.size(); ++i) { + Costs.data()[i] = prevHyps[i]->GetCost(); + } - size_t beamSize = beamSizes[0]; + Probs *= weights_.at(scorers[0]->GetName()); + AddBiasVector<byColumn>(Probs, Costs); - std::vector<size_t> bestKeys(beamSize); - std::vector<float> bestCosts(beamSize); + for (size_t i = 1; i < scorers.size(); ++i) { + mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorers[i]->GetProbs()); - if (!god.Get<bool>("allow-unk")) { - blaze::column(Probs, UNK_ID) = std::numeric_limits<float>::lowest(); - } + Probs += weights_.at(scorers[i]->GetName()) * currProb; + } - //std::cerr << "2Probs=" << Probs.Debug(1) << std::endl; - //std::cerr << "beamSizes=" << amunmt::Debug(beamSizes, 2) << " " << std::endl; + size_t size = Probs.rows() * Probs.columns(); // Probs.size(); + std::vector<size_t> keys(size); + for (size_t i = 0; i < keys.size(); ++i) { + keys[i] = i; + } - std::nth_element(keys.begin(), keys.begin() + beamSize, keys.end(), - ProbCompare(Probs.data())); + size_t beamSize = beamSizes[0]; - for (size_t i = 0; i < beamSize; ++i) { - bestKeys[i] = keys[i]; - bestCosts[i] = Probs.data()[keys[i]]; - } + std::vector<size_t> bestKeys(beamSize); + std::vector<float> bestCosts(beamSize); - //std::cerr << "bestCosts=" << amunmt::Debug(bestCosts, 2) << " " << std::endl; - //std::cerr << "bestKeys=" << amunmt::Debug(bestKeys, 2) << std::endl; + if (forbidUNK_) { + blaze::column(Probs, UNK_ID) = std::numeric_limits<float>::lowest(); + } - std::vector<std::vector<float>> breakDowns; - bool doBreakdown = god.Get<bool>("n-best"); - if (doBreakdown) { - breakDowns.push_back(bestCosts); - for (auto& scorer : scorers) { - std::vector<float> modelCosts(beamSize); - mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorer->GetProbs()); + std::nth_element(keys.begin(), keys.begin() + beamSize, keys.end(), + ProbCompare(Probs.data())); - auto it = boost::make_permutation_iterator(currProb.begin(), keys.begin()); - std::copy(it, it + beamSize, modelCosts.begin()); - breakDowns.push_back(modelCosts); + for (size_t i = 0; i < beamSize; ++i) { + bestKeys[i] = keys[i]; + bestCosts[i] = Probs.data()[keys[i]]; } - } - - bool filter = god.Get<std::vector<std::string>>("softmax-filter").size(); - for (size_t i = 0; i < beamSize; i++) { - size_t wordIndex = bestKeys[i] % Probs.columns(); + std::vector<std::vector<float>> breakDowns; + if (returnNBestList_) { + breakDowns.push_back(bestCosts); + for (auto& scorer : scorers) { + std::vector<float> modelCosts(beamSize); + mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorer->GetProbs()); - if (filter) { - wordIndex = filterIndices[wordIndex]; + auto it = boost::make_permutation_iterator(currProb.begin(), keys.begin()); + std::copy(it, it + beamSize, modelCosts.begin()); + breakDowns.push_back(modelCosts); + } } - size_t hypIndex = bestKeys[i] / Probs.columns(); - float cost = bestCosts[i]; + for (size_t i = 0; i < beamSize; i++) { + size_t wordIndex = bestKeys[i] % Probs.columns(); - HypothesisPtr hyp; - if (returnAlignment) { - std::vector<SoftAlignmentPtr> alignments; - for (auto& scorer : scorers) { - if (CPU::EncoderDecoder* encdec = dynamic_cast<CPU::EncoderDecoder*>(scorer.get())) { - auto& attention = encdec->GetAttention(); - alignments.emplace_back(new SoftAlignment(attention.begin(hypIndex), - attention.end(hypIndex))); - } else { - amunmt_UTIL_THROW2("Return Alignment is allowed only with Nematus scorer."); - } + if (isInputFiltered_) { + wordIndex = filterIndices[wordIndex]; } - hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost, alignments)); - } else { - hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost)); - } + size_t hypIndex = bestKeys[i] / Probs.columns(); + float cost = bestCosts[i]; + + HypothesisPtr hyp; + if (returnAttentionWeights_) { + std::vector<SoftAlignmentPtr> alignments; + for (auto& scorer : scorers) { + if (CPU::EncoderDecoder* encdec = dynamic_cast<CPU::EncoderDecoder*>(scorer.get())) { + auto& attention = encdec->GetAttention(); + alignments.emplace_back(new SoftAlignment(attention.begin(hypIndex), + attention.end(hypIndex))); + } else { + amunmt_UTIL_THROW2("Return Alignment is allowed only with Nematus scorer."); + } + } - if (doBreakdown) { - hyp->GetCostBreakdown().resize(scorers.size()); - float sum = 0; - for(size_t j = 0; j < scorers.size(); ++j) { - if (j == 0) { - hyp->GetCostBreakdown()[0] = breakDowns[0][i]; - } else { - float cost = 0; - if (j < scorers.size()) { - if (prevHyps[hypIndex]->GetCostBreakdown().size() < scorers.size()) - const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(scorers.size(), 0.0); - cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j]; + hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost, alignments)); + } else { + hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost)); + } + + if (returnNBestList_) { + hyp->GetCostBreakdown().resize(scorers.size()); + float sum = 0; + for(size_t j = 0; j < scorers.size(); ++j) { + if (j == 0) { + hyp->GetCostBreakdown()[0] = breakDowns[0][i]; + } else { + float cost = 0; + if (j < scorers.size()) { + if (prevHyps[hypIndex]->GetCostBreakdown().size() < scorers.size()) + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(scorers.size(), 0.0); + cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j]; + } + sum += weights_.at(scorers[j]->GetName()) * cost; + hyp->GetCostBreakdown()[j] = cost; } - sum += weights.at(scorers[j]->GetName()) * cost; - hyp->GetCostBreakdown()[j] = cost; } + hyp->GetCostBreakdown()[0] -= sum; + hyp->GetCostBreakdown()[0] /= weights_.at(scorers[0]->GetName()); } - hyp->GetCostBreakdown()[0] -= sum; - hyp->GetCostBreakdown()[0] /= weights.at(scorers[0]->GetName()); + beams[0].push_back(hyp); } - beams[0].push_back(hyp); } - } }; -} // namespace -} +} // namespace CPU +} // namespace amunmt |