Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2018-06-08 18:18:07 +0300
committerHieu Hoang <hieuhoang@gmail.com>2018-06-08 18:18:07 +0300
commitf429d4a801465b1862b7da724624d16be31533ef (patch)
treef82429d11f17d035d95d1201e35f62695344a42d
parent7c417028b64fdfc09e2d80ba751188c08cebd4af (diff)
max-length per sentence in batch1.5.0
-rw-r--r--src/amun/common/history.h3
-rw-r--r--src/amun/common/search.cpp12
-rw-r--r--src/amun/common/search.h3
3 files changed, 14 insertions, 4 deletions
diff --git a/src/amun/common/history.h b/src/amun/common/history.h
index 62d44ae0..2878cd42 100644
--- a/src/amun/common/history.h
+++ b/src/amun/common/history.h
@@ -46,6 +46,9 @@ class History {
unsigned GetLineNum() const
{ return lineNo_; }
+ unsigned GetMaxLength() const
+ { return maxLength_; }
+
void SetActive(bool active);
bool GetActive() const;
diff --git a/src/amun/common/search.cpp b/src/amun/common/search.cpp
index 606b7911..e7583c35 100644
--- a/src/amun/common/search.cpp
+++ b/src/amun/common/search.cpp
@@ -100,7 +100,7 @@ std::shared_ptr<Histories> Search::Translate(const Sentences& sentences) {
}
//cerr << "beamSizes=" << Debug(beamSizes, 1) << endl;
- bool hasSurvivors = CalcBeam(histories, beamSizes, prevHyps, states, nextStates);
+ bool hasSurvivors = CalcBeam(histories, beamSizes, prevHyps, states, nextStates, decoderStep);
if (!hasSurvivors) {
break;
}
@@ -134,18 +134,24 @@ bool Search::CalcBeam(
std::vector<unsigned>& beamSizes,
Beam& prevHyps,
States& states,
- States& nextStates)
+ States& nextStates,
+ unsigned decoderStep)
{
unsigned batchSize = beamSizes.size();
Beams beams(batchSize);
bestHyps_->CalcBeam(prevHyps, scorers_, filterIndices_, beams, beamSizes);
histories->Add(beams);
+ //cerr << "batchSize=" << batchSize << endl;
histories->SetActive(false);
Beam survivors;
for (unsigned batchId = 0; batchId < batchSize; ++batchId) {
+ const History &hist = *histories->at(batchId);
+ unsigned maxLength = hist.GetMaxLength();
+
+ //cerr << "beamSizes[batchId]=" << batchId << " " << beamSizes[batchId] << " " << maxLength << endl;
for (auto& h : beams[batchId]) {
- if (h->GetWord() != EOS_ID) {
+ if (decoderStep < maxLength && h->GetWord() != EOS_ID) {
survivors.push_back(h);
histories->SetActive(batchId, true);
diff --git a/src/amun/common/search.h b/src/amun/common/search.h
index 159ede63..81c3d807 100644
--- a/src/amun/common/search.h
+++ b/src/amun/common/search.h
@@ -30,7 +30,8 @@ class Search {
std::vector<unsigned>& beamSizes,
Beam& prevHyps,
States& states,
- States& nextStates);
+ States& nextStates,
+ unsigned decoderStep);
Search(const Search&) = delete;