
github.com/marian-nmt/marian.git
path: root/src/graph
author    Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2018-06-08 00:34:39 +0300
committer Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2018-06-08 00:34:39 +0300
commit    7bb558ecfcdfef5c629f5a9d85ea2b4680bb60aa (patch)
tree      0664fc883448c68ad81b91e4b76bef319d45b638 /src/graph
parent    68d61a662294cb3f26b3935da95a8ce1c404c293 (diff)
parent    ac21830517e75e31a0bca3b071292acff0d9610d (diff)
Merge branch 'master' into jonathac/windows_build
Diffstat (limited to 'src/graph')
-rw-r--r--  src/graph/expression_operators.cpp  48
-rw-r--r--  src/graph/expression_operators.h     2
-rw-r--r--  src/graph/node_operators.cpp         1
-rw-r--r--  src/graph/node_operators_unary.h    39
4 files changed, 80 insertions, 10 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 287f771d..fd985d4a 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -31,6 +31,13 @@ Expr prelu(Expr a, float alpha) {
return Expression<PReLUNodeOp>(alpha, a);
}
+Expr clip(Expr a, float c) {
+ if(c == 0)
+ return a;
+ else
+ return Expression<ClipNodeOp>(a, c);
+}
+
Expr log(Expr a) {
return Expression<LogNodeOp>(a);
};
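Note on the hunk above: the c == 0 case returns the input expression itself rather than allocating a ClipNodeOp, so clipping costs nothing when disabled, and callers can detect by pointer comparison whether a node was actually inserted (the alg2 lambda in affine() below relies on exactly this). A hedged usage sketch, where `x` stands for any existing graph Expr:

    // Sketch only: clip with c == 0 is the identity and adds no node.
    Expr y0 = clip(x, 0.f);   // y0 == x holds by pointer identity
    Expr y1 = clip(x, 2.f);   // distinct ClipNodeOp; clamps values to [-2, 2]
    bool added = (y1 != x);   // true: a node was actually inserted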
@@ -204,16 +211,21 @@ Expr weighted_average(Expr in, Expr weights, keywords::axis_k ax) {
Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) {
auto device = a->graph()->getDevice().type;
+ float clipValue = a->graph()->getBackend()->getClip();
+
+ // Currently only true when the command line options
+ // --optimize --cpu-threads=N with N > 0 are set.
if(a->graph()->isOptimized() && device == DeviceType::cpu) {
// dotInt16 computes A * B.T, hence the transpose for B to get A * B
// if transA = false and transB = false.
- return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a),
- cpu::int16::quantize(transB ? b : transpose(b)),
+ return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
+ cpu::int16::quantize(transB ? b : transpose(b), clipValue),
scale);
}
else {
- return Expression<DotNodeOp>(a, b, transA, transB, scale);
+ return Expression<DotNodeOp>(clip(a, clipValue), clip(b, clipValue),
+ transA, transB, scale);
}
}
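Why the clip value is threaded into cpu::int16::quantize: once the inputs are known to lie in [-clipValue, clipValue], a single fixed scale can map them onto the int16 range without per-tensor range estimation. A self-contained scalar sketch of that idea (illustrative only; the real kernels in cpu::int16 are vectorized, live under tensors/cpu/sharp, and may use a different fixed scale):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Illustrative scalar quantizer: with inputs bounded by clipValue, a
    // fixed scale uses the full int16 range. The default range of 1.f for
    // clipValue == 0 is an assumption for this sketch.
    int16_t quantizeOne(float x, float clipValue) {
      float c = clipValue > 0.f ? clipValue : 1.f;
      float clipped = std::max(-c, std::min(c, x));  // same clamp as clip()
      return static_cast<int16_t>(std::lround(clipped / c * 32767.f));
    }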
@@ -223,6 +235,9 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
auto device = a->graph()->getDevice().type;
+
+ float clipValue = a->graph()->getBackend()->getClip();
+
if(a->graph()->isOptimized() && device == DeviceType::cpu) {
bool autotune = true;
@@ -255,8 +270,8 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
return e;
};
auto alg1 = [=]() {
- return rec1(cpu::int16::affine(rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a)),
- cpu::int16::quantize(transB ? b : transpose(b)),
+ return rec1(cpu::int16::affine(rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a, clipValue)),
+ cpu::int16::quantize(transB ? b : transpose(b), clipValue),
bias,
scale),
true);
@@ -270,8 +285,18 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
e->record(tuner, hash2, stop);
return e;
};
+
+
auto alg2 = [=]() {
- std::vector<Expr> nodes = {a, b, bias};
+ auto ac = clip(a, clipValue);
+ if(ac != a)
+ ac = rec2(ac);
+
+ auto bc = clip(b, clipValue);
+ if(bc != b)
+ bc = rec2(bc);
+
+ std::vector<Expr> nodes = {ac, bc, bias};
return rec2(Expression<AffineNodeOp>(nodes, transA, transB, scale),
true);
};
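The ac != a and bc != b checks in alg2 piggyback on the c == 0 shortcut in clip(): a clip node is recorded for the autotuner's timing only if one was actually created. The same pattern in isolation (stand-in types and names, not Marian's autotuner API):

    // Stand-in sketch: attribute a wrapper node's cost to this algorithm
    // only when the transform really produced a new node; pointer identity
    // means the input came back unchanged.
    template <typename Ptr, typename Rec>
    Ptr recordIfNew(Ptr original, Ptr transformed, Rec record) {
      return transformed != original ? record(transformed) : transformed;
    }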
@@ -283,16 +308,21 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
}
else {
// cpu int16 version
- return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a),
- cpu::int16::quantize(transB ? b : transpose(b)),
+ return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
+ cpu::int16::quantize(transB ? b : transpose(b), clipValue),
bias,
scale);
}
}
else {
// general version, MKL, CBlas or CUDA
- std::vector<Expr> nodes = {a, b, bias};
+ // If clipValue > 0, the inputs are clipped to the range [-clipValue, clipValue].
+ // This is meant to keep values in the same range as seen during training when
+ // optimizing for 8-bit integer products. Likely to be removed in the future
+ // when we explore better ways to handle this.
+ std::vector<Expr> nodes = {clip(a, clipValue), clip(b, clipValue), bias};
return Expression<AffineNodeOp>(nodes, transA, transB, scale);
+
}
}
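The comment in the general-version branch is the heart of the change: clipping both GEMM inputs to [-clipValue, clipValue] bounds every product term by clipValue squared, so inference sees the same dynamic range the 8-bit quantizer was calibrated to during training. A self-contained illustration (values made up):

    #include <algorithm>
    #include <cstdio>

    float clipv(float x, float c) { return std::max(-c, std::min(c, x)); }

    int main() {
      const float c = 2.f;
      const float a[] = {0.5f, 7.0f, -3.0f};
      const float b[] = {1.5f, 4.0f, -0.5f};
      float raw = 0.f, clipped = 0.f;
      for(int i = 0; i < 3; ++i) {
        raw     += a[i] * b[i];
        clipped += clipv(a[i], c) * clipv(b[i], c);  // each term in [-4, 4]
      }
      std::printf("raw=%.2f clipped=%.2f\n", raw, clipped);  // 30.25 vs 5.75
    }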
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 070ee0ba..036445ee 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -34,6 +34,8 @@ Expr log(Expr a);
Expr exp(Expr a);
+Expr clip(Expr a, float c);
+
Expr operator-(Expr a);
/*********************************************************/
diff --git a/src/graph/node_operators.cpp b/src/graph/node_operators.cpp
index 788b3fc6..315cb5eb 100644
--- a/src/graph/node_operators.cpp
+++ b/src/graph/node_operators.cpp
@@ -2,7 +2,6 @@
#include "expression_graph.h"
#include "tensors/tensor_operators.h"
-#include "tensors/cpu/sharp/sse_gemm.h"
namespace marian {
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index e14f6546..259e6072 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -99,6 +99,45 @@ public:
}
};
+struct ClipNodeOp : public UnaryNodeOp {
+private:
+ float clip_{0};
+
+public:
+ ClipNodeOp(Expr a, float clip) : UnaryNodeOp(a), clip_{clip} {}
+
+ NodeOps forwardOps() {
+ using namespace functional;
+ return {NodeOp(Element(_1 = clip(_2, clip_), val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ using namespace functional;
+ return {NodeOp(Add(bump(_1, clip_) * _2, child(0)->grad(), child(0)->val(), adj_))};
+ }
+
+ const std::string type() { return "clip"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, clip_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ClipNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(clip_ != cnode->clip_)
+ return false;
+ return true;
+ }
+};
+
struct LogitNodeOp : public UnaryNodeOp {
LogitNodeOp(Expr a) : UnaryNodeOp(a) {}
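For readers unfamiliar with marian::functional: ClipNodeOp's forward pass clamps elementwise, and its backward pass masks the incoming gradient with bump, which is 1 inside the clipping interval and 0 outside, so no gradient flows through saturated elements. A plain-C++ sketch of the assumed elementwise semantics (the real clip and bump primitives are expression templates shared by the CPU and GPU backends; the boundary behavior at |x| == c is an assumption here):

    #include <cmath>

    // clip(x, c): clamp into [-c, c]; what forwardOps() applies per element.
    inline float clipElem(float x, float c) {
      return x > c ? c : (x < -c ? -c : x);
    }

    // bump(x, c): gradient mask used in backwardOps(); 1 where the input
    // was inside the interval, 0 where it was clipped.
    inline float bumpElem(float x, float c) {
      return std::fabs(x) <= c ? 1.f : 0.f;
    }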