Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-06-18 20:18:31 +0300
committerHieu Hoang <hihoan@microsoft.com>2021-06-18 20:18:31 +0300
commitcd292d3b32428b6c1cf57e9eb6ad06b1db1e5452 (patch)
tree6bd696bd0ca90f832022899416a2dda781ad6a3b /src
parenta332e550a5cf236d5ab97fea3a512c3eff5d3947 (diff)
changes for review
Diffstat (limited to 'src')
-rw-r--r--src/common/utils.h2
-rw-r--r--src/data/factored_vocab.cpp3
-rw-r--r--src/data/shortlist.cpp22
-rw-r--r--src/data/shortlist.h15
-rw-r--r--src/data/vocab.cpp2
-rw-r--r--src/data/vocab.h2
-rw-r--r--src/layers/logits.cpp2
-rw-r--r--src/layers/output.cpp4
-rw-r--r--src/translator/beam_search.cpp2
9 files changed, 19 insertions, 35 deletions
diff --git a/src/common/utils.h b/src/common/utils.h
index d8d387a8..13b50c0b 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -63,7 +63,7 @@ std::string findReplace(const std::string& in, const std::string& what, const st
double parseDouble(std::string s);
double parseNumber(std::string s);
-
+// prints vector values with a custom label.
template<class T>
void Debug(const T *arr, size_t size, const std::string &str) {
std::cerr << str << ":" << size << ": ";
diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp
index e26a8479..4c5207dd 100644
--- a/src/data/factored_vocab.cpp
+++ b/src/data/factored_vocab.cpp
@@ -274,7 +274,10 @@ void FactoredVocab::constructGroupInfoFromFactorVocab() {
groupRanges_[g].second = u + 1;
groupCounts[g]++;
}
+
+ // required by LSH shortlist
lemmaSize_ = groupCounts[0];
+
for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups
LOG(info, "[vocab] Factor group '{}' has {} members", groupPrefixes_[g], groupCounts[g]);
if (groupCounts[g] == 0) { // factor group is unused --@TODO: once this is not hard-coded, this is an error condition
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index a965f249..b7c03436 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -24,9 +24,9 @@ Shortlist::Shortlist(const std::vector<WordIndex>& indices)
Shortlist::~Shortlist() {}
-WordIndex Shortlist::reverseMap(int , int , int idx) const { return indices_[idx]; }
+WordIndex Shortlist::reverseMap(int /*beamIdx*/, int /*batchIdx*/, int idx) const { return indices_[idx]; }
-WordIndex Shortlist::tryForwardMap(int , int , WordIndex wIdx) const {
+WordIndex Shortlist::tryForwardMap(WordIndex wIdx) const {
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
return (int)std::distance(indices_.begin(), first); // return coordinate if found
@@ -83,15 +83,8 @@ Ptr<faiss::IndexLSH> LSHShortlist::index_;
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
: Shortlist(std::vector<WordIndex>())
, k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
- /*
- for (int i = 0; i < k_; ++i) {
- indices_.push_back(i);
- }
- */
}
-//#define BLAS_FOUND 1
-
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
//int currBeamSize = indicesExpr_->shape()[0];
int currBatchSize = indicesExpr_->shape()[1];
@@ -100,15 +93,6 @@ WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
return indices_[idx];
}
-WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
- auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
- bool found = first != indices_.end();
- if(found && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
- return (int)std::distance(indices_.begin(), first); // return coordinate if found
- else
- return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
-}
-
Expr LSHShortlist::getIndicesExpr() const {
return indicesExpr_;
}
@@ -128,7 +112,6 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
int dim = values->shape()[-1];
if(!index_) {
- //std::cerr << "build lsh index" << std::endl;
LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
index_.reset(new faiss::IndexLSH(dim, nbits_,
/*rotate=*/dim != nbits_,
@@ -199,7 +182,6 @@ void LSHShortlist::createCachedTensors(Expr weights,
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits, size_t lemmaSize)
: k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
- //std::cerr << "LSHShortlistGenerator" << std::endl;
}
Ptr<Shortlist> LSHShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index 1d8903e6..cd96e0d7 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -29,7 +29,7 @@ protected:
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
- bool done_;
+ bool done_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch
void createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
@@ -43,7 +43,7 @@ public:
virtual ~Shortlist();
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const;
- virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const;
+ virtual WordIndex tryForwardMap(WordIndex wIdx) const;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
virtual Expr getIndicesExpr() const;
@@ -66,12 +66,14 @@ public:
};
///////////////////////////////////////////////////////////////////////////////////
+// implements SLIDE for faster inference.
+// https://arxiv.org/pdf/1903.03129.pdf
class LSHShortlist: public Shortlist {
private:
- int k_;
- int nbits_;
- size_t lemmaSize_;
- static Ptr<faiss::IndexLSH> index_;
+ int k_; // number of candidates returned from each input
+ int nbits_; // length of hash
+ size_t lemmaSize_; // vocab size
+ static Ptr<faiss::IndexLSH> index_; // LSH index to store all possible candidates
void createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
@@ -82,7 +84,6 @@ private:
public:
LSHShortlist(int k, int nbits, size_t lemmaSize);
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
- virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
virtual Expr getIndicesExpr() const override;
diff --git a/src/data/vocab.cpp b/src/data/vocab.cpp
index 38eddd01..82a4b8da 100644
--- a/src/data/vocab.cpp
+++ b/src/data/vocab.cpp
@@ -133,7 +133,7 @@ size_t Vocab::lemmaSize() const {
return vImpl_->lemmaSize();
}
-// number of vocabulary items
+// type of vocabulary items
std::string Vocab::type() const { return vImpl_->type(); }
// return EOS symbol id
diff --git a/src/data/vocab.h b/src/data/vocab.h
index f4a7e0b7..4af82e8e 100644
--- a/src/data/vocab.h
+++ b/src/data/vocab.h
@@ -61,7 +61,7 @@ public:
// number of vocabulary items
size_t size() const;
- // number of vocabulary items
+ // number of lemma items. Same as size() except in factored models
size_t lemmaSize() const;
// number of vocabulary items
diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp
index 1830741e..0bd8aa91 100644
--- a/src/layers/logits.cpp
+++ b/src/layers/logits.cpp
@@ -247,8 +247,6 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
std::vector<float> Logits::getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr)
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
- //std::cerr << "indicesExpr=" << indicesExpr->shape() << std::endl;
- //int batchSize
int batchSize = indicesExpr->shape()[0];
int currBeamSize = indicesExpr->shape()[1];
int numHypos = batchSize * currBeamSize;
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 964cb724..21eb3714 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -56,12 +56,12 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
lazyConstruct(input->shape()[-1]);
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
-
+ /*
std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
-
+ */
if(b)
return affine(x, W, b, transA, transB);
else
diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp
index eda288a4..94de3db0 100644
--- a/src/translator/beam_search.cpp
+++ b/src/translator/beam_search.cpp
@@ -315,7 +315,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
suppressed.erase(std::remove_if(suppressed.begin(),
suppressed.end(),
[&](WordIndex i) {
- return shortlist->tryForwardMap(4545, 3343, i) == data::Shortlist::npos; // TODO beamIdx
+ return shortlist->tryForwardMap(i) == data::Shortlist::npos; // TODO beamIdx
}),
suppressed.end());