diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-08 00:34:39 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-08 00:34:39 +0300 |
commit | 7bb558ecfcdfef5c629f5a9d85ea2b4680bb60aa (patch) | |
tree | 0664fc883448c68ad81b91e4b76bef319d45b638 /src/graph | |
parent | 68d61a662294cb3f26b3935da95a8ce1c404c293 (diff) | |
parent | ac21830517e75e31a0bca3b071292acff0d9610d (diff) |
Merge branch 'master' into jonathac/windows_build
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/expression_operators.cpp | 48 | ||||
-rw-r--r-- | src/graph/expression_operators.h | 2 | ||||
-rw-r--r-- | src/graph/node_operators.cpp | 1 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 39 |
4 files changed, 80 insertions, 10 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 287f771d..fd985d4a 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -31,6 +31,13 @@ Expr prelu(Expr a, float alpha) { return Expression<PReLUNodeOp>(alpha, a); } +Expr clip(Expr a, float c) { + if(c == 0) + return a; + else + return Expression<ClipNodeOp>(a, c); +} + Expr log(Expr a) { return Expression<LogNodeOp>(a); }; @@ -204,16 +211,21 @@ Expr weighted_average(Expr in, Expr weights, keywords::axis_k ax) { Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) { auto device = a->graph()->getDevice().type; + float clipValue = a->graph()->getBackend()->getClip(); + + // Currently only true when command line options + // --optimize --cpu-thread=N with N > 0 are set. if(a->graph()->isOptimized() && device == DeviceType::cpu) { // dotInt16 computes A * B.T, hence the transpose for B to get A * B // if transA = false and transB = false. - return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a), - cpu::int16::quantize(transB ? b : transpose(b)), + return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a, clipValue), + cpu::int16::quantize(transB ? b : transpose(b), clipValue), scale); } else { - return Expression<DotNodeOp>(a, b, transA, transB, scale); + return Expression<DotNodeOp>(clip(a, clipValue), clip(b, clipValue), + transA, transB, scale); } } @@ -223,6 +235,9 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) { Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) { auto device = a->graph()->getDevice().type; + + float clipValue = a->graph()->getBackend()->getClip(); + if(a->graph()->isOptimized() && device == DeviceType::cpu) { bool autotune = true; @@ -255,8 +270,8 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) { return e; }; auto alg1 = [=]() { - return rec1(cpu::int16::affine(rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a)), - cpu::int16::quantize(transB ? b : transpose(b)), + return rec1(cpu::int16::affine(rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a, clipValue)), + cpu::int16::quantize(transB ? b : transpose(b), clipValue), bias, scale), true); @@ -270,8 +285,18 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) { e->record(tuner, hash2, stop); return e; }; + + auto alg2 = [=]() { - std::vector<Expr> nodes = {a, b, bias}; + auto ac = clip(a, clipValue); + if(ac != a) + ac = rec2(ac); + + auto bc = clip(b, clipValue); + if(bc != b) + bc = rec2(bc); + + std::vector<Expr> nodes = {ac, bc, bias}; return rec2(Expression<AffineNodeOp>(nodes, transA, transB, scale), true); }; @@ -283,16 +308,21 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) { } else { // cpu int16 version - return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a), - cpu::int16::quantize(transB ? b : transpose(b)), + return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a, clipValue), + cpu::int16::quantize(transB ? b : transpose(b), clipValue), bias, scale); } } else { // general version, MKL, CBlas or CUDA - std::vector<Expr> nodes = {a, b, bias}; + // if clipValue > 0, the inputs will be clipped to range [-clipValue, clipValue] + // This is meant to keep values at the same range as used during training when + // optimizing for 8-bit integer products. Likely to be removed in the future + // when we explore better ways to handle this. + std::vector<Expr> nodes = {clip(a, clipValue), clip(b, clipValue), bias}; return Expression<AffineNodeOp>(nodes, transA, transB, scale); + } } diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index 070ee0ba..036445ee 100644 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -34,6 +34,8 @@ Expr log(Expr a); Expr exp(Expr a); +Expr clip(Expr a, float c); + Expr operator-(Expr a); /*********************************************************/ diff --git a/src/graph/node_operators.cpp b/src/graph/node_operators.cpp index 788b3fc6..315cb5eb 100644 --- a/src/graph/node_operators.cpp +++ b/src/graph/node_operators.cpp @@ -2,7 +2,6 @@ #include "expression_graph.h" #include "tensors/tensor_operators.h" -#include "tensors/cpu/sharp/sse_gemm.h" namespace marian { diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index e14f6546..259e6072 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -99,6 +99,45 @@ public: } }; +struct ClipNodeOp : public UnaryNodeOp { +private: + float clip_{0}; + +public: + ClipNodeOp(Expr a, float clip) : UnaryNodeOp(a), clip_{clip} {} + + NodeOps forwardOps() { + using namespace functional; + return {NodeOp(Element(_1 = clip(_2, clip_), val_, child(0)->val()))}; + } + + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(bump(_1, clip_) * _2, child(0)->grad(), child(0)->val(), adj_))}; + } + + const std::string type() { return "clip"; } + + virtual size_t hash() { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + boost::hash_combine(hash_, clip_); + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast<ClipNodeOp>(node); + if(!cnode) + return false; + if(clip_ != cnode->clip_) + return false; + return true; + } +}; + struct LogitNodeOp : public UnaryNodeOp { LogitNodeOp(Expr a) : UnaryNodeOp(a) {} |