github.com/marian-nmt/marian.git
author    Marcin Junczys-Dowmunt <marcinjd@microsoft.com>  2021-07-03 22:01:22 +0300
committer Marcin Junczys-Dowmunt <marcinjd@microsoft.com>  2021-07-03 22:01:22 +0300
commit    8bfa6a44e307a58b80784040d741058855f5b519 (patch)
tree      783b9cab2ac01319af59f261339d8f768adbdae2
parent    64e787afcea2137eb54dab0650f5f93cdc27b857 (diff)
parent    4ace42f35aa9fe6d51884ccd5486ba3e8cf626b4 (diff)
Merge branch 'hihoan/lsh7' of vs-ssh.visualstudio.com:v3/machinetranslation/Marian/marian-dev into hihoan/lsh7
-rw-r--r--  src/data/shortlist.cpp          | 6
-rw-r--r--  src/data/shortlist.h            | 5
-rw-r--r--  src/layers/logits.cpp           | 6
-rw-r--r--  src/layers/logits.h             | 2
-rw-r--r--  src/layers/output.cpp           | 5
-rw-r--r--  src/translator/beam_search.cpp  | 9
-rw-r--r--  src/translator/translator.h     | 2
7 files changed, 17 insertions(+), 18 deletions(-)
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index ad2525dc..9f4a4ebd 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -79,6 +79,7 @@ void Shortlist::createCachedTensors(Expr weights,
///////////////////////////////////////////////////////////////////////////////////
Ptr<faiss::IndexLSH> LSHShortlist::index_;
+std::mutex LSHShortlist::mutex_;
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
: Shortlist(std::vector<WordIndex>())
@@ -111,6 +112,7 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
auto values = inputs[1];
int dim = values->shape()[-1];
+ mutex_.lock();
if(!index_) {
LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
index_.reset(new faiss::IndexLSH(dim, nbits_,
@@ -119,6 +121,7 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
index_->train(lemmaSize_, values->val()->data<float>());
index_->add( lemmaSize_, values->val()->data<float>());
}
+ mutex_.unlock();
int qRows = query->shape().elements() / dim;
std::vector<float> distances(qRows * k_);
@@ -317,7 +320,8 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
size_t srcIdx,
size_t trgIdx,
bool shared) {
- if (lshOpts.size() == 2) {
+ if (lshOpts.size()) {
+ assert(lshOpts.size() == 2);
size_t lemmaSize = trgVocab->lemmaSize();
return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1], lemmaSize);
}
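Note on the locking added in LSHShortlist::filter(): it is the usual guard-one-time-initialization-of-a-shared-object pattern; the first caller builds the static LSH index, later callers reuse it. A minimal stand-alone sketch of the same idea, assuming a placeholder Index type in place of faiss::IndexLSH and an RAII std::lock_guard instead of the explicit lock()/unlock() in the patch:

#include <memory>
#include <mutex>
#include <vector>

// Placeholder for the real faiss::IndexLSH; only here to keep the sketch self-contained.
struct Index {
  Index(int dim, int nbits) : dim_(dim), nbits_(nbits) {}
  void train(size_t, const float*) { /* hash-threshold training would go here */ }
  void add(size_t, const float*)   { /* indexing of the vectors would go here */ }
  int dim_, nbits_;
};

class LazyIndex {
public:
  // Build the shared index exactly once, even when several translation threads
  // call this concurrently; subsequent callers return the cached index.
  static Index& get(int dim, int nbits, size_t rows, const float* data) {
    std::lock_guard<std::mutex> guard(mutex_);  // released automatically on scope exit
    if(!index_) {
      index_.reset(new Index(dim, nbits));
      index_->train(rows, data);
      index_->add(rows, data);
    }
    return *index_;
  }

private:
  static std::unique_ptr<Index> index_;
  static std::mutex mutex_;
};

std::unique_ptr<Index> LazyIndex::index_;
std::mutex LazyIndex::mutex_;

int main() {
  std::vector<float> data(1000 * 256, 0.5f);
  Index& idx = LazyIndex::get(/*dim=*/256, /*nbits=*/1024, /*rows=*/1000, data.data());
  (void)idx;
}

One design note: the RAII guard releases the mutex even if train() or add() throws, which the explicit lock()/unlock() pair in the patch would not.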
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index cd96e0d7..1ce8fbf4 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -66,14 +66,15 @@ public:
};
///////////////////////////////////////////////////////////////////////////////////
-// implements SLIDE for faster inference.
-// https://arxiv.org/pdf/1903.03129.pdf
+// Faster inference, inspired by these two papers:
+// https://arxiv.org/pdf/1903.03129.pdf and https://arxiv.org/pdf/1806.00588.pdf
class LSHShortlist: public Shortlist {
private:
int k_; // number of candidates returned from each input
int nbits_; // length of hash
size_t lemmaSize_; // vocab size
static Ptr<faiss::IndexLSH> index_; // LSH index to store all possible candidates
+ static std::mutex mutex_;
void createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp
index 0bd8aa91..794323d0 100644
--- a/src/layers/logits.cpp
+++ b/src/layers/logits.cpp
@@ -62,7 +62,7 @@ Expr Logits::applyLossFunction(
auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
- std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
+ //std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
// For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
// clang-format on
@@ -113,7 +113,7 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
else {
auto forward = [this, g](Expr out, const std::vector<Expr>& inputs) {
Expr lastIndices = inputs[0];
- std::vector<float> masks = getFactorMasksMultiDim(g, lastIndices);
+ std::vector<float> masks = getFactorMasks(g, lastIndices);
out->val()->set(masks);
};
@@ -245,7 +245,7 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
return res;
}
-std::vector<float> Logits::getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr)
+std::vector<float> Logits::getFactorMasks(size_t factorGroup, Expr indicesExpr)
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
int batchSize = indicesExpr->shape()[0];
int currBeamSize = indicesExpr->shape()[1];
diff --git a/src/layers/logits.h b/src/layers/logits.h
index 1a57657d..a92a01c3 100644
--- a/src/layers/logits.h
+++ b/src/layers/logits.h
@@ -77,7 +77,7 @@ private:
} // actually the same as constant(data) for this data type
std::vector<float> getFactorMasks(size_t factorGroup,
const std::vector<WordIndex>& indices) const;
- std::vector<float> getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr) const;
+ std::vector<float> getFactorMasks(size_t factorGroup, Expr indicesExpr) const; // same as above but separate indices for each batch and beam
private:
// members
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 21eb3714..d7ba4490 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -75,7 +75,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
*/
- ABORT_IF(!(!transA && transB), "Must be transA==0 and transB==1");
+ ABORT_IF(!(!transA && transB), "affineShortlist. Must be transA==0 and transB==1");
ABORT_IF(b, "affineShortlist not tested with bias");
Expr ret = bdot(x, W, transA, transB);
//std::cerr << "ret=" << ret->shape() << std::endl;
@@ -83,8 +83,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
return ret;
};
- if(shortlist_) { // shortlisted versions of parameters are cached within one
- // batch, then clear()ed
+ if(shortlist_) {
shortlist_->filter(input, Wt_, isLegacyUntransposedW, b_, lemmaEt_);
}
diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp
index 94de3db0..2a0d3947 100644
--- a/src/translator/beam_search.cpp
+++ b/src/translator/beam_search.cpp
@@ -20,7 +20,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
const std::vector<bool>& dropBatchEntries, // [origDimBatch] - empty source batch entries are marked with true, should be cleared after first use.
const std::vector<IndexType>& batchIdxMap) const { // [origBatchIdx -> currentBatchIdx]
std::vector<float> align; // collects alignment information from the last executed time step
- //utils::Debug(batchIdxMap, "batchIdxMap");
if(options_->hasAndNotEmpty("alignment") && factorGroup == 0)
align = scorers_[0]->getAlignment(); // [beam depth * max src length * current batch size] -> P(s|t); use alignments from the first scorer, even if ensemble,
@@ -86,12 +85,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
// map wordIdx to word
auto prevBeamHypIdx = beamHypIdx; // back pointer
- /*std::cerr << "currentBatchIdx=" << currentBatchIdx
- << " origBatchIdx=" << origBatchIdx
- << " beamHypIdx=" << beamHypIdx
- << " prevBeamHypIdx=" << prevBeamHypIdx
- << std::endl;*/
-
auto prevHyp = beam[prevBeamHypIdx];
Word word;
// If short list has been set, then wordIdx is an index into the short-listed word set,
@@ -315,7 +308,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
suppressed.erase(std::remove_if(suppressed.begin(),
suppressed.end(),
[&](WordIndex i) {
- return shortlist->tryForwardMap(i) == data::Shortlist::npos; // TODO beamIdx
+ return shortlist->tryForwardMap(i) == data::Shortlist::npos;
}),
suppressed.end());
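The suppression filter above is the standard erase-remove idiom: drop suppressed word ids that the shortlist cannot produce anyway. A minimal, self-contained sketch of the same idiom, with a stand-in forward map and npos marker (assumptions; only the names tryForwardMap and npos come from the diff):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

using WordIndex = uint32_t;
static const WordIndex npos = static_cast<WordIndex>(-1);  // "not in shortlist" marker

// Stand-in for a shortlist forward map: global vocab index -> shortlist index, or npos.
WordIndex tryForwardMap(const std::unordered_map<WordIndex, WordIndex>& map, WordIndex i) {
  auto it = map.find(i);
  return it == map.end() ? npos : it->second;
}

int main() {
  std::unordered_map<WordIndex, WordIndex> forward = {{0, 0}, {2, 1}, {5, 2}};
  std::vector<WordIndex> suppressed = {0, 1, 2, 3, 5};

  // Keep only suppressed ids the shortlist actually contains; ids mapping to npos
  // cannot be generated in the first place, so suppressing them is pointless.
  suppressed.erase(std::remove_if(suppressed.begin(), suppressed.end(),
                                  [&](WordIndex i) { return tryForwardMap(forward, i) == npos; }),
                   suppressed.end());

  for(auto i : suppressed) std::cout << i << " ";  // prints: 0 2 5
  std::cout << "\n";
}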
diff --git a/src/translator/translator.h b/src/translator/translator.h
index b6be3242..8cc301b4 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -63,6 +63,8 @@ public:
auto srcVocab = corpus_->getVocabs()[0];
std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
+ ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");
+
if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabs.front() == vocabs.back());
}
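The new ABORT_IF mirrors the assert added in createShortlistGenerator: --output-approx-knn is either absent or carries exactly two integers. A small sketch of the same validate-then-dispatch shape (the helper names and the error macro are illustrative, not Marian's API):

#include <cstdio>
#include <cstdlib>
#include <vector>

#define CHECK_OR_DIE(cond, msg)                                          \
  do {                                                                   \
    if(!(cond)) { std::fprintf(stderr, "%s\n", msg); std::exit(1); }     \
  } while(0)

// lshOpts holds the values of --output-approx-knn: empty (feature off) or exactly two ints.
void configureShortlist(const std::vector<int>& lshOpts, bool hasShortlistFile) {
  CHECK_OR_DIE(lshOpts.size() == 0 || lshOpts.size() == 2,
               "--output-approx-knn takes 2 parameters");

  if(lshOpts.size() == 2) {
    std::printf("LSH shortlist with parameters %d and %d\n", lshOpts[0], lshOpts[1]);
  } else if(hasShortlistFile) {
    std::printf("lexical shortlist loaded from file\n");
  } else {
    std::printf("no shortlist\n");
  }
}

int main() {
  configureShortlist({}, /*hasShortlistFile=*/false);  // feature off
  configureShortlist({1024, 16}, false);               // LSH shortlist enabled
}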