diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-03-23 04:19:16 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-03-23 04:19:16 +0300 |
commit | 415769fb2f953c5abf88f2d3498f2a46ea3607d7 (patch) | |
tree | 4a5fc40ec0623094d075d77a7afc20ceb64dce09 | |
parent | b36d0bbbab1fa4d435e517c974c3fde96e3145fe (diff) |
start lsh shortlist
-rw-r--r-- | src/data/shortlist.cpp | 16 | ||||
-rw-r--r-- | src/data/shortlist.h | 8 | ||||
-rw-r--r-- | src/translator/translator.h | 5 |
3 files changed, 26 insertions, 3 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index 6f551262..2d4a5edc 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -133,16 +133,30 @@ Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batc return New<Shortlist>(indices); } +LSHlistGenerator::LSHlistGenerator(int k, int nbits) { + +} + +Ptr<Shortlist> LSHlistGenerator::generate(Ptr<data::CorpusBatch> batch) const { + +} + + Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options, Ptr<const Vocab> srcVocab, Ptr<const Vocab> trgVocab, size_t srcIdx, size_t trgIdx, + const std::vector<int> &lshOpts, bool shared) { + std::cerr << "lshOpts=" << lshOpts.size() << std::endl; std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist"); ABORT_IF(vals.empty(), "No path to shortlist given"); std::string fname = vals[0]; - if(filesystem::Path(fname).extension().string() == ".bin") { + if (lshOpts.size() == 2) { + return New<LSHlistGenerator>(lshOpts[0], lshOpts[1]); + } + else if(filesystem::Path(fname).extension().string() == ".bin") { return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); } else { return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); diff --git a/src/data/shortlist.h b/src/data/shortlist.h index ab6a087b..be04e518 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -328,6 +328,13 @@ public: virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override; }; +class LSHlistGenerator : public ShortlistGenerator { +private: + +public: + LSHlistGenerator(int k, int nbits); +}; + /* Shortlist factory to create correct type of shortlist. Currently assumes everything is a text shortlist unless the extension is *.bin for which the Microsoft legacy binary shortlist is used. @@ -337,6 +344,7 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options, Ptr<const Vocab> trgVocab, size_t srcIdx = 0, size_t trgIdx = 1, + const std::vector<int> &lshOpts, bool shared = false); } // namespace data diff --git a/src/translator/translator.h b/src/translator/translator.h index fe01065b..edc4a4fa 100644 --- a/src/translator/translator.h +++ b/src/translator/translator.h @@ -62,8 +62,9 @@ public: trgVocab_->load(vocabs.back()); auto srcVocab = corpus_->getVocabs()[0]; - if(options_->hasAndNotEmpty("shortlist")) - shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back()); + std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn"); + if(lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) + shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, lshOpts, vocabs.front() == vocabs.back()); auto devices = Config::getDevices(options_); numDevices_ = devices.size(); |