Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Hieu Hoang <hihoan@microsoft.com> 2021-03-23 04:19:16 +0300
committer: Hieu Hoang <hihoan@microsoft.com> 2021-03-23 04:19:16 +0300
commit: 415769fb2f953c5abf88f2d3498f2a46ea3607d7 (patch)
tree: 4a5fc40ec0623094d075d77a7afc20ceb64dce09
parent: b36d0bbbab1fa4d435e517c974c3fde96e3145fe (diff)
start lsh shortlist
-rw-r--r-- src/data/shortlist.cpp 16
-rw-r--r-- src/data/shortlist.h 8
-rw-r--r-- src/translator/translator.h 5
3 files changed, 26 insertions, 3 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 6f551262..2d4a5edc 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -133,16 +133,30 @@ Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batc
return New<Shortlist>(indices);
}
+LSHlistGenerator::LSHlistGenerator(int k, int nbits) {
+
+}
+
+Ptr<Shortlist> LSHlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
+
+}
+
+
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
size_t srcIdx,
size_t trgIdx,
+ const std::vector<int> &lshOpts,
bool shared) {
+ std::cerr << "lshOpts=" << lshOpts.size() << std::endl;
std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
ABORT_IF(vals.empty(), "No path to shortlist given");
std::string fname = vals[0];
- if(filesystem::Path(fname).extension().string() == ".bin") {
+ if (lshOpts.size() == 2) {
+ return New<LSHlistGenerator>(lshOpts[0], lshOpts[1]);
+ }
+ else if(filesystem::Path(fname).extension().string() == ".bin") {
return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
} else {
return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index ab6a087b..be04e518 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -328,6 +328,13 @@ public:
virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
};
+class LSHlistGenerator : public ShortlistGenerator {
+private:
+
+public:
+ LSHlistGenerator(int k, int nbits);
+};
+
/*
Shortlist factory to create correct type of shortlist. Currently assumes everything is a text shortlist
unless the extension is *.bin for which the Microsoft legacy binary shortlist is used.
@@ -337,6 +344,7 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> trgVocab,
size_t srcIdx = 0,
size_t trgIdx = 1,
+ const std::vector<int> &lshOpts,
bool shared = false);
} // namespace data
diff --git a/src/translator/translator.h b/src/translator/translator.h
index fe01065b..edc4a4fa 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -62,8 +62,9 @@ public:
trgVocab_->load(vocabs.back());
auto srcVocab = corpus_->getVocabs()[0];
- if(options_->hasAndNotEmpty("shortlist"))
- shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
+ std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
+ if(lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist"))
+ shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, lshOpts, vocabs.front() == vocabs.back());
auto devices = Config::getDevices(options_);
numDevices_ = devices.size();