Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src/graph
diff options
context:
space:
mode:
author: Marcin Junczys-Dowmunt <junczys@amu.edu.pl> 2018-04-11 06:46:29 +0300
committer: Marcin Junczys-Dowmunt <junczys@amu.edu.pl> 2018-04-11 06:46:29 +0300
commit: afc3bde59710c2d5b8f4b0b4b731752dc2cf9a41 (patch)
tree: 4b32ed16d45c98516b190ab2c2e0d51ea77a99dd /src/graph
parent: b0622430caab902efd9681ac9582d0f274168764 (diff)
add int16 operators, attempt at memoization
Diffstat (limited to 'src/graph')
-rw-r--r-- src/graph/chainable.h 2
-rw-r--r-- src/graph/expression_graph.cpp 4
-rw-r--r-- src/graph/expression_graph.h 10
-rw-r--r-- src/graph/expression_operators.cpp 32
-rw-r--r-- src/graph/expression_operators.h 1
-rw-r--r-- src/graph/node.cpp 10
-rw-r--r-- src/graph/node.h 11
-rw-r--r-- src/graph/node_operators.cpp 2
-rw-r--r-- src/graph/node_operators.h 6
-rw-r--r-- src/graph/node_operators_unary.h 6
10 files changed, 59 insertions, 25 deletions
diff --git a/src/graph/chainable.h b/src/graph/chainable.h
index 9e5d8fb2..281d3653 100644
--- a/src/graph/chainable.h
+++ b/src/graph/chainable.h
@@ -72,7 +72,9 @@ struct Chainable {
// virtual const std::string& type() = 0;
virtual Ptr<ExpressionGraph> graph() = 0;
+
virtual const Shape& shape() = 0;
+ virtual const Type& value_type() = 0;
virtual std::vector<Expr>& children() = 0;
virtual Expr child(size_t) = 0;
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp
index 4a0edb34..e89f45f3 100644
--- a/src/graph/expression_graph.cpp
+++ b/src/graph/expression_graph.cpp
@@ -5,8 +5,8 @@
namespace marian {
-ExpressionGraph::ExpressionGraph(bool inference)
- : inferenceOnly_(inference), backend_(nullptr) {}
+ExpressionGraph::ExpressionGraph(bool inference, bool optimized)
+ : inferenceOnly_(inference), optimized_(optimized), backend_(nullptr) {}
void ExpressionGraph::setDevice(DeviceId deviceId) {
if(!backend_) {
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index c5f12335..c9f2b245 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -38,6 +38,8 @@ private:
std::unordered_map<size_t, std::vector<WExpr>> hashMap_;
bool inferenceOnly_{false};
+ bool optimized_{false};
+
bool reloaded_{false};
std::string namespace_;
@@ -53,9 +55,10 @@ public:
*
* Constructor should be used as New<ExpressionGraph>()
*/
- ExpressionGraph(bool inference = false);
+ ExpressionGraph(bool inference = false, bool optimized = false);
void setInference(bool inference) { inferenceOnly_ = inference; }
+ bool isInference() { return inferenceOnly_; }
~ExpressionGraph() {
clear();
@@ -63,9 +66,14 @@ public:
}
void setDevice(DeviceId deviceId = {0, DeviceType::gpu});
+
DeviceId getDevice() { return backend_->getDevice(); }
+
Ptr<Backend> getBackend() { return backend_; }
+ void setOptimized(bool optimized) { optimized_ = optimized; }
+ bool isOptimized() { return (optimized_ && inferenceOnly_); }
+
void switchParams(const std::string& newNamespace) {
namespace_ = newNamespace;
}
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index a4a8b079..5f649318 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -5,6 +5,8 @@
#include "graph/node_operators_binary.h"
#include "graph/node_operators_unary.h"
+#include "tensors/cpu/int16.h"
+
namespace marian {
Expr debug(Expr a, const std::string& message) {
@@ -200,13 +202,36 @@ Expr weighted_average(Expr in, Expr weights, keywords::axis_k ax) {
}
Expr dot(Expr a, Expr b, bool transA, bool transB, float scalar) {
- return Expression<DotNodeOp>(a, b, transA, transB, scalar);
+ auto device = a->graph()->getDevice().type;
+ if(a->graph()->isOptimized() && device == DeviceType::cpu) {
+ // dotInt16 computes A * B.T, hence the transpose for B to get A * B
+ // if transA = false and transB = false.
+ return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a),
+ cpu::int16::quantize(transB ? b : transpose(b)),
+ scalar);
+ }
+ else {
+ return Expression<DotNodeOp>(a, b, transA, transB, scalar);
+ }
}
Expr bdot(Expr a, Expr b, bool transA, bool transB, float scalar) {
return Expression<DotBatchedNodeOp>(a, b, transA, transB, scalar);
}
+Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scalar) {
+ auto device = a->graph()->getDevice().type;
+ if(a->graph()->isOptimized() && device == DeviceType::cpu) {
+ return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a),
+ cpu::int16::quantize(transB ? b : transpose(b)),
+ bias, scalar);
+ }
+ else {
+ std::vector<Expr> nodes = {a, b, bias};
+ return Expression<AffineNodeOp>(nodes, transA, transB, scalar);
+ }
+}
+
Expr transpose(Expr a) {
std::vector<int> axes(a->shape().size());
for(int i = 0; i < axes.size(); ++i) {
@@ -237,11 +262,6 @@ Expr cross_entropy(Expr a, Expr b) {
return Expression<CrossEntropyNodeOp>(a, b);
}
-Expr affine(Expr a, Expr b, Expr c, bool transA, bool transB, float scalar) {
- std::vector<Expr> nodes = {a, b, c};
- return Expression<AffineNodeOp>(nodes, transA, transB, scalar);
-}
-
Expr plus(const std::vector<Expr>&) {
ABORT("Not implemented");
}
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index c637105f..070ee0ba 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -63,6 +63,7 @@ Expr dot(Expr a,
bool transA = false,
bool transB = false,
float scalar = 1.f);
+
Expr bdot(Expr a,
Expr b,
bool transA = false,
diff --git a/src/graph/node.cpp b/src/graph/node.cpp
index 12399f53..66c32d39 100644
--- a/src/graph/node.cpp
+++ b/src/graph/node.cpp
@@ -7,7 +7,7 @@ namespace marian {
size_t Node::allocate() {
size_t elements = 0;
if(!val_) {
- graph()->tensor(val_, shape_);
+ graph()->tensor(val_, shape_, value_type_);
elements = val_->shape().elements();
}
return elements;
@@ -24,15 +24,15 @@ void Node::free() {
void Node::init_dependent() {
if(!adj_) {
- graph()->tensor(adj_, shape_);
- adj_->set(1);
+ graph()->tensor(adj_, shape_, value_type_);
+ adj_->set(1.f);
}
}
void Node::set_zero_adjoint() {
if(!adj_) {
- graph()->tensor(adj_, shape_);
- adj_->set(0);
+ graph()->tensor(adj_, shape_, value_type_);
+ adj_->set(0.f);
}
}
diff --git a/src/graph/node.h b/src/graph/node.h
index 15f223aa..6cc1595d 100644
--- a/src/graph/node.h
+++ b/src/graph/node.h
@@ -23,6 +23,8 @@ protected:
Weak<ExpressionGraph> graph_;
Shape shape_{1, 1, 1, 1};
+ Type value_type_{Type::float32};
+
std::string name_{"none"};
Tensor val_{nullptr};
@@ -32,8 +34,8 @@ protected:
std::string debugMessage_;
public:
- Node(Ptr<ExpressionGraph> graph, Shape shape)
- : graph_(graph), shape_(shape) {}
+ Node(Ptr<ExpressionGraph> graph, Shape shape, Type value_type = Type::float32)
+ : graph_(graph), shape_(shape), value_type_(value_type) {}
virtual ~Node() {
if(destroy_) {
@@ -99,6 +101,7 @@ public:
virtual Tensor& grad() { return adj_; };
virtual const Shape& shape() { return shape_; }
+ virtual const Type& value_type() { return value_type_; }
void set_name(const std::string& name) { name_ = name; }
@@ -139,8 +142,8 @@ public:
struct NaryNodeOp : public Node {
size_t hash_{0};
- NaryNodeOp(const std::vector<Expr>& nodes, Shape shape)
- : Node(nodes.front()->graph(), shape) {
+ NaryNodeOp(const std::vector<Expr>& nodes, Shape shape, Type value_type = Type::float32)
+ : Node(nodes.front()->graph(), shape, value_type) {
children_.resize(nodes.size());
for(int i = 0; i < nodes.size(); ++i)
children_[i] = nodes[i];
diff --git a/src/graph/node_operators.cpp b/src/graph/node_operators.cpp
index 146da4b1..4125d4d1 100644
--- a/src/graph/node_operators.cpp
+++ b/src/graph/node_operators.cpp
@@ -7,7 +7,6 @@
namespace marian {
size_t ConstantNode::allocate() {
- // @TODO params
size_t elements = 0;
if(!val_) {
graph()->tensor(val_, shape_);
@@ -25,7 +24,6 @@ void ConstantNode::init() {
}
size_t ParamNode::allocate() {
- // @TODO params
size_t elements = 0;
if(!val_) {
graph()->tensor(val_, shape_);
diff --git a/src/graph/node_operators.h b/src/graph/node_operators.h
index 7626073d..c53cfe1b 100644
--- a/src/graph/node_operators.h
+++ b/src/graph/node_operators.h
@@ -10,7 +10,7 @@ struct ConstantNode : public Node {
ConstantNode(Ptr<ExpressionGraph> graph,
const Shape& shape,
const NodeInitializer& init)
- : Node(graph, shape),
+ : Node(graph, shape), // TODO: add value_type
init_(new NodeInitializer(init)),
initialized_(false) {
setTrainable(false);
@@ -28,7 +28,7 @@ struct ConstantNode : public Node {
const std::string color() { return "white"; }
virtual size_t hash() {
- std::size_t seed = boost::hash<std::string>()(name());
+ std::size_t seed = boost::hash<std::string>()(name()); // TODO: add value_type
boost::hash_combine(seed, type());
boost::hash_combine(seed, this);
return seed;
@@ -49,7 +49,7 @@ struct ParamNode : public Node {
const Shape& shape,
const NodeInitializer& init,
bool fixed = false)
- : Node(graph, shape),
+ : Node(graph, shape), // TODO: add value_type
init_(new NodeInitializer(init)),
initialized_(false) {
setTrainable(!fixed);
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 8ab249a8..dc4015b2 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -12,9 +12,11 @@
namespace marian {
struct UnaryNodeOp : public NaryNodeOp {
- UnaryNodeOp(Expr a, Shape shape) : NaryNodeOp({a}, shape) {}
+ UnaryNodeOp(Expr a, Shape shape, Type value_type = Type::float32)
+ : NaryNodeOp({a}, shape, value_type) {}
- UnaryNodeOp(Expr a) : NaryNodeOp({a}, a->shape()) {}
+ UnaryNodeOp(Expr a, Type value_type = Type::float32)
+ : NaryNodeOp({a}, a->shape(), value_type) {}
const std::string color() { return "yellow"; }
};