diff options
author | afaji <afaji321@gmail.com> | 2020-11-05 01:25:40 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-11-05 01:25:40 +0300 |
commit | e274ac76b2c967c69ac9ab053425bc0e29c6b718 (patch) | |
tree | 5114cb80e1a21235e79a716de05ae8fe60f0d71a | |
parent | fabbe203091dec345165dc63528072b86d177d19 (diff) |
Quantized model finetuning (#690)
* enable quantized training
-rw-r--r-- | CHANGELOG.md | 1 | ||||
-rw-r--r-- | src/CMakeLists.txt | 1 | ||||
-rwxr-xr-x | src/common/config_parser.cpp | 19 | ||||
-rw-r--r-- | src/common/config_parser.h | 1 | ||||
-rw-r--r-- | src/common/config_validator.cpp | 6 | ||||
-rwxr-xr-x | src/functional/operators.h | 28 | ||||
-rw-r--r-- | src/optimizers/quantizer.cpp | 185 | ||||
-rw-r--r-- | src/optimizers/quantizer.h | 48 | ||||
-rwxr-xr-x | src/tensors/gpu/add.inc | 1 | ||||
-rw-r--r-- | src/tensors/gpu/add_all.inc | 3 | ||||
-rwxr-xr-x | src/tensors/gpu/element.inc | 10 | ||||
-rwxr-xr-x | src/training/graph_group_sync.cpp | 18 | ||||
-rwxr-xr-x | src/training/graph_group_sync.h | 4 |
13 files changed, 322 insertions, 3 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 420f0445..8dbbec48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Support for reading from TSV files from STDIN and other sources during training and translation with options --tsv and --tsv-fields n. - Internal optional parameter in n-best list generation that skips empty hypotheses. +- Quantized training (fixed point or log-based quantization) with --quantize-bits N command ### Fixed - Fix bug causing certain reductions into scalars to be 0 on the GPU backend. Removed unnecessary warp shuffle instructions. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5c52fd50..6dcf7fd8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -76,6 +76,7 @@ set(MARIAN_SOURCES rnn/cells.cpp rnn/attention.cpp + optimizers/quantizer.cpp optimizers/clippers.cpp optimizers/optimizers.cpp diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index e67c9b2c..68a3416f 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -540,6 +540,9 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { "Overlap model computations with MPI communication", true); + // model quantization training + addSuboptionsQuantization(cli); + // add ULR settings addSuboptionsULR(cli); @@ -909,6 +912,22 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) { // clang-format on } +void ConfigParser::addSuboptionsQuantization(cli::CLIWrapper& cli) { + // clang-format off + // model quantization training + cli.add<size_t>("--quantize-bits", + "Number of bits to compress model to. 
Set to 0 to disable", + 0); + cli.add<size_t>("--quantize-optimization-steps", + "Adjust quantization scaling factor for N steps", + 0); + cli.add<bool>("--quantize-log-based", + "Uses log-based quantization"); + cli.add<bool>("--quantize-biases", + "Apply quantization to biases"); + // clang-format on +} + cli::mode ConfigParser::getMode() const { return mode_; } diff --git a/src/common/config_parser.h b/src/common/config_parser.h index 933bbb59..18b6eccb 100644 --- a/src/common/config_parser.h +++ b/src/common/config_parser.h @@ -138,6 +138,7 @@ private: void addSuboptionsInputLength(cli::CLIWrapper&); void addSuboptionsTSV(cli::CLIWrapper&); void addSuboptionsULR(cli::CLIWrapper&); + void addSuboptionsQuantization(cli::CLIWrapper&); // Extract paths to all config files found in the config object. // Look at --config option and model.npz.yml files. diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp index 44d7eec6..46dcee5e 100644 --- a/src/common/config_validator.cpp +++ b/src/common/config_validator.cpp @@ -147,6 +147,12 @@ void ConfigValidator::validateOptionsTraining() const { || get<std::string>("ulr-keys-vectors") == "")), "ULR enablign requires query and keys vectors specified with --ulr-query-vectors and " "--ulr-keys-vectors option"); + + // validate model quantization + size_t bits = get<size_t>("quantize-bits"); + ABORT_IF(bits > 32, "Invalid quantization bits. 
Must be from 0 to 32 bits"); + + ABORT_IF(bits > 0 && !get<bool>("sync-sgd"), "Model quantization only works with synchronous training (--sync-sgd)"); } void ConfigValidator::validateModelExtension(cli::mode mode) const { diff --git a/src/functional/operators.h b/src/functional/operators.h index 25982009..6345bfb6 100755 --- a/src/functional/operators.h +++ b/src/functional/operators.h @@ -25,6 +25,10 @@ struct Ops { static HOST_DEVICE_INLINE T neg(const T&) { ABORT("Unknown type"); } static HOST_DEVICE_INLINE T sgn(const T&) { ABORT("Unknown type"); } + static HOST_DEVICE_INLINE T round(const T&) { ABORT("Unknown type"); } + static HOST_DEVICE_INLINE T floor(const T&) { ABORT("Unknown type"); } + static HOST_DEVICE_INLINE T ceil(const T&) { ABORT("Unknown type"); } + static HOST_DEVICE_INLINE T add(const T&, const T&) { ABORT("Unknown type"); } static HOST_DEVICE_INLINE T sub(const T&, const T&) { ABORT("Unknown type"); } static HOST_DEVICE_INLINE T mul(const T&, const T&) { ABORT("Unknown type"); } @@ -79,6 +83,10 @@ struct Ops<float> { static HOST_DEVICE_INLINE float neg(const float& x) { return -x; } static HOST_DEVICE_INLINE float sgn(const float& x) { return (float)((0 < x) - (x < 0)); } + static HOST_DEVICE_INLINE float round(const float& x) { return roundf(x); } + static HOST_DEVICE_INLINE float floor(const float& x) { return floorf(x); } + static HOST_DEVICE_INLINE float ceil(const float& x) { return ceilf(x); } + static HOST_DEVICE_INLINE float add(const float& x, const float& y) { return x + y; } static HOST_DEVICE_INLINE float sub(const float& x, const float& y) { return x - y; } static HOST_DEVICE_INLINE float mul(const float& x, const float& y) { return x * y; } @@ -144,6 +152,10 @@ struct Ops<double> { static HOST_DEVICE_INLINE double neg(const double& x) { return -x; } static HOST_DEVICE_INLINE double sgn(const double& x) { return (0 < x) - (x < 0); } + static HOST_DEVICE_INLINE double round(const double& x) { return std::round(x); } + static 
HOST_DEVICE_INLINE double floor(const double& x) { return std::floor(x); } + static HOST_DEVICE_INLINE double ceil(const double& x) { return std::ceil(x); } + static HOST_DEVICE_INLINE double add(const double& x, const double& y) { return x + y; } static HOST_DEVICE_INLINE double sub(const double& x, const double& y) { return x - y; } static HOST_DEVICE_INLINE double mul(const double& x, const double& y) { return x * y; } @@ -254,6 +266,10 @@ struct Ops<float32x4> { // @TODO: get rid of loop4 with proper intrisics static inline float32x4 sgn(const float32x4& x) { return loop4(Ops<float>::sgn, x); } + static inline float32x4 round(const float32x4& x) { return _mm_round_ps(x, _MM_FROUND_TO_NEAREST_INT); } + static inline float32x4 floor(const float32x4& x) { return _mm_floor_ps(x); } + static inline float32x4 ceil(const float32x4& x) { return _mm_ceil_ps(x); } + static inline float32x4 add(const float32x4& x, const float32x4& y) { return _mm_add_ps(x, y); } static inline float32x4 sub(const float32x4& x, const float32x4& y) { return _mm_sub_ps(x, y); } static inline float32x4 mul(const float32x4& x, const float32x4& y) { return _mm_mul_ps(x, y); } @@ -380,6 +396,10 @@ struct Ops<float32x8> { // @TODO: get rid of loop8 with proper intrisics static inline float32x8 sgn(const float32x8& x) { return loop8(Ops<float>::sgn, x); } + static inline float32x8 round(const float32x8& x) { return _mm256_round_ps(x, _MM_FROUND_TO_NEAREST_INT); } + static inline float32x8 floor(const float32x8& x) { return _mm256_floor_ps(x); } + static inline float32x8 ceil(const float32x8& x) { return _mm256_ceil_ps(x); } + static inline float32x8 add(const float32x8& x, const float32x8& y) { return _mm256_add_ps(x, y); } static inline float32x8 sub(const float32x8& x, const float32x8& y) { return _mm256_sub_ps(x, y); } static inline float32x8 mul(const float32x8& x, const float32x8& y) { return _mm256_mul_ps(x, y); } @@ -473,6 +493,10 @@ struct Ops<half> { static DEVICE_INLINE half abs(const 
half& x) { return fabs((float)x); }// @TODO half has this information somewhere in the struct, right? static DEVICE_INLINE half sgn(const half& x) { half zero = 0.f; return (zero < x) - (x < zero); } // @TODO half has this information somewhere in the struct, right? + static DEVICE_INLINE half round(const half& x) { return hrint(x); } + static DEVICE_INLINE half floor(const half& x) { return hfloor(x); } + static DEVICE_INLINE half ceil(const half& x) { return hceil(x); } + static DEVICE_INLINE half add(const half& x, const half& y) { return x + y; } static DEVICE_INLINE half sub(const half& x, const half& y) { return x - y; } static DEVICE_INLINE half mul(const half& x, const half& y) { return x * y; } @@ -578,6 +602,10 @@ UNARY(Sqrt, sqrt, Ops<ElementType>::sqrt(x)); UNARY(Neg, operator-, Ops<ElementType>::neg(x)); UNARY(Sgn, sgn, Ops<ElementType>::sgn(x)); +UNARY(Round, round, Ops<ElementType>::round(x)); +UNARY(Floor, floor, Ops<ElementType>::floor(x)); +UNARY(Ceil, ceil, Ops<ElementType>::ceil(x)); + BINARY(Plus, operator+, Ops<ElementType>::add(x, y)); BINARY(Minus, operator-, Ops<ElementType>::sub(x, y)); BINARY(Mult, operator*, Ops<ElementType>::mul(x, y)); diff --git a/src/optimizers/quantizer.cpp b/src/optimizers/quantizer.cpp new file mode 100644 index 00000000..fc175672 --- /dev/null +++ b/src/optimizers/quantizer.cpp @@ -0,0 +1,185 @@ +#include <cmath> + +#include "optimizers/quantizer.h" +#include "tensors/tensor_allocator.h" +#include "tensors/tensor_operators.h" + +#include "functional/functional.h" + +namespace marian { + +/* simulate a fixed quantization for values in data. + * For example: + * data = [0.96, 0.73, 0.82, 0.84, 0.42, 0.29, 0.65] + * res = [1 , 0.6, 0.8 , 0.8 , 0.4, 0.2 , 0.6 ] + * + * @param data contains the original data + * @param res will contain the resulting quantized data. set data = res for in-place operation + * @param numCenters the number of quantized centers in absolute. 
It should be 2^(bit-1) + * @param S stores the scaling factor. + */ +static void fixedPointQuantization(Tensor data, Tensor res, int numCenters, float S) { + using namespace functional; + float multiplier = numCenters / S; + + // clip based on the scale + Element(_1 = clip(_2, S), res, data); + + // get the quantization bin ID + Element(_1 = round(_1 * multiplier), res); + + // revert back to floating point representation + Element(_1 /= multiplier, res); +} + +/* simulate a log-based quantization for values in data. The quantized value will be in the form of + * S*2^q For example: + * data = [0.9, 0.7, 0.5, 0.2 , 1.1] + * res = [1, 0.5, 0.5, 0.25, 1 ] + * + * @param data contains the original data. + * @param res will contain the resulting quantized data. set data = res for in-place operation. + * @param size the data size. + * @param numCenters the number of quantized centers in absolute. It should be 2^(bit-1). + * @param S stores the scaling factor. + * @param base for log quantized center. Default of 2. + */ +static void logQuantization(Tensor data, Tensor res, int numCenters, float S, float base = 2.0f) { + using namespace functional; + + // clip based on the scaling factor + Element(_1 = clip(_2, S), res, data); + + // multiplier such that the quantization is rounded in normal-space instead of log space. + // 4/3 for base = 2. example: 11.8 should be quantized to 8, instead of 16. + float mult = (2.0f * base) / (1.0f + base); + + // log-quantization works as the following: + // 1. capture the sign: + // sign = sgn(v) + // 2. get the quantization center: + // qc = floor(log2(abs(v/S) * _mult)) + // 3. clip the center to make sure we have no more than 2^(bit-1) centers. + // qc = clip(qc, num_centers) + // 4. 
revert back to floating point space: + // q = 2^qc * S * sign + // + // The above steps are written in 1 call as below, to avoid reserving extra Tensors: + + Element( + _1 = sgn(_1) // revert the sign back + * S // revert the scaling function + * pow(base, // revert from log space to normal FP representation + clip(floor(log(abs(_1 / S) * mult) / log(base)), // get the quantization center + (float) numCenters)), // clip + res); +} + +/* Quantize all the parameters (except bias, unless enabled via --quantize-biases). + * Quantization only works if we store the quantization error residual. + * The stored residual will be added for the next quantization. + * @param graph is the model graph to be quantized (in-place). + */ +void ModelQuantizer::quantize(Ptr<ExpressionGraph> graph) { + // lazily allocate tensor for error feedback mechanism + if(!errorResidual_) { + LOG(info, "Quantizing the model to {}-bits", bits_); + + int numElements = (int)graph->params()->vals()->size(); + auto allocator = New<TensorAllocator>(graph->getBackend()); + allocator->reserveExact(graph->params()->vals()->memory()->size()); + allocator->allocate(errorResidual_, {1, numElements}); + + allocators_.push_back(allocator); + isFirstError_ = true; + } + + { + // apply error feedback mechanism + using namespace functional; + Element(_1 += _2, graph->params()->vals(), errorResidual_); // add the previous error residual to the current model + errorResidual_->copyFrom(graph->params()->vals()); // set the model as the error-residual (will be updated below) + } + + for(auto p : *graph->params()) { + // quantize weight tensors, biases optional + if(quantBias_ || p->val()->shape()[0] > 1) + quantizeImpl(p->val()); + } + + // get new error residual. Skip the first one.
+ if (!isFirstError_) { + using namespace functional; + Element(_1 -= _2, errorResidual_, graph->params()->vals()); // new error-residual = original model - quantized model + } + else { + errorResidual_->set(0); + isFirstError_ = false; + } +} + + +/* Tensor quantization implementation. + * @param t is the tensor to be quantized (in-place) + */ +void ModelQuantizer::quantizeImpl(Tensor t) { + if(!tempVar_) { + // init the swap tensor + auto allocator = New<TensorAllocator>(t->getBackend()); + allocator->reserveExact(sizeof(float)); + allocator->allocate(tempVar_, {1, 1}); + allocators_.push_back(allocator); + } + + // init additional tensor for scaling optimization + if(!delta_ && optSteps_ > 0) { + int msize = (int) t->size(); + auto allocator = New<TensorAllocator>(t->getBackend()); + allocator->reserveExact(msize * sizeof(float)); + allocator->allocate(delta_, {1, msize}); + allocators_.push_back(allocator); + } + + Tensor q = delta_->subtensor(0, t->size()); // to store the quantized t + Tensor tflat = t->subtensor(0, t->size()); // flatten t for reduce + + float S = 0.0f; // scaling factor S + // get initial scaling factor (S) based on max element in Tensor + { + using namespace functional; + Reduce(abs(_1), max(_1, _2), 0.0f, tempVar_, tflat); + S = tempVar_->get(0); + } + + // optimize the scaling factor S + for(int i = 0; i < optSteps_; i++) { + // let t be the original tensor, and q be the quantized tensor, and q = S*a where S is the + // scaling factor. we want to optimize S to minimize MSE(S*a - t) therefore, S = + // sum(a*t)/sum(a*a) see https://www.aclweb.org/anthology/2020.ngt-1.4.pdf for more details.
+ if(logQuant_) + logQuantization(t, q, (1 << (bits_ - 1)) - 1, S); + else + fixedPointQuantization(t, q, (1 << (bits_ - 1)) - 1, S); + + // obtain a by applying q/=S + using namespace functional; + Element(_1 /= S, delta_); + + // get sum(a*t) + Reduce(_1 * _2, tempVar_, tflat, q); + float deltaNumer = tempVar_->get(0); + + // get sum(a*a) + Reduce(_1 * _1, tempVar_, q); + float deltaDenom = tempVar_->get(0); + + S = deltaNumer / deltaDenom; // S = sum(a*t)/sum(a*a) + } + + // final quantization + if(logQuant_) { + logQuantization(t, t, (1 << (bits_ - 1)) - 1, S); + } else + fixedPointQuantization(t, t,(1 << (bits_ - 1)) - 1, S); +} +} // namespace marian diff --git a/src/optimizers/quantizer.h b/src/optimizers/quantizer.h new file mode 100644 index 00000000..e385ed5e --- /dev/null +++ b/src/optimizers/quantizer.h @@ -0,0 +1,48 @@ +#pragma once + +#include "common/options.h" +#include "functional/functional.h" +#include "graph/expression_graph.h" +#include "tensors/backend.h" +#include "tensors/tensor.h" +#include "tensors/tensor_allocator.h" +#include "tensors/tensor_operators.h" + +namespace marian { + +/* Class to implement quantization of all the parameters in a model graph + * This class handles the required error-feedback mechanism internally. + * Example: + * auto mq = New<ModelQuantizer>(options_); + * mq->quantize(graph_); + * + * Parameters in graph_ will be quantized every time quantize is called. + * The internal error-residual is also updated each quantize call, + * therefore, use the same ModelQuantizer object to quantize the same graph. 
+ */ +class ModelQuantizer { +public: + ModelQuantizer(Ptr<Options> options) + : bits_{options->get<size_t>("quantize-bits")}, + optSteps_{options->get<size_t>("quantize-optimization-steps")}, + quantBias_{options->get<bool>("quantize-biases")}, + logQuant_{options->get<bool>("quantize-log-based")} {} + + void quantize(Ptr<ExpressionGraph> graph); + +protected: + void quantizeImpl(Tensor t); + + size_t bits_; + size_t optSteps_; + bool quantBias_; + bool logQuant_; + bool isFirstError_; + + std::vector<Ptr<TensorAllocator>> allocators_; + + Tensor errorResidual_; // Tensor to store the error-residual + Tensor delta_; // temporary Tensor for storing q to calculate optimal S + Tensor tempVar_; // single element Tensor for Reduce swap variable +}; +} // namespace marian diff --git a/src/tensors/gpu/add.inc b/src/tensors/gpu/add.inc index 64093253..903ee3ba 100755 --- a/src/tensors/gpu/add.inc +++ b/src/tensors/gpu/add.inc @@ -35,4 +35,5 @@ template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1 template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor); template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, 
marian::functional::Assignee<2> >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); template void marian::gpu::Add<marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase> >(marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); +template void marian::gpu::Aggregate<marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, IntrusivePtr<marian::TensorBase> >(marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase> 
>(marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase>);
diff --git a/src/tensors/gpu/add_all.inc b/src/tensors/gpu/add_all.inc index 2147f260..29a3a5d6 100644 --- a/src/tensors/gpu/add_all.inc +++ b/src/tensors/gpu/add_all.inc @@ -36,7 +36,7 @@ template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, template void marian::AggregateAll<float, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, 
IntrusivePtr<marian::TensorBase>); template void marian::AggregateAll<float,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
- +template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); #if COMPILE_FP16 template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor); template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor); @@ -74,4 +74,5 @@ template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, 
marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); template void marian::AggregateAll<__half,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > 
>(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
+template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); #endif diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc index e2e74200..0eb75625 100755 --- a/src/tensors/gpu/element.inc +++ b/src/tensors/gpu/element.inc @@ -59,9 +59,15 @@ template void marian::gpu::Element<marian::functional::Assign<marian::functional template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, marian::Tensor, marian::Tensor); template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Lt, 
marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Capture>, marian::functional::Capture> >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Capture>, marian::functional::Capture> >, marian::Tensor); template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<1> > > > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<1> > > > > >, marian::Tensor, marian::Tensor); +template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, 
marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); -template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
- +template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); +template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, IntrusivePtr<marian::TensorBase>); +template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture> > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture> > >, IntrusivePtr<marian::TensorBase>); +template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Round, 
marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::Capture> > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Round, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::Capture> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); +template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); +template void 
marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>); +template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, 
marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase>); // How to add new specializations: // When you use a new specialization, it will cause a link error of this form (example): // .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element<marian::functional::Assign< ... > ( ... 
)' diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp index 1457faff..eaeefb42 100755 --- a/src/training/graph_group_sync.cpp +++ b/src/training/graph_group_sync.cpp @@ -297,6 +297,11 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num return nullptr; // null if we reached beyond the end }; + // Helper to quantize the model + auto quantizeModel = [&](size_t idx, size_t /*begin*/, size_t /*end*/) { + quantizers_[idx]->quantize(graphs_[idx]); + }; + // Upon very first execution, reset everything if(first_) { LOG(info, "[training] Batches are processed as {} process(es) x {} devices/process", @@ -304,6 +309,14 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num initialize(subBatches.front()); if(mvAvg_ && paramsAvg_.empty()) initializeAvg(); + + // initialize model quantization + if (options_->get<size_t>("quantize-bits") > 0) { + for (int idx = 0; idx < graphs_.size(); idx++) + quantizers_.push_back(New<ModelQuantizer>(options_)); + comm_->foreach(quantizeModel); + } + first_ = false; } @@ -357,6 +370,11 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num comm_->scatterReduceAndResetGrads(); // reduce gradients across all devices and MPI nodes into shards comm_->foreach(update); // per-shard model-update comm_->allGatherParams(); // distribute param value shards back + + // Re-add the error residual from previous quantization, + // then re-quantize the model back and update the error residual + if (options_->get<size_t>("quantize-bits") > 0) + comm_->foreach(quantizeModel); } else LOG(info, "[training] skipping {}-th update due to loss being {}", scheduler_->numberOfBatches(), localLoss.loss); diff --git a/src/training/graph_group_sync.h b/src/training/graph_group_sync.h index 147b172c..5b31a933 100755 --- a/src/training/graph_group_sync.h +++ b/src/training/graph_group_sync.h @@ -1,5 +1,6 @@ #pragma once +#include "optimizers/quantizer.h" 
#include "training/graph_group.h" #include "training/communicator.h" #include "training/exponential_smoothing.h" @@ -23,6 +24,9 @@ class SyncGraphGroup : public GraphGroup, public ExponentialSmoothing { std::vector<Ptr<TensorAllocator>> paramsAllocs_; // [deviceIndex] we must hold a reference to the memory until this class dies // @TODO: move this nto ExponentialSmoothing, together with paramsAvg_? + // model quantizer + std::vector<Ptr<ModelQuantizer>> quantizers_; + // state for update() bool first_{ true }; // gets interpreted and cleared by update() std::vector<Ptr<data::Batch>> pendingBatches_; // in case of dynamic MB-size scaling, we temporarly buffer up batches across update() calls until enough |