Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src/graph
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-04-12 02:50:45 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-04-12 02:50:45 +0300
commit78a99473749ee038778f8b9ec37b16d0a62b86b7 (patch)
tree79d4664fc1d847262e060f1b74040e9ce66b88a1 /src/graph
parentbd06e1919ea908662ce6fe12d468be69c4f8c6f4 (diff)
working memoization
Diffstat (limited to 'src/graph')
-rw-r--r--src/graph/chainable.h4
-rw-r--r--src/graph/expression_graph.cpp2
-rw-r--r--src/graph/expression_graph.h160
-rw-r--r--src/graph/node.cpp6
-rw-r--r--src/graph/node.h10
-rw-r--r--src/graph/node_operators.cpp19
-rw-r--r--src/graph/node_operators.h12
-rw-r--r--src/graph/node_operators_unary.h9
8 files changed, 170 insertions, 52 deletions
diff --git a/src/graph/chainable.h b/src/graph/chainable.h
index 281d3653..b1e4b46e 100644
--- a/src/graph/chainable.h
+++ b/src/graph/chainable.h
@@ -64,9 +64,13 @@ struct Chainable {
virtual void init() = 0;
virtual void init_dependent() {}
virtual void set_zero_adjoint() {}
+
virtual bool trainable() = 0;
virtual void setTrainable(bool) = 0;
+ virtual bool memoize() = 0;
+ virtual void setMemoize(bool) = 0;
+
virtual void setId(size_t) = 0;
virtual size_t getId() = 0;
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp
index e89f45f3..6dd31968 100644
--- a/src/graph/expression_graph.cpp
+++ b/src/graph/expression_graph.cpp
@@ -13,7 +13,7 @@ void ExpressionGraph::setDevice(DeviceId deviceId) {
backend_ = BackendByDevice(deviceId, Config::seed);
params_ = New<Parameters>();
params_->init(backend_);
- tensors_ = New<TensorAllocator>(backend_);
+ tensors_ = New<Tensors>(backend_);
}
}
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index c9f2b245..039ca3a5 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -22,6 +22,99 @@ namespace marian {
template <class T, typename... Args>
Expr Expression(Args&&... args);
+class Tensors {
+private:
+ Ptr<TensorAllocator> tensors_;
+ Ptr<TensorAllocator> cache_;
+
+ typedef std::unordered_map<size_t, std::vector<WExpr>> WeakMemory;
+ typedef std::unordered_map<size_t, std::vector<Expr>> Memory;
+
+ Ptr<WeakMemory> shortterm_;
+ Ptr<Memory> longterm_;
+
+public:
+ Tensors(Ptr<Backend> backend)
+ : tensors_(New<TensorAllocator>(backend)),
+ cache_(New<TensorAllocator>(backend)),
+ shortterm_(New<WeakMemory>()),
+ longterm_(New<Memory>()) {}
+
+ void reserve(size_t bytes) {
+ tensors_->reserve(bytes);
+ }
+
+ void throwAtReallocation(bool throwAtRealloc) {
+ tensors_->throwAtReallocation(throwAtRealloc);
+ }
+
+ void allocateForward(Expr node) {
+ if(!node->val()) {
+ if(node->memoize())
+ cache_->allocate(node->val(), node->shape(), node->value_type());
+ else
+ tensors_->allocate(node->val(), node->shape(), node->value_type());
+ }
+ }
+
+ void allocateBackward(Expr node) {
+ if(!node->grad())
+ tensors_->allocate(node->grad(), node->shape(), node->value_type());
+ }
+
+ void free(Tensor& tensor) {
+ tensors_->free(tensor);
+ }
+
+ // @TODO: get rid of this, not really used or can be done better
+ Ptr<Allocator> allocator() {
+ return tensors_->allocator();
+ }
+
+ Expr findOrRemember(Expr node) {
+ size_t hash = node->hash();
+ if(node->memoize()) {
+ auto it = longterm_->find(hash);
+ if(it != longterm_->end()) {
+ for(auto found : it->second) {
+ return found;
+ //if(node->equal(found)) {
+ //std::cerr << "found memoized" << std::endl;
+ //return found;
+ //}
+ }
+ }
+ (*longterm_)[hash].push_back(node);
+ }
+
+ auto it = shortterm_->find(hash);
+ if(it != shortterm_->end()) {
+ for(auto foundWeak : it->second) {
+ auto found = foundWeak.lock();
+ if(node->equal(found)) {
+ return found;
+ }
+ }
+ }
+ (*shortterm_)[hash].push_back(node);
+ return nullptr;
+ }
+
+ void clear() {
+ tensors_->clear();
+ shortterm_->clear();
+ }
+
+ void clearShorttermMemory() {
+ shortterm_->clear();
+ }
+
+ void clearLongtermMemory() {
+ longterm_->clear();
+ }
+
+};
+
class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
private:
size_t count_{0};
@@ -31,11 +124,11 @@ private:
std::unordered_set<Expr> topNodes_;
Ptr<Parameters> params_;
- Ptr<TensorAllocator> tensors_;
+ Ptr<Tensors> tensors_;
Ptr<Backend> backend_;
- std::unordered_map<size_t, std::vector<WExpr>> hashMap_;
+ std::unordered_map<size_t, std::vector<Expr>> memoized_;
bool inferenceOnly_{false};
bool optimized_{false};
@@ -127,7 +220,7 @@ public:
void forwardNext() {
// @TODO: check if allocation works properly
- hashMap_.clear();
+ tensors_->clearShorttermMemory();
while(!nodesForward_.empty()) {
auto v = nodesForward_.front();
@@ -160,7 +253,8 @@ public:
// named_.clear();
topNodes_.clear();
- hashMap_.clear();
+
+ tensors_->clearShorttermMemory();
while(!nodesBackward_.empty()) {
auto v = nodesBackward_.back();
@@ -273,54 +367,60 @@ public:
return Expr();
}
- Ptr<Parameters>& params() { return params_; }
+ Ptr<Parameters>& params() {
+ return params_;
+ }
Expr add(Expr node) {
- size_t hash = node->hash();
- auto it = hashMap_.find(hash);
- if(it != hashMap_.end()) {
- for(auto foundWeak : it->second) {
- auto found = foundWeak.lock();
- if(node->equal(found))
- return found;
- }
- }
- hashMap_[hash].push_back(node);
+ auto found = tensors_->findOrRemember(node);
+ if(found) {
+ return found;
+ } else {
+ node->setId(count_++);
- node->setId(count_++);
+ nodesForward_.push_back(node);
+ if(!inferenceOnly_ && node->trainable()) {
+ nodesBackward_.push_back(node);
+ topNodes_.insert(node);
+ }
- nodesForward_.push_back(node);
- if(!inferenceOnly_ && node->trainable()) {
- nodesBackward_.push_back(node);
- topNodes_.insert(node);
+ return node;
}
+ }
- return node;
+ void remove_top_node(Expr node) {
+ topNodes_.erase(node);
}
- void remove_top_node(Expr node) { topNodes_.erase(node); }
+ void allocateForward(Expr node) {
+ if(tensors_)
+ tensors_->allocateForward(node);
+ }
- template <class... Args>
- void tensor(Tensor& t, Args&&... args) {
- tensors_->allocate(t, args...);
+ void allocateBackward(Expr node) {
+ if(tensors_)
+ tensors_->allocateBackward(node);
}
- void free(Tensor& t) {
+ void free(Tensor& tensor) {
if(tensors_)
- tensors_->free(t);
+ tensors_->free(tensor);
}
- Ptr<Allocator> allocator() { return tensors_->allocator(); }
+ // @TODO: get rid of this, not really used or can be done better
+ Ptr<Allocator> allocator() {
+ return tensors_->allocator();
+ }
void clear() {
- // clear everything apart from parameters
+ // clear everything apart from parameters and memoized nodes
count_ = 0;
nodesForward_.clear();
nodesBackward_.clear();
topNodes_.clear();
- hashMap_.clear();
+
tensors_->clear();
}
diff --git a/src/graph/node.cpp b/src/graph/node.cpp
index 66c32d39..5160e95c 100644
--- a/src/graph/node.cpp
+++ b/src/graph/node.cpp
@@ -7,7 +7,7 @@ namespace marian {
size_t Node::allocate() {
size_t elements = 0;
if(!val_) {
- graph()->tensor(val_, shape_, value_type_);
+ graph()->allocateForward(shared_from_this());
elements = val_->shape().elements();
}
return elements;
@@ -24,14 +24,14 @@ void Node::free() {
void Node::init_dependent() {
if(!adj_) {
- graph()->tensor(adj_, shape_, value_type_);
+ graph()->allocateBackward(shared_from_this());
adj_->set(1.f);
}
}
void Node::set_zero_adjoint() {
if(!adj_) {
- graph()->tensor(adj_, shape_, value_type_);
+ graph()->allocateBackward(shared_from_this());
adj_->set(0.f);
}
}
diff --git a/src/graph/node.h b/src/graph/node.h
index 6cc1595d..43dcadb1 100644
--- a/src/graph/node.h
+++ b/src/graph/node.h
@@ -19,6 +19,8 @@ protected:
size_t edges_{0};
bool trainable_{true};
bool destroy_{true};
+ bool memoize_{false};
+
std::vector<Expr> children_;
Weak<ExpressionGraph> graph_;
@@ -68,6 +70,9 @@ public:
virtual void setTrainable(bool trainable) { trainable_ = trainable; }
+ virtual bool memoize() { return memoize_; };
+ virtual void setMemoize(bool memoize) { memoize_ = memoize; };
+
virtual void setId(size_t id) { id_ = id; }
virtual size_t getId() { return id_; }
@@ -150,6 +155,11 @@ struct NaryNodeOp : public Node {
setTrainable(std::any_of(
nodes.begin(), nodes.end(), [](Expr a) { return a->trainable(); }));
+
+ // Node is to be memoized if all children are to be memoized.
+ setMemoize(std::all_of(
+ nodes.begin(), nodes.end(), [](Expr a) { return a->memoize(); }));
+
remove_children_from_top_nodes();
}
diff --git a/src/graph/node_operators.cpp b/src/graph/node_operators.cpp
index 4125d4d1..788b3fc6 100644
--- a/src/graph/node_operators.cpp
+++ b/src/graph/node_operators.cpp
@@ -9,7 +9,7 @@ namespace marian {
size_t ConstantNode::allocate() {
size_t elements = 0;
if(!val_) {
- graph()->tensor(val_, shape_);
+ graph()->allocateForward(shared_from_this());
elements = val_->shape().elements();
}
return elements;
@@ -23,15 +23,18 @@ void ConstantNode::init() {
init_.reset();
}
-size_t ParamNode::allocate() {
- size_t elements = 0;
- if(!val_) {
- graph()->tensor(val_, shape_);
- elements = val_->shape().elements();
- }
- return elements;
+ParamNode::ParamNode(Ptr<ExpressionGraph> graph,
+ const Shape& shape,
+ const NodeInitializer& init,
+ bool fixed)
+ : Node(graph, shape), // TODO: add value_type
+ init_(new NodeInitializer(init)),
+ initialized_(false) {
+ setTrainable(!fixed);
+ setMemoize(graph->isInference());
}
+
void ParamNode::init() {
if(!initialized_) {
(*init_)(val_);
diff --git a/src/graph/node_operators.h b/src/graph/node_operators.h
index 4a554de7..cd490ffc 100644
--- a/src/graph/node_operators.h
+++ b/src/graph/node_operators.h
@@ -45,16 +45,14 @@ struct ParamNode : public Node {
ParamNode(Ptr<ExpressionGraph> graph,
const Shape& shape,
const NodeInitializer& init,
- bool fixed = false)
- : Node(graph, shape), // TODO: add value_type
- init_(new NodeInitializer(init)),
- initialized_(false) {
- setTrainable(!fixed);
- }
+ bool fixed = false);
~ParamNode() {}
- virtual size_t allocate();
+ virtual size_t allocate() {
+ ABORT_IF(!val_, "Parameters should be allocated by their graph");
+ return 0;
+ }
virtual void init();
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index dc4015b2..273adf44 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -77,7 +77,7 @@ public:
return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))};
}
- const std::string type() { return "scalar_add"; }
+ const std::string type() { return "scalar_mult"; }
virtual size_t hash() {
if(!hash_) {
@@ -605,8 +605,11 @@ struct NegNodeOp : public UnaryNodeOp {
};
struct RowsNodeOp : public UnaryNodeOp {
- RowsNodeOp(Expr a, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {}
+ RowsNodeOp(Expr a, const std::vector<size_t>& indices)
+ : UnaryNodeOp(a, newShape(a, indices)), indices_(indices) {
+ // @TODO: fix this by using int32 tensor for indices
+ setMemoize(false);
+ }
NodeOps forwardOps() {
// @TODO: solve this with a tensor!