diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-06-18 20:18:31 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-06-18 20:18:31 +0300 |
commit | cd292d3b32428b6c1cf57e9eb6ad06b1db1e5452 (patch) | |
tree | 6bd696bd0ca90f832022899416a2dda781ad6a3b /src | |
parent | a332e550a5cf236d5ab97fea3a512c3eff5d3947 (diff) |
changes for review
Diffstat (limited to 'src')
-rw-r--r-- | src/common/utils.h | 2 | ||||
-rw-r--r-- | src/data/factored_vocab.cpp | 3 | ||||
-rw-r--r-- | src/data/shortlist.cpp | 22 | ||||
-rw-r--r-- | src/data/shortlist.h | 15 | ||||
-rw-r--r-- | src/data/vocab.cpp | 2 | ||||
-rw-r--r-- | src/data/vocab.h | 2 | ||||
-rw-r--r-- | src/layers/logits.cpp | 2 | ||||
-rw-r--r-- | src/layers/output.cpp | 4 | ||||
-rw-r--r-- | src/translator/beam_search.cpp | 2 |
9 files changed, 19 insertions(+), 35 deletions(-)
diff --git a/src/common/utils.h b/src/common/utils.h index d8d387a8..13b50c0b 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -63,7 +63,7 @@ std::string findReplace(const std::string& in, const std::string& what, const st double parseDouble(std::string s); double parseNumber(std::string s); - +// prints vector values with a custom label. template<class T> void Debug(const T *arr, size_t size, const std::string &str) { std::cerr << str << ":" << size << ": "; diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp index e26a8479..4c5207dd 100644 --- a/src/data/factored_vocab.cpp +++ b/src/data/factored_vocab.cpp @@ -274,7 +274,10 @@ void FactoredVocab::constructGroupInfoFromFactorVocab() { groupRanges_[g].second = u + 1; groupCounts[g]++; } + + // required by LSH shortlist lemmaSize_ = groupCounts[0]; + for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups LOG(info, "[vocab] Factor group '{}' has {} members", groupPrefixes_[g], groupCounts[g]); if (groupCounts[g] == 0) { // factor group is unused --@TODO: once this is not hard-coded, this is an error condition diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index a965f249..b7c03436 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -24,9 +24,9 @@ Shortlist::Shortlist(const std::vector<WordIndex>& indices) Shortlist::~Shortlist() {} -WordIndex Shortlist::reverseMap(int , int , int idx) const { return indices_[idx]; } +WordIndex Shortlist::reverseMap(int /*beamIdx*/, int /*batchIdx*/, int idx) const { return indices_[idx]; } -WordIndex Shortlist::tryForwardMap(int , int , WordIndex wIdx) const { +WordIndex Shortlist::tryForwardMap(WordIndex wIdx) const { auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx); if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx return (int)std::distance(indices_.begin(), first); // return coordinate if found @@ -83,15 +83,8 @@ 
Ptr<faiss::IndexLSH> LSHShortlist::index_; LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize) : Shortlist(std::vector<WordIndex>()) , k_(k), nbits_(nbits), lemmaSize_(lemmaSize) { - /* - for (int i = 0; i < k_; ++i) { - indices_.push_back(i); - } - */ } -//#define BLAS_FOUND 1 - WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const { //int currBeamSize = indicesExpr_->shape()[0]; int currBatchSize = indicesExpr_->shape()[1]; @@ -100,15 +93,6 @@ WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const { return indices_[idx]; } -WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const { - auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx); - bool found = first != indices_.end(); - if(found && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx - return (int)std::distance(indices_.begin(), first); // return coordinate if found - else - return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17? 
-} - Expr LSHShortlist::getIndicesExpr() const { return indicesExpr_; } @@ -128,7 +112,6 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, int dim = values->shape()[-1]; if(!index_) { - //std::cerr << "build lsh index" << std::endl; LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_); index_.reset(new faiss::IndexLSH(dim, nbits_, /*rotate=*/dim != nbits_, @@ -199,7 +182,6 @@ void LSHShortlist::createCachedTensors(Expr weights, LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits, size_t lemmaSize) : k_(k), nbits_(nbits), lemmaSize_(lemmaSize) { - //std::cerr << "LSHShortlistGenerator" << std::endl; } Ptr<Shortlist> LSHShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const { diff --git a/src/data/shortlist.h b/src/data/shortlist.h index 1d8903e6..cd96e0d7 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -29,7 +29,7 @@ protected: Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) Expr cachedShortb_; // these match the current value of shortlist_ Expr cachedShortLemmaEt_; - bool done_; + bool done_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch void createCachedTensors(Expr weights, bool isLegacyUntransposedW, @@ -43,7 +43,7 @@ public: virtual ~Shortlist(); virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const; - virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const; + virtual WordIndex tryForwardMap(WordIndex wIdx) const; virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt); virtual Expr getIndicesExpr() const; @@ -66,12 +66,14 @@ public: }; /////////////////////////////////////////////////////////////////////////////////// +// implements SLIDE for faster inference. 
+// https://arxiv.org/pdf/1903.03129.pdf class LSHShortlist: public Shortlist { private: - int k_; - int nbits_; - size_t lemmaSize_; - static Ptr<faiss::IndexLSH> index_; + int k_; // number of candidates returned from each input + int nbits_; // length of hash + size_t lemmaSize_; // vocab size + static Ptr<faiss::IndexLSH> index_; // LSH index to store all possible candidates void createCachedTensors(Expr weights, bool isLegacyUntransposedW, @@ -82,7 +84,6 @@ private: public: LSHShortlist(int k, int nbits, size_t lemmaSize); virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override; - virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override; virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override; virtual Expr getIndicesExpr() const override; diff --git a/src/data/vocab.cpp b/src/data/vocab.cpp index 38eddd01..82a4b8da 100644 --- a/src/data/vocab.cpp +++ b/src/data/vocab.cpp @@ -133,7 +133,7 @@ size_t Vocab::lemmaSize() const { return vImpl_->lemmaSize(); } -// number of vocabulary items +// type of vocabulary items std::string Vocab::type() const { return vImpl_->type(); } // return EOS symbol id diff --git a/src/data/vocab.h b/src/data/vocab.h index f4a7e0b7..4af82e8e 100644 --- a/src/data/vocab.h +++ b/src/data/vocab.h @@ -61,7 +61,7 @@ public: // number of vocabulary items size_t size() const; - // number of vocabulary items + // number of lemma items. 
Same as size() except in factored models size_t lemmaSize() const; // number of vocabulary items diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp index 1830741e..0bd8aa91 100644 --- a/src/layers/logits.cpp +++ b/src/layers/logits.cpp @@ -247,8 +247,6 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector< std::vector<float> Logits::getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 - //std::cerr << "indicesExpr=" << indicesExpr->shape() << std::endl; - //int batchSize int batchSize = indicesExpr->shape()[0]; int currBeamSize = indicesExpr->shape()[1]; int numHypos = batchSize * currBeamSize; diff --git a/src/layers/output.cpp b/src/layers/output.cpp index 964cb724..21eb3714 100644 --- a/src/layers/output.cpp +++ b/src/layers/output.cpp @@ -56,12 +56,12 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { lazyConstruct(input->shape()[-1]); auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) { - + /* std::cerr << "affineOrDot.x=" << x->shape() << std::endl; std::cerr << "affineOrDot.W=" << W->shape() << std::endl; std::cerr << "affineOrDot.b=" << b->shape() << std::endl; std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl; - + */ if(b) return affine(x, W, b, transA, transB); else diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp index eda288a4..94de3db0 100644 --- a/src/translator/beam_search.cpp +++ b/src/translator/beam_search.cpp @@ -315,7 +315,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> suppressed.erase(std::remove_if(suppressed.begin(), suppressed.end(), [&](WordIndex i) { - return shortlist->tryForwardMap(4545, 3343, i) == data::Shortlist::npos; // TODO beamIdx + return shortlist->tryForwardMap(i) == data::Shortlist::npos; // TODO beamIdx }), suppressed.end()); |