Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-04-12 02:50:45 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-04-12 02:50:45 +0300
commit78a99473749ee038778f8b9ec37b16d0a62b86b7 (patch)
tree79d4664fc1d847262e060f1b74040e9ce66b88a1 /src
parentbd06e1919ea908662ce6fe12d468be69c4f8c6f4 (diff)
working memoization
Diffstat (limited to 'src')
-rw-r--r--src/graph/chainable.h4
-rw-r--r--src/graph/expression_graph.cpp2
-rw-r--r--src/graph/expression_graph.h160
-rw-r--r--src/graph/node.cpp6
-rw-r--r--src/graph/node.h10
-rw-r--r--src/graph/node_operators.cpp19
-rw-r--r--src/graph/node_operators.h12
-rw-r--r--src/graph/node_operators_unary.h9
-rwxr-xr-xsrc/tensors/cpu/sharp/sse_gemm.h49
9 files changed, 182 insertions, 89 deletions
diff --git a/src/graph/chainable.h b/src/graph/chainable.h
index 281d3653..b1e4b46e 100644
--- a/src/graph/chainable.h
+++ b/src/graph/chainable.h
@@ -64,9 +64,13 @@ struct Chainable {
virtual void init() = 0;
virtual void init_dependent() {}
virtual void set_zero_adjoint() {}
+
virtual bool trainable() = 0;
virtual void setTrainable(bool) = 0;
+ virtual bool memoize() = 0;
+ virtual void setMemoize(bool) = 0;
+
virtual void setId(size_t) = 0;
virtual size_t getId() = 0;
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp
index e89f45f3..6dd31968 100644
--- a/src/graph/expression_graph.cpp
+++ b/src/graph/expression_graph.cpp
@@ -13,7 +13,7 @@ void ExpressionGraph::setDevice(DeviceId deviceId) {
backend_ = BackendByDevice(deviceId, Config::seed);
params_ = New<Parameters>();
params_->init(backend_);
- tensors_ = New<TensorAllocator>(backend_);
+ tensors_ = New<Tensors>(backend_);
}
}
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index c9f2b245..039ca3a5 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -22,6 +22,99 @@ namespace marian {
template <class T, typename... Args>
Expr Expression(Args&&... args);
+class Tensors {
+private:
+ Ptr<TensorAllocator> tensors_;
+ Ptr<TensorAllocator> cache_;
+
+ typedef std::unordered_map<size_t, std::vector<WExpr>> WeakMemory;
+ typedef std::unordered_map<size_t, std::vector<Expr>> Memory;
+
+ Ptr<WeakMemory> shortterm_;
+ Ptr<Memory> longterm_;
+
+public:
+ Tensors(Ptr<Backend> backend)
+ : tensors_(New<TensorAllocator>(backend)),
+ cache_(New<TensorAllocator>(backend)),
+ shortterm_(New<WeakMemory>()),
+ longterm_(New<Memory>()) {}
+
+ void reserve(size_t bytes) {
+ tensors_->reserve(bytes);
+ }
+
+ void throwAtReallocation(bool throwAtRealloc) {
+ tensors_->throwAtReallocation(throwAtRealloc);
+ }
+
+ void allocateForward(Expr node) {
+ if(!node->val()) {
+ if(node->memoize())
+ cache_->allocate(node->val(), node->shape(), node->value_type());
+ else
+ tensors_->allocate(node->val(), node->shape(), node->value_type());
+ }
+ }
+
+ void allocateBackward(Expr node) {
+ if(!node->grad())
+ tensors_->allocate(node->grad(), node->shape(), node->value_type());
+ }
+
+ void free(Tensor& tensor) {
+ tensors_->free(tensor);
+ }
+
+ // @TODO: get rid of this, not really used or can be done better
+ Ptr<Allocator> allocator() {
+ return tensors_->allocator();
+ }
+
+ Expr findOrRemember(Expr node) {
+ size_t hash = node->hash();
+ if(node->memoize()) {
+ auto it = longterm_->find(hash);
+ if(it != longterm_->end()) {
+ for(auto found : it->second) {
+ return found;
+ //if(node->equal(found)) {
+ //std::cerr << "found memoized" << std::endl;
+ //return found;
+ //}
+ }
+ }
+ (*longterm_)[hash].push_back(node);
+ }
+
+ auto it = shortterm_->find(hash);
+ if(it != shortterm_->end()) {
+ for(auto foundWeak : it->second) {
+ auto found = foundWeak.lock();
+ if(node->equal(found)) {
+ return found;
+ }
+ }
+ }
+ (*shortterm_)[hash].push_back(node);
+ return nullptr;
+ }
+
+ void clear() {
+ tensors_->clear();
+ shortterm_->clear();
+ }
+
+ void clearShorttermMemory() {
+ shortterm_->clear();
+ }
+
+ void clearLongtermMemory() {
+ longterm_->clear();
+ }
+
+};
+
class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
private:
size_t count_{0};
@@ -31,11 +124,11 @@ private:
std::unordered_set<Expr> topNodes_;
Ptr<Parameters> params_;
- Ptr<TensorAllocator> tensors_;
+ Ptr<Tensors> tensors_;
Ptr<Backend> backend_;
- std::unordered_map<size_t, std::vector<WExpr>> hashMap_;
+ std::unordered_map<size_t, std::vector<Expr>> memoized_;
bool inferenceOnly_{false};
bool optimized_{false};
@@ -127,7 +220,7 @@ public:
void forwardNext() {
// @TODO: check if allocation works properly
- hashMap_.clear();
+ tensors_->clearShorttermMemory();
while(!nodesForward_.empty()) {
auto v = nodesForward_.front();
@@ -160,7 +253,8 @@ public:
// named_.clear();
topNodes_.clear();
- hashMap_.clear();
+
+ tensors_->clearShorttermMemory();
while(!nodesBackward_.empty()) {
auto v = nodesBackward_.back();
@@ -273,54 +367,60 @@ public:
return Expr();
}
- Ptr<Parameters>& params() { return params_; }
+ Ptr<Parameters>& params() {
+ return params_;
+ }
Expr add(Expr node) {
- size_t hash = node->hash();
- auto it = hashMap_.find(hash);
- if(it != hashMap_.end()) {
- for(auto foundWeak : it->second) {
- auto found = foundWeak.lock();
- if(node->equal(found))
- return found;
- }
- }
- hashMap_[hash].push_back(node);
+ auto found = tensors_->findOrRemember(node);
+ if(found) {
+ return found;
+ } else {
+ node->setId(count_++);
- node->setId(count_++);
+ nodesForward_.push_back(node);
+ if(!inferenceOnly_ && node->trainable()) {
+ nodesBackward_.push_back(node);
+ topNodes_.insert(node);
+ }
- nodesForward_.push_back(node);
- if(!inferenceOnly_ && node->trainable()) {
- nodesBackward_.push_back(node);
- topNodes_.insert(node);
+ return node;
}
+ }
- return node;
+ void remove_top_node(Expr node) {
+ topNodes_.erase(node);
}
- void remove_top_node(Expr node) { topNodes_.erase(node); }
+ void allocateForward(Expr node) {
+ if(tensors_)
+ tensors_->allocateForward(node);
+ }
- template <class... Args>
- void tensor(Tensor& t, Args&&... args) {
- tensors_->allocate(t, args...);
+ void allocateBackward(Expr node) {
+ if(tensors_)
+ tensors_->allocateBackward(node);
}
- void free(Tensor& t) {
+ void free(Tensor& tensor) {
if(tensors_)
- tensors_->free(t);
+ tensors_->free(tensor);
}
- Ptr<Allocator> allocator() { return tensors_->allocator(); }
+ // @TODO: get rid of this, not really used or can be done better
+ Ptr<Allocator> allocator() {
+ return tensors_->allocator();
+ }
void clear() {
- // clear everything apart from parameters
+ // clear everything apart from parameters and memoized nodes
count_ = 0;
nodesForward_.clear();
nodesBackward_.clear();
topNodes_.clear();
- hashMap_.clear();
+
tensors_->clear();
}
diff --git a/src/graph/node.cpp b/src/graph/node.cpp
index 66c32d39..5160e95c 100644
--- a/src/graph/node.cpp
+++ b/src/graph/node.cpp
@@ -7,7 +7,7 @@ namespace marian {
size_t Node::allocate() {
size_t elements = 0;
if(!val_) {
- graph()->tensor(val_, shape_, value_type_);
+ graph()->allocateForward(shared_from_this());
elements = val_->shape().elements();
}
return elements;
@@ -24,14 +24,14 @@ void Node::free() {
void Node::init_dependent() {
if(!adj_) {
- graph()->tensor(adj_, shape_, value_type_);
+ graph()->allocateBackward(shared_from_this());
adj_->set(1.f);
}
}
void Node::set_zero_adjoint() {
if(!adj_) {
- graph()->tensor(adj_, shape_, value_type_);
+ graph()->allocateBackward(shared_from_this());
adj_->set(0.f);
}
}
diff --git a/src/graph/node.h b/src/graph/node.h
index 6cc1595d..43dcadb1 100644
--- a/src/graph/node.h
+++ b/src/graph/node.h
@@ -19,6 +19,8 @@ protected:
size_t edges_{0};
bool trainable_{true};
bool destroy_{true};
+ bool memoize_{false};
+
std::vector<Expr> children_;
Weak<ExpressionGraph> graph_;
@@ -68,6 +70,9 @@ public:
virtual void setTrainable(bool trainable) { trainable_ = trainable; }
+ virtual bool memoize() { return memoize_; };
+ virtual void setMemoize(bool memoize) { memoize_ = memoize; };
+
virtual void setId(size_t id) { id_ = id; }
virtual size_t getId() { return id_; }
@@ -150,6 +155,11 @@ struct NaryNodeOp : public Node {
setTrainable(std::any_of(
nodes.begin(), nodes.end(), [](Expr a) { return a->trainable(); }));
+
+ // Node is to be memoized if all children are to be memoized.
+ setMemoize(std::all_of(
+ nodes.begin(), nodes.end(), [](Expr a) { return a->memoize(); }));
+
remove_children_from_top_nodes();
}
diff --git a/src/graph/node_operators.cpp b/src/graph/node_operators.cpp
index 4125d4d1..788b3fc6 100644
--- a/src/graph/node_operators.cpp
+++ b/src/graph/node_operators.cpp
@@ -9,7 +9,7 @@ namespace marian {
size_t ConstantNode::allocate() {
size_t elements = 0;
if(!val_) {
- graph()->tensor(val_, shape_);
+ graph()->allocateForward(shared_from_this());
elements = val_->shape().elements();
}
return elements;
@@ -23,15 +23,18 @@ void ConstantNode::init() {
init_.reset();
}
-size_t ParamNode::allocate() {
- size_t elements = 0;
- if(!val_) {
- graph()->tensor(val_, shape_);
- elements = val_->shape().elements();
- }
- return elements;
+ParamNode::ParamNode(Ptr<ExpressionGraph> graph,
+ const Shape& shape,
+ const NodeInitializer& init,
+ bool fixed)
+ : Node(graph, shape), // TODO: add value_type
+ init_(new NodeInitializer(init)),
+ initialized_(false) {
+ setTrainable(!fixed);
+ setMemoize(graph->isInference());
}
+
void ParamNode::init() {
if(!initialized_) {
(*init_)(val_);
diff --git a/src/graph/node_operators.h b/src/graph/node_operators.h
index 4a554de7..cd490ffc 100644
--- a/src/graph/node_operators.h
+++ b/src/graph/node_operators.h
@@ -45,16 +45,14 @@ struct ParamNode : public Node {
ParamNode(Ptr<ExpressionGraph> graph,
const Shape& shape,
const NodeInitializer& init,
- bool fixed = false)
- : Node(graph, shape), // TODO: add value_type
- init_(new NodeInitializer(init)),
- initialized_(false) {
- setTrainable(!fixed);
- }
+ bool fixed = false);
~ParamNode() {}
- virtual size_t allocate();
+ virtual size_t allocate() {
+ ABORT_IF(!val_, "Parameters should be allocated by their graph");
+ return 0;
+ }
virtual void init();
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index dc4015b2..273adf44 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -77,7 +77,7 @@ public:
return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))};
}
- const std::string type() { return "scalar_add"; }
+ const std::string type() { return "scalar_mult"; }
virtual size_t hash() {
if(!hash_) {
@@ -605,8 +605,11 @@ struct NegNodeOp : public UnaryNodeOp {
};
struct RowsNodeOp : public UnaryNodeOp {
- RowsNodeOp(Expr a, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {}
+ RowsNodeOp(Expr a, const std::vector<size_t>& indices)
+ : UnaryNodeOp(a, newShape(a, indices)), indices_(indices) {
+ // @TODO: fix this by using int32 tensor for indices
+ setMemoize(false);
+ }
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
diff --git a/src/tensors/cpu/sharp/sse_gemm.h b/src/tensors/cpu/sharp/sse_gemm.h
index 5b8f5aba..542c3f6d 100755
--- a/src/tensors/cpu/sharp/sse_gemm.h
+++ b/src/tensors/cpu/sharp/sse_gemm.h
@@ -65,33 +65,6 @@
namespace marian {
-//void ParamNode::transposeAndQuantize() {
-// Tensor temp;
-// graph()->tensor(temp, shape_, Type::float32);
-// (*init_)(temp);
-//
-// if(transpose_) {
-// Tensor temp2;
-// graph()->tensor(temp2, Shape{shape_[-1], shape_[-2]}, Type::float32);
-// TransposeND(temp2, temp, {1, 0});
-// graph()->free(temp);
-// temp = temp2;
-// }
-//
-// int num_rows = temp->shape()[-2];
-// int width = temp->shape()[-1];
-// double quant_mult = pow(2.0, 10.0);
-// assert(width % 8 == 0);
-//
-// Quantize(temp->data(),
-// val_->data<__m128i>(),
-// (float)quant_mult,
-// num_rows,
-// width);
-//
-// graph()->free(temp);
-//}
-
namespace cpu {
namespace int16 {
@@ -247,16 +220,16 @@ static inline void SSE_MatrixMult(marian::Tensor C,
// We can't have consecutive accesses of qA, qB, *and* C. But we access qA and qB a lot more so it makes
// sense to do it this way.
_mm_store_ss(C1, _mm_cvtepi32_ps(sum1));
- *(C1) *= unquant_mult * scale;
+ *(C1) *= unquant_mult;
_mm_store_ss(C2, _mm_cvtepi32_ps(sum2));
- *(C2) *= unquant_mult * scale;
+ *(C2) *= unquant_mult;
_mm_store_ss(C3, _mm_cvtepi32_ps(sum3));
- *(C3) *= unquant_mult * scale;
+ *(C3) *= unquant_mult;
_mm_store_ss(C4, _mm_cvtepi32_ps(sum4));
- *(C4) *= unquant_mult * scale;
+ *(C4) *= unquant_mult;
}
}
if(rest == 1) {
@@ -281,7 +254,7 @@ static inline void SSE_MatrixMult(marian::Tensor C,
float * C1 = fC + (i + 0) * num_B_rows + j;
_mm_store_ss(C1, _mm_cvtepi32_ps(sum1));
- *(C1) *= unquant_mult * scale;
+ *(C1) *= unquant_mult;
}
}
else if(rest == 2) {
@@ -313,10 +286,10 @@ static inline void SSE_MatrixMult(marian::Tensor C,
float * C2 = fC + (i+1)*num_B_rows + j;
_mm_store_ss(C1, _mm_cvtepi32_ps(sum1));
- *(C1) *= unquant_mult * scale;
+ *(C1) *= unquant_mult;
_mm_store_ss(C2, _mm_cvtepi32_ps(sum2));
- *(C2) *= unquant_mult * scale;
+ *(C2) *= unquant_mult;
}
}
else if(rest == 3) {
@@ -355,13 +328,13 @@ static inline void SSE_MatrixMult(marian::Tensor C,
float * C3 = fC + (i+2)*num_B_rows + j;
_mm_store_ss(C1, _mm_cvtepi32_ps(sum1));
- *(C1) *= unquant_mult * scale;
+ *(C1) *= unquant_mult;
_mm_store_ss(C2, _mm_cvtepi32_ps(sum2));
- *(C2) *= unquant_mult * scale;
+ *(C2) *= unquant_mult;
_mm_store_ss(C3, _mm_cvtepi32_ps(sum3));
- *(C3) *= unquant_mult * scale;
+ *(C3) *= unquant_mult;
}
}
}
@@ -393,6 +366,8 @@ static void ProdInt(marian::Tensor C,
const marian::Tensor B,
float scale) {
+ ABORT_IF(scale != 1, "Scale other than 1 not supported");
+
// @TODO: make this a parameter
float quant_mult = pow(2.0, (float)BITS);