github.com/marian-nmt/marian.git
author     Hieu Hoang <hihoan@microsoft.com>   2021-07-02 22:06:03 +0300
committer  Hieu Hoang <hihoan@microsoft.com>   2021-07-02 22:06:03 +0300
commit     bd1f1ee9cb0cd316a62cbcce6653406979be9a00 (patch)
tree       9b79484978bc8337c27ad7e06c4a1a6de41481ba
parent     ff8af52624682180dc415fbfbab1d9b40fc87eea (diff)
marcin's review changes
-rw-r--r--  src/data/shortlist.cpp          3
-rw-r--r--  src/data/shortlist.h            2
-rw-r--r--  src/layers/logits.cpp           6
-rw-r--r--  src/layers/logits.h             2
-rw-r--r--  src/layers/output.cpp           5
-rw-r--r--  src/translator/beam_search.cpp  7
-rw-r--r--  src/translator/translator.h     2
7 files changed, 10 insertions(+), 17 deletions(-)
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index b9a48b39..9f4a4ebd 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -320,7 +320,8 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
size_t srcIdx,
size_t trgIdx,
bool shared) {
- if (lshOpts.size() == 2) {
+ if (lshOpts.size()) {
+ assert(lshOpts.size() == 2);
size_t lemmaSize = trgVocab->lemmaSize();
return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1], lemmaSize);
}
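
Note on the hunk above: a non-empty --output-approx-knn vector now selects the LSH shortlist, and the expected arity is documented with an assert instead of being skipped silently when the size is not exactly 2. A minimal standalone sketch of that guard, using a hypothetical helper name (useLSH) rather than Marian's actual API:

#include <cassert>
#include <vector>

// Non-empty vector -> LSH shortlist requested; the assert documents the
// expected arity (two values) instead of silently falling through.
bool useLSH(const std::vector<int>& lshOpts) {
  if (lshOpts.size()) {
    assert(lshOpts.size() == 2);
    return true;
  }
  return false;  // fall back to the non-LSH shortlist path
}
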
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index 7fc48ec2..519b6b5f 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -66,8 +66,6 @@ public:
};
///////////////////////////////////////////////////////////////////////////////////
-// implements SLIDE for faster inference.
-// https://arxiv.org/pdf/1903.03129.pdf
class LSHShortlist: public Shortlist {
private:
int k_; // number of candidates returned from each input
diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp
index 0bd8aa91..794323d0 100644
--- a/src/layers/logits.cpp
+++ b/src/layers/logits.cpp
@@ -62,7 +62,7 @@ Expr Logits::applyLossFunction(
auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
- std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
+ //std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
// For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
// clang-format on
@@ -113,7 +113,7 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
else {
auto forward = [this, g](Expr out, const std::vector<Expr>& inputs) {
Expr lastIndices = inputs[0];
- std::vector<float> masks = getFactorMasksMultiDim(g, lastIndices);
+ std::vector<float> masks = getFactorMasks(g, lastIndices);
out->val()->set(masks);
};
@@ -245,7 +245,7 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
return res;
}
-std::vector<float> Logits::getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr)
+std::vector<float> Logits::getFactorMasks(size_t factorGroup, Expr indicesExpr)
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
int batchSize = indicesExpr->shape()[0];
int currBeamSize = indicesExpr->shape()[1];
diff --git a/src/layers/logits.h b/src/layers/logits.h
index 1a57657d..a92a01c3 100644
--- a/src/layers/logits.h
+++ b/src/layers/logits.h
@@ -77,7 +77,7 @@ private:
} // actually the same as constant(data) for this data type
std::vector<float> getFactorMasks(size_t factorGroup,
const std::vector<WordIndex>& indices) const;
- std::vector<float> getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr) const;
+ std::vector<float> getFactorMasks(size_t factorGroup, Expr indicesExpr) const; // same as above but separate indices for each batch and beam
private:
// members
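
Note on the hunk above: getFactorMasksMultiDim becomes an overload of getFactorMasks, so call sites keep a single name and the argument type (flat index vector vs. an expression carrying separate batch and beam dimensions) selects the variant. A minimal sketch of that overload pattern; Expr, WordIndex, and FactorMaskProvider here are stand-ins, not Marian's definitions:

#include <cstddef>
#include <vector>

struct Expr {};                  // stand-in for Marian's expression handle
using WordIndex = unsigned int;  // stand-in for marian::WordIndex

struct FactorMaskProvider {
  // flat per-position indices, batch and beam already folded together
  std::vector<float> getFactorMasks(std::size_t factorGroup,
                                    const std::vector<WordIndex>& indices) const {
    (void)factorGroup;
    return std::vector<float>(indices.size(), 1.0f);  // dummy masks
  }
  // same name, different argument: indices kept separate per batch and beam
  std::vector<float> getFactorMasks(std::size_t factorGroup, Expr indicesExpr) const {
    (void)factorGroup; (void)indicesExpr;
    return {};  // dummy masks
  }
};
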
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 21eb3714..d7ba4490 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -75,7 +75,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
*/
- ABORT_IF(!(!transA && transB), "Must be transA==0 and transB==1");
+ ABORT_IF(!(!transA && transB), "affineShortlist. Must be transA==0 and transB==1");
ABORT_IF(b, "affineShortlist not tested with bias");
Expr ret = bdot(x, W, transA, transB);
//std::cerr << "ret=" << ret->shape() << std::endl;
@@ -83,8 +83,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
return ret;
};
- if(shortlist_) { // shortlisted versions of parameters are cached within one
- // batch, then clear()ed
+ if(shortlist_) {
shortlist_->filter(input, Wt_, isLegacyUntransposedW, b_, lemmaEt_);
}
diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp
index 94de3db0..da529980 100644
--- a/src/translator/beam_search.cpp
+++ b/src/translator/beam_search.cpp
@@ -20,7 +20,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
const std::vector<bool>& dropBatchEntries, // [origDimBatch] - empty source batch entries are marked with true, should be cleared after first use.
const std::vector<IndexType>& batchIdxMap) const { // [origBatchIdx -> currentBatchIdx]
std::vector<float> align; // collects alignment information from the last executed time step
- //utils::Debug(batchIdxMap, "batchIdxMap");
if(options_->hasAndNotEmpty("alignment") && factorGroup == 0)
align = scorers_[0]->getAlignment(); // [beam depth * max src length * current batch size] -> P(s|t); use alignments from the first scorer, even if ensemble,
@@ -86,12 +85,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
// map wordIdx to word
auto prevBeamHypIdx = beamHypIdx; // back pointer
- /*std::cerr << "currentBatchIdx=" << currentBatchIdx
- << " origBatchIdx=" << origBatchIdx
- << " beamHypIdx=" << beamHypIdx
- << " prevBeamHypIdx=" << prevBeamHypIdx
- << std::endl;*/
-
auto prevHyp = beam[prevBeamHypIdx];
Word word;
// If short list has been set, then wordIdx is an index into the short-listed word set,
diff --git a/src/translator/translator.h b/src/translator/translator.h
index f4f9ec4c..f1acd5a1 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -63,6 +63,8 @@ public:
auto srcVocab = corpus_->getVocabs()[0];
std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
+ ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");
+
if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabs.front() == vocabs.back());
}
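
Note on the hunk above: the added ABORT_IF rejects any --output-approx-knn value count other than 0 (option absent) or 2 before createShortlistGenerator is reached, so the assert added in shortlist.cpp should not be reachable through bad user input. A minimal standalone sketch of that check, with a hypothetical function name and plain abort() in place of Marian's ABORT_IF:

#include <cstdio>
#include <cstdlib>
#include <vector>

// The option may be absent (0 values) or carry exactly two values;
// any other count is rejected before the shortlist generator is built.
void validateOutputApproxKnn(const std::vector<int>& lshOpts) {
  if (lshOpts.size() != 0 && lshOpts.size() != 2) {
    std::fprintf(stderr, "--output-approx-knn takes 2 parameters\n");
    std::abort();
  }
}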