diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-04-29 09:40:00 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-04-29 09:40:00 +0300 |
commit | 909df372d10803395684a60d6d6fe0cb7de83637 (patch) | |
tree | 8a3453d4cf43ee9917558e265a9353f840d7edef | |
parent | 49e379bba5c77c1b80927b7f0db5603e171a1903 (diff) |
restart
-rw-r--r-- | src/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/data/shortlist.cpp | 43 | ||||
-rw-r--r-- | src/data/shortlist.h | 29 | ||||
-rw-r--r-- | src/layers/generic.cpp | 1 | ||||
-rw-r--r-- | src/layers/lsh.cpp | 130 | ||||
-rw-r--r-- | src/layers/lsh.h | 31 | ||||
-rw-r--r-- | src/layers/output.cpp | 14 | ||||
-rw-r--r-- | src/layers/output.h | 2 |
8 files changed, 59 insertions, 192 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cf276137..d2fd269f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -72,7 +72,6 @@ set(MARIAN_SOURCES layers/generic.cpp layers/loss.cpp layers/weight.cpp - layers/lsh.cpp layers/embedding.cpp layers/output.cpp layers/logits.cpp diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index 6f551262..67317f4b 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -1,5 +1,6 @@ #include "data/shortlist.h" #include "microsoft/shortlist/utils/ParameterTree.h" +#include "marian.h" namespace marian { namespace data { @@ -12,6 +13,48 @@ const T* get(const void*& current, size_t num = 1) { return ptr; } +////////////////////////////////////////////////////////////////////////////////////// +Shortlist::Shortlist(const std::vector<WordIndex>& indices) + : indices_(indices) {} + +const std::vector<WordIndex>& Shortlist::indices() const { return indices_; } +WordIndex Shortlist::reverseMap(int idx) { return indices_[idx]; } + +WordIndex Shortlist::tryForwardMap(WordIndex wIdx) { + auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx); + if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx + return (int)std::distance(indices_.begin(), first); // return coordinate if found + else + return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17? +} + +void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) { + int k = indices_.size(); + int currBeamSize = input->shape()[0]; + int batchSize = input->shape()[2]; + std::cerr << "currBeamSize=" << currBeamSize << std::endl; + std::cerr << "batchSize=" << batchSize << std::endl; + + Expr indicesExprBC; + broadcast(weights, isLegacyUntransposedW, b, lemmaEt, indicesExprBC, k); +} + + +void Shortlist::broadcast(Expr weights, + bool isLegacyUntransposedW, + Expr b, + Expr lemmaEt, + Expr indicesExprBC, + int k) { + cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indices()); + if (b) { + cachedShortb_ = index_select(b, -1, indices()); + } + cachedShortLemmaEt_ = index_select(lemmaEt, -1, indices()); + return; + +} +////////////////////////////////////////////////////////////////////////////////////// QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options, Ptr<const Vocab> srcVocab, Ptr<const Vocab> trgVocab, diff --git a/src/data/shortlist.h b/src/data/shortlist.h index f0467640..dd7d0589 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -19,26 +19,29 @@ namespace marian { namespace data { class Shortlist { -private: +protected: std::vector<WordIndex> indices_; // // [packed shortlist index] -> word index, used to select columns from output embeddings + Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) + Expr cachedShortb_; // these match the current value of shortlist_ + Expr cachedShortLemmaEt_; + + virtual void broadcast(Expr weights, + bool isLegacyUntransposedW, + Expr b, + Expr lemmaEt, + Expr indicesExprBC, + int k); public: static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos - Shortlist(const std::vector<WordIndex>& indices) - : indices_(indices) {} + Shortlist(const std::vector<WordIndex>& indices); - const std::vector<WordIndex>& indices() const { return indices_; } - WordIndex reverseMap(int idx) { return indices_[idx]; } - - WordIndex tryForwardMap(WordIndex wIdx) { - auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx); - if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx - return (int)std::distance(indices_.begin(), first); // return coordinate if found - else - return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17? - } + const std::vector<WordIndex>& indices() const; + WordIndex reverseMap(int idx); + WordIndex tryForwardMap(WordIndex wIdx); + virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt); }; class ShortlistGenerator { diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index 8e2ecfd7..17ef32fc 100644 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -4,7 +4,6 @@ #include "layers/constructors.h" #include "layers/generic.h" #include "layers/loss.h" -#include "layers/lsh.h" #include "models/states.h" // for EncoderState namespace marian {} // namespace marian diff --git a/src/layers/lsh.cpp b/src/layers/lsh.cpp deleted file mode 100644 index a91778ed..00000000 --- a/src/layers/lsh.cpp +++ /dev/null @@ -1,130 +0,0 @@ -#include "layers/lsh.h" -#include "graph/expression_operators.h" -#include "tensors/cpu/prod_blas.h" - -#if BLAS_FOUND -#include "3rd_party/faiss/IndexLSH.h" -#endif - -namespace marian { - -Expr LSH::apply(Expr input, Expr W, Expr b) { - auto idx = search(input, W); - return affine(idx, input, W, b); -} - -Expr LSH::search(Expr query, Expr values) { -#if BLAS_FOUND - ABORT_IF(query->graph()->getDeviceId().type == DeviceType::gpu, - "LSH index (--output-approx-knn) currently not implemented for GPU"); - - auto kShape = query->shape(); - kShape.set(-1, k_); - - auto forward = [this](Expr out, const std::vector<Expr>& inputs) { - auto query = inputs[0]; - auto values = inputs[1]; - - int dim = values->shape()[-1]; - - if(!index_ || indexHash_ != values->hash()) { - LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_); - index_.reset(new faiss::IndexLSH(dim, nbits_, - /*rotate=*/dim != nbits_, - /*train_thesholds*/false)); - int vRows = values->shape().elements() / dim; - index_->train(vRows, values->val()->data<float>()); - index_->add( vRows, values->val()->data<float>()); - indexHash_ = values->hash(); - } - - int qRows = query->shape().elements() / dim; - std::vector<float> distances(qRows * k_); - std::vector<faiss::Index::idx_t> ids(qRows * k_); - - index_->search(qRows, query->val()->data<float>(), k_, - distances.data(), ids.data()); - - std::vector<IndexType> vOut; - vOut.reserve(ids.size()); - for(auto id : ids) - vOut.push_back((IndexType)id); - - out->val()->set(vOut); - }; - - return lambda({query, values}, kShape, Type::uint32, forward); -#else - query; values; - ABORT("LSH output layer requires a CPU BLAS library"); -#endif -} - -Expr LSH::affine(Expr idx, Expr input, Expr W, Expr b) { - auto outShape = input->shape(); - int dimVoc = W->shape()[-2]; - outShape.set(-1, dimVoc); - - auto forward = [this](Expr out, const std::vector<Expr>& inputs) { - auto lowest = NumericLimits<float>(out->value_type()).lowest; - out->val()->set(lowest); - - int dimIn = inputs[1]->shape()[-1]; - int dimOut = out->shape()[-1]; - int dimRows = out->shape().elements() / dimOut; - - auto outPtr = out->val()->data<float>(); - auto idxPtr = inputs[0]->val()->data<uint32_t>(); - auto queryPtr = inputs[1]->val()->data<float>(); - auto WPtr = inputs[2]->val()->data<float>(); - auto bPtr = inputs.size() > 3 ? inputs[3]->val()->data<float>() : nullptr; // nullptr if no bias given - - for(int row = 0; row < dimRows; ++row) { - auto currIdxPtr = idxPtr + row * k_; // move to next batch of k entries - auto currQueryPtr = queryPtr + row * dimIn; // move to next input query vector - auto currOutPtr = outPtr + row * dimOut; // move to next output position vector (of vocabulary size) - for(int k = 0; k < k_; k++) { - int relPos = currIdxPtr[k]; // k-th best vocabulay item - auto currWPtr = WPtr + relPos * dimIn; // offset for k-th best embedding - currOutPtr[relPos] = bPtr ? bPtr[relPos] : 0; // write bias value to position, init to 0 if no bias given - - // proceed one vector product at a time writing to the correct position - sgemm(false, true, 1, 1, dimIn, 1.0f, currQueryPtr, dimIn, currWPtr, dimIn, 1.0f, &currOutPtr[relPos], 1); - } - } - }; - - std::vector<Expr> nodes = {idx, input, W}; - if(b) // bias is optional - nodes.push_back(b); - - return lambda(nodes, - outShape, - input->value_type(), - forward); -} - -// @TODO: alternative version which does the same as above with Marian operators, currently missing "scatter". -// this uses more memory and likely to be slower. Would make sense to have a scatter node that actually creates -// the node instead of relying on an existing node, e.g. scatter(shape, defaultValue, axis, indices, values); -#if 0 -Expr LSH::affine(Expr idx, Expr input, Expr W, Expr b) { - int dim = input->shape()[-1]; - int bch = idx->shape().elements() / k; - - auto W = reshape(rows(Wt_, flatten(idx)), {bch, k, dim}); // [rows, k, dim] - auto b = reshape(cols(b_, flatten(idx)), {bch, 1, k}); // [rows, 1, k] - - auto aff = reshape(bdot(reshape(input, {bch, 1, dim}), W, false, true) + b, idx->shape()); // [beam, time, batch, k] - - int dimVoc = Wt_->shape()[-2]; - auto oShape = input->shape(); - oShape.set(-1, dimVoc); - auto lowest = graph_->constant(oShape, - inits::fromValue(NumericLimits<float>(input->value_type()).lowest), - input->value_type()); - return scatter(lowest, -1, idx, aff); -} -#endif - -} // namespace marian
\ No newline at end of file diff --git a/src/layers/lsh.h b/src/layers/lsh.h deleted file mode 100644 index bf498cc6..00000000 --- a/src/layers/lsh.h +++ /dev/null @@ -1,31 +0,0 @@ -#include "graph/expression_graph.h" -#include <memory> - -namespace faiss { - struct IndexLSH; -} - -namespace marian { - -class LSH { -public: - LSH(int k, int nbits) : k_{k}, nbits_{nbits} { -#if !BLAS_FOUND - ABORT("LSH-based output approximation requires BLAS library"); -#endif - } - - Expr apply(Expr query, Expr values, Expr bias); - -private: - Ptr<faiss::IndexLSH> index_; - size_t indexHash_{0}; - - int k_{100}; - int nbits_{1024}; - - Expr search(Expr query, Expr values); - Expr affine(Expr idx, Expr query, Expr values, Expr bias); -}; - -}
\ No newline at end of file diff --git a/src/layers/output.cpp b/src/layers/output.cpp index 4c34bdce..e9bffac4 100644 --- a/src/layers/output.cpp +++ b/src/layers/output.cpp @@ -2,7 +2,6 @@ #include "common/timer.h" #include "data/factored_vocab.h" #include "layers/loss.h" -#include "layers/lsh.h" namespace marian { namespace mlp { @@ -12,13 +11,6 @@ namespace mlp { if(Wt_) return; - // this option is only set in the decoder - if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { - auto k = opt<std::vector<int>>("output-approx-knn")[0]; - auto nbits = opt<std::vector<int>>("output-approx-knn")[1]; - lsh_ = New<LSH>(k, nbits); - } - auto name = options_->get<std::string>("prefix"); auto numOutputClasses = options_->get<int>("dim"); @@ -71,13 +63,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { }; auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) { - if(lsh_) { - ABORT_IF(transA, "Transposed query not supported for LSH"); - ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); - return lsh_->apply(x, W, b); // knows how to deal with undefined bias - } else { return affineOrDot(x, W, b, transA, transB); - } }; if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one diff --git a/src/layers/output.h b/src/layers/output.h index 2b6f4986..bf8a580a 100644 --- a/src/layers/output.h +++ b/src/layers/output.h @@ -7,7 +7,6 @@ #include "marian.h" namespace marian { -class LSH; namespace mlp { @@ -28,7 +27,6 @@ private: // optional parameters set/updated after construction Expr tiedParam_; Ptr<data::Shortlist> shortlist_; - Ptr<LSH> lsh_; void lazyConstruct(int inputDim); |