author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-05-12 08:03:52 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-05-12 08:03:52 +0300 |
commit | 6571741b735e1fed9f999476361f0d68de6f1118 (patch) | |
tree | 5c90beca3a04a8fd41aef66681da050f978ae8b9 /src | |
parent | 7d5ea76d7b671cb3282a964e79e05dd1787d4aca (diff) |
clipping gemm
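In substance, this commit moves `clip-gemm` from the training-only options into the common options (so the translator sees it too), routes the clip value through the graph backend, and replaces the in-place GPU `Clip` kernel in `prod.cu` with a symbolic `ClipNodeOp` applied to GEMM inputs at graph-construction time. A minimal sketch of the element-wise semantics this relies on; `clipScalar` is an illustrative stand-in, not a Marian function:

```cpp
#include <algorithm>
#include <cassert>

// Assumed semantics of clip(x, c): y = max(-c, min(c, x)).
// A clip value of 0 disables clipping, which is why the new
// clip() expression operator returns its input unchanged when c == 0.
inline float clipScalar(float x, float c) {
  return std::max(-c, std::min(c, x));
}

int main() {
  assert(clipScalar( 3.5f, 2.f) ==  2.f);  // clipped from above
  assert(clipScalar(-3.5f, 2.f) == -2.f);  // clipped from below
  assert(clipScalar( 1.5f, 2.f) ==  1.5f); // inside [-c, c], unchanged
  return 0;
}
```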
Diffstat (limited to 'src')
-rw-r--r-- | src/common/config_parser.cpp | 6
-rw-r--r-- | src/graph/expression_operators.cpp | 44
-rw-r--r-- | src/graph/expression_operators.h | 2
-rw-r--r-- | src/graph/node_operators_unary.h | 40
-rw-r--r-- | src/models/encoder_decoder.cpp | 2
-rw-r--r-- | src/tensors/cpu/int16.h | 11
-rwxr-xr-x | src/tensors/cpu/sharp/sse_gemm.h | 3
-rw-r--r-- | src/tensors/gpu/element.inc | 3
-rw-r--r-- | src/tensors/gpu/prod.cu | 33
-rw-r--r-- | src/translator/translator.h | 2
10 files changed, 93 insertions, 53 deletions
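Reading order for the diff below: `translator.h` copies the CLI value onto each graph's backend, and `expression_operators.cpp` reads it back when `dot` and `affine` nodes are built. A compressed sketch of that hand-off, with simplified stand-in types rather than Marian's actual class hierarchy:

```cpp
#include <memory>

// Stand-in for the backend object that carries the clip threshold.
struct Backend {
  void setClip(float c) { clip_ = c; }
  float getClip() const { return clip_; }
private:
  float clip_ = 0.f; // 0 means "no clipping", matching the CLI default
};

int main() {
  // At worker setup (translator.h): copy the CLI option onto the backend.
  auto backend = std::make_shared<Backend>();
  backend->setClip(3.0f);               // e.g. from --clip-gemm 3.0
  // At graph construction (expression_operators.cpp): read it back and
  // wrap the GEMM inputs in clip() before the node is created.
  float clipValue = backend->getClip();
  return clipValue == 3.0f ? 0 : 1;
}
```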
```diff
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index da63d4da..a936f64f 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -260,6 +260,8 @@ void ConfigParser::addOptionsCommon(po::options_description& desc) {
      "Suppress logging for translation")
     ("seed", po::value<size_t>()->default_value(0),
      "Seed for all random number generators. 0 means initialize randomly")
+    ("clip-gemm", po::value<float>()->default_value(0.f),
+     "If not 0 clip GEMM input values to +/- arg")
     ("interpolate-env-vars", po::value<bool>()->zero_tokens()->default_value(false),
      "allow the use of environment variables in paths, of the form ${VAR_NAME}")
     ("relative-paths", po::value<bool>()->zero_tokens()->default_value(false),
@@ -480,8 +482,6 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) {
      "Number of batches to preload for length-based sorting")
     ("maxi-batch-sort", po::value<std::string>()->default_value("trg"),
      "Sorting strategy for maxi-batch: trg (default) src none")
-    ("clip-gemm", po::value<float>()->default_value(0.f),
-     "If not 0 clip GEMM input values to +/- arg")
     ("optimizer,o", po::value<std::string>()->default_value("adam"),
      "Optimization algorithm (possible values: sgd, adagrad, adam")
     ("optimizer-params", po::value<std::vector<float>>()
@@ -916,7 +916,6 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
   SET_OPTION("sqlite", std::string);
   SET_OPTION("sqlite-drop", bool);
 
-  SET_OPTION("clip-gemm", float);
   SET_OPTION("optimizer", std::string);
   SET_OPTION_NONDEFAULT("optimizer-params", std::vector<float>);
   SET_OPTION("optimizer-delay", size_t);
@@ -1021,6 +1020,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
   SET_OPTION("quiet-translation", bool);
   SET_OPTION_NONDEFAULT("log", std::string);
   SET_OPTION("seed", size_t);
+  SET_OPTION("clip-gemm", float);
   SET_OPTION("interpolate-env-vars", bool);
   SET_OPTION("relative-paths", bool);
   SET_OPTION("devices", std::vector<std::string>);
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 6e80cb08..7ff661b7 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -31,6 +31,13 @@ Expr prelu(Expr a, float alpha) {
   return Expression<PReLUNodeOp>(alpha, a);
 }
 
+Expr clip(Expr a, float c) {
+  if(c == 0)
+    return a;
+  else
+    return Expression<ClipNodeOp>(a, c);
+}
+
 Expr log(Expr a) {
   return Expression<LogNodeOp>(a);
 };
@@ -204,16 +211,19 @@ Expr weighted_average(Expr in, Expr weights, keywords::axis_k ax) {
 }
 
 Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) {
   auto device = a->graph()->getDevice().type;
+  float clipValue = a->graph()->getBackend()->getClip();
+
   if(a->graph()->isOptimized() && device == DeviceType::cpu) {
     // dotInt16 computes A * B.T, hence the transpose for B to get A * B
     // if transA = false and transB = false.
-    return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a),
-                           cpu::int16::quantize(transB ? b : transpose(b)),
+    return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
+                           cpu::int16::quantize(transB ? b : transpose(b), clipValue),
                            scale);
   }
   else {
-    return Expression<DotNodeOp>(a, b, transA, transB, scale);
+    return Expression<DotNodeOp>(clip(a, clipValue), clip(b, clipValue),
+                                 transA, transB, scale);
   }
 }
@@ -223,9 +233,12 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
 }
 
 Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
   auto device = a->graph()->getDevice().type;
+
+  float clipValue = a->graph()->getBackend()->getClip();
+
   if(a->graph()->isOptimized() && device == DeviceType::cpu) {
 
-    bool autotune = false;
+    bool autotune = true;
     if(autotune) {
       thread_local Ptr<AutoTuner<Expr>> tuner = New<AutoTuner<Expr>>();
@@ -255,8 +268,8 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
         return e;
       };
       auto alg1 = [=]() {
-        return rec1(cpu::int16::affine(rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a)),
-                                       cpu::int16::quantize(transB ? b : transpose(b)),
+        return rec1(cpu::int16::affine(rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a, clipValue)),
+                                       cpu::int16::quantize(transB ? b : transpose(b), clipValue),
                                        bias,
                                        scale),
                     true);
@@ -270,8 +283,18 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
         e->record(tuner, hash2, stop);
         return e;
       };
+
+
       auto alg2 = [=]() {
-        std::vector<Expr> nodes = {a, b, bias};
+        auto ac = clip(a, clipValue);
+        if(ac != a)
+          ac = rec2(ac);
+
+        auto bc = clip(b, clipValue);
+        if(bc != b)
+          bc = rec2(bc);
+
+        std::vector<Expr> nodes = {ac, bc, bias};
         return rec2(Expression<AffineNodeOp>(nodes, transA, transB, scale),
                     true);
       };
@@ -283,16 +306,17 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
     }
     else {
       // cpu int16 version
-      return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a),
-                                cpu::int16::quantize(transB ? b : transpose(b)),
+      return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
+                                cpu::int16::quantize(transB ? b : transpose(b), clipValue),
                                 bias,
                                 scale);
     }
   }
   else {
     // general version, MKL, CBlas or CUDA
-    std::vector<Expr> nodes = {a, b, bias};
+    std::vector<Expr> nodes = {clip(a, clipValue), clip(b, clipValue), bias};
     return Expression<AffineNodeOp>(nodes, transA, transB, scale);
+
   }
 }
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 070ee0ba..036445ee 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -34,6 +34,8 @@ Expr log(Expr a);
 
 Expr exp(Expr a);
 
+Expr clip(Expr a, float c);
+
 Expr operator-(Expr a);
 
 /*********************************************************/
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 273adf44..8f1fd2f7 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -99,6 +99,46 @@ public:
   }
 };
 
+struct ClipNodeOp : public UnaryNodeOp {
+private:
+  float clip_{0};
+
+public:
+  ClipNodeOp(Expr a, float clip) : UnaryNodeOp(a), clip_{clip} {}
+
+  NodeOps forwardOps() {
+    using namespace functional;
+    return {NodeOp(Element(_1 = clip(_2, clip_), val_, child(0)->val()))};
+  }
+
+  NodeOps backwardOps() {
+    using namespace functional;
+    // @TODO: is this correct?
+    return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+  }
+
+  const std::string type() { return "clip"; }
+
+  virtual size_t hash() {
+    if(!hash_) {
+      hash_ = NaryNodeOp::hash();
+      boost::hash_combine(hash_, clip_);
+    }
+    return hash_;
+  }
+
+  virtual bool equal(Expr node) {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<ClipNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(clip_ != cnode->clip_)
+      return false;
+    return true;
+  }
+};
+
 struct LogitNodeOp : public UnaryNodeOp {
   LogitNodeOp(Expr a) : UnaryNodeOp(a) {}
 
diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp
index 73691257..c17adb69 100644
--- a/src/models/encoder_decoder.cpp
+++ b/src/models/encoder_decoder.cpp
@@ -26,7 +26,7 @@ EncoderDecoder::EncoderDecoder(Ptr<Options> options)
       "special-vocab",
       "tied-embeddings",
       "tied-embeddings-src",
-      "tied-embeddings-all",
+      "tied-embeddings-all"
   };
 
   modelFeatures_.insert("transformer-heads");
diff --git a/src/tensors/cpu/int16.h b/src/tensors/cpu/int16.h
index aca49e17..621a3e30 100644
--- a/src/tensors/cpu/int16.h
+++ b/src/tensors/cpu/int16.h
@@ -8,11 +8,14 @@ namespace cpu {
 namespace int16 {
 
 struct QuantizeNodeOp : public UnaryNodeOp {
-  QuantizeNodeOp(Expr a) : UnaryNodeOp(a, Type::int16) {}
+  float clipValue_;
+
+  QuantizeNodeOp(Expr a, float clipValue)
+      : UnaryNodeOp(a, Type::int16), clipValue_{clipValue} {}
 
   NodeOps forwardOps() {
     return {
-      NodeOp(Quantize(val_, child(0)->val()))
+      NodeOp(Quantize(val_, child(0)->val(), clipValue_))
     };
   }
 
@@ -118,8 +121,8 @@ static inline Expr affine(Expr a, Expr b, Expr c, float scalar) {
   return Expression<cpu::int16::AffineNodeOp>(nodes, scalar);
 }
 
-static inline Expr quantize(Expr a) {
-  return Expression<cpu::int16::QuantizeNodeOp>(a);
+static inline Expr quantize(Expr a, float clipValue) {
+  return Expression<cpu::int16::QuantizeNodeOp>(a, clipValue);
 }
 
diff --git a/src/tensors/cpu/sharp/sse_gemm.h b/src/tensors/cpu/sharp/sse_gemm.h
index 542c3f6d..72f1d8d6 100755
--- a/src/tensors/cpu/sharp/sse_gemm.h
+++ b/src/tensors/cpu/sharp/sse_gemm.h
@@ -71,7 +71,8 @@ namespace int16 {
 
 const int BITS = 10;
 
 static inline void Quantize(marian::Tensor out,
-                            const marian::Tensor in) {
+                            const marian::Tensor in,
+                            float clipValue) {
   int num_rows = in->shape().elements() / in->shape()[-1];
   int width = in->shape()[-1];
diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc
index 27f12d0d..002c668f 100644
--- a/src/tensors/gpu/element.inc
+++ b/src/tensors/gpu/element.inc
@@ -36,4 +36,5 @@ template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunc
 template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Gt, UnaryFunctor<elem::Abs, Assignee<2>>, Capture>, Capture, Assignee<2>>>, marian::Tensor>(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Gt, UnaryFunctor<elem::Abs, Assignee<2>>, Capture>, Capture, Assignee<2>>>, marian::Tensor, marian::Tensor);
 template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Lt, UnaryFunctor<elem::Abs, Assignee<1>>, Capture>, Capture, Assignee<1>>>>(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Lt, UnaryFunctor<elem::Abs, Assignee<1>>, Capture>, Capture, Assignee<1>>>, marian::Tensor);
 template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<1> >, Capture>, Capture, Assignee<1> > >>(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<1> >, Capture>, Capture, Assignee<1> > >, marian::Tensor);
-template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor >(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor, marian::Tensor);
\ No newline at end of file
+template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor >(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Clip, Assignee<2>, Capture>>, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Clip, Assignee<2>, Capture>>, marian::Tensor, marian::Tensor);
diff --git a/src/tensors/gpu/prod.cu b/src/tensors/gpu/prod.cu
index e0b526a1..d75a2ceb 100644
--- a/src/tensors/gpu/prod.cu
+++ b/src/tensors/gpu/prod.cu
@@ -11,26 +11,6 @@ namespace marian {
 
 namespace gpu {
 
-__global__ void gClip(float* in, int length, float clip) {
-  for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
-    int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
-    if(index < length) {
-      if(in[index] < -clip)
-        in[index] = -clip;
-      if(in[index] > clip)
-        in[index] = clip;
-    }
-  }
-}
-
-void Clip(marian::Tensor A, float clip) {
-  int length = A->shape().elements();
-  int threads = std::min(MAX_THREADS, length);
-  int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
-
-  gClip<<<blocks, threads>>>(A->data(), length, clip);
-}
-
 void Prod(marian::Tensor C,
           marian::Tensor A,
           marian::Tensor B,
@@ -41,12 +21,6 @@ void Prod(marian::Tensor C,
   cudaSetDevice(C->getDevice().no);
   float alpha = scalar;
 
-  float clip = C->getBackend()->getClip();
-  if(clip != 0.f) {
-    Clip(A, clip);
-    Clip(B, clip);
-  }
-
   size_t m = A->shape().elements() / A->shape().back();
   size_t k = A->shape().back();
   if(transA)
@@ -116,13 +90,6 @@ void ProdBatched(marian::Tensor C,
   cudaSetDevice(C->getDevice().no);
   float alpha = scalar;
 
-  float clip = C->getBackend()->getClip();
-  if(clip != 0.f) {
-    Clip(A, clip);
-    Clip(B, clip);
-  }
-
-
   size_t batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
   size_t batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
diff --git a/src/translator/translator.h b/src/translator/translator.h
index 71e37475..1f3253ae 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -55,6 +55,7 @@ public:
     auto task = [&](DeviceId device, size_t id) {
       auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
       graph->setDevice(device);
+      graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
       graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
       graphs_[id] = graph;
 
@@ -154,6 +155,7 @@ public:
     for(auto device : devices_) {
       auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
       graph->setDevice(device);
+      graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
       graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
       graphs_.push_back(graph);
```
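One open question the diff itself flags: `ClipNodeOp::backwardOps()` carries a `// @TODO: is this correct?` because it propagates the incoming gradient unchanged, i.e. a straight-through estimator, whereas the exact derivative of clip is zero wherever |x| > c. A sketch of both variants on plain arrays, illustrative and outside Marian's functional framework:

```cpp
#include <cmath>

// Straight-through backward, as in the diff: dx += dy regardless of x.
void clipBackwardStraightThrough(float* dx, const float* dy, int n) {
  for(int i = 0; i < n; ++i)
    dx[i] += dy[i];
}

// Exact backward: the derivative of y = max(-c, min(c, x)) is 1 inside
// [-c, c] and 0 outside, so clipped positions receive no gradient.
void clipBackwardExact(float* dx, const float* dy,
                       const float* x, float c, int n) {
  for(int i = 0; i < n; ++i)
    dx[i] += (std::fabs(x[i]) <= c) ? dy[i] : 0.f;
}
```

The straight-through form keeps gradients flowing through saturated values, a common choice when clipping exists only to bound activations for quantized inference; the exact form is what autodiff of `max(-c, min(c, x))` would produce.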