github.com/marian-nmt/marian.git

author    afaji <afaji321@gmail.com>   2020-11-05 01:25:40 +0300
committer GitHub <noreply@github.com>  2020-11-05 01:25:40 +0300
commit    e274ac76b2c967c69ac9ab053425bc0e29c6b718 (patch)
tree      5114cb80e1a21235e79a716de05ae8fe60f0d71a
parent    fabbe203091dec345165dc63528072b86d177d19 (diff)
Quantized model finetuning (#690)
* enable quantized training
 CHANGELOG.md                       |   1
 src/CMakeLists.txt                 |   1
 src/common/config_parser.cpp       |  19
 src/common/config_parser.h         |   1
 src/common/config_validator.cpp    |   6
 src/functional/operators.h         |  28
 src/optimizers/quantizer.cpp       | 185
 src/optimizers/quantizer.h         |  48
 src/tensors/gpu/add.inc            |   1
 src/tensors/gpu/add_all.inc        |   3
 src/tensors/gpu/element.inc        |  10
 src/training/graph_group_sync.cpp  |  18
 src/training/graph_group_sync.h    |   4
 13 files changed, 322 insertions(+), 3 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 420f0445..8dbbec48 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Support for reading from TSV files from STDIN and other sources during training
and translation with options --tsv and --tsv-fields n.
- Internal optional parameter in n-best list generation that skips empty hypotheses.
+- Quantized training (fixed point or log-based quantization) with --quantize-bits N command
### Fixed
- Fix bug causing certain reductions into scalars to be 0 on the GPU backend. Removed unnecessary warp shuffle instructions.
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 5c52fd50..6dcf7fd8 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -76,6 +76,7 @@ set(MARIAN_SOURCES
rnn/cells.cpp
rnn/attention.cpp
+ optimizers/quantizer.cpp
optimizers/clippers.cpp
optimizers/optimizers.cpp
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index e67c9b2c..68a3416f 100755
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -540,6 +540,9 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"Overlap model computations with MPI communication",
true);
+ // model quantization training
+ addSuboptionsQuantization(cli);
+
// add ULR settings
addSuboptionsULR(cli);
@@ -909,6 +912,22 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
// clang-format on
}
+void ConfigParser::addSuboptionsQuantization(cli::CLIWrapper& cli) {
+ // clang-format off
+ // model quantization training
+ cli.add<size_t>("--quantize-bits",
+ "Number of bits to compress model to. Set to 0 to disable",
+ 0);
+ cli.add<size_t>("--quantize-optimization-steps",
+ "Adjust quantization scaling factor for N steps",
+ 0);
+ cli.add<bool>("--quantize-log-based",
+ "Uses log-based quantization");
+ cli.add<bool>("--quantize-biases",
+ "Apply quantization to biases");
+ // clang-format on
+}
+
cli::mode ConfigParser::getMode() const { return mode_; }
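
For reference, the new options slot into a regular training invocation roughly as follows. This is a hypothetical command line: the model and corpus paths, and the chosen bit width, are placeholders; the --quantize-* flags are the ones defined above, and the validator change below additionally requires --sync-sgd.

    ./marian --model model.npz --train-sets corpus.src corpus.trg --sync-sgd \
             --quantize-bits 8 --quantize-optimization-steps 3 --quantize-log-based
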
diff --git a/src/common/config_parser.h b/src/common/config_parser.h
index 933bbb59..18b6eccb 100644
--- a/src/common/config_parser.h
+++ b/src/common/config_parser.h
@@ -138,6 +138,7 @@ private:
void addSuboptionsInputLength(cli::CLIWrapper&);
void addSuboptionsTSV(cli::CLIWrapper&);
void addSuboptionsULR(cli::CLIWrapper&);
+ void addSuboptionsQuantization(cli::CLIWrapper&);
// Extract paths to all config files found in the config object.
// Look at --config option and model.npz.yml files.
diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp
index 44d7eec6..46dcee5e 100644
--- a/src/common/config_validator.cpp
+++ b/src/common/config_validator.cpp
@@ -147,6 +147,12 @@ void ConfigValidator::validateOptionsTraining() const {
|| get<std::string>("ulr-keys-vectors") == "")),
"ULR enablign requires query and keys vectors specified with --ulr-query-vectors and "
"--ulr-keys-vectors option");
+
+ // validate model quantization
+ size_t bits = get<size_t>("quantize-bits");
+ ABORT_IF(bits > 32, "Invalid quantization bits. Must be from 0 to 32 bits");
+
+ ABORT_IF(bits > 0 && !get<bool>("sync-sgd"), "Model quantization only works with synchronous training (--sync-sgd)");
}
void ConfigValidator::validateModelExtension(cli::mode mode) const {
diff --git a/src/functional/operators.h b/src/functional/operators.h
index 25982009..6345bfb6 100755
--- a/src/functional/operators.h
+++ b/src/functional/operators.h
@@ -25,6 +25,10 @@ struct Ops {
static HOST_DEVICE_INLINE T neg(const T&) { ABORT("Unknown type"); }
static HOST_DEVICE_INLINE T sgn(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T round(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T floor(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T ceil(const T&) { ABORT("Unknown type"); }
+
static HOST_DEVICE_INLINE T add(const T&, const T&) { ABORT("Unknown type"); }
static HOST_DEVICE_INLINE T sub(const T&, const T&) { ABORT("Unknown type"); }
static HOST_DEVICE_INLINE T mul(const T&, const T&) { ABORT("Unknown type"); }
@@ -79,6 +83,10 @@ struct Ops<float> {
static HOST_DEVICE_INLINE float neg(const float& x) { return -x; }
static HOST_DEVICE_INLINE float sgn(const float& x) { return (float)((0 < x) - (x < 0)); }
+ static HOST_DEVICE_INLINE float round(const float& x) { return roundf(x); }
+ static HOST_DEVICE_INLINE float floor(const float& x) { return floorf(x); }
+ static HOST_DEVICE_INLINE float ceil(const float& x) { return ceilf(x); }
+
static HOST_DEVICE_INLINE float add(const float& x, const float& y) { return x + y; }
static HOST_DEVICE_INLINE float sub(const float& x, const float& y) { return x - y; }
static HOST_DEVICE_INLINE float mul(const float& x, const float& y) { return x * y; }
@@ -144,6 +152,10 @@ struct Ops<double> {
static HOST_DEVICE_INLINE double neg(const double& x) { return -x; }
static HOST_DEVICE_INLINE double sgn(const double& x) { return (0 < x) - (x < 0); }
+ static HOST_DEVICE_INLINE double round(const double& x) { return std::round(x); }
+ static HOST_DEVICE_INLINE double floor(const double& x) { return std::floor(x); }
+ static HOST_DEVICE_INLINE double ceil(const double& x) { return std::ceil(x); }
+
static HOST_DEVICE_INLINE double add(const double& x, const double& y) { return x + y; }
static HOST_DEVICE_INLINE double sub(const double& x, const double& y) { return x - y; }
static HOST_DEVICE_INLINE double mul(const double& x, const double& y) { return x * y; }
@@ -254,6 +266,10 @@ struct Ops<float32x4> {
// @TODO: get rid of loop4 with proper intrisics
static inline float32x4 sgn(const float32x4& x) { return loop4(Ops<float>::sgn, x); }
+ static inline float32x4 round(const float32x4& x) { return _mm_round_ps(x, _MM_FROUND_TO_NEAREST_INT); }
+ static inline float32x4 floor(const float32x4& x) { return _mm_floor_ps(x); }
+ static inline float32x4 ceil(const float32x4& x) { return _mm_ceil_ps(x); }
+
static inline float32x4 add(const float32x4& x, const float32x4& y) { return _mm_add_ps(x, y); }
static inline float32x4 sub(const float32x4& x, const float32x4& y) { return _mm_sub_ps(x, y); }
static inline float32x4 mul(const float32x4& x, const float32x4& y) { return _mm_mul_ps(x, y); }
@@ -380,6 +396,10 @@ struct Ops<float32x8> {
// @TODO: get rid of loop8 with proper intrisics
static inline float32x8 sgn(const float32x8& x) { return loop8(Ops<float>::sgn, x); }
+ static inline float32x8 round(const float32x8& x) { return _mm256_round_ps(x, _MM_FROUND_TO_NEAREST_INT); }
+ static inline float32x8 floor(const float32x8& x) { return _mm256_floor_ps(x); }
+ static inline float32x8 ceil(const float32x8& x) { return _mm256_ceil_ps(x); }
+
static inline float32x8 add(const float32x8& x, const float32x8& y) { return _mm256_add_ps(x, y); }
static inline float32x8 sub(const float32x8& x, const float32x8& y) { return _mm256_sub_ps(x, y); }
static inline float32x8 mul(const float32x8& x, const float32x8& y) { return _mm256_mul_ps(x, y); }
@@ -473,6 +493,10 @@ struct Ops<half> {
static DEVICE_INLINE half abs(const half& x) { return fabs((float)x); }// @TODO half has this information somewhere in the struct, right?
static DEVICE_INLINE half sgn(const half& x) { half zero = 0.f; return (zero < x) - (x < zero); } // @TODO half has this information somewhere in the struct, right?
+ static DEVICE_INLINE half round(const half& x) { return hrint(x); }
+ static DEVICE_INLINE half floor(const half& x) { return hfloor(x); }
+ static DEVICE_INLINE half ceil(const half& x) { return hceil(x); }
+
static DEVICE_INLINE half add(const half& x, const half& y) { return x + y; }
static DEVICE_INLINE half sub(const half& x, const half& y) { return x - y; }
static DEVICE_INLINE half mul(const half& x, const half& y) { return x * y; }
@@ -578,6 +602,10 @@ UNARY(Sqrt, sqrt, Ops<ElementType>::sqrt(x));
UNARY(Neg, operator-, Ops<ElementType>::neg(x));
UNARY(Sgn, sgn, Ops<ElementType>::sgn(x));
+UNARY(Round, round, Ops<ElementType>::round(x));
+UNARY(Floor, floor, Ops<ElementType>::floor(x));
+UNARY(Ceil, ceil, Ops<ElementType>::ceil(x));
+
BINARY(Plus, operator+, Ops<ElementType>::add(x, y));
BINARY(Minus, operator-, Ops<ElementType>::sub(x, y));
BINARY(Mult, operator*, Ops<ElementType>::mul(x, y));
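
The Round/Floor/Ceil functors added above make rounding available inside marian's element-wise functional expressions, which is what the new quantizer relies on. A minimal sketch of that usage follows; roundToGrid is a hypothetical helper, while the Element/functional calls mirror fixedPointQuantization in quantizer.cpp below.

    #include "functional/functional.h"
    #include "tensors/tensor_operators.h"

    namespace marian {

    // Snap every element of t onto a grid with step 1/multiplier, in place,
    // using the Round functor added above.
    static void roundToGrid(Tensor t, float multiplier) {
      using namespace functional;
      Element(_1 = round(_1 * multiplier), t);  // map to the nearest integer bin
      Element(_1 /= multiplier, t);             // scale back to floating point
    }

    }  // namespace marian
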
diff --git a/src/optimizers/quantizer.cpp b/src/optimizers/quantizer.cpp
new file mode 100644
index 00000000..fc175672
--- /dev/null
+++ b/src/optimizers/quantizer.cpp
@@ -0,0 +1,185 @@
+#include <cmath>
+
+#include "optimizers/quantizer.h"
+#include "tensors/tensor_allocator.h"
+#include "tensors/tensor_operators.h"
+
+#include "functional/functional.h"
+
+namespace marian {
+
+/* simulate a fixed quantization for values in data.
+ * For example:
+ * data = [0.96, 0.73, 0.82, 0.84, 0.42, 0.29, 0.65]
+ * res = [1 , 0.6, 0.8 , 0.8 , 0.4, 0.2 , 0.6 ]
+ *
+ * @param data contains the original data
+ * @param res will contain the resulting quantized data. set data = res for in-place operation
+ * @param numCenters the number of quantized centers in absolute. It should be 2^(bit-1)
+ * @param S stores the scaling factor.
+ */
+static void fixedPointQuantization(Tensor data, Tensor res, int numCenters, float S) {
+ using namespace functional;
+ float multiplier = numCenters / S;
+
+ // clip based on the scale
+ Element(_1 = clip(_2, S), res, data);
+
+ // get the quantization bin ID
+ Element(_1 = round(_1 * multiplier), res);
+
+ // revert back to floating point representation
+ Element(_1 /= multiplier, res);
+}
+
+/* simulate a log-based quantization for values in data. The quantized value will be in the form of
+ * S*2^q For example:
+ * data = [0.9, 0.7, 0.5, 0.2 , 1.1]
+ * res = [1, 0.5, 0.5, 0.25, 1 ]
+ *
+ * @param data contains the original data.
+ * @param res will contain the resulting quantized data. set data = res for in-place operation.
+ * @param size the data size.
+ * @param numCenters the number of quantized centers in absolute. It should be 2^(bit-1).
+ * @param S stores the scaling factor.
+ * @param base for log quantized center. Default of 2.
+ */
+static void logQuantization(Tensor data, Tensor res, int numCenters, float S, float base = 2.0f) {
+ using namespace functional;
+
+ // clip based on the scaling factor
+ Element(_1 = clip(_2, S), res, data);
+
+ // multiplier such that the quantization is rounded in normal-space instead of log space.
+ // 4/3 for base = 2. example: 11.8 should be quantized to 8, instead of 16.
+ float mult = (2.0f * base) / (1.0f + base);
+
+ // log-quantization works as the following:
+ // 1. capture the sign:
+ // sign = sgn(v)
+ // 2. get the quantization center:
+ // qc = floor(log2(abs(v/S) * _mult))
+ // 3. clip the center to make sure we have no more than 2^(bit-1) centers.
+ // qc = clip(qc, num_centers)
+ // 4. revert back to floating point space:
+ // q = 2^qc * S * sign
+ //
+ // The above steps are written in 1 call as below, to avoid reserving extra Tensors:
+
+ Element(
+ _1 = sgn(_1) // revert the sign back
+ * S // revert the scaling function
+ * pow(base, // revert from log space to normal FP representation
+ clip(floor(log(abs(_1 / S) * mult) / log(base)), // get the quantization center
+ (float) numCenters)), // clip
+ res);
+}
+
+/* Quantize all the parameters (except bias, unless enabled via --quantize-biases).
+ * Quantization only works if we store the quantization error residual.
+ * The stored residual will be added for the next quantization.
+ * @param graph is the model graph to be quantized (in-place).
+ */
+void ModelQuantizer::quantize(Ptr<ExpressionGraph> graph) {
+ // lazily allocate tensor for error feedback mechanism
+ if(!errorResidual_) {
+ LOG(info, "Quantizing the model to {}-bits", bits_);
+
+ int numElements = (int)graph->params()->vals()->size();
+ auto allocator = New<TensorAllocator>(graph->getBackend());
+ allocator->reserveExact(graph->params()->vals()->memory()->size());
+ allocator->allocate(errorResidual_, {1, numElements});
+
+ allocators_.push_back(allocator);
+ isFirstError_ = true;
+ }
+
+ {
+ // apply error feedback mechanism
+ using namespace functional;
+ Element(_1 += _2, graph->params()->vals(), errorResidual_); // add the previous error residual to the current model
+ errorResidual_->copyFrom(graph->params()->vals()); // set the model as the error-residual (will be updated below)
+ }
+
+ for(auto p : *graph->params()) {
+ // quantize weight tensors, biases optional
+ if(quantBias_ || p->val()->shape()[0] > 1)
+ quantizeImpl(p->val());
+ }
+
+ // get new error residual. Skip the first one.
+ if (!isFirstError_) {
+ using namespace functional;
+ Element(_1 -= _2, errorResidual_, graph->params()->vals()); // new error-residual = original model - quantized model
+ }
+ else {
+ errorResidual_->set(0);
+ isFirstError_ = false;
+ }
+}
+
+
+/* Tensor quantization implementation.
+ * @param t is the tensor to be quantized (in-place)
+ */
+void ModelQuantizer::quantizeImpl(Tensor t) {
+ if(!tempVar_) {
+ // init the swap tensor
+ auto allocator = New<TensorAllocator>(t->getBackend());
+ allocator->reserveExact(sizeof(float));
+ allocator->allocate(tempVar_, {1, 1});
+ allocators_.push_back(allocator);
+ }
+
+ // init additional tensor for scaling optimization
+ if(!delta_ && optSteps_ > 0) {
+ int msize = (int) t->size();
+ auto allocator = New<TensorAllocator>(t->getBackend());
+ allocator->reserveExact(msize * sizeof(float));
+ allocator->allocate(delta_, {1, msize});
+ allocators_.push_back(allocator);
+ }
+
+ Tensor q = delta_->subtensor(0, t->size()); // to store the quantized t
+ Tensor tflat = t->subtensor(0, t->size()); // flatten t for reduce
+
+ float S = 0.0f; // scaling factor S
+ // get initial scaling factor (S) based on max element in Tensor
+ {
+ using namespace functional;
+ Reduce(abs(_1), max(_1, _2), 0.0f, tempVar_, tflat);
+ S = tempVar_->get(0);
+ }
+
+ // optimize the scaling factor S
+ for(int i = 0; i < optSteps_; i++) {
+ // let t be the original tensor, and q be the quantized tensor, and q = S*a where S is the
+ // scaling factor. we want to optimize S to minimize MSE(S*a - t) therefore, S =
+ // sum(a*t)/sum(a*a) see https://www.aclweb.org/anthology/2020.ngt-1.4.pdf for more details.
+ if(logQuant_)
+ logQuantization(t, q, (1 << (bits_ - 1)) - 1, S);
+ else
+ fixedPointQuantization(t, q, (1 << (bits_ - 1)) - 1, S);
+
+ // obtain a by applying q/=S
+ using namespace functional;
+ Element(_1 /= S, delta_);
+
+ // get sum(a*t)
+ Reduce(_1 * _2, tempVar_, tflat, q);
+ float deltaNumer = tempVar_->get(0);
+
+ // get sum(a*a)
+ Reduce(_1 * _1, tempVar_, q);
+ float deltaDenom = tempVar_->get(0);
+
+ S = deltaNumer / deltaDenom; // S = sum(a*t)/sum(a*a)
+ }
+
+ // final quantization
+ if(logQuant_) {
+ logQuantization(t, t, (1 << (bits_ - 1)) - 1, S);
+ } else
+ fixedPointQuantization(t, t,(1 << (bits_ - 1)) - 1, S);
+}
+} // namespace marian
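
To make the arithmetic above easier to follow outside of the Tensor/Element machinery, here is a self-contained sketch of the same fixed-point and log-based formulas, plus one scaling-factor refinement step (S = sum(a*t)/sum(a*a)), on plain floats. The helper names and sample values are illustrative only; the real code operates on whole tensors via the functional expressions shown in the diff.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // clip x into [-S, S], like marian's functional clip(_, S)
    static float clipTo(float x, float S) { return std::max(-S, std::min(S, x)); }

    // fixed-point: clip to [-S, S], round onto numCenters uniform bins, scale back
    static float fixedPointQuantize(float v, int numCenters, float S) {
      float multiplier = numCenters / S;
      return std::round(clipTo(v, S) * multiplier) / multiplier;
    }

    // log-based: q = sgn(v) * S * base^qc, qc = clip(floor(log_base(|v/S| * mult)), numCenters)
    static float logQuantize(float v, int numCenters, float S, float base = 2.0f) {
      if(v == 0.0f)
        return 0.0f;
      float sign = (v > 0.0f) ? 1.0f : -1.0f;
      float mult = (2.0f * base) / (1.0f + base);  // round in normal space rather than log space
      float qc = std::floor(std::log(std::fabs(clipTo(v, S) / S) * mult) / std::log(base));
      qc = clipTo(qc, (float)numCenters);          // limit the exponent range, as clip(..., numCenters) does above
      return sign * S * std::pow(base, qc);
    }

    // one refinement step for S: with a = q/S, minimizing MSE(S*a - t) over S
    // gives S = sum(a*t) / sum(a*a), as in the optimization loop of quantizeImpl
    static float refineScale(const std::vector<float>& t, const std::vector<float>& q, float S) {
      float num = 0.0f, den = 0.0f;
      for(size_t i = 0; i < t.size(); ++i) {
        float a = q[i] / S;
        num += a * t[i];
        den += a * a;
      }
      return num / den;
    }

    int main() {
      std::vector<float> t = {0.96f, 0.73f, 0.82f, 0.84f, 0.42f, 0.29f, 0.65f};
      int bits = 4;
      int numCenters = (1 << (bits - 1)) - 1;  // 2^(bits-1) - 1, as in quantizeImpl
      float S = 0.96f;                         // initial scale = max |element|

      std::vector<float> q(t.size());
      for(int step = 0; step < 3; ++step) {    // --quantize-optimization-steps 3
        for(size_t i = 0; i < t.size(); ++i)
          q[i] = fixedPointQuantize(t[i], numCenters, S);
        S = refineScale(t, q, S);
      }
      for(size_t i = 0; i < t.size(); ++i)
        std::printf("%.2f -> %.4f (fixed) / %.4f (log)\n",
                    t[i], fixedPointQuantize(t[i], numCenters, S),
                    logQuantize(t[i], numCenters, S));
      return 0;
    }
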
diff --git a/src/optimizers/quantizer.h b/src/optimizers/quantizer.h
new file mode 100644
index 00000000..e385ed5e
--- /dev/null
+++ b/src/optimizers/quantizer.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include "common/options.h"
+#include "functional/functional.h"
+#include "graph/expression_graph.h"
+#include "tensors/backend.h"
+#include "tensors/tensor.h"
+#include "tensors/tensor_allocator.h"
+#include "tensors/tensor_operators.h"
+
+namespace marian {
+
+/* Class to implement quantization of all the parameters in a model graph
+ * This class handles the required error-feedback mechanism internally.
+ * Example:
+ * auto mq = New<ModelQuantizer>(options_);
+ * mq->quantize(graph_);
+ *
+ * Parameters in graph_ will be quantized every time quantize is called.
+ * The internal error-residual is also updated each quantize call,
+ * therefore, use the same ModelQuantizer object to quantize the same graph.
+ */
+class ModelQuantizer {
+public:
+ ModelQuantizer(Ptr<Options> options)
+ : bits_{options->get<size_t>("quantize-bits")},
+ optSteps_{options->get<size_t>("quantize-optimization-steps")},
+ quantBias_{options->get<bool>("quantize-biases")},
+ logQuant_{options->get<bool>("quantize-log-based")} {}
+
+ void quantize(Ptr<ExpressionGraph> graph);
+
+protected:
+ void quantizeImpl(Tensor t);
+
+ size_t bits_;
+ size_t optSteps_;
+ bool quantBias_;
+ bool logQuant_;
+ bool isFirstError_;
+
+ std::vector<Ptr<TensorAllocator>> allocators_;
+
+ Tensor errorResidual_; // Tensor to store the error-residual
+ Tensor delta_; // temporary Tensor for storing q to calculate optimal S
+ Tensor tempVar_; // single element Tensor for Reduce swap variable
+};
+} // namespace marian
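
A minimal usage sketch based on the class comment above; the function name and the loop are hypothetical, and the real integration lives in SyncGraphGroup below. The key point is to keep a single ModelQuantizer per graph so the error residual persists across calls, and to re-quantize after every parameter update.

    #include "optimizers/quantizer.h"

    namespace marian {

    static void quantizedTrainingSketch(Ptr<Options> options, Ptr<ExpressionGraph> graph) {
      auto mq = New<ModelQuantizer>(options);   // reads the --quantize-* options

      for(size_t step = 0; step < 100; ++step) {
        // ... forward/backward pass and optimizer update on `graph` ...

        if(options->get<size_t>("quantize-bits") > 0)
          mq->quantize(graph);                  // snap parameters back onto the quantized grid
      }
    }

    }  // namespace marian
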
diff --git a/src/tensors/gpu/add.inc b/src/tensors/gpu/add.inc
index 64093253..903ee3ba 100755
--- a/src/tensors/gpu/add.inc
+++ b/src/tensors/gpu/add.inc
@@ -35,4 +35,5 @@ template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Add<marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase> >(marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
+template void marian::gpu::Aggregate<marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, IntrusivePtr<marian::TensorBase> >(marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase>);
diff --git a/src/tensors/gpu/add_all.inc b/src/tensors/gpu/add_all.inc
index 2147f260..29a3a5d6 100644
--- a/src/tensors/gpu/add_all.inc
+++ b/src/tensors/gpu/add_all.inc
@@ -36,7 +36,7 @@ template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>,
template void marian::AggregateAll<float, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
-
+template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
#if COMPILE_FP16
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
@@ -74,4 +74,5 @@ template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>,
template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
+template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
#endif
diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc
index e2e74200..0eb75625 100755
--- a/src/tensors/gpu/element.inc
+++ b/src/tensors/gpu/element.inc
@@ -59,9 +59,15 @@ template void marian::gpu::Element<marian::functional::Assign<marian::functional
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, marian::Tensor, marian::Tensor);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Capture>, marian::functional::Capture> >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Capture>, marian::functional::Capture> >, marian::Tensor);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<1> > > > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<1> > > > > >, marian::Tensor, marian::Tensor);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
-template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
-
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, IntrusivePtr<marian::TensorBase>);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture> > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture> > >, IntrusivePtr<marian::TensorBase>);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::Capture> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::Capture> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase>);
// How to add new specializations:
// When you use a new specialization, it will cause a link error of this form (example):
// .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element<marian::functional::Assign< ... > ( ... )'
diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp
index 1457faff..eaeefb42 100755
--- a/src/training/graph_group_sync.cpp
+++ b/src/training/graph_group_sync.cpp
@@ -297,6 +297,11 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
return nullptr; // null if we reached beyond the end
};
+ // Helper to quantize the model
+ auto quantizeModel = [&](size_t idx, size_t /*begin*/, size_t /*end*/) {
+ quantizers_[idx]->quantize(graphs_[idx]);
+ };
+
// Upon very first execution, reset everything
if(first_) {
LOG(info, "[training] Batches are processed as {} process(es) x {} devices/process",
@@ -304,6 +309,14 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
initialize(subBatches.front());
if(mvAvg_ && paramsAvg_.empty())
initializeAvg();
+
+ // initialize model quantization
+ if (options_->get<size_t>("quantize-bits") > 0) {
+ for (int idx = 0; idx < graphs_.size(); idx++)
+ quantizers_.push_back(New<ModelQuantizer>(options_));
+ comm_->foreach(quantizeModel);
+ }
+
first_ = false;
}
@@ -357,6 +370,11 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
comm_->scatterReduceAndResetGrads(); // reduce gradients across all devices and MPI nodes into shards
comm_->foreach(update); // per-shard model-update
comm_->allGatherParams(); // distribute param value shards back
+
+ // Re-add the error residual from previous quantization,
+ // then re-quantize the model back and update the error residual
+ if (options_->get<size_t>("quantize-bits") > 0)
+ comm_->foreach(quantizeModel);
}
else
LOG(info, "[training] skipping {}-th update due to loss being {}", scheduler_->numberOfBatches(), localLoss.loss);
diff --git a/src/training/graph_group_sync.h b/src/training/graph_group_sync.h
index 147b172c..5b31a933 100755
--- a/src/training/graph_group_sync.h
+++ b/src/training/graph_group_sync.h
@@ -1,5 +1,6 @@
#pragma once
+#include "optimizers/quantizer.h"
#include "training/graph_group.h"
#include "training/communicator.h"
#include "training/exponential_smoothing.h"
@@ -23,6 +24,9 @@ class SyncGraphGroup : public GraphGroup, public ExponentialSmoothing {
std::vector<Ptr<TensorAllocator>> paramsAllocs_; // [deviceIndex] we must hold a reference to the memory until this class dies
// @TODO: move this nto ExponentialSmoothing, together with paramsAvg_?
+ // model quantizer
+ std::vector<Ptr<ModelQuantizer>> quantizers_;
+
// state for update()
bool first_{ true }; // gets interpreted and cleared by update()
std::vector<Ptr<data::Batch>> pendingBatches_; // in case of dynamic MB-size scaling, we temporarly buffer up batches across update() calls until enough