Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Hieu Hoang <hieuhoang@gmail.com> 2017-06-16 03:21:33 +0300
committer: Hieu Hoang <hieuhoang@gmail.com> 2017-06-16 03:21:33 +0300
commit: 347b8110caaa71db67c47cf4356ca2546c4a58a2 (patch)
tree: c99d82dc874b7dbb456250fa2280cc779aaf77a4 /src/amun/cpu/decoder/best_hyps.h
parent: 9eb238f6c3b2f711d3a94354d5545982572bd539 (diff)
parent: b72076fefdecf7fee8e008091de3a08fdf691eaf (diff)
merge
Diffstat (limited to 'src/amun/cpu/decoder/best_hyps.h')
-rw-r--r-- src/amun/cpu/decoder/best_hyps.h 208
1 files changed, 101 insertions, 107 deletions
diff --git a/src/amun/cpu/decoder/best_hyps.h b/src/amun/cpu/decoder/best_hyps.h
index 8ebc9ce6..82d0f3e1 100644
--- a/src/amun/cpu/decoder/best_hyps.h
+++ b/src/amun/cpu/decoder/best_hyps.h
@@ -5,7 +5,6 @@
#include "common/scorer.h"
#include "common/god.h"
-#include "common/utils.h"
#include "common/exception.h"
#include "cpu/mblas/matrix.h"
@@ -24,134 +23,129 @@ struct ProbCompare {
class BestHyps : public BestHypsBase
{
-public:
- void CalcBeam(
- const God &god,
- const Beam& prevHyps,
- const std::vector<ScorerPtr>& scorers,
- const Words& filterIndices,
- bool returnAlignment,
- std::vector<Beam>& beams,
- std::vector<uint>& beamSizes
- )
- {
- using namespace mblas;
-
- auto& weights = god.GetScorerWeights();
-
- mblas::ArrayMatrix& Probs = static_cast<mblas::ArrayMatrix&>(scorers[0]->GetProbs());
-
- mblas::ArrayMatrix Costs(Probs.rows(), 1);
- for (size_t i = 0; i < prevHyps.size(); ++i) {
- Costs.data()[i] = prevHyps[i]->GetCost();
- }
-
- Probs *= weights.at(scorers[0]->GetName());
- AddBiasVector<byColumn>(Probs, Costs);
-
- for (size_t i = 1; i < scorers.size(); ++i) {
- mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorers[i]->GetProbs());
-
- Probs += weights.at(scorers[i]->GetName()) * currProb;
- }
-
- size_t size = Probs.rows() * Probs.columns(); // Probs.size();
- std::vector<size_t> keys(size);
- for (size_t i = 0; i < keys.size(); ++i) {
- keys[i] = i;
- }
+ public:
+ BestHyps(const God &god)
+ : BestHypsBase(
+ !god.Get<bool>("allow-unk"),
+ god.Get<bool>("n-best"),
+ god.Get<std::vector<std::string>>("softmax-filter").size(),
+ god.Get<bool>("return-alignment") || god.Get<bool>("return-soft-alignment"),
+ god.GetScorerWeights())
+ {}
+
+ void CalcBeam(
+ const Beam& prevHyps,
+ const std::vector<ScorerPtr>& scorers,
+ const Words& filterIndices,
+ std::vector<Beam>& beams,
+ std::vector<uint>& beamSizes)
+ {
+ using namespace mblas;
+
+ mblas::ArrayMatrix& Probs = static_cast<mblas::ArrayMatrix&>(scorers[0]->GetProbs());
+
+ mblas::ArrayMatrix Costs(Probs.rows(), 1);
+ for (size_t i = 0; i < prevHyps.size(); ++i) {
+ Costs.data()[i] = prevHyps[i]->GetCost();
+ }
- size_t beamSize = beamSizes[0];
+ Probs *= weights_.at(scorers[0]->GetName());
+ AddBiasVector<byColumn>(Probs, Costs);
- std::vector<size_t> bestKeys(beamSize);
- std::vector<float> bestCosts(beamSize);
+ for (size_t i = 1; i < scorers.size(); ++i) {
+ mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorers[i]->GetProbs());
- if (!god.Get<bool>("allow-unk")) {
- blaze::column(Probs, UNK_ID) = std::numeric_limits<float>::lowest();
- }
+ Probs += weights_.at(scorers[i]->GetName()) * currProb;
+ }
- //std::cerr << "2Probs=" << Probs.Debug(1) << std::endl;
- //std::cerr << "beamSizes=" << amunmt::Debug(beamSizes, 2) << " " << std::endl;
+ size_t size = Probs.rows() * Probs.columns(); // Probs.size();
+ std::vector<size_t> keys(size);
+ for (size_t i = 0; i < keys.size(); ++i) {
+ keys[i] = i;
+ }
- std::nth_element(keys.begin(), keys.begin() + beamSize, keys.end(),
- ProbCompare(Probs.data()));
+ size_t beamSize = beamSizes[0];
- for (size_t i = 0; i < beamSize; ++i) {
- bestKeys[i] = keys[i];
- bestCosts[i] = Probs.data()[keys[i]];
- }
+ std::vector<size_t> bestKeys(beamSize);
+ std::vector<float> bestCosts(beamSize);
- //std::cerr << "bestCosts=" << amunmt::Debug(bestCosts, 2) << " " << std::endl;
- //std::cerr << "bestKeys=" << amunmt::Debug(bestKeys, 2) << std::endl;
+ if (forbidUNK_) {
+ blaze::column(Probs, UNK_ID) = std::numeric_limits<float>::lowest();
+ }
- std::vector<std::vector<float>> breakDowns;
- bool doBreakdown = god.Get<bool>("n-best");
- if (doBreakdown) {
- breakDowns.push_back(bestCosts);
- for (auto& scorer : scorers) {
- std::vector<float> modelCosts(beamSize);
- mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorer->GetProbs());
+ std::nth_element(keys.begin(), keys.begin() + beamSize, keys.end(),
+ ProbCompare(Probs.data()));
- auto it = boost::make_permutation_iterator(currProb.begin(), keys.begin());
- std::copy(it, it + beamSize, modelCosts.begin());
- breakDowns.push_back(modelCosts);
+ for (size_t i = 0; i < beamSize; ++i) {
+ bestKeys[i] = keys[i];
+ bestCosts[i] = Probs.data()[keys[i]];
}
- }
-
- bool filter = god.Get<std::vector<std::string>>("softmax-filter").size();
- for (size_t i = 0; i < beamSize; i++) {
- size_t wordIndex = bestKeys[i] % Probs.columns();
+ std::vector<std::vector<float>> breakDowns;
+ if (returnNBestList_) {
+ breakDowns.push_back(bestCosts);
+ for (auto& scorer : scorers) {
+ std::vector<float> modelCosts(beamSize);
+ mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorer->GetProbs());
- if (filter) {
- wordIndex = filterIndices[wordIndex];
+ auto it = boost::make_permutation_iterator(currProb.begin(), keys.begin());
+ std::copy(it, it + beamSize, modelCosts.begin());
+ breakDowns.push_back(modelCosts);
+ }
}
- size_t hypIndex = bestKeys[i] / Probs.columns();
- float cost = bestCosts[i];
+ for (size_t i = 0; i < beamSize; i++) {
+ size_t wordIndex = bestKeys[i] % Probs.columns();
- HypothesisPtr hyp;
- if (returnAlignment) {
- std::vector<SoftAlignmentPtr> alignments;
- for (auto& scorer : scorers) {
- if (CPU::EncoderDecoder* encdec = dynamic_cast<CPU::EncoderDecoder*>(scorer.get())) {
- auto& attention = encdec->GetAttention();
- alignments.emplace_back(new SoftAlignment(attention.begin(hypIndex),
- attention.end(hypIndex)));
- } else {
- amunmt_UTIL_THROW2("Return Alignment is allowed only with Nematus scorer.");
- }
+ if (isInputFiltered_) {
+ wordIndex = filterIndices[wordIndex];
}
- hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost, alignments));
- } else {
- hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost));
- }
+ size_t hypIndex = bestKeys[i] / Probs.columns();
+ float cost = bestCosts[i];
+
+ HypothesisPtr hyp;
+ if (returnAttentionWeights_) {
+ std::vector<SoftAlignmentPtr> alignments;
+ for (auto& scorer : scorers) {
+ if (CPU::EncoderDecoder* encdec = dynamic_cast<CPU::EncoderDecoder*>(scorer.get())) {
+ auto& attention = encdec->GetAttention();
+ alignments.emplace_back(new SoftAlignment(attention.begin(hypIndex),
+ attention.end(hypIndex)));
+ } else {
+ amunmt_UTIL_THROW2("Return Alignment is allowed only with Nematus scorer.");
+ }
+ }
- if (doBreakdown) {
- hyp->GetCostBreakdown().resize(scorers.size());
- float sum = 0;
- for(size_t j = 0; j < scorers.size(); ++j) {
- if (j == 0) {
- hyp->GetCostBreakdown()[0] = breakDowns[0][i];
- } else {
- float cost = 0;
- if (j < scorers.size()) {
- if (prevHyps[hypIndex]->GetCostBreakdown().size() < scorers.size())
- const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(scorers.size(), 0.0);
- cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j];
+ hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost, alignments));
+ } else {
+ hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost));
+ }
+
+ if (returnNBestList_) {
+ hyp->GetCostBreakdown().resize(scorers.size());
+ float sum = 0;
+ for(size_t j = 0; j < scorers.size(); ++j) {
+ if (j == 0) {
+ hyp->GetCostBreakdown()[0] = breakDowns[0][i];
+ } else {
+ float cost = 0;
+ if (j < scorers.size()) {
+ if (prevHyps[hypIndex]->GetCostBreakdown().size() < scorers.size())
+ const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(scorers.size(), 0.0);
+ cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j];
+ }
+ sum += weights_.at(scorers[j]->GetName()) * cost;
+ hyp->GetCostBreakdown()[j] = cost;
}
- sum += weights.at(scorers[j]->GetName()) * cost;
- hyp->GetCostBreakdown()[j] = cost;
}
+ hyp->GetCostBreakdown()[0] -= sum;
+ hyp->GetCostBreakdown()[0] /= weights_.at(scorers[0]->GetName());
}
- hyp->GetCostBreakdown()[0] -= sum;
- hyp->GetCostBreakdown()[0] /= weights.at(scorers[0]->GetName());
+ beams[0].push_back(hyp);
}
- beams[0].push_back(hyp);
}
- }
};
-} // namespace
-}
+} // namespace CPU
+} // namespace amunmt