
github.com/marian-nmt/marian.git
author    Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-04-23 18:40:43 +0300
committer Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-04-23 18:40:43 +0300
commit    c13bc7cec3a18dca1f8cd1a5f3a24617e96c0e3f (patch)
tree      addf17f949ab45ee1a1eb5158a2456d1e66f3bdf /src/graph/node_operators_unary.h
parent    38d9204fe565335536b7daa94568367253c14063 (diff)
better memory handling
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--  src/graph/node_operators_unary.h  142
1 file changed, 76 insertions(+), 66 deletions(-)
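
This patch does two things. First, every direct read of the children_
vector is replaced by a child(i) accessor. Second, ReshapeNodeOp and
TimestepNodeOp become classes that hold their own Expr handle to the child
(reshapee_ and stepNode_ below); both nodes hand out tensors that alias the
child's buffer, so a shared handle keeps that buffer alive even if the base
class lets go of children_, which is a plausible reading of the "better
memory handling" subject line. The accessor itself is not part of this
diff; a minimal sketch of its assumed shape, reusing the Expr and children_
names that do appear here:

    // Sketch only: assumed definition of the accessor used in this diff.
    // Expr is a shared-pointer-like handle to a graph node, so returning
    // it by value shares ownership instead of exposing the raw vector.
    struct Node {
      std::vector<Expr> children_;
      Expr child(size_t i) { return children_[i]; }
    };
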
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 20298e2a..8b78343f 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -29,14 +29,14 @@ struct LogitNodeOp : public UnaryNodeOp {
return {
NodeOp(Element(_1 = Sigma(_2),
val_,
- children_[0]->val()))
+ child(0)->val()))
};
}
NodeOps backwardOps() {
return {
NodeOp(Add(_1 * _2 * (1.0f - _2),
- children_[0]->grad(),
+ child(0)->grad(),
adj_, val_))
};
}
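
For reference, the backward line above is the textbook sigmoid derivative;
with s = \sigma(x) already stored in val_:

    \frac{d}{dx}\,\sigma(x) = \sigma(x)\,(1 - \sigma(x)) = s\,(1 - s)

so the op accumulates adj_ * val_ * (1 - val_) into the child's gradient,
reusing the forward result instead of recomputing the sigmoid.
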
@@ -69,28 +69,28 @@ struct TanhNodeOp : public NaryNodeOp {
case 1:
return { NodeOp(Element(_1 = Tanh(_2),
val_,
- children_[0]->val())) };
+ child(0)->val())) };
case 2:
return { NodeOp(Element(_1 = Tanh(_2 + _3),
val_,
- children_[0]->val(),
- children_[1]->val())) };
+ child(0)->val(),
+ child(1)->val())) };
case 3:
return { NodeOp(Element(_1 = Tanh(_2 + _3 + _4),
val_,
- children_[0]->val(),
- children_[1]->val(),
- children_[2]->val())) };
+ child(0)->val(),
+ child(1)->val(),
+ child(2)->val())) };
default:
return {
NodeOp(
Element(_1 = _2 + _3 + _4,
val_,
- children_[0]->val(),
- children_[1]->val(),
- children_[2]->val());
+ child(0)->val(),
+ child(1)->val(),
+ child(2)->val());
for(int i = 3; i < children_.size(); ++i)
- Element(_1 += _2, val_, children_[i]->val());
+ Element(_1 += _2, val_, child(i)->val());
Element(_1 = Tanh(_1), val_);
)
};
@@ -99,10 +99,10 @@ struct TanhNodeOp : public NaryNodeOp {
NodeOps backwardOps() {
NodeOps ops;
- for(auto&& child : children_) {
+ for(int i = 0; i < children_.size(); i++) {
ops.push_back(
NodeOp(Add(_1 * (1.0f - (_2 * _2)),
- child->grad(), adj_, val_))
+ child(i)->grad(), adj_, val_))
);
}
return ops;
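
A note on the loop rewrite above. NodeOp appears to wrap its argument in a
deferred lambda; assuming it captures by value, as such macros usually do,
the old range-for stored a copy of the Expr child in every op, pinning each
child from inside the lambda, and a local named child would also shadow the
new child(i) accessor. The indexed form captures only an int and resolves
child(i) when the op runs, leaving ownership with the node itself. That is
an inference from the code, not something the commit states; schematically:

    // Before: each op lambda holds its own shared handle to `child`.
    ops.push_back(NodeOp(... child->grad() ...));
    // After: the lambda keeps just `i`; child(i) is looked up at run time.
    ops.push_back(NodeOp(... child(i)->grad() ...));
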
@@ -141,15 +141,15 @@ struct ReLUNodeOp : public UnaryNodeOp {
return {
NodeOp(Element(_1 = ReLU(_2),
val_,
- children_[0]->val()))
+ child(0)->val()))
};
}
NodeOps backwardOps() {
return {
NodeOp(Add(_1 * ReLUback(_2),
- children_[0]->grad(),
- adj_, children_[0]->val()))
+ child(0)->grad(),
+ adj_, child(0)->val()))
};
}
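
Unlike the sigmoid and tanh backward ops, which are written in terms of the
stored output val_, ReLU's derivative needs the original input, which is
why the op above passes child(0)->val():

    \mathrm{ReLU}'(x) = \begin{cases} 1 & x > 0 \\ 0 & x \le 0 \end{cases}

ReLUback is presumably this step function applied elementwise.
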
@@ -174,7 +174,7 @@ struct SoftmaxNodeOp : public NaryNodeOp {
NodeOps forwardOps() {
return {
NodeOp(Softmax(val_,
- children_[0]->val(),
+ child(0)->val(),
mask_ ? mask_->val() : nullptr))
};
}
@@ -203,7 +203,7 @@ struct SoftmaxNodeOp : public NaryNodeOp {
// val_ is already masked if there is a mask, so no need to apply here.
return {
- NodeOp(SoftmaxGrad(children_[0]->grad(), adj_, val_))
+ NodeOp(SoftmaxGrad(child(0)->grad(), adj_, val_))
};
}
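
SoftmaxGrad is the standard softmax Jacobian-vector product. With
y = softmax(x) stored in val_ and dy = adj_:

    \frac{\partial y_i}{\partial x_j} = y_i(\delta_{ij} - y_j)
    \quad\Rightarrow\quad
    dx_i = y_i\Big(dy_i - \sum_j dy_j\,y_j\Big)

so the kernel needs only the adjoint and the forward output, and a masked
val_ (see the comment above) zeroes the corresponding dx entries for free.
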
@@ -219,7 +219,7 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
NodeOps forwardOps() {
return {
- NodeOp(LogSoftmax(val_, children_[0]->val()))
+ NodeOp(LogSoftmax(val_, child(0)->val()))
};
}
@@ -228,7 +228,7 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
// J * dy = dy - avg*1
// where avg = exp(p)'*dy and p is the softmax output (probabilities).
return {
- NodeOp(LogSoftmaxGrad(children_[0]->grad(), adj_, val_))
+ NodeOp(LogSoftmaxGrad(child(0)->grad(), adj_, val_))
};
}
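
The comment above compresses the standard log-softmax backward. With
p = log softmax(x) stored in val_, exp(p) recovers the probabilities, and:

    \frac{\partial p_i}{\partial x_j} = \delta_{ij} - e^{p_j}
    \quad\Rightarrow\quad
    dx_j = dy_j - e^{p_j}\sum_i dy_i

i.e. the adjoint minus the probabilities scaled by the adjoint's sum.
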
@@ -246,11 +246,11 @@ struct SumNodeOp : public UnaryNodeOp {
ax_(keywords::Get(keywords::axis, -1, args...)) { }
NodeOps forwardOps() {
- return { NodeOp(Reduce(_1, val_, children_[0]->val())) };
+ return { NodeOp(Reduce(_1, val_, child(0)->val())) };
}
NodeOps backwardOps() {
- return { NodeOp(Add(_1, children_[0]->grad(), adj_)) };
+ return { NodeOp(Add(_1, child(0)->grad(), adj_)) };
}
template <class ...Args>
@@ -297,20 +297,20 @@ struct MeanNodeOp : public UnaryNodeOp {
ax_(keywords::Get(keywords::axis, -1, args...)) { }
NodeOps forwardOps() {
- int left = children_[0]->shape().elements() / val_->shape().elements();
+ int left = child(0)->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
return {
- NodeOp(Reduce(_1, val_, children_[0]->val(), scale))
+ NodeOp(Reduce(_1, val_, child(0)->val(), scale))
};
}
NodeOps backwardOps() {
- int left = children_[0]->shape().elements() / val_->shape().elements();
+ int left = child(0)->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
return {
- NodeOp(Add(_1, children_[0]->grad(), adj_, scale))
+ NodeOp(Add(_1, child(0)->grad(), adj_, scale))
};
}
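
The scale is one over the number of input elements collapsed into each
output cell. A hypothetical example: averaging a [64, 10] child down to a
[1, 10] mean gives left = 640 / 10 = 64 and scale = 1/64, applied once in
the forward reduction and again when broadcasting adj_ back to the child's
gradient.
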
@@ -358,16 +358,16 @@ struct LogNodeOp : public UnaryNodeOp {
return {
NodeOp(Element(_1 = Log(_2),
val_,
- children_[0]->val()))
+ child(0)->val()))
};
}
NodeOps backwardOps() {
return {
NodeOp(Add(_1 * (1.f / _2),
- children_[0]->grad(),
+ child(0)->grad(),
adj_,
- children_[0]->val()))
+ child(0)->val()))
};
}
@@ -385,16 +385,16 @@ struct ExpNodeOp : public UnaryNodeOp {
return {
NodeOp(Element(_1 = Exp(_2),
val_,
- children_[0]->val()))
+ child(0)->val()))
};
}
NodeOps backwardOps() {
return {
NodeOp(Add(_1 * Exp(_2),
- children_[0]->grad(),
+ child(0)->grad(),
adj_,
- children_[0]->val()))
+ child(0)->val()))
};
}
@@ -416,14 +416,14 @@ struct SqrtNodeOp : public UnaryNodeOp {
return {
NodeOp(Element(_1 = Sqrt(_2 + epsilon_),
val_,
- children_[0]->val()))
+ child(0)->val()))
};
}
NodeOps backwardOps() {
return {
NodeOp(Add(0.5f * (1.f / _1) * _2,
- children_[0]->grad(),
+ child(0)->grad(),
val_,
adj_))
};
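
The 0.5f * (1.f / _1) factor is the square-root derivative expressed
through the stored output. Since val_ = \sqrt{x + \epsilon}:

    \frac{d}{dx}\,\sqrt{x + \epsilon} = \frac{1}{2\sqrt{x + \epsilon}}
      = 0.5 \cdot \frac{1}{\mathrm{val}}

so the backward op reuses val_ instead of recomputing the forward pass.
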
@@ -456,15 +456,15 @@ struct SquareNodeOp : public UnaryNodeOp {
return {
NodeOp(Element(_1 = _2 * _2,
val_,
- children_[0]->val()))
+ child(0)->val()))
};
}
NodeOps backwardOps() {
return {
NodeOp(Add(2.f * _1 * _2,
- children_[0]->grad(),
- children_[0]->val(),
+ child(0)->grad(),
+ child(0)->val(),
adj_))
};
}
@@ -485,14 +485,14 @@ struct NegNodeOp : public UnaryNodeOp {
return {
NodeOp(Element(_1 = -_2,
val_,
- children_[0]->val()))
+ child(0)->val()))
};
}
NodeOps backwardOps() {
return {
NodeOp(Add(-_1,
- children_[0]->grad(),
+ child(0)->grad(),
adj_))
};
}
@@ -514,14 +514,14 @@ struct RowsNodeOp : public UnaryNodeOp {
return {
NodeOp(CopyRows(val_,
- children_[0]->val(),
+ child(0)->val(),
indeces_))
};
}
NodeOps backwardOps() {
return {
- NodeOp(PasteRows(children_[0]->grad(),
+ NodeOp(PasteRows(child(0)->grad(),
adj_,
indeces_))
};
@@ -568,14 +568,14 @@ struct ColsNodeOp : public UnaryNodeOp {
return {
NodeOp(CopyCols(val_,
- children_[0]->val(),
+ child(0)->val(),
indeces_))
};
}
NodeOps backwardOps() {
return {
- NodeOp(PasteCols(children_[0]->grad(),
+ NodeOp(PasteCols(child(0)->grad(),
adj_,
indeces_))
};
@@ -619,14 +619,14 @@ struct TransposeNodeOp : public UnaryNodeOp {
NodeOps forwardOps() {
return {
NodeOp(Transpose(getCublasHandle(),
- val_, children_[0]->val()))
+ val_, child(0)->val()))
};
}
NodeOps backwardOps() {
return {
NodeOp(Transpose(getCublasHandle(),
- children_[0]->grad(), adj_))
+ child(0)->grad(), adj_))
};
}
@@ -648,11 +648,18 @@ struct TransposeNodeOp : public UnaryNodeOp {
}
};
-struct ReshapeNodeOp : public UnaryNodeOp {
+class ReshapeNodeOp : public UnaryNodeOp {
+private:
+ Expr reshapee_;
+
+public:
template <typename ...Args>
ReshapeNodeOp(Expr a, Shape shape, Args ...args)
- : UnaryNodeOp(a, keywords::shape=shape, args...) { }
+ : UnaryNodeOp(a, keywords::shape=shape, args...),
+ reshapee_(a) { }
+
+
size_t allocate() { return 0; }
void free() {}
@@ -660,21 +667,21 @@ struct ReshapeNodeOp : public UnaryNodeOp {
void backward() {}
void init_dependent() {
- children_[0]->init_dependent();
+ reshapee_->init_dependent();
}
void set_zero_adjoint() {
- children_[0]->set_zero_adjoint();
+ reshapee_->set_zero_adjoint();
}
Tensor& val() {
- auto childVal = children_[0]->val();
+ auto childVal = reshapee_->val();
val_.reset(new TensorBase(childVal->data(), shape(), childVal->getDevice()));
return val_;
};
Tensor& grad() {
- auto childGrad = children_[0]->grad();
+ auto childGrad = reshapee_->grad();
adj_.reset(new TensorBase(childGrad->data(), shape(), childGrad->getDevice()));
return adj_;
};
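
This hunk shows why ReshapeNodeOp now keeps reshapee_: val() and grad() own
no storage, they re-wrap the child's buffer under a new shape. Holding the
child as an Expr member gives the view shared ownership of that buffer, so
it cannot be freed out from under the node; the commit message's "better
memory handling" plausibly refers to exactly this, though it does not say
so. The pattern, with explanatory comments added:

    Tensor& val() {
      auto childVal = reshapee_->val();            // child owns the storage
      val_.reset(new TensorBase(childVal->data(),  // same device pointer,
                                shape(),           // new shape, zero copy
                                childVal->getDevice()));
      return val_;
    }
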
@@ -699,12 +706,15 @@ struct ReshapeNodeOp : public UnaryNodeOp {
};
-struct TimestepNodeOp : public UnaryNodeOp {
+class TimestepNodeOp : public UnaryNodeOp {
+private:
+ Expr stepNode_;
size_t step_;
+public:
TimestepNodeOp(Expr a, size_t step)
: UnaryNodeOp(a, keywords::shape=newShape(a)),
- step_(step)
+ stepNode_(a), step_(step)
{ }
Shape newShape(Expr a) {
@@ -721,22 +731,22 @@ struct TimestepNodeOp : public UnaryNodeOp {
void backward() {}
void init_dependent() {
- children_[0]->init_dependent();
+ stepNode_->init_dependent();
}
void set_zero_adjoint() {
- children_[0]->set_zero_adjoint();
+ stepNode_->set_zero_adjoint();
}
Tensor& val() {
- auto childVal = children_[0]->val();
+ auto childVal = stepNode_->val();
size_t offset = step_ * shape().elements();
val_.reset(new TensorBase(childVal->data() + offset, shape(), childVal->getDevice()));
return val_;
};
Tensor& grad() {
- auto childGrad = children_[0]->grad();
+ auto childGrad = stepNode_->grad();
size_t offset = step_ * shape().elements();
adj_.reset(new TensorBase(childGrad->data() + offset, shape(), childGrad->getDevice()));
return adj_;
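
TimestepNodeOp is the same aliasing pattern plus an offset: one timestep of
a contiguous step-major tensor is itself contiguous, so the view starts
step_ * shape().elements() elements into the child's buffer. With
hypothetical sizes, a per-step shape holding 4 x 512 = 2048 elements and
step_ = 3 gives offset 6144. stepNode_ plays the same lifetime-keeping role
as reshapee_ above.
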
@@ -770,14 +780,14 @@ struct ShiftNodeOp : public UnaryNodeOp {
NodeOps forwardOps() {
return {
NodeOp(Shift(val_,
- children_[0]->val(),
+ child(0)->val(),
shift_))
};
}
NodeOps backwardOps() {
return {
- NodeOp(Shift(children_[0]->grad(),
+ NodeOp(Shift(child(0)->grad(),
adj_,
shift_,
true))
@@ -812,20 +822,20 @@ struct LexicalProbNodeOp : public NaryNodeOp {
void forward() {
sparse::LfaForward(val_,
- children_[0]->val(),
- children_[1]->val(),
+ child(0)->val(),
+ child(1)->val(),
lf_);
// val = x + ln(p + eps)
Element(_1 = (Log(_1 + eps_) + _2),
- val_, children_[0]->val());
+ val_, child(0)->val());
}
void backward() {
- Add(_1, children_[0]->grad(), adj_);
+ Add(_1, child(0)->grad(), adj_);
// adj' = adj / (p + eps) = adj / exp(val - x)
Element(_1 = _1 / Exp(_2 - _3),
- adj_, val_, children_[0]->val());
- sparse::LfaBackward(children_[1]->grad(), adj_, lf_);
+ adj_, val_, child(0)->val());
+ sparse::LfaBackward(child(1)->grad(), adj_, lf_);
}
const std::string type() {
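
For completeness, the arithmetic behind the LexicalProbNodeOp comments: the
forward pass sets val = x + ln(p + eps), hence

    \frac{\partial\,\mathrm{val}}{\partial x} = 1,
    \qquad
    \frac{\partial\,\mathrm{val}}{\partial p} = \frac{1}{p + \epsilon}
      = e^{-(\mathrm{val} - x)}

so the backward pass adds adj_ unchanged into child(0)'s gradient, then
rescales the adjoint by 1 / exp(val - x), exactly as the
"adj / (p + eps) = adj / exp(val - x)" comment says, before LfaBackward
scatters the rescaled adjoint into child(1)'s gradient.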