Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTomasz Dwojak <t.dwojak@amu.edu.pl>2017-05-31 16:27:47 +0300
committerTomasz Dwojak <t.dwojak@amu.edu.pl>2017-06-05 17:10:33 +0300
commite7bc69b077e214f7fc47b246d565a13d6836e6c5 (patch)
tree9590bf5099ef8ddd20d3682ac68434e8934bd6f7 /src/amun/cpu/decoder/best_hyps.h
parent4574e3bc997f478db63f93e3aaf395e0aa30a055 (diff)
Refactoring
* Removed redundant God dependencies * Made Search class more readable * Added constructor to BaseBestHyps * Cleaned up Sentences and Histories classes
Diffstat (limited to 'src/amun/cpu/decoder/best_hyps.h')
-rw-r--r--src/amun/cpu/decoder/best_hyps.h201
1 files changed, 101 insertions, 100 deletions
diff --git a/src/amun/cpu/decoder/best_hyps.h b/src/amun/cpu/decoder/best_hyps.h
index 3cf37acd..d7858953 100644
--- a/src/amun/cpu/decoder/best_hyps.h
+++ b/src/amun/cpu/decoder/best_hyps.h
@@ -23,128 +23,129 @@ struct ProbCompare {
class BestHyps : public BestHypsBase
{
-public:
- void CalcBeam(
- const God &god,
- const Beam& prevHyps,
- const std::vector<ScorerPtr>& scorers,
- const Words& filterIndices,
- bool returnAlignment,
- std::vector<Beam>& beams,
- std::vector<size_t>& beamSizes
- )
- {
- using namespace mblas;
-
- auto& weights = god.GetScorerWeights();
-
- mblas::ArrayMatrix& Probs = static_cast<mblas::ArrayMatrix&>(scorers[0]->GetProbs());
-
- mblas::ArrayMatrix Costs(Probs.rows(), 1);
- for (size_t i = 0; i < prevHyps.size(); ++i) {
- Costs.data()[i] = prevHyps[i]->GetCost();
- }
-
- Probs *= weights.at(scorers[0]->GetName());
- AddBiasVector<byColumn>(Probs, Costs);
+ public:
+ BestHyps(const God &god)
+ : BestHypsBase(
+ !god.Get<bool>("allow-unk"),
+ god.Get<bool>("n-best"),
+ god.Get<std::vector<std::string>>("softmax-filter").size(),
+ god.Get<bool>("return-alignment") || god.Get<bool>("return-soft-alignment"),
+ god.GetScorerWeights())
+ {}
+
+ void CalcBeam(
+ const Beam& prevHyps,
+ const std::vector<ScorerPtr>& scorers,
+ const Words& filterIndices,
+ std::vector<Beam>& beams,
+ std::vector<size_t>& beamSizes)
+ {
+ using namespace mblas;
+
+ mblas::ArrayMatrix& Probs = static_cast<mblas::ArrayMatrix&>(scorers[0]->GetProbs());
+
+ mblas::ArrayMatrix Costs(Probs.rows(), 1);
+ for (size_t i = 0; i < prevHyps.size(); ++i) {
+ Costs.data()[i] = prevHyps[i]->GetCost();
+ }
- for (size_t i = 1; i < scorers.size(); ++i) {
- mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorers[i]->GetProbs());
+ Probs *= weights_.at(scorers[0]->GetName());
+ AddBiasVector<byColumn>(Probs, Costs);
- Probs += weights.at(scorers[i]->GetName()) * currProb;
- }
+ for (size_t i = 1; i < scorers.size(); ++i) {
+ mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorers[i]->GetProbs());
- size_t size = Probs.rows() * Probs.columns(); // Probs.size();
- std::vector<size_t> keys(size);
- for (size_t i = 0; i < keys.size(); ++i) {
- keys[i] = i;
- }
+ Probs += weights_.at(scorers[i]->GetName()) * currProb;
+ }
- size_t beamSize = beamSizes[0];
+ size_t size = Probs.rows() * Probs.columns(); // Probs.size();
+ std::vector<size_t> keys(size);
+ for (size_t i = 0; i < keys.size(); ++i) {
+ keys[i] = i;
+ }
- std::vector<size_t> bestKeys(beamSize);
- std::vector<float> bestCosts(beamSize);
+ size_t beamSize = beamSizes[0];
- if (!god.Get<bool>("allow-unk")) {
- blaze::column(Probs, UNK_ID) = std::numeric_limits<float>::lowest();
- }
+ std::vector<size_t> bestKeys(beamSize);
+ std::vector<float> bestCosts(beamSize);
- std::nth_element(keys.begin(), keys.begin() + beamSize, keys.end(),
- ProbCompare(Probs.data()));
+ if (forbidUNK_) {
+ blaze::column(Probs, UNK_ID) = std::numeric_limits<float>::lowest();
+ }
- for (size_t i = 0; i < beamSize; ++i) {
- bestKeys[i] = keys[i];
- bestCosts[i] = Probs.data()[keys[i]];
- }
+ std::nth_element(keys.begin(), keys.begin() + beamSize, keys.end(),
+ ProbCompare(Probs.data()));
- std::vector<std::vector<float>> breakDowns;
- bool doBreakdown = god.Get<bool>("n-best");
- if (doBreakdown) {
- breakDowns.push_back(bestCosts);
- for (auto& scorer : scorers) {
- std::vector<float> modelCosts(beamSize);
- mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorer->GetProbs());
-
- auto it = boost::make_permutation_iterator(currProb.begin(), keys.begin());
- std::copy(it, it + beamSize, modelCosts.begin());
- breakDowns.push_back(modelCosts);
+ for (size_t i = 0; i < beamSize; ++i) {
+ bestKeys[i] = keys[i];
+ bestCosts[i] = Probs.data()[keys[i]];
}
- }
- bool filter = god.Get<std::vector<std::string>>("softmax-filter").size();
-
- for (size_t i = 0; i < beamSize; i++) {
- size_t wordIndex = bestKeys[i] % Probs.columns();
+ std::vector<std::vector<float>> breakDowns;
+ if (returnNBestList_) {
+ breakDowns.push_back(bestCosts);
+ for (auto& scorer : scorers) {
+ std::vector<float> modelCosts(beamSize);
+ mblas::ArrayMatrix &currProb = static_cast<mblas::ArrayMatrix&>(scorer->GetProbs());
- if (filter) {
- wordIndex = filterIndices[wordIndex];
+ auto it = boost::make_permutation_iterator(currProb.begin(), keys.begin());
+ std::copy(it, it + beamSize, modelCosts.begin());
+ breakDowns.push_back(modelCosts);
+ }
}
- size_t hypIndex = bestKeys[i] / Probs.columns();
- float cost = bestCosts[i];
+ for (size_t i = 0; i < beamSize; i++) {
+ size_t wordIndex = bestKeys[i] % Probs.columns();
- HypothesisPtr hyp;
- if (returnAlignment) {
- std::vector<SoftAlignmentPtr> alignments;
- for (auto& scorer : scorers) {
- if (CPU::EncoderDecoder* encdec = dynamic_cast<CPU::EncoderDecoder*>(scorer.get())) {
- auto& attention = encdec->GetAttention();
- alignments.emplace_back(new SoftAlignment(attention.begin(hypIndex),
- attention.end(hypIndex)));
- } else {
- amunmt_UTIL_THROW2("Return Alignment is allowed only with Nematus scorer.");
- }
+ if (isInputFiltered_) {
+ wordIndex = filterIndices[wordIndex];
}
- hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost, alignments));
- } else {
- hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost));
- }
+ size_t hypIndex = bestKeys[i] / Probs.columns();
+ float cost = bestCosts[i];
+
+ HypothesisPtr hyp;
+ if (returnAttentionWeights_) {
+ std::vector<SoftAlignmentPtr> alignments;
+ for (auto& scorer : scorers) {
+ if (CPU::EncoderDecoder* encdec = dynamic_cast<CPU::EncoderDecoder*>(scorer.get())) {
+ auto& attention = encdec->GetAttention();
+ alignments.emplace_back(new SoftAlignment(attention.begin(hypIndex),
+ attention.end(hypIndex)));
+ } else {
+ amunmt_UTIL_THROW2("Return Alignment is allowed only with Nematus scorer.");
+ }
+ }
- if (doBreakdown) {
- hyp->GetCostBreakdown().resize(scorers.size());
- float sum = 0;
- for(size_t j = 0; j < scorers.size(); ++j) {
- if (j == 0) {
- hyp->GetCostBreakdown()[0] = breakDowns[0][i];
- } else {
- float cost = 0;
- if (j < scorers.size()) {
- if (prevHyps[hypIndex]->GetCostBreakdown().size() < scorers.size())
- const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(scorers.size(), 0.0);
- cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j];
+ hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost, alignments));
+ } else {
+ hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost));
+ }
+
+ if (returnNBestList_) {
+ hyp->GetCostBreakdown().resize(scorers.size());
+ float sum = 0;
+ for(size_t j = 0; j < scorers.size(); ++j) {
+ if (j == 0) {
+ hyp->GetCostBreakdown()[0] = breakDowns[0][i];
+ } else {
+ float cost = 0;
+ if (j < scorers.size()) {
+ if (prevHyps[hypIndex]->GetCostBreakdown().size() < scorers.size())
+ const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown().resize(scorers.size(), 0.0);
+ cost = breakDowns[j][i] + const_cast<HypothesisPtr&>(prevHyps[hypIndex])->GetCostBreakdown()[j];
+ }
+ sum += weights_.at(scorers[j]->GetName()) * cost;
+ hyp->GetCostBreakdown()[j] = cost;
}
- sum += weights.at(scorers[j]->GetName()) * cost;
- hyp->GetCostBreakdown()[j] = cost;
}
+ hyp->GetCostBreakdown()[0] -= sum;
+ hyp->GetCostBreakdown()[0] /= weights_.at(scorers[0]->GetName());
}
- hyp->GetCostBreakdown()[0] -= sum;
- hyp->GetCostBreakdown()[0] /= weights.at(scorers[0]->GetName());
+ beams[0].push_back(hyp);
}
- beams[0].push_back(hyp);
}
- }
};
-} // namespace
-}
+} // namespace CPU
+} // namespace amunmt