diff options
author | Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-07-16 23:04:16 +0300 |
---|---|---|
committer | Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-07-16 23:04:16 +0300 |
commit | 8e88071ae8caa89e3926c0d7281d8a59897e222c (patch) | |
tree | bb4f9e505ff2a5faa82e28797e7196927c83fc41 /src/microsoft | |
parent | 42f0b8b74bba16fed646c8af7b2f75e02af7a85c (diff) |
Merged PR 19842: Adapt LSH to work with Leaf
Small changes to make the LSH work with Leaf server and QuickSand.
Diffstat (limited to 'src/microsoft')
-rw-r--r-- | src/microsoft/quicksand.cpp | 27 | ||||
-rw-r--r-- | src/microsoft/quicksand.h | 2 | ||||
-rw-r--r-- | src/microsoft/shortlist/utils/Converter.cpp | 2 | ||||
-rw-r--r-- | src/microsoft/shortlist/utils/Converter.h | 2 | ||||
-rw-r--r-- | src/microsoft/shortlist/utils/ParameterTree.cpp | 3 | ||||
-rw-r--r-- | src/microsoft/shortlist/utils/ParameterTree.h | 2 | ||||
-rw-r--r-- | src/microsoft/shortlist/utils/StringUtils.cpp | 2 | ||||
-rw-r--r-- | src/microsoft/shortlist/utils/StringUtils.h | 2 |
8 files changed, 35 insertions, 7 deletions
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp index 70e657a9..099ce180 100644 --- a/src/microsoft/quicksand.cpp +++ b/src/microsoft/quicksand.cpp @@ -78,7 +78,7 @@ public: graph_->setDevice(deviceId, device_); #if MKL_FOUND - mkl_set_num_threads(options->get<int>("mkl-threads", 1)); + mkl_set_num_threads(options_->get<int>("mkl-threads", 1)); #endif std::vector<std::string> models @@ -114,6 +114,9 @@ public: for(auto scorer : scorers_) { scorer->init(graph_); } + + // run parameter init once, this is required for graph_->get("parameter name") to work correctly + graph_->forward(); } void setWorkspace(uint8_t* data, size_t size) override { device_->set(data, size); } @@ -121,8 +124,21 @@ public: QSNBestBatch decode(const QSBatch& qsBatch, size_t maxLength, const std::unordered_set<WordIndex>& shortlist) override { - if(shortlist.size() > 0) { - auto shortListGen = New<data::FakeShortlistGenerator>(shortlist); + + std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn", {}); + ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters"); + ABORT_IF(lshOpts.size() == 2 && shortlist.size() > 0, "LSH and shortlist cannot be used at the same time"); + + if(lshOpts.size() == 2 || shortlist.size() > 0) { + Ptr<data::ShortlistGenerator> shortListGen; + // both ShortListGenerators are thin wrappers, hence no problem with calling this per query + if(lshOpts.size() == 2) { + // Setting abortIfDynamic to true disallows memory allocation for LSH parameters, this is specifically for use in Quicksand. + // If we want to use the LSH in Quicksand we need to create a binary model that contains the LSH parameters via conversion. + shortListGen = New<data::LSHShortlistGenerator>(lshOpts[0], lshOpts[1], vocabs_[1]->lemmaSize(), /*abortIfDynamic=*/true); + } else { + shortListGen = New<data::FakeShortlistGenerator>(shortlist); + } for(auto scorer : scorers_) scorer->setShortlistGenerator(shortListGen); } @@ -249,7 +265,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name) { // This function converts an fp32 model into an FBGEMM based packed model. // marian defined types are used for external project as well. // The targetPrec is passed as int32_t for the exported function definition. -bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh) { +bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, int32_t lshNBits) { std::cerr << "Converting from: " << inputFile << ", to: " << outputFile << ", precision: " << targetPrec << std::endl; YAML::Node config; @@ -264,9 +280,10 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP // MJD: Note, this is a default settings which we might want to change or expose. Use this only with Polonium students. // The LSH will not be used by default even if it exists in the model. That has to be enabled in the decoder config. - int lshNBits = 1024; std::string lshOutputWeights = "Wemb"; + bool addLsh = lshNBits > 0; if(addLsh) { + std::cerr << "Adding LSH to model with hash size " << lshNBits << std::endl; // Add dummy parameters for the LSH before the model gets actually initialized. // This create the parameters with useless values in the tensors, but it gives us the memory we need. graph->setReloaded(false); diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h index b710e135..cddcfd22 100644 --- a/src/microsoft/quicksand.h +++ b/src/microsoft/quicksand.h @@ -79,7 +79,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name); // MJD: added "addLsh" which will now break whatever compilation after update. That's on purpose. // The calling code should be adapted, not this interface. If you need to fix things in QS because of this // talk to me first! -bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh); +bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, int32_t lshNBits); } // namespace quicksand } // namespace marian diff --git a/src/microsoft/shortlist/utils/Converter.cpp b/src/microsoft/shortlist/utils/Converter.cpp index c28178cd..df44b338 100644 --- a/src/microsoft/shortlist/utils/Converter.cpp +++ b/src/microsoft/shortlist/utils/Converter.cpp @@ -1,5 +1,6 @@ #include "microsoft/shortlist/utils/Converter.h" +namespace marian { namespace quicksand { #include "microsoft/shortlist/logging/LoggerMacros.h" @@ -57,3 +58,4 @@ void Converter::HandleConversionError(const std::string& str, const char * type_ } } // namespace quicksand +} // namespace marian
\ No newline at end of file diff --git a/src/microsoft/shortlist/utils/Converter.h b/src/microsoft/shortlist/utils/Converter.h index 9d9dd96d..ecbb5457 100644 --- a/src/microsoft/shortlist/utils/Converter.h +++ b/src/microsoft/shortlist/utils/Converter.h @@ -5,6 +5,7 @@ #include <vector> #include <sstream> +namespace marian { namespace quicksand { class Converter { @@ -81,3 +82,4 @@ std::vector<T> Converter::ConvertVectorInternal(I begin, I end, const char * typ } } // namespace quicksand +} // namespace marian
\ No newline at end of file diff --git a/src/microsoft/shortlist/utils/ParameterTree.cpp b/src/microsoft/shortlist/utils/ParameterTree.cpp index 465d2e0d..b7396b5e 100644 --- a/src/microsoft/shortlist/utils/ParameterTree.cpp +++ b/src/microsoft/shortlist/utils/ParameterTree.cpp @@ -5,6 +5,7 @@ #include "microsoft/shortlist/utils/StringUtils.h" #include "microsoft/shortlist/utils/Converter.h" +namespace marian { namespace quicksand { #include "microsoft/shortlist/logging/LoggerMacros.h" @@ -414,4 +415,4 @@ void ParameterTree::ReplaceVariablesInternal( } } // namespace quicksand - +} // namespace marian
\ No newline at end of file diff --git a/src/microsoft/shortlist/utils/ParameterTree.h b/src/microsoft/shortlist/utils/ParameterTree.h index 1474ff64..e9052f2e 100644 --- a/src/microsoft/shortlist/utils/ParameterTree.h +++ b/src/microsoft/shortlist/utils/ParameterTree.h @@ -8,6 +8,7 @@ #include "microsoft/shortlist/utils/StringUtils.h" +namespace marian { namespace quicksand { class ParameterTree { @@ -183,3 +184,4 @@ void ParameterTree::SetParam(const std::string& name, const T& obj) { } } // namespace quicksand +} // namespace marian
\ No newline at end of file diff --git a/src/microsoft/shortlist/utils/StringUtils.cpp b/src/microsoft/shortlist/utils/StringUtils.cpp index 7870b542..e4fb8815 100644 --- a/src/microsoft/shortlist/utils/StringUtils.cpp +++ b/src/microsoft/shortlist/utils/StringUtils.cpp @@ -4,6 +4,7 @@ #include <algorithm> #include <string> +namespace marian { namespace quicksand { #include "microsoft/shortlist/logging/LoggerMacros.h" @@ -336,3 +337,4 @@ std::string StringUtils::ToLower(const std::string& str) { } } // namespace quicksand +} // namespace marian
\ No newline at end of file diff --git a/src/microsoft/shortlist/utils/StringUtils.h b/src/microsoft/shortlist/utils/StringUtils.h index 31bb1fcc..be9d1e54 100644 --- a/src/microsoft/shortlist/utils/StringUtils.h +++ b/src/microsoft/shortlist/utils/StringUtils.h @@ -8,6 +8,7 @@ #include "microsoft/shortlist/utils/PrintTypes.h" +namespace marian { namespace quicksand { class StringUtils { @@ -96,3 +97,4 @@ std::string StringUtils::ToString(const T& obj) { } } // namespace quicksand +} // namespace marian
\ No newline at end of file |