github.com/marian-nmt/marian.git

author    Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-07-09 23:35:09 +0300
committer Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-07-09 23:35:09 +0300
commit    35c822eb4ea29e5445b7c75b665c4872a2cc1adb (patch)
tree      06d2fe62c56eda33e428242c01c931604c298064 /src
parent    d6c09b24de0c4576b4883b460b99dd108ce8f4a9 (diff)
Merged PR 19685: Marianize LSH as operators for mmapping and use in Quicksand
This PR turns the LSH index and search into a set of operators that live in the expression graph. This makes creation thread-safe (one index per graph) and allows GPU versions to be implemented later. It also makes it possible to mmap the LSH as a Marian parameter, since we now only need to turn the index into something that can be saved to disk using the existing tensors. This happens in marian_conv or in the equivalent interface function of the Quicksand interface.
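For illustration, a minimal sketch of how the operator-based API introduced here is meant to be used (the lsh::search signature is taken from src/layers/lsh.h below; the surrounding setup and shape comments are schematic, not part of this PR):

    #include "layers/lsh.h"

    using namespace marian;

    // query:   [beam, 1, batch, dim] decoder activations
    // weights: [vocab, dim] output embedding matrix
    // Returns a [beam, batch, k] uint32 expression holding, per hypothesis, the k
    // row indices of `weights` closest to the query under hamming distance on
    // nBits-bit signatures. If nBits != dim, lsh::search inserts (or reuses) the
    // rotation parameter "lsh_output_rotation" and the codes "lsh_output_codes".
    Expr approxTopK(Expr query, Expr weights, int k, int nBits) {
      return lsh::search(query, weights, k, nBits);
    }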
Diffstat (limited to 'src')
-rw-r--r--  src/3rd_party/faiss/Index.cpp                | 119
-rw-r--r--  src/3rd_party/faiss/Index.h                  | 177
-rw-r--r--  src/3rd_party/faiss/IndexLSH.cpp             | 224
-rw-r--r--  src/3rd_party/faiss/IndexLSH.h               | 90
-rw-r--r--  src/3rd_party/faiss/utils/hamming-inl.h      | 10
-rw-r--r--  src/3rd_party/faiss/utils/hamming.h          | 4
-rw-r--r--  src/CMakeLists.txt                           | 1
-rw-r--r--  src/command/marian_conv.cpp                  | 39
-rw-r--r--  src/data/shortlist.cpp                       | 74
-rw-r--r--  src/data/shortlist.h                         | 3
-rw-r--r--  src/graph/expression_graph.h                 | 10
-rw-r--r--  src/graph/expression_operators.cpp           | 12
-rw-r--r--  src/graph/expression_operators.h             | 11
-rw-r--r--  src/graph/node_initializers.cpp              | 30
-rw-r--r--  src/graph/node_initializers.h                | 12
-rw-r--r--  src/graph/node_operators_binary.h            | 22
-rw-r--r--  src/graph/node_operators_unary.h             | 41
-rw-r--r--  src/layers/lsh.cpp                           | 233
-rw-r--r--  src/layers/lsh.h                             | 49
-rw-r--r--  src/microsoft/quicksand.cpp                  | 24
-rw-r--r--  src/microsoft/quicksand.h                    | 5
-rw-r--r--  src/tensors/cpu/expression_graph_packable.h  | 36
-rw-r--r--  src/tensors/tensor.h                         | 7
-rw-r--r--  src/training/training_state.h                | 3
24 files changed, 499 insertions, 737 deletions
diff --git a/src/3rd_party/faiss/Index.cpp b/src/3rd_party/faiss/Index.cpp
deleted file mode 100644
index eac5f3d9..00000000
--- a/src/3rd_party/faiss/Index.cpp
+++ /dev/null
@@ -1,119 +0,0 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-// -*- c++ -*-
-
-#include "Index.h"
-#include "common/logging.h"
-#include <cstring>
-
-namespace faiss {
-
-Index::~Index ()
-{
-}
-
-
-void Index::train(idx_t /*n*/, const float* /*x*/) {
- // does nothing by default
-}
-
-
-void Index::range_search (idx_t , const float *, float,
- RangeSearchResult *) const
-{
- ABORT ("range search not implemented");
-}
-
-void Index::assign (idx_t n, const float * x, idx_t * labels, idx_t k)
-{
- float * distances = new float[n * k];
- ScopeDeleter<float> del(distances);
- search (n, x, k, distances, labels);
-}
-
-void Index::add_with_ids(
- idx_t /*n*/,
- const float* /*x*/,
- const idx_t* /*xids*/) {
- ABORT ("add_with_ids not implemented for this type of index");
-}
-
-size_t Index::remove_ids(const IDSelector& /*sel*/) {
- ABORT ("remove_ids not implemented for this type of index");
- return -1;
-}
-
-
-void Index::reconstruct (idx_t, float * ) const {
- ABORT ("reconstruct not implemented for this type of index");
-}
-
-
-void Index::reconstruct_n (idx_t i0, idx_t ni, float *recons) const {
- for (idx_t i = 0; i < ni; i++) {
- reconstruct (i0 + i, recons + i * d);
- }
-}
-
-
-void Index::search_and_reconstruct (idx_t n, const float *x, idx_t k,
- float *distances, idx_t *labels,
- float *recons) const {
- search (n, x, k, distances, labels);
- for (idx_t i = 0; i < n; ++i) {
- for (idx_t j = 0; j < k; ++j) {
- idx_t ij = i * k + j;
- idx_t key = labels[ij];
- float* reconstructed = recons + ij * d;
- if (key < 0) {
- // Fill with NaNs
- memset(reconstructed, -1, sizeof(*reconstructed) * d);
- } else {
- reconstruct (key, reconstructed);
- }
- }
- }
-}
-
-void Index::compute_residual (const float * x,
- float * residual, idx_t key) const {
- reconstruct (key, residual);
- for (size_t i = 0; i < d; i++) {
- residual[i] = x[i] - residual[i];
- }
-}
-
-void Index::compute_residual_n (idx_t n, const float* xs,
- float* residuals,
- const idx_t* keys) const {
-//#pragma omp parallel for
- for (idx_t i = 0; i < n; ++i) {
- compute_residual(&xs[i * d], &residuals[i * d], keys[i]);
- }
-}
-
-
-
-size_t Index::sa_code_size () const
-{
- ABORT ("standalone codec not implemented for this type of index");
-}
-
-void Index::sa_encode (idx_t, const float *,
- uint8_t *) const
-{
- ABORT ("standalone codec not implemented for this type of index");
-}
-
-void Index::sa_decode (idx_t, const uint8_t *,
- float *) const
-{
- ABORT ("standalone codec not implemented for this type of index");
-}
-
-}
diff --git a/src/3rd_party/faiss/Index.h b/src/3rd_party/faiss/Index.h
index deaabcaa..24765f7d 100644
--- a/src/3rd_party/faiss/Index.h
+++ b/src/3rd_party/faiss/Index.h
@@ -39,11 +39,6 @@
namespace faiss {
-/// Forward declarations see AuxIndexStructures.h
-struct IDSelector;
-struct RangeSearchResult;
-struct DistanceComputer;
-
/** Abstract structure for an index, supports adding vectors and searching them.
*
* All vectors provided at add or search time are 32-bit float arrays,
@@ -53,178 +48,6 @@ struct Index {
using idx_t = int64_t; ///< all indices are this type
using component_t = float;
using distance_t = float;
-
- int d; ///< vector dimension
- idx_t ntotal; ///< total nb of indexed vectors
- bool verbose; ///< verbosity level
-
- /// set if the Index does not require training, or if training is
- /// done already
- bool is_trained;
-
- /// type of metric this index uses for search
- MetricType metric_type;
- float metric_arg; ///< argument of the metric type
-
- explicit Index (idx_t d = 0, MetricType metric = METRIC_L2):
- d((int)d),
- ntotal(0),
- verbose(false),
- is_trained(true),
- metric_type (metric),
- metric_arg(0) {}
-
- virtual ~Index ();
-
-
- /** Perform training on a representative set of vectors
- *
- * @param n nb of training vectors
- * @param x training vecors, size n * d
- */
- virtual void train(idx_t n, const float* x);
-
- /** Add n vectors of dimension d to the index.
- *
- * Vectors are implicitly assigned labels ntotal .. ntotal + n - 1
- * This function slices the input vectors in chuncks smaller than
- * blocksize_add and calls add_core.
- * @param x input matrix, size n * d
- */
- virtual void add (idx_t n, const float *x) = 0;
-
- /** Same as add, but stores xids instead of sequential ids.
- *
- * The default implementation fails with an assertion, as it is
- * not supported by all indexes.
- *
- * @param xids if non-null, ids to store for the vectors (size n)
- */
- virtual void add_with_ids (idx_t n, const float * x, const idx_t *xids);
-
- /** query n vectors of dimension d to the index.
- *
- * return at most k vectors. If there are not enough results for a
- * query, the result array is padded with -1s.
- *
- * @param x input vectors to search, size n * d
- * @param labels output labels of the NNs, size n*k
- * @param distances output pairwise distances, size n*k
- */
- virtual void search (idx_t n, const float *x, idx_t k,
- float *distances, idx_t *labels) const = 0;
-
- /** query n vectors of dimension d to the index.
- *
- * return all vectors with distance < radius. Note that many
- * indexes do not implement the range_search (only the k-NN search
- * is mandatory).
- *
- * @param x input vectors to search, size n * d
- * @param radius search radius
- * @param result result table
- */
- virtual void range_search (idx_t n, const float *x, float radius,
- RangeSearchResult *result) const;
-
- /** return the indexes of the k vectors closest to the query x.
- *
- * This function is identical as search but only return labels of neighbors.
- * @param x input vectors to search, size n * d
- * @param labels output labels of the NNs, size n*k
- */
- void assign (idx_t n, const float * x, idx_t * labels, idx_t k = 1);
-
- /// removes all elements from the database.
- virtual void reset() = 0;
-
- /** removes IDs from the index. Not supported by all
- * indexes. Returns the number of elements removed.
- */
- virtual size_t remove_ids (const IDSelector & sel);
-
- /** Reconstruct a stored vector (or an approximation if lossy coding)
- *
- * this function may not be defined for some indexes
- * @param key id of the vector to reconstruct
- * @param recons reconstucted vector (size d)
- */
- virtual void reconstruct (idx_t key, float * recons) const;
-
- /** Reconstruct vectors i0 to i0 + ni - 1
- *
- * this function may not be defined for some indexes
- * @param recons reconstucted vector (size ni * d)
- */
- virtual void reconstruct_n (idx_t i0, idx_t ni, float *recons) const;
-
- /** Similar to search, but also reconstructs the stored vectors (or an
- * approximation in the case of lossy coding) for the search results.
- *
- * If there are not enough results for a query, the resulting arrays
- * is padded with -1s.
- *
- * @param recons reconstructed vectors size (n, k, d)
- **/
- virtual void search_and_reconstruct (idx_t n, const float *x, idx_t k,
- float *distances, idx_t *labels,
- float *recons) const;
-
- /** Computes a residual vector after indexing encoding.
- *
- * The residual vector is the difference between a vector and the
- * reconstruction that can be decoded from its representation in
- * the index. The residual can be used for multiple-stage indexing
- * methods, like IndexIVF's methods.
- *
- * @param x input vector, size d
- * @param residual output residual vector, size d
- * @param key encoded index, as returned by search and assign
- */
- virtual void compute_residual (const float * x,
- float * residual, idx_t key) const;
-
- /** Computes a residual vector after indexing encoding (batch form).
- * Equivalent to calling compute_residual for each vector.
- *
- * The residual vector is the difference between a vector and the
- * reconstruction that can be decoded from its representation in
- * the index. The residual can be used for multiple-stage indexing
- * methods, like IndexIVF's methods.
- *
- * @param n number of vectors
- * @param xs input vectors, size (n x d)
- * @param residuals output residual vectors, size (n x d)
- * @param keys encoded index, as returned by search and assign
- */
- virtual void compute_residual_n (idx_t n, const float* xs,
- float* residuals,
- const idx_t* keys) const;
-
- /* The standalone codec interface */
-
- /** size of the produced codes in bytes */
- virtual size_t sa_code_size () const;
-
- /** encode a set of vectors
- *
- * @param n number of vectors
- * @param x input vectors, size n * d
- * @param bytes output encoded vectors, size n * sa_code_size()
- */
- virtual void sa_encode (idx_t n, const float *x,
- uint8_t *bytes) const;
-
- /** encode a set of vectors
- *
- * @param n number of vectors
- * @param bytes input encoded vectors, size n * sa_code_size()
- * @param x output vectors, size n * d
- */
- virtual void sa_decode (idx_t n, const uint8_t *bytes,
- float *x) const;
-
-
};
}
diff --git a/src/3rd_party/faiss/IndexLSH.cpp b/src/3rd_party/faiss/IndexLSH.cpp
deleted file mode 100644
index 6df84331..00000000
--- a/src/3rd_party/faiss/IndexLSH.cpp
+++ /dev/null
@@ -1,224 +0,0 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-// -*- c++ -*-
-
-#include <faiss/IndexLSH.h>
-
-#include <cstdio>
-#include <cstring>
-
-#include <algorithm>
-
-#include <faiss/utils/hamming.h>
-#include "common/logging.h"
-
-
-namespace faiss {
-
-/***************************************************************
- * IndexLSH
- ***************************************************************/
-
-
-IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds):
- Index(d), nbits(nbits), rotate_data(rotate_data),
- train_thresholds (train_thresholds), rrot(d, nbits)
-{
- is_trained = !train_thresholds;
-
- bytes_per_vec = (nbits + 7) / 8;
-
- if (rotate_data) {
- rrot.init(5);
- } else {
- ABORT_UNLESS(d >= nbits, "d >= nbits");
- }
-}
-
-IndexLSH::IndexLSH ():
- nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false)
-{
-}
-
-
-const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const
-{
-
- float *xt = nullptr;
- if (rotate_data) {
- // also applies bias if exists
- xt = rrot.apply (n, x);
- } else if (d != nbits) {
- assert (nbits < d);
- xt = new float [nbits * n];
- float *xp = xt;
- for (idx_t i = 0; i < n; i++) {
- const float *xl = x + i * d;
- for (int j = 0; j < nbits; j++)
- *xp++ = xl [j];
- }
- }
-
- if (train_thresholds) {
-
- if (xt == NULL) {
- xt = new float [nbits * n];
- memcpy (xt, x, sizeof(*x) * n * nbits);
- }
-
- float *xp = xt;
- for (idx_t i = 0; i < n; i++)
- for (int j = 0; j < nbits; j++)
- *xp++ -= thresholds [j];
- }
-
- return xt ? xt : x;
-}
-
-
-
-void IndexLSH::train (idx_t n, const float *x)
-{
- if (train_thresholds) {
- thresholds.resize (nbits);
- train_thresholds = false;
- const float *xt = apply_preprocess (n, x);
- ScopeDeleter<float> del (xt == x ? nullptr : xt);
- train_thresholds = true;
-
- float * transposed_x = new float [n * nbits];
- ScopeDeleter<float> del2 (transposed_x);
-
- for (idx_t i = 0; i < n; i++)
- for (idx_t j = 0; j < nbits; j++)
- transposed_x [j * n + i] = xt [i * nbits + j];
-
- for (idx_t i = 0; i < nbits; i++) {
- float *xi = transposed_x + i * n;
- // std::nth_element
- std::sort (xi, xi + n);
- if (n % 2 == 1)
- thresholds [i] = xi [n / 2];
- else
- thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2;
-
- }
- }
- is_trained = true;
-}
-
-
-void IndexLSH::add (idx_t n, const float *x)
-{
- ABORT_UNLESS (is_trained, "is_trained");
- codes.resize ((ntotal + n) * bytes_per_vec);
-
- sa_encode (n, x, &codes[ntotal * bytes_per_vec]);
-
- ntotal += n;
-}
-
-
-void IndexLSH::search (
- idx_t n,
- const float *x,
- idx_t k,
- float *distances,
- idx_t *labels) const
-{
- ABORT_UNLESS (is_trained, "is_trained");
- const float *xt = apply_preprocess (n, x);
- ScopeDeleter<float> del (xt == x ? nullptr : xt);
-
- uint8_t * qcodes = new uint8_t [n * bytes_per_vec];
- ScopeDeleter<uint8_t> del2 (qcodes);
-
- fvecs2bitvecs (xt, qcodes, nbits, n);
-
- int * idistances = new int [n * k];
- ScopeDeleter<int> del3 (idistances);
-
- int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances};
-
- hammings_knn_hc (&res, qcodes, codes.data(),
- ntotal, bytes_per_vec, true);
-
-
- // convert distances to floats
- for (int i = 0; i < k * n; i++)
- distances[i] = idistances[i];
-
-}
-
-
-void IndexLSH::transfer_thresholds (LinearTransform *vt) {
- if (!train_thresholds) return;
- ABORT_UNLESS (nbits == vt->d_out, "nbits == vt->d_out");
- if (!vt->have_bias) {
- vt->b.resize (nbits, 0);
- vt->have_bias = true;
- }
- for (int i = 0; i < nbits; i++)
- vt->b[i] -= thresholds[i];
- train_thresholds = false;
- thresholds.clear();
-}
-
-void IndexLSH::reset() {
- codes.clear();
- ntotal = 0;
-}
-
-
-size_t IndexLSH::sa_code_size () const
-{
- return bytes_per_vec;
-}
-
-void IndexLSH::sa_encode (idx_t n, const float *x,
- uint8_t *bytes) const
-{
- ABORT_UNLESS (is_trained, "is_trained");
- const float *xt = apply_preprocess (n, x);
- ScopeDeleter<float> del (xt == x ? nullptr : xt);
- fvecs2bitvecs (xt, bytes, nbits, n);
-}
-
-void IndexLSH::sa_decode (idx_t n, const uint8_t *bytes,
- float *x) const
-{
- float *xt = x;
- ScopeDeleter<float> del;
- if (rotate_data || nbits != d) {
- xt = new float [n * nbits];
- del.set(xt);
- }
- bitvecs2fvecs (bytes, xt, nbits, n);
-
- if (train_thresholds) {
- float *xp = xt;
- for (idx_t i = 0; i < n; i++) {
- for (int j = 0; j < nbits; j++) {
- *xp++ += thresholds [j];
- }
- }
- }
-
- if (rotate_data) {
- rrot.reverse_transform (n, xt, x);
- } else if (nbits != d) {
- for (idx_t i = 0; i < n; i++) {
- memcpy (x + i * d, xt + i * nbits,
- nbits * sizeof(xt[0]));
- }
- }
-}
-
-
-
-} // namespace faiss
diff --git a/src/3rd_party/faiss/IndexLSH.h b/src/3rd_party/faiss/IndexLSH.h
deleted file mode 100644
index 66435363..00000000
--- a/src/3rd_party/faiss/IndexLSH.h
+++ /dev/null
@@ -1,90 +0,0 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-// -*- c++ -*-
-
-#ifndef INDEX_LSH_H
-#define INDEX_LSH_H
-
-#include <vector>
-
-#include <faiss/Index.h>
-#include <faiss/VectorTransform.h>
-
-namespace faiss {
-
-
-/** The sign of each vector component is put in a binary signature */
-struct IndexLSH:Index {
- typedef unsigned char uint8_t;
-
- int nbits; ///< nb of bits per vector
- int bytes_per_vec; ///< nb of 8-bits per encoded vector
- bool rotate_data; ///< whether to apply a random rotation to input
- bool train_thresholds; ///< whether we train thresholds or use 0
-
- RandomRotationMatrix rrot; ///< optional random rotation
-
- std::vector <float> thresholds; ///< thresholds to compare with
-
- /// encoded dataset
- std::vector<uint8_t> codes;
-
- IndexLSH (
- idx_t d, int nbits,
- bool rotate_data = true,
- bool train_thresholds = false);
-
- /** Preprocesses and resizes the input to the size required to
- * binarize the data
- *
- * @param x input vectors, size n * d
- * @return output vectors, size n * bits. May be the same pointer
- * as x, otherwise it should be deleted by the caller
- */
- const float *apply_preprocess (idx_t n, const float *x) const;
-
- void train(idx_t n, const float* x) override;
-
- void add(idx_t n, const float* x) override;
-
- void search(
- idx_t n,
- const float* x,
- idx_t k,
- float* distances,
- idx_t* labels) const override;
-
- void reset() override;
-
- /// transfer the thresholds to a pre-processing stage (and unset
- /// train_thresholds)
- void transfer_thresholds (LinearTransform * vt);
-
- ~IndexLSH() override {}
-
- IndexLSH ();
-
- /* standalone codec interface.
- *
- * The vectors are decoded to +/- 1 (not 0, 1) */
-
- size_t sa_code_size () const override;
-
- void sa_encode (idx_t n, const float *x,
- uint8_t *bytes) const override;
-
- void sa_decode (idx_t n, const uint8_t *bytes,
- float *x) const override;
-
-};
-
-
-}
-
-
-#endif
diff --git a/src/3rd_party/faiss/utils/hamming-inl.h b/src/3rd_party/faiss/utils/hamming-inl.h
index d32da758..b164dc88 100644
--- a/src/3rd_party/faiss/utils/hamming-inl.h
+++ b/src/3rd_party/faiss/utils/hamming-inl.h
@@ -10,8 +10,8 @@
namespace faiss {
-#ifdef _MSC_VER
-#define bzero(p,n) (memset((p),0,(n)))
+#ifdef _MSC_VER
+#define bzero(p,n) (memset((p),0,(n)))
#endif
inline BitstringWriter::BitstringWriter(uint8_t *code, int code_size):
code (code), code_size (code_size), i(0)
@@ -29,7 +29,7 @@ inline void BitstringWriter::write(uint64_t x, int nbit) {
i += nbit;
return;
} else {
- int j = i >> 3;
+ size_t j = i >> 3;
code[j++] |= x << (i & 7);
i += nbit;
x >>= na;
@@ -57,7 +57,7 @@ inline uint64_t BitstringReader::read(int nbit) {
return res;
} else {
int ofs = na;
- int j = (i >> 3) + 1;
+ size_t j = (i >> 3) + 1;
i += nbit;
nbit -= na;
while (nbit > 8) {
@@ -160,7 +160,7 @@ struct HammingComputer20 {
void set (const uint8_t *a8, int code_size) {
assert (code_size == 20);
const uint64_t *a = (uint64_t *)a8;
- a0 = a[0]; a1 = a[1]; a2 = a[2];
+ a0 = a[0]; a1 = a[1]; a2 = (uint32_t)a[2];
}
inline int hamming (const uint8_t *b8) const {
diff --git a/src/3rd_party/faiss/utils/hamming.h b/src/3rd_party/faiss/utils/hamming.h
index 762d3773..0c89c4d1 100644
--- a/src/3rd_party/faiss/utils/hamming.h
+++ b/src/3rd_party/faiss/utils/hamming.h
@@ -31,7 +31,7 @@
#ifdef _MSC_VER
#include <intrin.h> // needed for some intrinsics in <memory>
-#define __builtin_popcountl __popcnt64
+#define __builtin_popcountl __popcnt64
#endif
/* The Hamming distance type */
@@ -116,7 +116,7 @@ struct BitstringReader {
extern size_t hamming_batch_size;
static inline int popcount64(uint64_t x) {
- return __builtin_popcountl(x);
+ return (int)__builtin_popcountl(x);
}
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index d2fd269f..1f5db423 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -75,6 +75,7 @@ set(MARIAN_SOURCES
layers/embedding.cpp
layers/output.cpp
layers/logits.cpp
+ layers/lsh.cpp
rnn/cells.cpp
rnn/attention.cpp
diff --git a/src/command/marian_conv.cpp b/src/command/marian_conv.cpp
index 26cac858..e0e89d2b 100644
--- a/src/command/marian_conv.cpp
+++ b/src/command/marian_conv.cpp
@@ -2,6 +2,7 @@
#include "common/cli_wrapper.h"
#include "tensors/cpu/expression_graph_packable.h"
#include "onnx/expression_graph_onnx_exporter.h"
+#include "layers/lsh.h"
#include <sstream>
@@ -25,6 +26,9 @@ int main(int argc, char** argv) {
cli->add<std::string>("--gemm-type,-g", "GEMM Type to be used: float32, packed16, packed8avx2, packed8avx512, "
"intgemm8, intgemm8ssse3, intgemm8avx2, intgemm8avx512, intgemm16, intgemm16sse2, intgemm16avx2, intgemm16avx512",
"float32");
+ cli->add<std::vector<std::string>>("--add-lsh",
+ "Encode output matrix and optional rotation matrix into model file. "
+ "arg1: number of bits in LSH encoding, arg2: name of output weights matrix")->implicit_val("1024 Wemb");
cli->add<std::vector<std::string>>("--vocabs,-V", "Vocabulary file, required for ONNX export");
cli->parse(argc, argv);
options->merge(config);
@@ -34,6 +38,16 @@ int main(int argc, char** argv) {
auto exportAs = options->get<std::string>("export-as");
auto vocabPaths = options->get<std::vector<std::string>>("vocabs");// , std::vector<std::string>());
+
+ bool addLsh = options->hasAndNotEmpty("add-lsh");
+ int lshNBits = 1024;
+ std::string lshOutputWeights = "Wemb";
+ if(addLsh) {
+ auto lshParams = options->get<std::vector<std::string>>("add-lsh");
+ lshNBits = std::stoi(lshParams[0]);
+ if(lshParams.size() > 1)
+ lshOutputWeights = lshParams[1];
+ }
// We accept any type here and will later croak during packAndSave if the type cannot be used for conversion
Type saveGemmType = typeFromString(options->get<std::string>("gemm-type", "float32"));
@@ -45,23 +59,36 @@ int main(int argc, char** argv) {
marian::io::getYamlFromModel(config, "special:model.yml", modelFrom);
configStr << config;
- auto load = [&](Ptr<ExpressionGraph> graph) {
+ if (exportAs == "marian-bin") {
+ auto graph = New<ExpressionGraphPackable>();
graph->setDevice(CPU0);
graph->load(modelFrom);
+
+ if(addLsh) {
+ // Add dummy parameters for the LSH before the model actually gets initialized.
+ // This creates the parameters with useless values in the tensors, but it gives us the memory we need.
+ graph->setReloaded(false);
+ lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
+ graph->setReloaded(true);
+ }
+
graph->forward(); // run the initializers
- };
+ if(addLsh) {
+ // After initialization, hijack the parameters for the LSH and force-overwrite them with correct values.
+ // Once this is done we can just pack and save as normal.
+ lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
+ }
- if (exportAs == "marian-bin") {
- auto graph = New<ExpressionGraphPackable>();
- load(graph);
// added a flag if the weights needs to be packed or not
graph->packAndSave(modelTo, configStr.str(), /* --gemm-type */ saveGemmType, Type::float32);
}
else if (exportAs == "onnx-encode") {
#ifdef USE_ONNX
auto graph = New<ExpressionGraphONNXExporter>();
- load(graph);
+ graph->setDevice(CPU0);
+ graph->load(modelFrom);
+ graph->forward(); // run the initializers
auto modelOptions = New<Options>(config)->with("vocabs", vocabPaths, "inference", true);
graph->exportToONNX(modelTo, modelOptions, vocabPaths);
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index f7e229ff..396c6ba4 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -1,10 +1,7 @@
#include "data/shortlist.h"
#include "microsoft/shortlist/utils/ParameterTree.h"
#include "marian.h"
-
-#if BLAS_FOUND
-#include "3rd_party/faiss/IndexLSH.h"
-#endif
+#include "layers/lsh.h"
namespace marian {
namespace data {
@@ -47,7 +44,6 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
Shape kShape({k});
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
- //std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k);
initialized_ = true;
}
@@ -78,12 +74,10 @@ void Shortlist::createCachedTensors(Expr weights,
}
///////////////////////////////////////////////////////////////////////////////////
-Ptr<faiss::IndexLSH> LSHShortlist::index_;
-std::mutex LSHShortlist::mutex_;
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
-: Shortlist(std::vector<WordIndex>())
-, k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
+: Shortlist(std::vector<WordIndex>()),
+ k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
}
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
@@ -99,67 +93,23 @@ Expr LSHShortlist::getIndicesExpr() const {
}
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
-#if BLAS_FOUND
+
ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
"LSH index (--output-approx-knn) currently not implemented for GPU");
- int currBeamSize = input->shape()[0];
- int batchSize = input->shape()[2];
- int numHypos = currBeamSize * batchSize;
-
- auto forward = [this, numHypos](Expr out, const std::vector<Expr>& inputs) {
- auto query = inputs[0];
- auto values = inputs[1];
- int dim = values->shape()[-1];
-
- mutex_.lock();
- if(!index_) {
- LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
- index_.reset(new faiss::IndexLSH(dim, nbits_,
- /*rotate=*/dim != nbits_,
- /*train_thesholds*/false));
- index_->train(lemmaSize_, values->val()->data<float>());
- index_->add( lemmaSize_, values->val()->data<float>());
- }
- mutex_.unlock();
-
- int qRows = query->shape().elements() / dim;
- std::vector<float> distances(qRows * k_);
- std::vector<faiss::Index::idx_t> ids(qRows * k_);
-
- index_->search(qRows, query->val()->data<float>(), k_,
- distances.data(), ids.data());
-
- indices_.clear();
- for(auto iter = ids.begin(); iter != ids.end(); ++iter) {
- faiss::Index::idx_t id = *iter;
- indices_.push_back((WordIndex)id);
- }
-
- for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
- size_t startIdx = k_ * hypoIdx;
- size_t endIdx = startIdx + k_;
- std::sort(indices_.begin() + startIdx, indices_.begin() + endIdx);
- }
- out->val()->set(indices_);
- };
-
- Shape kShape({currBeamSize, batchSize, k_});
- indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
+ indicesExpr_ = callback(lsh::search(input, weights, k_, nbits_, (int)lemmaSize_),
+ [this](Expr node) {
+ node->val()->get(indices_); // set the value of the field indices_ whenever the graph traverses this node
+ });
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k_);
-
-#else
- input; weights; isLegacyUntransposedW; b; lemmaEt;
- ABORT("LSH output layer requires a CPU BLAS library");
-#endif
}
void LSHShortlist::createCachedTensors(Expr weights,
- bool isLegacyUntransposedW,
- Expr b,
- Expr lemmaEt,
- int k) {
+ bool isLegacyUntransposedW,
+ Expr b,
+ Expr lemmaEt,
+ int k) {
int currBeamSize = indicesExpr_->shape()[0];
int batchSize = indicesExpr_->shape()[1];
ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index a75d2c4b..d3841b21 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -25,7 +25,8 @@ namespace data {
class Shortlist {
protected:
std::vector<WordIndex> indices_; // // [packed shortlist index] -> word index, used to select columns from output embeddings
- Expr indicesExpr_;
+ Expr indicesExpr_; // cache an expression that contains the short list indices
+
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index fce7d532..553a5d63 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -647,6 +647,16 @@ public:
}
/**
+ * Return the Parameters object related to the graph by elementType.
+ * The Parameters object holds the whole set of the parameter nodes of the given type.
+ */
+ Ptr<Parameters>& params(Type elementType) {
+ auto it = paramsByElementType_.find(elementType);
+ ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", elementType);
+ return it->second;
+ }
+
+ /**
* Set default element type for the graph.
* The default value is used if some node type is not specified.
*/
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 24d12eea..560ab4e7 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -28,13 +28,17 @@ Expr checkpoint(Expr a) {
}
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
- LambdaNodeFunctor fwd) {
- return Expression<LambdaNodeOp>(nodes, shape, type, fwd);
+ LambdaNodeFunctor fwd, size_t hash) {
+ return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
}
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
- LambdaNodeFunctor fwd, LambdaNodeFunctor bwd) {
- return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd);
+ LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) {
+ return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash);
+}
+
+Expr callback(Expr node, LambdaNodeCallback call) {
+ return Expression<CallbackNodeOp>(node, call);
}
// logistic function. Note: scipy name is expit()
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 6c7e5758..e34ddc8a 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -26,12 +26,19 @@ typedef std::function<void(Expr out, const std::vector<Expr>& in)> LambdaNodeFun
/**
* Arbitrary node with forward operation only.
*/
-Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd);
+Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, size_t hash = 0);
/**
* Arbitrary node with forward and backward operation.
*/
-Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd);
+Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash = 0);
+
+
+/**
+ * Convenience typedef for graph @ref lambda expressions.
+ */
+typedef std::function<void(Expr)> LambdaNodeCallback;
+Expr callback(Expr node, LambdaNodeCallback call);
/**
* @addtogroup graph_ops_activation Activation Functions
diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp
index 4e39d1bf..e44b4828 100644
--- a/src/graph/node_initializers.cpp
+++ b/src/graph/node_initializers.cpp
@@ -11,6 +11,15 @@ namespace marian {
namespace inits {
+class DummyInit : public NodeInitializer {
+public:
+ void apply(Tensor tensor) override {
+ tensor;
+ }
+};
+
+Ptr<NodeInitializer> dummy() { return New<DummyInit>(); }
+
class LambdaInit : public NodeInitializer {
private:
std::function<void(Tensor)> lambda_;
@@ -237,24 +246,3 @@ template Ptr<NodeInitializer> range<IndexType>(IndexType begin, IndexType end, I
} // namespace inits
} // namespace marian
-
-#if BLAS_FOUND
-#include "faiss/VectorTransform.h"
-
-namespace marian {
-namespace inits {
-
-Ptr<NodeInitializer> randomRotation(size_t seed) {
- auto rot = [=](Tensor t) {
- int rows = t->shape()[-2];
- int cols = t->shape()[-1];
- faiss::RandomRotationMatrix rrot(cols, rows); // transposed in faiss
- rrot.init((int)seed);
- t->set(rrot.A);
- };
- return fromLambda(rot, Type::float32);
-}
-
-} // namespace inits
-} // namespace marian
-#endif
diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h
index 7cdb4183..5e9f8013 100644
--- a/src/graph/node_initializers.h
+++ b/src/graph/node_initializers.h
@@ -36,6 +36,11 @@ public:
};
/**
+ * Dummy do-nothing initializer. Mostly for testing.
+ */
+Ptr<NodeInitializer> dummy();
+
+/**
* Use a lambda function of form [](Tensor t) { do something with t } to initialize tensor.
* @param func functor
*/
@@ -264,13 +269,6 @@ Ptr<NodeInitializer> fromWord2vec(const std::string& file,
Ptr<NodeInitializer> sinusoidalPositionEmbeddings(int start);
/**
- * Computes a random rotation matrix for LSH hashing.
- * This is part of a hash function. The values are orthonormal and computed via
- * QR decomposition. Same seed results in same random rotation.
- */
-Ptr<NodeInitializer> randomRotation(size_t seed = Config::seed);
-
-/**
* Computes the equivalent of Python's range().
* Computes a range from begin to end-1, like Python's range().
* The constant being initialized must have one dimension that matches
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 169b1420..a180bb5c 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -21,20 +21,26 @@ private:
std::unique_ptr<LambdaNodeFunctor> forward_;
std::unique_ptr<LambdaNodeFunctor> backward_;
+ size_t externalHash_;
+
public:
LambdaNodeOp(Inputs inputs, Shape shape, Type type,
- LambdaNodeFunctor forward)
+ LambdaNodeFunctor forward,
+ size_t externalHash = 0)
: NaryNodeOp(inputs, shape, type),
- forward_(new LambdaNodeFunctor(forward)) {
+ forward_(new LambdaNodeFunctor(forward)),
+ externalHash_(externalHash) {
Node::trainable_ = !!backward_;
}
LambdaNodeOp(Inputs inputs, Shape shape, Type type,
LambdaNodeFunctor forward,
- LambdaNodeFunctor backward)
+ LambdaNodeFunctor backward,
+ size_t externalHash = 0)
: NaryNodeOp(inputs, shape, type),
forward_(new LambdaNodeFunctor(forward)),
- backward_(new LambdaNodeFunctor(backward)) {
+ backward_(new LambdaNodeFunctor(backward)),
+ externalHash_(externalHash) {
}
void forward() override {
@@ -50,8 +56,12 @@ public:
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
- util::hash_combine(seed, forward_.get());
- util::hash_combine(seed, backward_.get());
+ if(externalHash_ != 0) {
+ util::hash_combine(seed, externalHash_);
+ } else {
+ util::hash_combine(seed, forward_.get());
+ util::hash_combine(seed, backward_.get());
+ }
return seed;
}
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 82b02a65..448b4c4a 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -795,7 +795,7 @@ private:
};
class ReshapeNodeOp : public UnaryNodeOp {
-private:
+protected:
friend class SerializationHelpers;
Expr reshapee_;
@@ -858,6 +858,45 @@ public:
}
};
+// @TODO: add version with access to backward step
+// This allows attaching a lambda function to any node during execution. It is a no-op otherwise,
+// i.e. it doesn't consume any memory or take any time to execute (it's a reshape onto itself) other
+// than the compute in the lambda function, which gets called after the forward step of the argument node.
+class CallbackNodeOp : public ReshapeNodeOp {
+private:
+ typedef std::function<void(Expr)> LambdaNodeCallback;
+ std::unique_ptr<LambdaNodeCallback> callback_;
+
+public:
+ CallbackNodeOp(Expr node, LambdaNodeCallback callback)
+ : ReshapeNodeOp(node, node->shape()),
+ callback_(new LambdaNodeCallback(callback)) {
+ }
+
+ void forward() override {
+ (*callback_)(ReshapeNodeOp::reshapee_);
+ }
+
+ const std::string type() override { return "callback"; }
+
+ virtual size_t hash() override {
+ size_t seed = ReshapeNodeOp::hash();
+ util::hash_combine(seed, callback_.get());
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!ReshapeNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<CallbackNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(callback_ != cnode->callback_) // pointer compare on purpose
+ return false;
+ return true;
+ }
+};
+
// @TODO: review if still required as this is an ugly hack anyway.
// Memory less operator that clips gradients during backward step
// Executes this as an additional operation on the gradient.
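A hedged usage sketch for the new CallbackNodeOp via the callback() operator added in expression_operators.cpp above (mirroring how LSHShortlist::filter uses it in shortlist.cpp; `indices` is a hypothetical host-side destination buffer):

    // Wrap a node so a side-effecting lambda fires right after its forward step;
    // the wrapper is a reshape onto the same memory, so it costs nothing extra.
    Expr searchAndCapture(Expr input, Expr weights, std::vector<uint32_t>& indices) {
      return callback(lsh::search(input, weights, /*k=*/8, /*nBits=*/1024),
                      [&indices](Expr node) {
                        node->val()->get(indices); // copy values out after forward()
                      });
    }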
diff --git a/src/layers/lsh.cpp b/src/layers/lsh.cpp
new file mode 100644
index 00000000..89b482f4
--- /dev/null
+++ b/src/layers/lsh.cpp
@@ -0,0 +1,233 @@
+#include "layers/lsh.h"
+#include "tensors/tensor_operators.h"
+#include "common/utils.h"
+
+#include "3rd_party/faiss/utils/hamming.h"
+#include "3rd_party/faiss/Index.h"
+
+#if BLAS_FOUND
+#include "3rd_party/faiss/VectorTransform.h"
+#endif
+
+
+namespace marian {
+namespace lsh {
+
+int bytesPerVector(int nBits) {
+ return (nBits + 7) / 8;
+}
+
+void fillRandomRotationMatrix(Tensor output, Ptr<Allocator> allocator) {
+#if BLAS_FOUND
+ int nRows = output->shape()[-2];
+ int nBits = output->shape()[-1];
+
+ // @TODO re-implement using Marian code so it uses the correct random generator etc.
+ faiss::RandomRotationMatrix rrot(nRows, nBits);
+ // Then we do not need to use this seed at all
+ rrot.init(5); // currently set to 5 following the default from FAISS, this could be any number really.
+
+ // The faiss random rotation matrix is column major, hence we create a temporary tensor,
+ // copy the rotation matrix into it and transpose to output.
+ Shape tempShape = {nBits, nRows};
+ auto memory = allocator->alloc(requiredBytes(tempShape, output->type()));
+ auto temp = TensorBase::New(memory,
+ tempShape,
+ output->type(),
+ output->getBackend());
+ temp->set(rrot.A);
+ TransposeND(output, temp, {0, 1, 3, 2});
+ allocator->free(memory);
+#else
+ output; allocator;
+ ABORT("LSH with rotation matrix requires Marian to be compiled with a BLAS library");
+#endif
+}
+
+void encode(Tensor output, Tensor input) {
+ int nBits = input->shape()[-1]; // number of bits is equal last dimension of float matrix
+ int nRows = input->shape().elements() / nBits;
+ faiss::fvecs2bitvecs(input->data<float>(), output->data<uint8_t>(), (size_t)nBits, (size_t)nRows);
+}
+
+void encodeWithRotation(Tensor output, Tensor input, Tensor rotation, Ptr<Allocator> allocator) {
+ int nBits = input->shape()[-1]; // number of bits is equal last dimension of float matrix unless we rotate
+ int nRows = input->shape().elements() / nBits;
+
+ Tensor tempInput = input;
+ MemoryPiece::PtrType memory;
+ if(rotation) {
+ int nBitsRot = rotation->shape()[-1];
+ Shape tempShape = {nRows, nBitsRot};
+ memory = allocator->alloc(requiredBytes(tempShape, rotation->type()));
+ tempInput = TensorBase::New(memory, tempShape, rotation->type(), rotation->getBackend());
+ Prod(tempInput, input, rotation, false, false, 0.f, 1.f);
+ }
+ encode(output, tempInput);
+
+ if(memory)
+ allocator->free(memory);
+};
+
+Expr encode(Expr input, Expr rotation) {
+ auto encodeFwd = [](Expr out, const std::vector<Expr>& inputs) {
+ if(inputs.size() == 1) {
+ encode(out->val(), inputs[0]->val());
+ } else if(inputs.size() == 2) {
+ encodeWithRotation(out->val(), inputs[0]->val(), inputs[1]->val(), out->graph()->allocator());
+ } else {
+ ABORT("Too many inputs to encode??");
+ }
+ };
+
+ // Use the address of the first lambda function as an immutable hash. Making it static and const makes sure
+ // that this hash value will not change. Next we pass the hash into the lambda functor, where it is used
+ // to identify this unique operation. Marian's ExpressionGraph can automatically memoize and identify nodes
+ // that operate only on immutable nodes (parameters) and have the same hash. This way we make sure that the
+ // codes node won't actually get recomputed during the ExpressionGraph's lifetime. `codes` will be reused
+ // and the body of the lambda will not be called again. This does, however, build one index per graph.
+ static const size_t encodeHash = (size_t)&encodeFwd;
+
+ Shape encodedShape = input->shape();
+
+ int nBits = rotation ? rotation->shape()[-1] : input->shape()[-1];
+ encodedShape.set(-1, bytesPerVector(nBits));
+ std::vector<Expr> inputs = {input};
+ if(rotation)
+ inputs.push_back(rotation);
+ return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash);
+}
+
+Expr rotator(Expr weights, int nBits) {
+ auto rotator = [](Expr out, const std::vector<Expr>& inputs) {
+ inputs;
+ fillRandomRotationMatrix(out->val(), out->graph()->allocator());
+ };
+
+ static const size_t rotatorHash = (size_t)&rotator;
+ int dim = weights->shape()[-1];
+ return lambda({weights}, {dim, nBits}, Type::float32, rotator, rotatorHash);
+}
+
+Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows) {
+ ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1],
+ "Query and index bit vectors need to be of same size ({} != {})", encodedQuery->shape()[-1], encodedWeights->shape()[-1]);
+
+ int currBeamSize = encodedQuery->shape()[0];
+ int batchSize = encodedQuery->shape()[2];
+ int numHypos = currBeamSize * batchSize;
+
+ auto search = [=](Expr out, const std::vector<Expr>& inputs) {
+ Expr encodedQuery = inputs[0];
+ Expr encodedWeights = inputs[1];
+
+ int bytesPerVector = encodedWeights->shape()[-1];
+ int wRows = encodedWeights->shape().elements() / bytesPerVector;
+
+ // we use this with Factored Segmenter to skip the factor embeddings at the end
+ if(firstNRows != 0)
+ wRows = firstNRows;
+
+ int qRows = encodedQuery->shape().elements() / bytesPerVector;
+
+ uint8_t* qCodes = encodedQuery->val()->data<uint8_t>();
+ uint8_t* wCodes = encodedWeights->val()->data<uint8_t>();
+
+ // use actual faiss code for performing the hamming search.
+ std::vector<int> distances(qRows * k);
+ std::vector<faiss::Index::idx_t> ids(qRows * k);
+ faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)k, ids.data(), distances.data()};
+ faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0);
+
+ // Copy int64_t indices to Marian index type and sort by increasing index value per hypothesis.
+ // The sorting is required as we later do a binary search on those values for reverse look-up.
+ uint32_t* outData = out->val()->data<uint32_t>();
+ for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
+ size_t startIdx = k * hypoIdx;
+ size_t endIdx = startIdx + k;
+ for(size_t i = startIdx; i < endIdx; ++i)
+ outData[i] = (uint32_t)ids[i];
+ std::sort(outData + startIdx, outData + endIdx);
+ }
+ };
+
+ Shape kShape({currBeamSize, batchSize, k});
+ return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
+}
+
+Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows) {
+ int dim = weights->shape()[-1];
+
+ Expr rotMat = nullptr;
+ if(dim != nBits) {
+ rotMat = weights->graph()->get("lsh_output_rotation");
+ if(rotMat) {
+ LOG_ONCE(info, "Reusing parameter LSH rotation matrix {} with shape {}", rotMat->name(), rotMat->shape());
+ } else {
+ LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
+ rotMat = rotator(weights, nBits);
+ }
+ }
+
+ Expr encodedWeights = weights->graph()->get("lsh_output_codes");
+ if(encodedWeights) {
+ LOG_ONCE(info, "Reusing parameter LSH code matrix {} with shape {}", encodedWeights->name(), encodedWeights->shape());
+ } else {
+ LOG_ONCE(info, "Creating ad-hoc code matrix with shape {}", Shape({weights->shape()[-2], lsh::bytesPerVector(nBits)}));
+ encodedWeights = encode(weights, rotMat);
+ }
+
+ return searchEncoded(encode(query, rotMat), encodedWeights, k, firstNRows);
+}
+
+class RandomRotation : public inits::NodeInitializer {
+public:
+ void apply(Tensor tensor) override {
+ auto sharedAllocator = allocator_.lock();
+ ABORT_IF(!sharedAllocator, "Allocator in RandomRotation has not been set or expired");
+ fillRandomRotationMatrix(tensor, sharedAllocator);
+ }
+};
+
+Ptr<inits::NodeInitializer> randomRotation() {
+ return New<RandomRotation>();
+}
+
+void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBitsRot) {
+ auto weights = graph->get(weightsName);
+
+ ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
+
+ int nBits = weights->shape()[-1];
+ int nRows = weights->shape().elements() / nBits;
+
+ Expr rotation;
+ if(nBits != nBitsRot) {
+ LOG(info, "Adding LSH rotation parameter lsh_output_rotation with shape {}", Shape({nBits, nBitsRot}));
+ rotation = graph->param("lsh_output_rotation", {nBits, nBitsRot}, inits::dummy(), Type::float32);
+ nBits = nBitsRot;
+ }
+
+ int bytesPerVector = lsh::bytesPerVector(nBits);
+ LOG(info, "Adding LSH encoded weights lsh_output_codes with shape {}", Shape({nRows, bytesPerVector}));
+ auto codes = graph->param("lsh_output_codes", {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
+}
+
+void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName) {
+ Expr weights = graph->get(weightsName);
+ Expr codes = graph->get("lsh_output_codes");
+ Expr rotation = graph->get("lsh_output_rotation");
+
+ ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
+ ABORT_IF(!codes, "Trying to overwrite non-existing LSH parameters lsh_output_codes??");
+
+ if(rotation) {
+ fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator());
+ encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator());
+ } else {
+ encode(codes->val(), weights->val());
+ }
+}
+
+}
+}
\ No newline at end of file
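The static-hash memoization pattern used by encode() above, reduced to a self-contained sketch (assumed here: the extended lambda(..., hash) overload from expression_operators.h; the computation itself is a placeholder):

    // Lambda nodes normally hash by the addresses of their std::function copies,
    // so two instantiations never collide. Seeding a function-local static with
    // the first call's lambda address yields one fixed hash for this operation,
    // letting the graph memoize nodes that depend only on immutable parameters:
    // the functor body runs once per graph, later calls reuse the cached node.
    Expr cachedIdentity(Expr param) {
      auto fwd = [](Expr out, const std::vector<Expr>& inputs) {
        out->val()->copyFrom(inputs[0]->val()); // placeholder computation
      };
      static const size_t fwdHash = (size_t)&fwd; // fixed after the first call
      return lambda({param}, param->shape(), param->value_type(), fwd, fwdHash);
    }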
diff --git a/src/layers/lsh.h b/src/layers/lsh.h
new file mode 100644
index 00000000..60908238
--- /dev/null
+++ b/src/layers/lsh.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include "graph/expression_operators.h"
+#include "graph/node_initializers.h"
+
+#include <vector>
+
+/**
+ * In this file we basically take the faiss::IndexLSH and pick it apart so that the individual steps
+ * can be implemented as Marian inference operators. We can encode the inputs and weights into their
+ * bitwise equivalents, apply the hashing rotation (if required), and perform the actual search.
+ *
+ * This also allows creating parameters that get dumped into the model weight file. This is currently
+ * a bit hacky (see marian-conv), but once this is done the model can memory-map the LSH with existing
+ * mechanisms and no additional memory is consumed to build the index or rotation matrix.
+ */
+
+namespace marian {
+namespace lsh {
+
+ // return the number of full bytes required to encode that many bits
+ int bytesPerVector(int nBits);
+
+ // encodes an input as a bit vector, with optional rotation
+ Expr encode(Expr input, Expr rotator = nullptr);
+
+ // compute the rotation matrix (maps weights->shape()[-1] to nbits floats)
+ Expr rotator(Expr weights, int nbits);
+
+ // perform the LSH search on fully encoded input and weights, return k results (indices) per input row
+ // @TODO: add a top-k like operator that also returns the bitwise computed distances
+ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0);
+
+ // same as above, but performs encoding on the fly
+ Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0);
+
+ // These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv
+ void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBits);
+ void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName);
+
+ /**
+ * Computes a random rotation matrix for LSH hashing.
+ * This is part of a hash function. The values are orthonormal and computed via
+ * QR decomposition.
+ */
+ Ptr<inits::NodeInitializer> randomRotation();
+}
+
+}
\ No newline at end of file
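Putting this header's pieces together, a hedged sketch of the decomposed pipeline, equivalent to a single lsh::search call but spelling out the individual operators (shape comments are illustrative):

    using namespace marian;

    Expr decomposedSearch(Expr query, Expr weights, int k, int nBits) {
      // Rotate dim-float rows to nBits floats only when the sizes differ.
      Expr rot = (weights->shape()[-1] != nBits) ? lsh::rotator(weights, nBits)
                                                 : nullptr;
      Expr wCodes = lsh::encode(weights, rot); // [vocab, bytesPerVector(nBits)]
      Expr qCodes = lsh::encode(query, rot);   // query rows binarized the same way
      return lsh::searchEncoded(qCodes, wCodes, k); // [beam, batch, k] indices
    }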
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp
index 6476df8f..70e657a9 100644
--- a/src/microsoft/quicksand.cpp
+++ b/src/microsoft/quicksand.cpp
@@ -11,6 +11,7 @@
#include "data/alignment.h"
#include "data/vocab_base.h"
#include "tensors/cpu/expression_graph_packable.h"
+#include "layers/lsh.h"
#if USE_FBGEMM
#include "fbgemm/Utils.h"
@@ -248,7 +249,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name) {
// This function converts an fp32 model into an FBGEMM based packed model.
// marian defined types are used for external project as well.
// The targetPrec is passed as int32_t for the exported function definition.
-bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec) {
+bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh) {
std::cerr << "Converting from: " << inputFile << ", to: " << outputFile << ", precision: " << targetPrec << std::endl;
YAML::Node config;
@@ -260,7 +261,26 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
graph->setDevice(CPU0);
graph->load(inputFile);
- graph->forward();
+
+ // MJD: Note, these are default settings which we might want to change or expose. Use this only with Polonium students.
+ // The LSH will not be used by default even if it exists in the model. That has to be enabled in the decoder config.
+ int lshNBits = 1024;
+ std::string lshOutputWeights = "Wemb";
+ if(addLsh) {
+ // Add dummy parameters for the LSH before the model actually gets initialized.
+ // This creates the parameters with useless values in the tensors, but it gives us the memory we need.
+ graph->setReloaded(false);
+ lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
+ graph->setReloaded(true);
+ }
+
+ graph->forward(); // run the initializers
+
+ if(addLsh) {
+ // After initialization, hijack the parameters for the LSH and force-overwrite them with correct values.
+ // Once this is done we can just pack and save as normal.
+ lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
+ }
Type targetPrecType = (Type) targetPrec;
if (targetPrecType == Type::packed16
diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h
index 87de1948..b710e135 100644
--- a/src/microsoft/quicksand.h
+++ b/src/microsoft/quicksand.h
@@ -76,7 +76,10 @@ std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocab
DecoderCpuAvxVersion getCpuAvxVersion();
DecoderCpuAvxVersion parseCpuAvxVersion(std::string name);
-bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec);
+// MJD: added "addLsh", which will deliberately break compilation of calling code after this update.
+// The calling code should be adapted, not this interface. If you need to fix things in QS because of this,
+// talk to me first!
+bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh);
} // namespace quicksand
} // namespace marian
diff --git a/src/tensors/cpu/expression_graph_packable.h b/src/tensors/cpu/expression_graph_packable.h
index 689aa3b1..f5a9cad9 100644
--- a/src/tensors/cpu/expression_graph_packable.h
+++ b/src/tensors/cpu/expression_graph_packable.h
@@ -27,14 +27,17 @@ public:
virtual ~ExpressionGraphPackable() {}
// Convert model weights into packed format and save to IO items.
- // @TODO: review this
- void packAndSave(const std::string& name, const std::string& meta, Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
+ std::vector<io::Item> pack(Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
std::vector<io::Item> ioItems;
+ // handle packable parameters first (a float32 parameter is packable)
+ auto packableParameters = paramsByElementType_[Type::float32];
// sorted by name in std::map
- for (auto p : params()->getMap()) {
+ for (auto p : packableParameters->getMap()) {
std::string pName = p.first;
+ LOG(info, "Processing parameter {} with shape {} and type {}", pName, p.second->shape(), p.second->value_type());
+
if (!namespace_.empty()) {
if (pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
pName = pName.substr(namespace_.size() + 2);
@@ -257,6 +260,33 @@ public:
}
}
+ // Now handle all non-float32 parameters
+ for(auto& iter : paramsByElementType_) {
+ auto type = iter.first;
+ if(type == Type::float32)
+ continue;
+
+ for (auto p : iter.second->getMap()) {
+ std::string pName = p.first;
+ LOG(info, "Processing parameter {} with shape {} and type {}", pName, p.second->shape(), p.second->value_type());
+
+ if (!namespace_.empty()) {
+ if (pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
+ pName = pName.substr(namespace_.size() + 2);
+ }
+
+ Tensor val = p.second->val();
+ io::Item item;
+ val->get(item, pName);
+ ioItems.emplace_back(std::move(item));
+ }
+ }
+
+ return ioItems;
+ }
+
+ void packAndSave(const std::string& name, const std::string& meta, Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
+ auto ioItems = pack(gemmElementType, saveElementType);
if (!meta.empty())
io::addMetaToItems(meta, "special:model.yml", ioItems);
io::saveItems(name, ioItems);
diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h
index 10c3e7f1..a7071404 100644
--- a/src/tensors/tensor.h
+++ b/src/tensors/tensor.h
@@ -35,7 +35,8 @@ class TensorBase {
ENABLE_INTRUSIVE_PTR(TensorBase)
- // Constructors are private, use TensorBase::New(...)
+protected:
+ // Constructors are protected, use TensorBase::New(...)
TensorBase(MemoryPiece::PtrType memory,
Shape shape,
Type type,
@@ -61,10 +62,10 @@ class TensorBase {
shape_(shape), type_(type), backend_(backend) {}
public:
- // Use this whenever pointing to MemoryPiece
+ // Use this whenever pointing to TensorBase
typedef IPtr<TensorBase> PtrType;
- // Use this whenever creating a pointer to MemoryPiece
+ // Use this whenever creating a pointer to TensorBase
template <class ...Args>
static PtrType New(Args&& ...args) {
return PtrType(new TensorBase(std::forward<Args>(args)...));
diff --git a/src/training/training_state.h b/src/training/training_state.h
index e0c1ba5d..ce0895a2 100644
--- a/src/training/training_state.h
+++ b/src/training/training_state.h
@@ -142,8 +142,9 @@ public:
// for periods.
bool enteredNewPeriodOf(std::string schedulingParam) const {
auto period = SchedulingParameter::parse(schedulingParam);
+ // @TODO: adapt to logical epochs
ABORT_IF(period.unit == SchedulingUnit::epochs,
- "Unit {} is not supported for frequency parameters (the one(s) with value {})",
+ "Unit {} is not supported for frequency parameters",
schedulingParam);
auto previousProgress = getPreviousProgressIn(period.unit);
auto progress = getProgressIn(period.unit);