Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2021-07-16 23:04:16 +0300
committerMartin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2021-07-16 23:04:16 +0300
commit8e88071ae8caa89e3926c0d7281d8a59897e222c (patch)
treebb4f9e505ff2a5faa82e28797e7196927c83fc41 /src/microsoft
parent42f0b8b74bba16fed646c8af7b2f75e02af7a85c (diff)
Merged PR 19842: Adapt LSH to work with Leaf
Small changes to make the LSH work with Leaf server and QuickSand.
Diffstat (limited to 'src/microsoft')
-rw-r--r--src/microsoft/quicksand.cpp27
-rw-r--r--src/microsoft/quicksand.h2
-rw-r--r--src/microsoft/shortlist/utils/Converter.cpp2
-rw-r--r--src/microsoft/shortlist/utils/Converter.h2
-rw-r--r--src/microsoft/shortlist/utils/ParameterTree.cpp3
-rw-r--r--src/microsoft/shortlist/utils/ParameterTree.h2
-rw-r--r--src/microsoft/shortlist/utils/StringUtils.cpp2
-rw-r--r--src/microsoft/shortlist/utils/StringUtils.h2
8 files changed, 35 insertions, 7 deletions
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp
index 70e657a9..099ce180 100644
--- a/src/microsoft/quicksand.cpp
+++ b/src/microsoft/quicksand.cpp
@@ -78,7 +78,7 @@ public:
graph_->setDevice(deviceId, device_);
#if MKL_FOUND
- mkl_set_num_threads(options->get<int>("mkl-threads", 1));
+ mkl_set_num_threads(options_->get<int>("mkl-threads", 1));
#endif
std::vector<std::string> models
@@ -114,6 +114,9 @@ public:
for(auto scorer : scorers_) {
scorer->init(graph_);
}
+
+ // run parameter init once, this is required for graph_->get("parameter name") to work correctly
+ graph_->forward();
}
void setWorkspace(uint8_t* data, size_t size) override { device_->set(data, size); }
@@ -121,8 +124,21 @@ public:
QSNBestBatch decode(const QSBatch& qsBatch,
size_t maxLength,
const std::unordered_set<WordIndex>& shortlist) override {
- if(shortlist.size() > 0) {
- auto shortListGen = New<data::FakeShortlistGenerator>(shortlist);
+
+ std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn", {});
+ ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");
+ ABORT_IF(lshOpts.size() == 2 && shortlist.size() > 0, "LSH and shortlist cannot be used at the same time");
+
+ if(lshOpts.size() == 2 || shortlist.size() > 0) {
+ Ptr<data::ShortlistGenerator> shortListGen;
+ // both ShortListGenerators are thin wrappers, hence no problem with calling this per query
+ if(lshOpts.size() == 2) {
+ // Setting abortIfDynamic to true disallows memory allocation for LSH parameters, this is specifically for use in Quicksand.
+ // If we want to use the LSH in Quicksand we need to create a binary model that contains the LSH parameters via conversion.
+ shortListGen = New<data::LSHShortlistGenerator>(lshOpts[0], lshOpts[1], vocabs_[1]->lemmaSize(), /*abortIfDynamic=*/true);
+ } else {
+ shortListGen = New<data::FakeShortlistGenerator>(shortlist);
+ }
for(auto scorer : scorers_)
scorer->setShortlistGenerator(shortListGen);
}
@@ -249,7 +265,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name) {
// This function converts an fp32 model into an FBGEMM based packed model.
// marian defined types are used for external project as well.
// The targetPrec is passed as int32_t for the exported function definition.
-bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh) {
+bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, int32_t lshNBits) {
std::cerr << "Converting from: " << inputFile << ", to: " << outputFile << ", precision: " << targetPrec << std::endl;
YAML::Node config;
@@ -264,9 +280,10 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
// MJD: Note, this is a default settings which we might want to change or expose. Use this only with Polonium students.
// The LSH will not be used by default even if it exists in the model. That has to be enabled in the decoder config.
- int lshNBits = 1024;
std::string lshOutputWeights = "Wemb";
+ bool addLsh = lshNBits > 0;
if(addLsh) {
+ std::cerr << "Adding LSH to model with hash size " << lshNBits << std::endl;
// Add dummy parameters for the LSH before the model gets actually initialized.
// This create the parameters with useless values in the tensors, but it gives us the memory we need.
graph->setReloaded(false);
diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h
index b710e135..cddcfd22 100644
--- a/src/microsoft/quicksand.h
+++ b/src/microsoft/quicksand.h
@@ -79,7 +79,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name);
// MJD: added "addLsh" which will now break whatever compilation after update. That's on purpose.
// The calling code should be adapted, not this interface. If you need to fix things in QS because of this
// talk to me first!
-bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh);
+bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, int32_t lshNBits);
} // namespace quicksand
} // namespace marian
diff --git a/src/microsoft/shortlist/utils/Converter.cpp b/src/microsoft/shortlist/utils/Converter.cpp
index c28178cd..df44b338 100644
--- a/src/microsoft/shortlist/utils/Converter.cpp
+++ b/src/microsoft/shortlist/utils/Converter.cpp
@@ -1,5 +1,6 @@
#include "microsoft/shortlist/utils/Converter.h"
+namespace marian {
namespace quicksand {
#include "microsoft/shortlist/logging/LoggerMacros.h"
@@ -57,3 +58,4 @@ void Converter::HandleConversionError(const std::string& str, const char * type_
}
} // namespace quicksand
+} // namespace marian \ No newline at end of file
diff --git a/src/microsoft/shortlist/utils/Converter.h b/src/microsoft/shortlist/utils/Converter.h
index 9d9dd96d..ecbb5457 100644
--- a/src/microsoft/shortlist/utils/Converter.h
+++ b/src/microsoft/shortlist/utils/Converter.h
@@ -5,6 +5,7 @@
#include <vector>
#include <sstream>
+namespace marian {
namespace quicksand {
class Converter {
@@ -81,3 +82,4 @@ std::vector<T> Converter::ConvertVectorInternal(I begin, I end, const char * typ
}
} // namespace quicksand
+} // namespace marian \ No newline at end of file
diff --git a/src/microsoft/shortlist/utils/ParameterTree.cpp b/src/microsoft/shortlist/utils/ParameterTree.cpp
index 465d2e0d..b7396b5e 100644
--- a/src/microsoft/shortlist/utils/ParameterTree.cpp
+++ b/src/microsoft/shortlist/utils/ParameterTree.cpp
@@ -5,6 +5,7 @@
#include "microsoft/shortlist/utils/StringUtils.h"
#include "microsoft/shortlist/utils/Converter.h"
+namespace marian {
namespace quicksand {
#include "microsoft/shortlist/logging/LoggerMacros.h"
@@ -414,4 +415,4 @@ void ParameterTree::ReplaceVariablesInternal(
}
} // namespace quicksand
-
+} // namespace marian \ No newline at end of file
diff --git a/src/microsoft/shortlist/utils/ParameterTree.h b/src/microsoft/shortlist/utils/ParameterTree.h
index 1474ff64..e9052f2e 100644
--- a/src/microsoft/shortlist/utils/ParameterTree.h
+++ b/src/microsoft/shortlist/utils/ParameterTree.h
@@ -8,6 +8,7 @@
#include "microsoft/shortlist/utils/StringUtils.h"
+namespace marian {
namespace quicksand {
class ParameterTree {
@@ -183,3 +184,4 @@ void ParameterTree::SetParam(const std::string& name, const T& obj) {
}
} // namespace quicksand
+} // namespace marian \ No newline at end of file
diff --git a/src/microsoft/shortlist/utils/StringUtils.cpp b/src/microsoft/shortlist/utils/StringUtils.cpp
index 7870b542..e4fb8815 100644
--- a/src/microsoft/shortlist/utils/StringUtils.cpp
+++ b/src/microsoft/shortlist/utils/StringUtils.cpp
@@ -4,6 +4,7 @@
#include <algorithm>
#include <string>
+namespace marian {
namespace quicksand {
#include "microsoft/shortlist/logging/LoggerMacros.h"
@@ -336,3 +337,4 @@ std::string StringUtils::ToLower(const std::string& str) {
}
} // namespace quicksand
+} // namespace marian \ No newline at end of file
diff --git a/src/microsoft/shortlist/utils/StringUtils.h b/src/microsoft/shortlist/utils/StringUtils.h
index 31bb1fcc..be9d1e54 100644
--- a/src/microsoft/shortlist/utils/StringUtils.h
+++ b/src/microsoft/shortlist/utils/StringUtils.h
@@ -8,6 +8,7 @@
#include "microsoft/shortlist/utils/PrintTypes.h"
+namespace marian {
namespace quicksand {
class StringUtils {
@@ -96,3 +97,4 @@ std::string StringUtils::ToString(const T& obj) {
}
} // namespace quicksand
+} // namespace marian \ No newline at end of file