From 35c822eb4ea29e5445b7c75b665c4872a2cc1adb Mon Sep 17 00:00:00 2001
From: Martin Junczys-Dowmunt
Date: Fri, 9 Jul 2021 20:35:09 +0000
Subject: Merged PR 19685: Marianize LSH as operators for mmapping and use in Quicksand

This PR turns the LSH index and search into a set of operators that live in the
expression graph. This makes index creation and search thread-safe (one index per
graph) and makes it possible to implement GPU versions later.

It also allows the LSH to be mmapped as a Marian parameter, since we now only need
to turn the index into something that can be saved to disk using the existing
tensors. This happens in marian_conv or in the equivalent interface function of
Quicksand.
---
 src/3rd_party/faiss/Index.cpp               | 119 --------------
 src/3rd_party/faiss/Index.h                 | 177 ---------------------
 src/3rd_party/faiss/IndexLSH.cpp            | 224 --------------------------
 src/3rd_party/faiss/IndexLSH.h              |  90 -----------
 src/3rd_party/faiss/utils/hamming-inl.h     |  10 +-
 src/3rd_party/faiss/utils/hamming.h         |   4 +-
 src/CMakeLists.txt                          |   1 +
 src/command/marian_conv.cpp                 |  39 ++++-
 src/data/shortlist.cpp                      |  74 ++------
 src/data/shortlist.h                        |   3 +-
 src/graph/expression_graph.h                |  10 ++
 src/graph/expression_operators.cpp          |  12 +-
 src/graph/expression_operators.h            |  11 +-
 src/graph/node_initializers.cpp             |  30 ++--
 src/graph/node_initializers.h               |  12 +-
 src/graph/node_operators_binary.h           |  22 ++-
 src/graph/node_operators_unary.h            |  41 ++++-
 src/layers/lsh.cpp                          | 233 ++++++++++++++++++++++++++++
 src/layers/lsh.h                            |  49 ++++++
 src/microsoft/quicksand.cpp                 |  24 ++-
 src/microsoft/quicksand.h                   |   5 +-
 src/tensors/cpu/expression_graph_packable.h |  36 ++++-
 src/tensors/tensor.h                        |   7 +-
 src/training/training_state.h               |   3 +-
 24 files changed, 499 insertions(+), 737 deletions(-)
 delete mode 100644 src/3rd_party/faiss/Index.cpp
 delete mode 100644 src/3rd_party/faiss/IndexLSH.cpp
 delete mode 100644 src/3rd_party/faiss/IndexLSH.h
 create mode 100644 src/layers/lsh.cpp
 create mode 100644 src/layers/lsh.h

diff --git a/src/3rd_party/faiss/Index.cpp b/src/3rd_party/faiss/Index.cpp
deleted file mode 100644
index eac5f3d9..00000000
--- a/src/3rd_party/faiss/Index.cpp
+++ /dev/null
@@ -1,119 +0,0 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */ - -// -*- c++ -*- - -#include "Index.h" -#include "common/logging.h" -#include - -namespace faiss { - -Index::~Index () -{ -} - - -void Index::train(idx_t /*n*/, const float* /*x*/) { - // does nothing by default -} - - -void Index::range_search (idx_t , const float *, float, - RangeSearchResult *) const -{ - ABORT ("range search not implemented"); -} - -void Index::assign (idx_t n, const float * x, idx_t * labels, idx_t k) -{ - float * distances = new float[n * k]; - ScopeDeleter del(distances); - search (n, x, k, distances, labels); -} - -void Index::add_with_ids( - idx_t /*n*/, - const float* /*x*/, - const idx_t* /*xids*/) { - ABORT ("add_with_ids not implemented for this type of index"); -} - -size_t Index::remove_ids(const IDSelector& /*sel*/) { - ABORT ("remove_ids not implemented for this type of index"); - return -1; -} - - -void Index::reconstruct (idx_t, float * ) const { - ABORT ("reconstruct not implemented for this type of index"); -} - - -void Index::reconstruct_n (idx_t i0, idx_t ni, float *recons) const { - for (idx_t i = 0; i < ni; i++) { - reconstruct (i0 + i, recons + i * d); - } -} - - -void Index::search_and_reconstruct (idx_t n, const float *x, idx_t k, - float *distances, idx_t *labels, - float *recons) const { - search (n, x, k, distances, labels); - for (idx_t i = 0; i < n; ++i) { - for (idx_t j = 0; j < k; ++j) { - idx_t ij = i * k + j; - idx_t key = labels[ij]; - float* reconstructed = recons + ij * d; - if (key < 0) { - // Fill with NaNs - memset(reconstructed, -1, sizeof(*reconstructed) * d); - } else { - reconstruct (key, reconstructed); - } - } - } -} - -void Index::compute_residual (const float * x, - float * residual, idx_t key) const { - reconstruct (key, residual); - for (size_t i = 0; i < d; i++) { - residual[i] = x[i] - residual[i]; - } -} - -void Index::compute_residual_n (idx_t n, const float* xs, - float* residuals, - const idx_t* keys) const { -//#pragma omp parallel for - for (idx_t i = 0; i < n; ++i) { - compute_residual(&xs[i * d], &residuals[i * d], keys[i]); - } -} - - - -size_t Index::sa_code_size () const -{ - ABORT ("standalone codec not implemented for this type of index"); -} - -void Index::sa_encode (idx_t, const float *, - uint8_t *) const -{ - ABORT ("standalone codec not implemented for this type of index"); -} - -void Index::sa_decode (idx_t, const uint8_t *, - float *) const -{ - ABORT ("standalone codec not implemented for this type of index"); -} - -} diff --git a/src/3rd_party/faiss/Index.h b/src/3rd_party/faiss/Index.h index deaabcaa..24765f7d 100644 --- a/src/3rd_party/faiss/Index.h +++ b/src/3rd_party/faiss/Index.h @@ -39,11 +39,6 @@ namespace faiss { -/// Forward declarations see AuxIndexStructures.h -struct IDSelector; -struct RangeSearchResult; -struct DistanceComputer; - /** Abstract structure for an index, supports adding vectors and searching them. 
* * All vectors provided at add or search time are 32-bit float arrays, @@ -53,178 +48,6 @@ struct Index { using idx_t = int64_t; ///< all indices are this type using component_t = float; using distance_t = float; - - int d; ///< vector dimension - idx_t ntotal; ///< total nb of indexed vectors - bool verbose; ///< verbosity level - - /// set if the Index does not require training, or if training is - /// done already - bool is_trained; - - /// type of metric this index uses for search - MetricType metric_type; - float metric_arg; ///< argument of the metric type - - explicit Index (idx_t d = 0, MetricType metric = METRIC_L2): - d((int)d), - ntotal(0), - verbose(false), - is_trained(true), - metric_type (metric), - metric_arg(0) {} - - virtual ~Index (); - - - /** Perform training on a representative set of vectors - * - * @param n nb of training vectors - * @param x training vecors, size n * d - */ - virtual void train(idx_t n, const float* x); - - /** Add n vectors of dimension d to the index. - * - * Vectors are implicitly assigned labels ntotal .. ntotal + n - 1 - * This function slices the input vectors in chuncks smaller than - * blocksize_add and calls add_core. - * @param x input matrix, size n * d - */ - virtual void add (idx_t n, const float *x) = 0; - - /** Same as add, but stores xids instead of sequential ids. - * - * The default implementation fails with an assertion, as it is - * not supported by all indexes. - * - * @param xids if non-null, ids to store for the vectors (size n) - */ - virtual void add_with_ids (idx_t n, const float * x, const idx_t *xids); - - /** query n vectors of dimension d to the index. - * - * return at most k vectors. If there are not enough results for a - * query, the result array is padded with -1s. - * - * @param x input vectors to search, size n * d - * @param labels output labels of the NNs, size n*k - * @param distances output pairwise distances, size n*k - */ - virtual void search (idx_t n, const float *x, idx_t k, - float *distances, idx_t *labels) const = 0; - - /** query n vectors of dimension d to the index. - * - * return all vectors with distance < radius. Note that many - * indexes do not implement the range_search (only the k-NN search - * is mandatory). - * - * @param x input vectors to search, size n * d - * @param radius search radius - * @param result result table - */ - virtual void range_search (idx_t n, const float *x, float radius, - RangeSearchResult *result) const; - - /** return the indexes of the k vectors closest to the query x. - * - * This function is identical as search but only return labels of neighbors. - * @param x input vectors to search, size n * d - * @param labels output labels of the NNs, size n*k - */ - void assign (idx_t n, const float * x, idx_t * labels, idx_t k = 1); - - /// removes all elements from the database. - virtual void reset() = 0; - - /** removes IDs from the index. Not supported by all - * indexes. Returns the number of elements removed. 
- */ - virtual size_t remove_ids (const IDSelector & sel); - - /** Reconstruct a stored vector (or an approximation if lossy coding) - * - * this function may not be defined for some indexes - * @param key id of the vector to reconstruct - * @param recons reconstucted vector (size d) - */ - virtual void reconstruct (idx_t key, float * recons) const; - - /** Reconstruct vectors i0 to i0 + ni - 1 - * - * this function may not be defined for some indexes - * @param recons reconstucted vector (size ni * d) - */ - virtual void reconstruct_n (idx_t i0, idx_t ni, float *recons) const; - - /** Similar to search, but also reconstructs the stored vectors (or an - * approximation in the case of lossy coding) for the search results. - * - * If there are not enough results for a query, the resulting arrays - * is padded with -1s. - * - * @param recons reconstructed vectors size (n, k, d) - **/ - virtual void search_and_reconstruct (idx_t n, const float *x, idx_t k, - float *distances, idx_t *labels, - float *recons) const; - - /** Computes a residual vector after indexing encoding. - * - * The residual vector is the difference between a vector and the - * reconstruction that can be decoded from its representation in - * the index. The residual can be used for multiple-stage indexing - * methods, like IndexIVF's methods. - * - * @param x input vector, size d - * @param residual output residual vector, size d - * @param key encoded index, as returned by search and assign - */ - virtual void compute_residual (const float * x, - float * residual, idx_t key) const; - - /** Computes a residual vector after indexing encoding (batch form). - * Equivalent to calling compute_residual for each vector. - * - * The residual vector is the difference between a vector and the - * reconstruction that can be decoded from its representation in - * the index. The residual can be used for multiple-stage indexing - * methods, like IndexIVF's methods. - * - * @param n number of vectors - * @param xs input vectors, size (n x d) - * @param residuals output residual vectors, size (n x d) - * @param keys encoded index, as returned by search and assign - */ - virtual void compute_residual_n (idx_t n, const float* xs, - float* residuals, - const idx_t* keys) const; - - /* The standalone codec interface */ - - /** size of the produced codes in bytes */ - virtual size_t sa_code_size () const; - - /** encode a set of vectors - * - * @param n number of vectors - * @param x input vectors, size n * d - * @param bytes output encoded vectors, size n * sa_code_size() - */ - virtual void sa_encode (idx_t n, const float *x, - uint8_t *bytes) const; - - /** encode a set of vectors - * - * @param n number of vectors - * @param bytes input encoded vectors, size n * sa_code_size() - * @param x output vectors, size n * d - */ - virtual void sa_decode (idx_t n, const uint8_t *bytes, - float *x) const; - - }; } diff --git a/src/3rd_party/faiss/IndexLSH.cpp b/src/3rd_party/faiss/IndexLSH.cpp deleted file mode 100644 index 6df84331..00000000 --- a/src/3rd_party/faiss/IndexLSH.cpp +++ /dev/null @@ -1,224 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -// -*- c++ -*- - -#include - -#include -#include - -#include - -#include -#include "common/logging.h" - - -namespace faiss { - -/*************************************************************** - * IndexLSH - ***************************************************************/ - - -IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds): - Index(d), nbits(nbits), rotate_data(rotate_data), - train_thresholds (train_thresholds), rrot(d, nbits) -{ - is_trained = !train_thresholds; - - bytes_per_vec = (nbits + 7) / 8; - - if (rotate_data) { - rrot.init(5); - } else { - ABORT_UNLESS(d >= nbits, "d >= nbits"); - } -} - -IndexLSH::IndexLSH (): - nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false) -{ -} - - -const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const -{ - - float *xt = nullptr; - if (rotate_data) { - // also applies bias if exists - xt = rrot.apply (n, x); - } else if (d != nbits) { - assert (nbits < d); - xt = new float [nbits * n]; - float *xp = xt; - for (idx_t i = 0; i < n; i++) { - const float *xl = x + i * d; - for (int j = 0; j < nbits; j++) - *xp++ = xl [j]; - } - } - - if (train_thresholds) { - - if (xt == NULL) { - xt = new float [nbits * n]; - memcpy (xt, x, sizeof(*x) * n * nbits); - } - - float *xp = xt; - for (idx_t i = 0; i < n; i++) - for (int j = 0; j < nbits; j++) - *xp++ -= thresholds [j]; - } - - return xt ? xt : x; -} - - - -void IndexLSH::train (idx_t n, const float *x) -{ - if (train_thresholds) { - thresholds.resize (nbits); - train_thresholds = false; - const float *xt = apply_preprocess (n, x); - ScopeDeleter del (xt == x ? nullptr : xt); - train_thresholds = true; - - float * transposed_x = new float [n * nbits]; - ScopeDeleter del2 (transposed_x); - - for (idx_t i = 0; i < n; i++) - for (idx_t j = 0; j < nbits; j++) - transposed_x [j * n + i] = xt [i * nbits + j]; - - for (idx_t i = 0; i < nbits; i++) { - float *xi = transposed_x + i * n; - // std::nth_element - std::sort (xi, xi + n); - if (n % 2 == 1) - thresholds [i] = xi [n / 2]; - else - thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2; - - } - } - is_trained = true; -} - - -void IndexLSH::add (idx_t n, const float *x) -{ - ABORT_UNLESS (is_trained, "is_trained"); - codes.resize ((ntotal + n) * bytes_per_vec); - - sa_encode (n, x, &codes[ntotal * bytes_per_vec]); - - ntotal += n; -} - - -void IndexLSH::search ( - idx_t n, - const float *x, - idx_t k, - float *distances, - idx_t *labels) const -{ - ABORT_UNLESS (is_trained, "is_trained"); - const float *xt = apply_preprocess (n, x); - ScopeDeleter del (xt == x ? 
nullptr : xt); - - uint8_t * qcodes = new uint8_t [n * bytes_per_vec]; - ScopeDeleter del2 (qcodes); - - fvecs2bitvecs (xt, qcodes, nbits, n); - - int * idistances = new int [n * k]; - ScopeDeleter del3 (idistances); - - int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances}; - - hammings_knn_hc (&res, qcodes, codes.data(), - ntotal, bytes_per_vec, true); - - - // convert distances to floats - for (int i = 0; i < k * n; i++) - distances[i] = idistances[i]; - -} - - -void IndexLSH::transfer_thresholds (LinearTransform *vt) { - if (!train_thresholds) return; - ABORT_UNLESS (nbits == vt->d_out, "nbits == vt->d_out"); - if (!vt->have_bias) { - vt->b.resize (nbits, 0); - vt->have_bias = true; - } - for (int i = 0; i < nbits; i++) - vt->b[i] -= thresholds[i]; - train_thresholds = false; - thresholds.clear(); -} - -void IndexLSH::reset() { - codes.clear(); - ntotal = 0; -} - - -size_t IndexLSH::sa_code_size () const -{ - return bytes_per_vec; -} - -void IndexLSH::sa_encode (idx_t n, const float *x, - uint8_t *bytes) const -{ - ABORT_UNLESS (is_trained, "is_trained"); - const float *xt = apply_preprocess (n, x); - ScopeDeleter del (xt == x ? nullptr : xt); - fvecs2bitvecs (xt, bytes, nbits, n); -} - -void IndexLSH::sa_decode (idx_t n, const uint8_t *bytes, - float *x) const -{ - float *xt = x; - ScopeDeleter del; - if (rotate_data || nbits != d) { - xt = new float [n * nbits]; - del.set(xt); - } - bitvecs2fvecs (bytes, xt, nbits, n); - - if (train_thresholds) { - float *xp = xt; - for (idx_t i = 0; i < n; i++) { - for (int j = 0; j < nbits; j++) { - *xp++ += thresholds [j]; - } - } - } - - if (rotate_data) { - rrot.reverse_transform (n, xt, x); - } else if (nbits != d) { - for (idx_t i = 0; i < n; i++) { - memcpy (x + i * d, xt + i * nbits, - nbits * sizeof(xt[0])); - } - } -} - - - -} // namespace faiss diff --git a/src/3rd_party/faiss/IndexLSH.h b/src/3rd_party/faiss/IndexLSH.h deleted file mode 100644 index 66435363..00000000 --- a/src/3rd_party/faiss/IndexLSH.h +++ /dev/null @@ -1,90 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -// -*- c++ -*- - -#ifndef INDEX_LSH_H -#define INDEX_LSH_H - -#include - -#include -#include - -namespace faiss { - - -/** The sign of each vector component is put in a binary signature */ -struct IndexLSH:Index { - typedef unsigned char uint8_t; - - int nbits; ///< nb of bits per vector - int bytes_per_vec; ///< nb of 8-bits per encoded vector - bool rotate_data; ///< whether to apply a random rotation to input - bool train_thresholds; ///< whether we train thresholds or use 0 - - RandomRotationMatrix rrot; ///< optional random rotation - - std::vector thresholds; ///< thresholds to compare with - - /// encoded dataset - std::vector codes; - - IndexLSH ( - idx_t d, int nbits, - bool rotate_data = true, - bool train_thresholds = false); - - /** Preprocesses and resizes the input to the size required to - * binarize the data - * - * @param x input vectors, size n * d - * @return output vectors, size n * bits. 
May be the same pointer
-   * as x, otherwise it should be deleted by the caller
-   */
-  const float *apply_preprocess (idx_t n, const float *x) const;
-
-  void train(idx_t n, const float* x) override;
-
-  void add(idx_t n, const float* x) override;
-
-  void search(
-      idx_t n,
-      const float* x,
-      idx_t k,
-      float* distances,
-      idx_t* labels) const override;
-
-  void reset() override;
-
-  /// transfer the thresholds to a pre-processing stage (and unset
-  /// train_thresholds)
-  void transfer_thresholds (LinearTransform * vt);
-
-  ~IndexLSH() override {}
-
-  IndexLSH ();
-
-  /* standalone codec interface.
-   *
-   * The vectors are decoded to +/- 1 (not 0, 1) */
-
-  size_t sa_code_size () const override;
-
-  void sa_encode (idx_t n, const float *x,
-                  uint8_t *bytes) const override;
-
-  void sa_decode (idx_t n, const uint8_t *bytes,
-                  float *x) const override;
-
-};
-
-
-}
-
-
-#endif
diff --git a/src/3rd_party/faiss/utils/hamming-inl.h b/src/3rd_party/faiss/utils/hamming-inl.h
index d32da758..b164dc88 100644
--- a/src/3rd_party/faiss/utils/hamming-inl.h
+++ b/src/3rd_party/faiss/utils/hamming-inl.h
@@ -10,8 +10,8 @@
 namespace faiss {
 
-#ifdef _MSC_VER
-#define bzero(p,n) (memset((p),0,(n)))
+#ifdef _MSC_VER
+#define bzero(p,n) (memset((p),0,(n)))
 #endif
 
 inline BitstringWriter::BitstringWriter(uint8_t *code, int code_size):
     code (code), code_size (code_size), i(0)
@@ -29,7 +29,7 @@ inline void BitstringWriter::write(uint64_t x, int nbit) {
         i += nbit;
         return;
     } else {
-        int j = i >> 3;
+        size_t j = i >> 3;
         code[j++] |= x << (i & 7);
         i += nbit;
         x >>= na;
@@ -57,7 +57,7 @@ inline uint64_t BitstringReader::read(int nbit) {
         return res;
     } else {
         int ofs = na;
-        int j = (i >> 3) + 1;
+        size_t j = (i >> 3) + 1;
         i += nbit;
         nbit -= na;
         while (nbit > 8) {
@@ -160,7 +160,7 @@ struct HammingComputer20 {
     void set (const uint8_t *a8, int code_size) {
         assert (code_size == 20);
         const uint64_t *a = (uint64_t *)a8;
-        a0 = a[0]; a1 = a[1]; a2 = a[2];
+        a0 = a[0]; a1 = a[1]; a2 = (uint32_t)a[2];
     }
 
     inline int hamming (const uint8_t *b8) const {
diff --git a/src/3rd_party/faiss/utils/hamming.h b/src/3rd_party/faiss/utils/hamming.h
index 762d3773..0c89c4d1 100644
--- a/src/3rd_party/faiss/utils/hamming.h
+++ b/src/3rd_party/faiss/utils/hamming.h
@@ -31,7 +31,7 @@
 #ifdef _MSC_VER
 #include <intrin.h> // needed for some intrinsics
-#define __builtin_popcountl __popcnt64
+#define __builtin_popcountl __popcnt64
 #endif
 
 /* The Hamming distance type */
@@ -116,7 +116,7 @@ struct BitstringReader {
 extern size_t hamming_batch_size;
 
 static inline int popcount64(uint64_t x) {
-    return __builtin_popcountl(x);
+    return (int)__builtin_popcountl(x);
 }
 
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index d2fd269f..1f5db423 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -75,6 +75,7 @@ set(MARIAN_SOURCES
     layers/embedding.cpp
     layers/output.cpp
     layers/logits.cpp
+    layers/lsh.cpp
 
     rnn/cells.cpp
     rnn/attention.cpp
diff --git a/src/command/marian_conv.cpp b/src/command/marian_conv.cpp
index 26cac858..e0e89d2b 100644
--- a/src/command/marian_conv.cpp
+++ b/src/command/marian_conv.cpp
@@ -2,6 +2,7 @@
 #include "common/cli_wrapper.h"
 #include "tensors/cpu/expression_graph_packable.h"
 #include "onnx/expression_graph_onnx_exporter.h"
+#include "layers/lsh.h"
 
 #include <sstream>
 
@@ -25,6 +26,9 @@ int main(int argc, char** argv) {
     cli->add<std::string>("--gemm-type,-g",
         "GEMM Type to be used: float32, packed16, packed8avx2, packed8avx512, "
         "intgemm8, intgemm8ssse3, intgemm8avx2, intgemm8avx512, intgemm16, intgemm16sse2, intgemm16avx2, intgemm16avx512",
        "float32");
+    cli->add<std::vector<std::string>>("--add-lsh",
+        "Encode output matrix and optional rotation matrix into model file. "
+        "arg1: number of bits in LSH encoding, arg2: name of output weights matrix")->implicit_val("1024 Wemb");
     cli->add<std::vector<std::string>>("--vocabs,-V", "Vocabulary file, required for ONNX export");
     cli->parse(argc, argv);
     options->merge(config);
   }
   auto exportAs = options->get<std::string>("export-as");
   auto vocabPaths = options->get<std::vector<std::string>>("vocabs");// , std::vector<std::string>());
+
+  bool addLsh = options->hasAndNotEmpty("add-lsh");
+  int lshNBits = 1024;
+  std::string lshOutputWeights = "Wemb";
+  if(addLsh) {
+    auto lshParams = options->get<std::vector<std::string>>("add-lsh");
+    lshNBits = std::stoi(lshParams[0]);
+    if(lshParams.size() > 1)
+      lshOutputWeights = lshParams[1];
+  }
 
   // We accept any type here and will later croak during packAndSave if the type cannot be used for conversion
   Type saveGemmType = typeFromString(options->get<std::string>("gemm-type", "float32"));
@@ -45,23 +59,36 @@ int main(int argc, char** argv) {
   marian::io::getYamlFromModel(config, "special:model.yml", modelFrom);
   configStr << config;
 
-  auto load = [&](Ptr<ExpressionGraph> graph) {
+  if (exportAs == "marian-bin") {
+    auto graph = New<ExpressionGraphPackable>();
     graph->setDevice(CPU0);
     graph->load(modelFrom);
+
+    if(addLsh) {
+      // Add dummy parameters for the LSH before the model actually gets initialized.
+      // This creates the parameters with useless values in the tensors, but it gives us the memory we need.
+      graph->setReloaded(false);
+      lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
+      graph->setReloaded(true);
+    }
+
     graph->forward(); // run the initializers
-  };
 
+    if(addLsh) {
+      // After initialization, hijack the parameters for the LSH and force-overwrite them with correct values.
+      // Once this is done we can just pack and save as normal.
+      lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
+    }
 
-  if (exportAs == "marian-bin") {
-    auto graph = New<ExpressionGraphPackable>();
-    load(graph);
     // added a flag if the weights need to be packed or not
     graph->packAndSave(modelTo, configStr.str(), /* --gemm-type */ saveGemmType, Type::float32);
   }
   else if (exportAs == "onnx-encode") {
 #ifdef USE_ONNX
     auto graph = New<ExpressionGraphONNXExporter>();
-    load(graph);
+    graph->setDevice(CPU0);
+    graph->load(modelFrom);
+    graph->forward(); // run the initializers
     auto modelOptions = New<Options>(config)->with("vocabs", vocabPaths, "inference", true);
     graph->exportToONNX(modelTo, modelOptions, vocabPaths);
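To make the effect of --add-lsh concrete, here is a purely illustrative back-of-the-envelope sketch of the two extra tensors the conversion writes. The model dimensions (embedding size 512, vocabulary 32000) are made up for this example; the parameter names are the ones introduced by this PR:

    // Illustrative only: sizes stored for --add-lsh 1024 Wemb, assuming a
    // hypothetical model with dim = 512 and vocab = 32000.
    int nBits = 1024;
    int bytesPerVec = (nBits + 7) / 8;                      // 128 bytes of codes per row
    size_t codesBytes    = 32000ull * bytesPerVec;          // lsh_output_codes: ~4 MB, uint8
    size_t rotationBytes = 512ull * nBits * sizeof(float);  // lsh_output_rotation: ~2 MB, float32

Both tensors are ordinary Marian parameters after conversion, so the existing memory-mapping machinery applies to them unchanged.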
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index f7e229ff..396c6ba4 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -1,10 +1,7 @@
 #include "data/shortlist.h"
 #include "microsoft/shortlist/utils/ParameterTree.h"
 #include "marian.h"
-
-#if BLAS_FOUND
-#include "3rd_party/faiss/IndexLSH.h"
-#endif
+#include "layers/lsh.h"
 
 namespace marian {
 namespace data {
@@ -47,7 +44,6 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
   Shape kShape({k});
   indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
-
-  //std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
   createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k);
   initialized_ = true;
 }
@@ -78,12 +74,10 @@ void Shortlist::createCachedTensors(Expr weights,
 }
 
 ///////////////////////////////////////////////////////////////////////////////////
-Ptr<faiss::IndexLSH> LSHShortlist::index_;
-std::mutex LSHShortlist::mutex_;
 
 LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
-: Shortlist(std::vector<WordIndex>())
-, k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
+: Shortlist(std::vector<WordIndex>()),
+  k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
 }
 
 WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
@@ -99,67 +93,23 @@ Expr LSHShortlist::getIndicesExpr() const {
 }
 
 void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
-#if BLAS_FOUND
+  ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
+           "LSH index (--output-approx-knn) currently not implemented for GPU");
 
-  int currBeamSize = input->shape()[0];
-  int batchSize = input->shape()[2];
-  int numHypos = currBeamSize * batchSize;
-
-  auto forward = [this, numHypos](Expr out, const std::vector<Expr>& inputs) {
-    auto query  = inputs[0];
-    auto values = inputs[1];
-    int dim = values->shape()[-1];
-
-    mutex_.lock();
-    if(!index_) {
-      LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
-      index_.reset(new faiss::IndexLSH(dim, nbits_,
-                                       /*rotate=*/dim != nbits_,
-                                       /*train_thresholds=*/false));
-      index_->train(lemmaSize_, values->val()->data());
-      index_->add(  lemmaSize_, values->val()->data());
-    }
-    mutex_.unlock();
-
-    int qRows = query->shape().elements() / dim;
-    std::vector<float> distances(qRows * k_);
-    std::vector<faiss::Index::idx_t> ids(qRows * k_);
-
-    index_->search(qRows, query->val()->data(), k_,
-                   distances.data(), ids.data());
-
-    indices_.clear();
-    for(auto iter = ids.begin(); iter != ids.end(); ++iter) {
-      faiss::Index::idx_t id = *iter;
-      indices_.push_back((WordIndex)id);
-    }
-
-    for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
-      size_t startIdx = k_ * hypoIdx;
-      size_t endIdx = startIdx + k_;
-      std::sort(indices_.begin() + startIdx, indices_.begin() + endIdx);
-    }
-    out->val()->set(indices_);
-  };
-
-  Shape kShape({currBeamSize, batchSize, k_});
-  indicesExpr_ = lambda({input, weights}, kShape, Type::uint32,
-                        forward);
+  indicesExpr_ = callback(lsh::search(input, weights, k_, nbits_, (int)lemmaSize_),
+                          [this](Expr node) {
+                            node->val()->get(indices_); // set the value of the field indices_ whenever the graph traverses this node
+                          });
 
   createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k_);
-
-#else
-  input; weights; isLegacyUntransposedW; b; lemmaEt;
-  ABORT("LSH output layer requires a CPU BLAS library");
-#endif
 }
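The per-hypothesis sort of the returned indices is what keeps the reverse mapping cheap: a shortlisted word can be located again with a binary search instead of a linear scan. A hypothetical sketch of that kind of look-up (not the actual reverseMap member function, whose body is not shown in this diff):

    #include <algorithm>
    #include <cstdint>

    // Locate a word in one hypothesis' sorted k-best shortlist;
    // returns -1 if the word was not shortlisted.
    int positionOf(const uint32_t* sortedIndices, int k, uint32_t wordIdx) {
      const uint32_t* end = sortedIndices + k;
      const uint32_t* it  = std::lower_bound(sortedIndices, end, wordIdx);
      return (it != end && *it == wordIdx) ? int(it - sortedIndices) : -1;
    }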
 void LSHShortlist::createCachedTensors(Expr weights,
-                                       bool isLegacyUntransposedW,
-                                       Expr b,
-                                       Expr lemmaEt,
-                                       int k) {
+                                        bool isLegacyUntransposedW,
+                                        Expr b,
+                                        Expr lemmaEt,
+                                        int k) {
   int currBeamSize = indicesExpr_->shape()[0];
   int batchSize = indicesExpr_->shape()[1];
   ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index a75d2c4b..d3841b21 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -25,7 +25,8 @@ namespace data {
 class Shortlist {
 protected:
   std::vector<WordIndex> indices_; // [packed shortlist index] -> word index, used to select columns from output embeddings
-  Expr indicesExpr_;
+  Expr indicesExpr_;      // cache an expression that contains the short list indices
+  Expr cachedShortWt_;    // short-listed version, cached (cleared by clear())
   Expr cachedShortb_;     // these match the current value of shortlist_
   Expr cachedShortLemmaEt_;
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index fce7d532..553a5d63 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -646,6 +646,16 @@ public:
     return it->second;
   }
 
+  /**
+   * Return the Parameters object related to the graph by elementType.
+   * The Parameters object holds the whole set of the parameter nodes of the given type.
+   */
+  Ptr<Parameters>& params(Type elementType) {
+    auto it = paramsByElementType_.find(elementType);
+    ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", elementType);
+    return it->second;
+  }
+
   /**
    * Set default element type for the graph.
    * The default value is used if some node type is not specified.
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 24d12eea..560ab4e7 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -28,13 +28,17 @@ Expr checkpoint(Expr a) {
 }
 
 Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
-            LambdaNodeFunctor fwd) {
-  return Expression<LambdaNodeOp>(nodes, shape, type, fwd);
+            LambdaNodeFunctor fwd, size_t hash) {
+  return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
 }
 
 Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
-            LambdaNodeFunctor fwd, LambdaNodeFunctor bwd) {
-  return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd);
+            LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) {
+  return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash);
+}
+
+Expr callback(Expr node, LambdaNodeCallback call) {
+  return Expression<CallbackNodeOp>(node, call);
 }
 
 // logistic function. Note: scipy name is expit()
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 6c7e5758..e34ddc8a 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -26,12 +26,19 @@
 typedef std::function<void(Expr out, const std::vector<Expr>& in)> LambdaNodeFunctor;
 
 /**
  * Arbitrary node with forward operation only.
  */
-Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd);
+Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, size_t hash = 0);
 
 /**
  * Arbitrary node with forward and backward operation.
  */
-Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd);
+Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash = 0);
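The new hash argument exists because every std::function instance has a unique address: without it, two lambda nodes built from identical code could never compare equal, so the graph could not memoize them. Passing a stable value marks them as the same operation. A toy sketch (the names x and myFwd are hypothetical):

    // Two nodes created with the same external hash (and the same inputs)
    // can be recognized as identical and deduplicated by the graph.
    static const size_t myOpHash = 12345; // any stable non-zero value
    auto a = lambda({x}, x->shape(), Type::float32, myFwd, myOpHash);
    auto b = lambda({x}, x->shape(), Type::float32, myFwd, myOpHash); // hashes equal to a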
+
+/**
+ * Convenience typedef for graph @ref lambda expressions.
+ */
+typedef std::function<void(Expr)> LambdaNodeCallback;
+Expr callback(Expr node, LambdaNodeCallback call);
 
 /**
  * @addtogroup graph_ops_activation Activation Functions
diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp
index 4e39d1bf..e44b4828 100644
--- a/src/graph/node_initializers.cpp
+++ b/src/graph/node_initializers.cpp
@@ -11,6 +11,15 @@
 namespace marian {
 namespace inits {
 
+class DummyInit : public NodeInitializer {
+public:
+  void apply(Tensor tensor) override {
+    tensor; // no-op, deliberately leaves the tensor memory untouched
+  }
+};
+
+Ptr<NodeInitializer> dummy() { return New<DummyInit>(); }
+
 class LambdaInit : public NodeInitializer {
 private:
   std::function<void(Tensor)> lambda_;
@@ -237,24 +246,3 @@ template Ptr<NodeInitializer> range<IndexType>(IndexType begin, IndexType end, IndexType step);
 
 } // namespace inits
 } // namespace marian
-
-#if BLAS_FOUND
-#include "faiss/VectorTransform.h"
-
-namespace marian {
-namespace inits {
-
-Ptr<NodeInitializer> randomRotation(size_t seed) {
-  auto rot = [=](Tensor t) {
-    int rows = t->shape()[-2];
-    int cols = t->shape()[-1];
-    faiss::RandomRotationMatrix rrot(cols, rows); // transposed in faiss
-    rrot.init((int)seed);
-    t->set(rrot.A);
-  };
-  return fromLambda(rot, Type::float32);
-}
-
-} // namespace inits
-} // namespace marian
-#endif
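inits::dummy() exists so that a parameter can be allocated first and filled later; that is exactly how the LSH parameters are created in marian-conv before being force-overwritten. A sketch with made-up shape values:

    // Allocate an uninitialized uint8 parameter; the values are garbage until
    // something overwrites them after graph->forward() has run.
    auto codes = graph->param("lsh_output_codes", {32000, 128},
                              inits::dummy(), Type::uint8);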
diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h
index 7cdb4183..5e9f8013 100644
--- a/src/graph/node_initializers.h
+++ b/src/graph/node_initializers.h
@@ -35,6 +35,11 @@ public:
   virtual ~NodeInitializer() {}
 };
 
+/**
+ * Dummy do-nothing initializer. Mostly for testing.
+ */
+Ptr<NodeInitializer> dummy();
+
 /**
  * Use a lambda function of form [](Tensor t) { do something with t } to initialize tensor.
  * @param func functor
@@ -263,13 +268,6 @@ Ptr<NodeInitializer> fromWord2vec(const std::string& file,
  */
 Ptr<NodeInitializer> sinusoidalPositionEmbeddings(int start);
 
-/**
- * Computes a random rotation matrix for LSH hashing.
- * This is part of a hash function. The values are orthonormal and computed via
- * QR decomposition. Same seed results in same random rotation.
- */
-Ptr<NodeInitializer> randomRotation(size_t seed = Config::seed);
-
 /**
  * Computes the equivalent of Python's range().
  * Computes a range from begin to end-1, like Python's range().
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 169b1420..a180bb5c 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -21,20 +21,26 @@ private:
   std::unique_ptr<LambdaNodeFunctor> forward_;
   std::unique_ptr<LambdaNodeFunctor> backward_;
 
+  size_t externalHash_;
+
 public:
   LambdaNodeOp(Inputs inputs, Shape shape, Type type,
-               LambdaNodeFunctor forward)
+               LambdaNodeFunctor forward,
+               size_t externalHash = 0)
   : NaryNodeOp(inputs, shape, type),
-    forward_(new LambdaNodeFunctor(forward)) {
+    forward_(new LambdaNodeFunctor(forward)),
+    externalHash_(externalHash) {
     Node::trainable_ = !!backward_;
   }
 
   LambdaNodeOp(Inputs inputs, Shape shape, Type type,
               LambdaNodeFunctor forward,
-              LambdaNodeFunctor backward)
+              LambdaNodeFunctor backward,
+              size_t externalHash = 0)
   : NaryNodeOp(inputs, shape, type),
     forward_(new LambdaNodeFunctor(forward)),
-    backward_(new LambdaNodeFunctor(backward)) {
+    backward_(new LambdaNodeFunctor(backward)),
+    externalHash_(externalHash) {
   }
 
   void forward() override {
@@ -50,8 +56,12 @@ public:
   virtual size_t hash() override {
     size_t seed = NaryNodeOp::hash();
-    util::hash_combine(seed, forward_.get());
-    util::hash_combine(seed, backward_.get());
+    if(externalHash_ != 0) {
+      util::hash_combine(seed, externalHash_);
+    } else {
+      util::hash_combine(seed, forward_.get());
+      util::hash_combine(seed, backward_.get());
+    }
     return seed;
   }
 
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 82b02a65..448b4c4a 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -795,7 +795,7 @@ private:
 };
 
 class ReshapeNodeOp : public UnaryNodeOp {
-private:
+protected:
   friend class SerializationHelpers;
 
   Expr reshapee_;
@@ -858,6 +858,45 @@ public:
   }
 };
 
+// @TODO: add version with access to backward step
+// This allows attaching a lambda function to any node during execution. It is a no-op otherwise,
+// i.e. it doesn't consume any memory or take any time to execute (it's a reshape onto itself) other than the
+// compute in the lambda function. The lambda gets called after the forward step of the argument node.
+class CallbackNodeOp : public ReshapeNodeOp {
+private:
+  typedef std::function<void(Expr)> LambdaNodeCallback;
+  std::unique_ptr<LambdaNodeCallback> callback_;
+
+public:
+  CallbackNodeOp(Expr node, LambdaNodeCallback callback)
+  : ReshapeNodeOp(node, node->shape()),
+    callback_(new LambdaNodeCallback(callback)) {
+  }
+
+  void forward() override {
+    (*callback_)(ReshapeNodeOp::reshapee_);
+  }
+
+  const std::string type() override { return "callback"; }
+
+  virtual size_t hash() override {
+    size_t seed = ReshapeNodeOp::hash();
+    util::hash_combine(seed, callback_.get());
+    return seed;
+  }
+
+  virtual bool equal(Expr node) override {
+    if(!ReshapeNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<CallbackNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(callback_ != cnode->callback_) // pointer comparison on purpose
+      return false;
+    return true;
+  }
+};
+
 // @TODO: review if still required as this is an ugly hack anyway.
 // Memory less operator that clips gradients during backward step
 // Executes this as an additional operation on the gradient.
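The CallbackNodeOp above is surfaced as the callback() operator and is how the LSH shortlist pulls its indices off the graph without an extra copy node. A minimal sketch of the pattern (node and variable names hypothetical):

    // Attach a side effect to a node: the lambda runs after `someUint32Node`
    // has been computed in the forward pass; the callback node itself is just
    // a reshape onto the same memory, so it adds no compute or storage.
    std::vector<uint32_t> hostIndices;
    auto observed = callback(someUint32Node, [&hostIndices](Expr n) {
      n->val()->get(hostIndices); // copy the freshly computed values to the host
    });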
diff --git a/src/layers/lsh.cpp b/src/layers/lsh.cpp
new file mode 100644
index 00000000..89b482f4
--- /dev/null
+++ b/src/layers/lsh.cpp
@@ -0,0 +1,233 @@
+#include "layers/lsh.h"
+#include "tensors/tensor_operators.h"
+#include "common/utils.h"
+
+#include "3rd_party/faiss/utils/hamming.h"
+#include "3rd_party/faiss/Index.h"
+
+#if BLAS_FOUND
+#include "3rd_party/faiss/VectorTransform.h"
+#endif
+
+namespace marian {
+namespace lsh {
+
+int bytesPerVector(int nBits) {
+  return (nBits + 7) / 8;
+}
+
+void fillRandomRotationMatrix(Tensor output, Ptr<Allocator> allocator) {
+#if BLAS_FOUND
+  int nRows = output->shape()[-2];
+  int nBits = output->shape()[-1];
+
+  // @TODO: re-implement using Marian code so it uses the correct random generator etc.
+  faiss::RandomRotationMatrix rrot(nRows, nBits);
+  // Then we would not need to use this seed at all.
+  rrot.init(5); // currently set to 5 following the default from FAISS, this could be any number really.
+
+  // The faiss random rotation matrix is column-major, hence we create a temporary tensor,
+  // copy the rotation matrix into it and transpose to output.
+  Shape tempShape = {nBits, nRows};
+  auto memory = allocator->alloc(requiredBytes(tempShape, output->type()));
+  auto temp = TensorBase::New(memory,
+                              tempShape,
+                              output->type(),
+                              output->getBackend());
+  temp->set(rrot.A);
+  TransposeND(output, temp, {0, 1, 3, 2});
+  allocator->free(memory);
+#else
+  output; allocator;
+  ABORT("LSH with rotation matrix requires Marian to be compiled with a BLAS library");
+#endif
+}
+
+void encode(Tensor output, Tensor input) {
+  int nBits = input->shape()[-1]; // the number of bits is equal to the last dimension of the float matrix
+  int nRows = input->shape().elements() / nBits;
+  faiss::fvecs2bitvecs(input->data<float>(), output->data<uint8_t>(), (size_t)nBits, (size_t)nRows);
+}
+
+void encodeWithRotation(Tensor output, Tensor input, Tensor rotation, Ptr<Allocator> allocator) {
+  int nBits = input->shape()[-1]; // the number of bits is equal to the last dimension of the float matrix unless we rotate
+  int nRows = input->shape().elements() / nBits;
+
+  Tensor tempInput = input;
+  MemoryPiece::PtrType memory;
+  if(rotation) {
+    int nBitsRot = rotation->shape()[-1];
+    Shape tempShape = {nRows, nBitsRot};
+    memory = allocator->alloc(requiredBytes(tempShape, rotation->type()));
+    tempInput = TensorBase::New(memory, tempShape, rotation->type(), rotation->getBackend());
+    Prod(tempInput, input, rotation, false, false, 0.f, 1.f);
+  }
+  encode(output, tempInput);
+
+  if(memory)
+    allocator->free(memory);
+}
+
+Expr encode(Expr input, Expr rotation) {
+  auto encodeFwd = [](Expr out, const std::vector<Expr>& inputs) {
+    if(inputs.size() == 1) {
+      encode(out->val(), inputs[0]->val());
+    } else if(inputs.size() == 2) {
+      encodeWithRotation(out->val(), inputs[0]->val(), inputs[1]->val(), out->graph()->allocator());
+    } else {
+      ABORT("Too many inputs to encode??");
+    }
+  };
+
+  // Use the address of the first lambda function as an immutable hash. Making it static and const makes sure
+  // that this hash value will not change. Next we pass the hash into the lambda functor where it will be used
+  // to identify this unique operation. Marian's ExpressionGraph can automatically memoize and identify nodes
+  // that operate only on immutable nodes (parameters) and have the same hash. This way we make sure that the
+  // codes node won't actually get recomputed throughout the ExpressionGraph's lifetime. `codes` will be reused
+  // and the body of the lambda will not be called again. This does, however, build one index per graph.
+  static const size_t encodeHash = (size_t)&encodeFwd;
+
+  Shape encodedShape = input->shape();
+
+  int nBits = rotation ? rotation->shape()[-1] : input->shape()[-1];
+  encodedShape.set(-1, bytesPerVector(nBits));
+  std::vector<Expr> inputs = {input};
+  if(rotation)
+    inputs.push_back(rotation);
+  return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash);
+}
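Conceptually, the encoding computes code = signbits(x * R): each of the nBits hash directions contributes one bit, set when the projection is positive. A scalar sketch of what the Prod + fvecs2bitvecs pipeline does per row (illustrative only; the real code is batched and vectorized, and faiss's exact tie-breaking at zero may differ):

    #include <cstring>
    #include <cstdint>

    // One row: project onto nBits random directions, keep only the signs.
    void encodeRow(const float* x, const float* R /*[dim][nBits], row-major*/,
                   uint8_t* code, int dim, int nBits) {
      std::memset(code, 0, (nBits + 7) / 8);
      for(int b = 0; b < nBits; ++b) {
        float acc = 0.f;
        for(int i = 0; i < dim; ++i)
          acc += x[i] * R[i * nBits + b];
        if(acc > 0.f)
          code[b / 8] |= 1 << (b % 8); // bit b encodes the sign of direction b
      }
    }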
+Expr rotator(Expr weights, int nBits) {
+  auto rotator = [](Expr out, const std::vector<Expr>& inputs) {
+    inputs; // unused
+    fillRandomRotationMatrix(out->val(), out->graph()->allocator());
+  };
+
+  static const size_t rotatorHash = (size_t)&rotator;
+  int dim = weights->shape()[-1];
+  return lambda({weights}, {dim, nBits}, Type::float32, rotator, rotatorHash);
+}
+
+Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows) {
+  ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1],
+           "Query and index bit vectors need to be of same size ({} != {})", encodedQuery->shape()[-1], encodedWeights->shape()[-1]);
+
+  int currBeamSize = encodedQuery->shape()[0];
+  int batchSize    = encodedQuery->shape()[2];
+  int numHypos     = currBeamSize * batchSize;
+
+  auto search = [=](Expr out, const std::vector<Expr>& inputs) {
+    Expr encodedQuery   = inputs[0];
+    Expr encodedWeights = inputs[1];
+
+    int bytesPerVector = encodedWeights->shape()[-1];
+    int wRows = encodedWeights->shape().elements() / bytesPerVector;
+
+    // we use this with Factored Segmenter to skip the factor embeddings at the end
+    if(firstNRows != 0)
+      wRows = firstNRows;
+
+    int qRows = encodedQuery->shape().elements() / bytesPerVector;
+
+    uint8_t* qCodes = encodedQuery->val()->data<uint8_t>();
+    uint8_t* wCodes = encodedWeights->val()->data<uint8_t>();
+
+    // use the actual faiss code for performing the Hamming search
+    std::vector<int> distances(qRows * k);
+    std::vector<faiss::Index::idx_t> ids(qRows * k);
+    faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)k, ids.data(), distances.data()};
+    faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0);
+
+    // Copy the int64_t indices to Marian's index type and sort by increasing index value per hypothesis.
+    // The sorting is required as we later do a binary search on those values for reverse look-up.
+    uint32_t* outData = out->val()->data<uint32_t>();
+    for(size_t hypoIdx = 0; hypoIdx < (size_t)numHypos; ++hypoIdx) {
+      size_t startIdx = k * hypoIdx;
+      size_t endIdx   = startIdx + k;
+      for(size_t i = startIdx; i < endIdx; ++i)
+        outData[i] = (uint32_t)ids[i];
+      std::sort(outData + startIdx, outData + endIdx);
+    }
+  };
+
+  Shape kShape({currBeamSize, batchSize, k});
+  return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
+}
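What hammings_knn_hc ranks by is the Hamming distance, i.e. the number of differing bits between two code vectors, which XOR plus popcount computes block by block. In scalar form (a sketch; the real faiss code works on wider blocks and maintains a heap of the k best rows):

    #include <cstdint>

    // Hamming distance between two nBytes-long binary codes.
    int hammingDistance(const uint8_t* a, const uint8_t* b, int nBytes) {
      int dist = 0;
      for(int i = 0; i < nBytes; ++i)
        dist += __builtin_popcount(unsigned(a[i] ^ b[i])); // count differing bits
      return dist;
    }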
+Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows) {
+  int dim = weights->shape()[-1];
+
+  Expr rotMat = nullptr;
+  if(dim != nBits) {
+    rotMat = weights->graph()->get("lsh_output_rotation");
+    if(rotMat) {
+      LOG_ONCE(info, "Reusing parameter LSH rotation matrix {} with shape {}", rotMat->name(), rotMat->shape());
+    } else {
+      LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
+      rotMat = rotator(weights, nBits);
+    }
+  }
+
+  Expr encodedWeights = weights->graph()->get("lsh_output_codes");
+  if(encodedWeights) {
+    LOG_ONCE(info, "Reusing parameter LSH code matrix {} with shape {}", encodedWeights->name(), encodedWeights->shape());
+  } else {
+    LOG_ONCE(info, "Creating ad-hoc code matrix with shape {}", Shape({weights->shape()[-2], lsh::bytesPerVector(nBits)}));
+    encodedWeights = encode(weights, rotMat);
+  }
+
+  return searchEncoded(encode(query, rotMat), encodedWeights, k, firstNRows);
+}
+
+class RandomRotation : public inits::NodeInitializer {
+public:
+  void apply(Tensor tensor) override {
+    auto sharedAllocator = allocator_.lock();
+    ABORT_IF(!sharedAllocator, "Allocator in RandomRotation has not been set or expired");
+    fillRandomRotationMatrix(tensor, sharedAllocator);
+  }
+};
+
+Ptr<inits::NodeInitializer> randomRotation() {
+  return New<RandomRotation>();
+}
+
+void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBitsRot) {
+  auto weights = graph->get(weightsName);
+
+  ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
+
+  int nBits = weights->shape()[-1];
+  int nRows = weights->shape().elements() / nBits;
+
+  Expr rotation;
+  if(nBits != nBitsRot) {
+    LOG(info, "Adding LSH rotation parameter lsh_output_rotation with shape {}", Shape({nBits, nBitsRot}));
+    rotation = graph->param("lsh_output_rotation", {nBits, nBitsRot}, inits::dummy(), Type::float32);
+    nBits = nBitsRot;
+  }
+
+  int bytesPerVector = lsh::bytesPerVector(nBits);
+  LOG(info, "Adding LSH encoded weights lsh_output_codes with shape {}", Shape({nRows, bytesPerVector}));
+  auto codes = graph->param("lsh_output_codes", {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
+}
+
+void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName) {
+  Expr weights  = graph->get(weightsName);
+  Expr codes    = graph->get("lsh_output_codes");
+  Expr rotation = graph->get("lsh_output_rotation");
+
+  ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
+  ABORT_IF(!codes,   "Trying to overwrite non-existing LSH parameters lsh_output_codes??");
+
+  if(rotation) {
+    fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator());
+    encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator());
+  } else {
+    encode(codes->val(), weights->val());
+  }
+}
+
+}
+}
\ No newline at end of file
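Putting the pieces together, a caller only ever needs the top-level operator; code matrices, rotation, and memoization are all resolved inside. A sketch with assumed shapes and made-up k/nBits values:

    // query: decoder states, e.g. [beam, 1, batch, dim]; W: [vocab, dim] output
    // embeddings. Returns a uint32 node of shape [beam, batch, k] holding
    // per-hypothesis sorted word indices.
    Expr indices = lsh::search(query, W, /*k=*/100, /*nbits=*/1024);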
diff --git a/src/layers/lsh.h b/src/layers/lsh.h
new file mode 100644
index 00000000..60908238
--- /dev/null
+++ b/src/layers/lsh.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include "graph/expression_operators.h"
+#include "graph/node_initializers.h"
+
+#include <vector>
+
+/**
+ * In this file we basically take the faiss::IndexLSH and pick it apart so that the individual steps
+ * can be implemented as Marian inference operators. We can encode the inputs and weights into their
+ * bitwise equivalents, apply the hashing rotation (if required), and perform the actual search.
+ *
+ * This also allows us to create parameters that get dumped into the model weight file. This is currently
+ * a bit hacky (see marian-conv), but once this is done the model can memory-map the LSH with existing
+ * mechanisms and no additional memory is consumed to build the index or rotation matrix.
+ */
+
+namespace marian {
+namespace lsh {
+
+  // returns the number of full bytes required to encode that many bits
+  int bytesPerVector(int nBits);
+
+  // encodes an input as a bit vector, with optional rotation
+  Expr encode(Expr input, Expr rotator = nullptr);
+
+  // computes the rotation matrix (maps weights->shape()[-1] floats to nbits floats)
+  Expr rotator(Expr weights, int nbits);
+
+  // performs the LSH search on fully encoded input and weights, returns k results (indices) per input row
+  // @TODO: add a top-k-like operator that also returns the bitwise computed distances
+  Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0);
+
+  // same as above, but performs the encoding on the fly
+  Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0);
+
+  // These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv
+  void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBits);
+  void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName);
+
+  /**
+   * Computes a random rotation matrix for LSH hashing.
+   * This is part of a hash function. The values are orthonormal and computed via
+   * QR decomposition.
+   */
+  Ptr<inits::NodeInitializer> randomRotation();
+}
+
+}
\ No newline at end of file
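On the Quicksand side, the same dummy-then-overwrite dance is wired into convertModel(), whose signature gains an addLsh flag below. A hypothetical caller after this change (file names and precision choice made up for the example):

    // Pack an fp32 model to packed16 and bake the LSH codes in while converting.
    bool ok = marian::quicksand::convertModel("model.npz", "model.bin",
                                              (int32_t)marian::Type::packed16,
                                              /*addLsh=*/true);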
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp
index 6476df8f..70e657a9 100644
--- a/src/microsoft/quicksand.cpp
+++ b/src/microsoft/quicksand.cpp
@@ -11,6 +11,7 @@
 #include "data/alignment.h"
 #include "data/vocab_base.h"
 #include "tensors/cpu/expression_graph_packable.h"
+#include "layers/lsh.h"
 
 #if USE_FBGEMM
 #include "fbgemm/Utils.h"
@@ -248,7 +249,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name) {
 // This function converts an fp32 model into an FBGEMM based packed model.
 // marian defined types are used for external project as well.
 // The targetPrec is passed as int32_t for the exported function definition.
-bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec) {
+bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh) {
   std::cerr << "Converting from: " << inputFile << ", to: " << outputFile << ", precision: " << targetPrec << std::endl;
 
   YAML::Node config;
@@ -260,7 +261,26 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh) {
   graph->setDevice(CPU0);
 
   graph->load(inputFile);
-  graph->forward();
+
+  // MJD: Note, these are default settings which we might want to change or expose. Use this only with Polonium students.
+  // The LSH will not be used by default even if it exists in the model. That has to be enabled in the decoder config.
+  int lshNBits = 1024;
+  std::string lshOutputWeights = "Wemb";
+  if(addLsh) {
+    // Add dummy parameters for the LSH before the model actually gets initialized.
+    // This creates the parameters with useless values in the tensors, but it gives us the memory we need.
+    graph->setReloaded(false);
+    lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
+    graph->setReloaded(true);
+  }
+
+  graph->forward(); // run the initializers
+
+  if(addLsh) {
+    // After initialization, hijack the parameters for the LSH and force-overwrite them with correct values.
+    // Once this is done we can just pack and save as normal.
+    lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
+  }
 
   Type targetPrecType = (Type) targetPrec;
   if (targetPrecType == Type::packed16
diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h
index 87de1948..b710e135 100644
--- a/src/microsoft/quicksand.h
+++ b/src/microsoft/quicksand.h
@@ -76,7 +76,10 @@ std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocabPaths);
 DecoderCpuAvxVersion getCpuAvxVersion();
 DecoderCpuAvxVersion parseCpuAvxVersion(std::string name);
 
-bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec);
+// MJD: added "addLsh", which will break compilation of calling code after this update. That's on purpose.
+// The calling code should be adapted, not this interface. If you need to fix things in QS because of this,
+// talk to me first!
+bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh);
 
 } // namespace quicksand
 } // namespace marian
diff --git a/src/tensors/cpu/expression_graph_packable.h b/src/tensors/cpu/expression_graph_packable.h
index 689aa3b1..f5a9cad9 100644
--- a/src/tensors/cpu/expression_graph_packable.h
+++ b/src/tensors/cpu/expression_graph_packable.h
@@ -27,14 +27,17 @@ public:
   virtual ~ExpressionGraphPackable() {}
 
   // Convert model weights into packed format and save to IO items.
-  // @TODO: review this
-  void packAndSave(const std::string& name, const std::string& meta, Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
+  std::vector<io::Item> pack(Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
     std::vector<io::Item> ioItems;
 
+    // handle packable parameters first (a float32 parameter is packable)
+    auto packableParameters = paramsByElementType_[Type::float32];
     // sorted by name in std::map
-    for (auto p : params()->getMap()) {
+    for (auto p : packableParameters->getMap()) {
       std::string pName = p.first;
 
+      LOG(info, "Processing parameter {} with shape {} and type {}", pName, p.second->shape(), p.second->value_type());
+
       if (!namespace_.empty()) {
         if (pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
           pName = pName.substr(namespace_.size() + 2);
@@ -257,6 +260,33 @@ public:
       }
     }
 
+    // Now handle all non-float32 parameters
+    for(auto& iter : paramsByElementType_) {
+      auto type = iter.first;
+      if(type == Type::float32)
+        continue;
+
+      for (auto p : iter.second->getMap()) {
+        std::string pName = p.first;
+        LOG(info, "Processing parameter {} with shape {} and type {}", pName, p.second->shape(), p.second->value_type());
+
+        if (!namespace_.empty()) {
+          if (pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
+            pName = pName.substr(namespace_.size() + 2);
+        }
+
+        Tensor val = p.second->val();
+        io::Item item;
+        val->get(item, pName);
+        ioItems.emplace_back(std::move(item));
+      }
+    }
+
+    return ioItems;
+  }
+
+  void packAndSave(const std::string& name, const std::string& meta, Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
+    auto ioItems = pack(gemmElementType, saveElementType);
     if (!meta.empty())
       io::addMetaToItems(meta, "special:model.yml", ioItems);
     io::saveItems(name, ioItems);
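The pack()/packAndSave() split means callers can hold on to or post-process the io::Item list before anything hits disk; packAndSave() itself becomes the thin composition sketched below at the call site (configStr and the type choices are placeholders):

    // Caller-side equivalent of packAndSave():
    auto items = graph->pack(/*gemmElementType=*/Type::intgemm8, /*saveElementType=*/Type::float32);
    io::addMetaToItems(configStr, "special:model.yml", items);
    io::saveItems("model.bin", items);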
diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h
index 10c3e7f1..a7071404 100644
--- a/src/tensors/tensor.h
+++ b/src/tensors/tensor.h
@@ -35,7 +35,8 @@ class TensorBase {
   ENABLE_INTRUSIVE_PTR(TensorBase)
 
-  // Constructors are private, use TensorBase::New(...)
+protected:
+  // Constructors are protected, use TensorBase::New(...)
   TensorBase(MemoryPiece::PtrType memory,
              Shape shape,
              Type type,
@@ -61,10 +62,10 @@ class TensorBase {
       shape_(shape), type_(type), backend_(backend) {}
 
 public:
-  // Use this whenever pointing to MemoryPiece
+  // Use this whenever pointing to TensorBase
   typedef IPtr<TensorBase> PtrType;
 
-  // Use this whenever creating a pointer to MemoryPiece
+  // Use this whenever creating a pointer to TensorBase
   template <class ...Args>
   static PtrType New(Args&& ...args) {
     return PtrType(new TensorBase(std::forward<Args>(args)...));
diff --git a/src/training/training_state.h b/src/training/training_state.h
index e0c1ba5d..ce0895a2 100644
--- a/src/training/training_state.h
+++ b/src/training/training_state.h
@@ -142,8 +142,9 @@ public:
   // for periods.
   bool enteredNewPeriodOf(std::string schedulingParam) const {
     auto period = SchedulingParameter::parse(schedulingParam);
+    // @TODO: adapt to logical epochs
     ABORT_IF(period.unit == SchedulingUnit::epochs,
-             "Unit {} is not supported for frequency parameters (the one(s) with value {})",
+             "Unit {} is not supported for frequency parameters",
              schedulingParam);
     auto previousProgress = getPreviousProgressIn(period.unit);
     auto progress = getProgressIn(period.unit);
-- 
cgit v1.2.3
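As a closing pointer, the decode-time consumer of all of the above is the LSHShortlist from src/data/shortlist.cpp; constructing one is what routes the output layer through the Hamming search (argument values illustrative, constructor signature as in this diff):

    // k candidates per hypothesis, 1024-bit codes, search only the first
    // lemmaSize rows of the output matrix (skips factor rows with Factored
    // Segmenter). vocabSize is a placeholder for the model's lemma count.
    auto shortlist = New<data::LSHShortlist>(/*k=*/100, /*nbits=*/1024, /*lemmaSize=*/vocabSize);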