Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-04-29 09:40:00 +0300
committerHieu Hoang <hihoan@microsoft.com>2021-04-29 09:40:00 +0300
commit909df372d10803395684a60d6d6fe0cb7de83637 (patch)
tree8a3453d4cf43ee9917558e265a9353f840d7edef
parent49e379bba5c77c1b80927b7f0db5603e171a1903 (diff)
restart
-rw-r--r--src/CMakeLists.txt1
-rw-r--r--src/data/shortlist.cpp43
-rw-r--r--src/data/shortlist.h29
-rw-r--r--src/layers/generic.cpp1
-rw-r--r--src/layers/lsh.cpp130
-rw-r--r--src/layers/lsh.h31
-rw-r--r--src/layers/output.cpp14
-rw-r--r--src/layers/output.h2
8 files changed, 59 insertions, 192 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index cf276137..d2fd269f 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -72,7 +72,6 @@ set(MARIAN_SOURCES
layers/generic.cpp
layers/loss.cpp
layers/weight.cpp
- layers/lsh.cpp
layers/embedding.cpp
layers/output.cpp
layers/logits.cpp
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 6f551262..67317f4b 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -1,5 +1,6 @@
#include "data/shortlist.h"
#include "microsoft/shortlist/utils/ParameterTree.h"
+#include "marian.h"
namespace marian {
namespace data {
@@ -12,6 +13,48 @@ const T* get(const void*& current, size_t num = 1) {
return ptr;
}
+//////////////////////////////////////////////////////////////////////////////////////
+Shortlist::Shortlist(const std::vector<WordIndex>& indices)
+ : indices_(indices) {}
+
+const std::vector<WordIndex>& Shortlist::indices() const { return indices_; }
+WordIndex Shortlist::reverseMap(int idx) { return indices_[idx]; }
+
+WordIndex Shortlist::tryForwardMap(WordIndex wIdx) {
+ auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
+ if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
+ return (int)std::distance(indices_.begin(), first); // return coordinate if found
+ else
+ return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
+}
+
+void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
+ int k = indices_.size();
+ int currBeamSize = input->shape()[0];
+ int batchSize = input->shape()[2];
+ std::cerr << "currBeamSize=" << currBeamSize << std::endl;
+ std::cerr << "batchSize=" << batchSize << std::endl;
+
+ Expr indicesExprBC;
+ broadcast(weights, isLegacyUntransposedW, b, lemmaEt, indicesExprBC, k);
+}
+
+
+void Shortlist::broadcast(Expr weights,
+ bool isLegacyUntransposedW,
+ Expr b,
+ Expr lemmaEt,
+ Expr indicesExprBC,
+ int k) {
+ cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indices());
+ if (b) {
+ cachedShortb_ = index_select(b, -1, indices());
+ }
+ cachedShortLemmaEt_ = index_select(lemmaEt, -1, indices());
+ return;
+
+}
+//////////////////////////////////////////////////////////////////////////////////////
QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index f0467640..dd7d0589 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -19,26 +19,29 @@ namespace marian {
namespace data {
class Shortlist {
-private:
+protected:
std::vector<WordIndex> indices_; // // [packed shortlist index] -> word index, used to select columns from output embeddings
+ Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
+ Expr cachedShortb_; // these match the current value of shortlist_
+ Expr cachedShortLemmaEt_;
+
+ virtual void broadcast(Expr weights,
+ bool isLegacyUntransposedW,
+ Expr b,
+ Expr lemmaEt,
+ Expr indicesExprBC,
+ int k);
public:
static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos
- Shortlist(const std::vector<WordIndex>& indices)
- : indices_(indices) {}
+ Shortlist(const std::vector<WordIndex>& indices);
- const std::vector<WordIndex>& indices() const { return indices_; }
- WordIndex reverseMap(int idx) { return indices_[idx]; }
-
- WordIndex tryForwardMap(WordIndex wIdx) {
- auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
- if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
- return (int)std::distance(indices_.begin(), first); // return coordinate if found
- else
- return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
- }
+ const std::vector<WordIndex>& indices() const;
+ WordIndex reverseMap(int idx);
+ WordIndex tryForwardMap(WordIndex wIdx);
+ virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
};
class ShortlistGenerator {
diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp
index 8e2ecfd7..17ef32fc 100644
--- a/src/layers/generic.cpp
+++ b/src/layers/generic.cpp
@@ -4,7 +4,6 @@
#include "layers/constructors.h"
#include "layers/generic.h"
#include "layers/loss.h"
-#include "layers/lsh.h"
#include "models/states.h" // for EncoderState
namespace marian {} // namespace marian
diff --git a/src/layers/lsh.cpp b/src/layers/lsh.cpp
deleted file mode 100644
index a91778ed..00000000
--- a/src/layers/lsh.cpp
+++ /dev/null
@@ -1,130 +0,0 @@
-#include "layers/lsh.h"
-#include "graph/expression_operators.h"
-#include "tensors/cpu/prod_blas.h"
-
-#if BLAS_FOUND
-#include "3rd_party/faiss/IndexLSH.h"
-#endif
-
-namespace marian {
-
-Expr LSH::apply(Expr input, Expr W, Expr b) {
- auto idx = search(input, W);
- return affine(idx, input, W, b);
-}
-
-Expr LSH::search(Expr query, Expr values) {
-#if BLAS_FOUND
- ABORT_IF(query->graph()->getDeviceId().type == DeviceType::gpu,
- "LSH index (--output-approx-knn) currently not implemented for GPU");
-
- auto kShape = query->shape();
- kShape.set(-1, k_);
-
- auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
- auto query = inputs[0];
- auto values = inputs[1];
-
- int dim = values->shape()[-1];
-
- if(!index_ || indexHash_ != values->hash()) {
- LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
- index_.reset(new faiss::IndexLSH(dim, nbits_,
- /*rotate=*/dim != nbits_,
- /*train_thesholds*/false));
- int vRows = values->shape().elements() / dim;
- index_->train(vRows, values->val()->data<float>());
- index_->add( vRows, values->val()->data<float>());
- indexHash_ = values->hash();
- }
-
- int qRows = query->shape().elements() / dim;
- std::vector<float> distances(qRows * k_);
- std::vector<faiss::Index::idx_t> ids(qRows * k_);
-
- index_->search(qRows, query->val()->data<float>(), k_,
- distances.data(), ids.data());
-
- std::vector<IndexType> vOut;
- vOut.reserve(ids.size());
- for(auto id : ids)
- vOut.push_back((IndexType)id);
-
- out->val()->set(vOut);
- };
-
- return lambda({query, values}, kShape, Type::uint32, forward);
-#else
- query; values;
- ABORT("LSH output layer requires a CPU BLAS library");
-#endif
-}
-
-Expr LSH::affine(Expr idx, Expr input, Expr W, Expr b) {
- auto outShape = input->shape();
- int dimVoc = W->shape()[-2];
- outShape.set(-1, dimVoc);
-
- auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
- auto lowest = NumericLimits<float>(out->value_type()).lowest;
- out->val()->set(lowest);
-
- int dimIn = inputs[1]->shape()[-1];
- int dimOut = out->shape()[-1];
- int dimRows = out->shape().elements() / dimOut;
-
- auto outPtr = out->val()->data<float>();
- auto idxPtr = inputs[0]->val()->data<uint32_t>();
- auto queryPtr = inputs[1]->val()->data<float>();
- auto WPtr = inputs[2]->val()->data<float>();
- auto bPtr = inputs.size() > 3 ? inputs[3]->val()->data<float>() : nullptr; // nullptr if no bias given
-
- for(int row = 0; row < dimRows; ++row) {
- auto currIdxPtr = idxPtr + row * k_; // move to next batch of k entries
- auto currQueryPtr = queryPtr + row * dimIn; // move to next input query vector
- auto currOutPtr = outPtr + row * dimOut; // move to next output position vector (of vocabulary size)
- for(int k = 0; k < k_; k++) {
- int relPos = currIdxPtr[k]; // k-th best vocabulay item
- auto currWPtr = WPtr + relPos * dimIn; // offset for k-th best embedding
- currOutPtr[relPos] = bPtr ? bPtr[relPos] : 0; // write bias value to position, init to 0 if no bias given
-
- // proceed one vector product at a time writing to the correct position
- sgemm(false, true, 1, 1, dimIn, 1.0f, currQueryPtr, dimIn, currWPtr, dimIn, 1.0f, &currOutPtr[relPos], 1);
- }
- }
- };
-
- std::vector<Expr> nodes = {idx, input, W};
- if(b) // bias is optional
- nodes.push_back(b);
-
- return lambda(nodes,
- outShape,
- input->value_type(),
- forward);
-}
-
-// @TODO: alternative version which does the same as above with Marian operators, currently missing "scatter".
-// this uses more memory and likely to be slower. Would make sense to have a scatter node that actually creates
-// the node instead of relying on an existing node, e.g. scatter(shape, defaultValue, axis, indices, values);
-#if 0
-Expr LSH::affine(Expr idx, Expr input, Expr W, Expr b) {
- int dim = input->shape()[-1];
- int bch = idx->shape().elements() / k;
-
- auto W = reshape(rows(Wt_, flatten(idx)), {bch, k, dim}); // [rows, k, dim]
- auto b = reshape(cols(b_, flatten(idx)), {bch, 1, k}); // [rows, 1, k]
-
- auto aff = reshape(bdot(reshape(input, {bch, 1, dim}), W, false, true) + b, idx->shape()); // [beam, time, batch, k]
-
- int dimVoc = Wt_->shape()[-2];
- auto oShape = input->shape();
- oShape.set(-1, dimVoc);
- auto lowest = graph_->constant(oShape,
- inits::fromValue(NumericLimits<float>(input->value_type()).lowest),
- input->value_type());
- return scatter(lowest, -1, idx, aff);
-}
-#endif
-
-} // namespace marian \ No newline at end of file
diff --git a/src/layers/lsh.h b/src/layers/lsh.h
deleted file mode 100644
index bf498cc6..00000000
--- a/src/layers/lsh.h
+++ /dev/null
@@ -1,31 +0,0 @@
-#include "graph/expression_graph.h"
-#include <memory>
-
-namespace faiss {
- struct IndexLSH;
-}
-
-namespace marian {
-
-class LSH {
-public:
- LSH(int k, int nbits) : k_{k}, nbits_{nbits} {
-#if !BLAS_FOUND
- ABORT("LSH-based output approximation requires BLAS library");
-#endif
- }
-
- Expr apply(Expr query, Expr values, Expr bias);
-
-private:
- Ptr<faiss::IndexLSH> index_;
- size_t indexHash_{0};
-
- int k_{100};
- int nbits_{1024};
-
- Expr search(Expr query, Expr values);
- Expr affine(Expr idx, Expr query, Expr values, Expr bias);
-};
-
-} \ No newline at end of file
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 4c34bdce..e9bffac4 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -2,7 +2,6 @@
#include "common/timer.h"
#include "data/factored_vocab.h"
#include "layers/loss.h"
-#include "layers/lsh.h"
namespace marian {
namespace mlp {
@@ -12,13 +11,6 @@ namespace mlp {
if(Wt_)
return;
- // this option is only set in the decoder
- if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
- auto k = opt<std::vector<int>>("output-approx-knn")[0];
- auto nbits = opt<std::vector<int>>("output-approx-knn")[1];
- lsh_ = New<LSH>(k, nbits);
- }
-
auto name = options_->get<std::string>("prefix");
auto numOutputClasses = options_->get<int>("dim");
@@ -71,13 +63,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
};
auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
- if(lsh_) {
- ABORT_IF(transA, "Transposed query not supported for LSH");
- ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
- return lsh_->apply(x, W, b); // knows how to deal with undefined bias
- } else {
return affineOrDot(x, W, b, transA, transB);
- }
};
if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one
diff --git a/src/layers/output.h b/src/layers/output.h
index 2b6f4986..bf8a580a 100644
--- a/src/layers/output.h
+++ b/src/layers/output.h
@@ -7,7 +7,6 @@
#include "marian.h"
namespace marian {
-class LSH;
namespace mlp {
@@ -28,7 +27,6 @@ private:
// optional parameters set/updated after construction
Expr tiedParam_;
Ptr<data::Shortlist> shortlist_;
- Ptr<LSH> lsh_;
void lazyConstruct(int inputDim);