diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-04-23 18:40:43 +0300
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-04-23 18:40:43 +0300
commit | c13bc7cec3a18dca1f8cd1a5f3a24617e96c0e3f (patch)
tree | addf17f949ab45ee1a1eb5158a2456d1e66f3bdf /src/graph/node_operators_unary.h
parent | 38d9204fe565335536b7daa94568367253c14063 (diff)
better memory handling
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 142
1 file changed, 76 insertions, 66 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 20298e2a..8b78343f 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -29,14 +29,14 @@ struct LogitNodeOp : public UnaryNodeOp {
     return {
       NodeOp(Element(_1 = Sigma(_2),
                      val_,
-                     children_[0]->val()))
+                     child(0)->val()))
     };
   }
 
   NodeOps backwardOps() {
     return {
       NodeOp(Add(_1 * _2 * (1.0f - _2),
-                 children_[0]->grad(),
+                 child(0)->grad(),
                  adj_, val_))
     };
   }
@@ -69,28 +69,28 @@ struct TanhNodeOp : public NaryNodeOp {
       case 1:
         return { NodeOp(Element(_1 = Tanh(_2),
                                 val_,
-                                children_[0]->val())) };
+                                child(0)->val())) };
       case 2:
         return { NodeOp(Element(_1 = Tanh(_2 + _3),
                                 val_,
-                                children_[0]->val(),
-                                children_[1]->val())) };
+                                child(0)->val(),
+                                child(1)->val())) };
       case 3:
         return { NodeOp(Element(_1 = Tanh(_2 + _3 + _4),
                                 val_,
-                                children_[0]->val(),
-                                children_[1]->val(),
-                                children_[2]->val())) };
+                                child(0)->val(),
+                                child(1)->val(),
+                                child(2)->val())) };
       default:
         return {
           NodeOp(
             Element(_1 = _2 + _3 + _4,
                     val_,
-                    children_[0]->val(),
-                    children_[1]->val(),
-                    children_[2]->val());
+                    child(0)->val(),
+                    child(1)->val(),
+                    child(2)->val());
             for(int i = 3; i < children_.size(); ++i)
-              Element(_1 += _2, val_, children_[i]->val());
+              Element(_1 += _2, val_, child(i)->val());
             Element(_1 = Tanh(_1), val_);
           )
         };
@@ -99,10 +99,10 @@ struct TanhNodeOp : public NaryNodeOp {
   NodeOps backwardOps() {
     NodeOps ops;
-    for(auto&& child : children_) {
+    for(int i = 0; i < children_.size(); i++) {
       ops.push_back(
         NodeOp(Add(_1 * (1.0f - (_2 * _2)),
-                   child->grad(), adj_, val_))
+                   child(i)->grad(), adj_, val_))
       );
     }
     return ops;
@@ -141,15 +141,15 @@ struct ReLUNodeOp : public UnaryNodeOp {
     return {
       NodeOp(Element(_1 = ReLU(_2),
                      val_,
-                     children_[0]->val()))
+                     child(0)->val()))
     };
   }
 
   NodeOps backwardOps() {
     return {
       NodeOp(Add(_1 * ReLUback(_2),
-                 children_[0]->grad(),
-                 adj_, children_[0]->val()))
+                 child(0)->grad(),
+                 adj_, child(0)->val()))
     };
   }
@@ -174,7 +174,7 @@ struct SoftmaxNodeOp : public NaryNodeOp {
   NodeOps forwardOps() {
     return {
       NodeOp(Softmax(val_,
-                     children_[0]->val(),
+                     child(0)->val(),
                      mask_ ? mask_->val() : nullptr))
     };
   }
@@ -203,7 +203,7 @@ struct SoftmaxNodeOp : public NaryNodeOp {
     // val_ is already masked if there is a mask, so no need to apply here.
 
     return {
-      NodeOp(SoftmaxGrad(children_[0]->grad(), adj_, val_))
+      NodeOp(SoftmaxGrad(child(0)->grad(), adj_, val_))
     };
   }
@@ -219,7 +219,7 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
   NodeOps forwardOps() {
     return {
-      NodeOp(LogSoftmax(val_, children_[0]->val()))
+      NodeOp(LogSoftmax(val_, child(0)->val()))
     };
   }
@@ -228,7 +228,7 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
     // J * dy = dy - avg*1
     // where avg = exp(p)'*dy and p is the softmax output (probabilities).
     return {
-      NodeOp(LogSoftmaxGrad(children_[0]->grad(), adj_, val_))
+      NodeOp(LogSoftmaxGrad(child(0)->grad(), adj_, val_))
     };
   }
@@ -246,11 +246,11 @@ struct SumNodeOp : public UnaryNodeOp {
       ax_(keywords::Get(keywords::axis, -1, args...)) { }
 
   NodeOps forwardOps() {
-    return { NodeOp(Reduce(_1, val_, children_[0]->val())) };
+    return { NodeOp(Reduce(_1, val_, child(0)->val())) };
   }
 
   NodeOps backwardOps() {
-    return { NodeOp(Add(_1, children_[0]->grad(), adj_)) };
+    return { NodeOp(Add(_1, child(0)->grad(), adj_)) };
   }
 
   template <class ...Args>
@@ -297,20 +297,20 @@ struct MeanNodeOp : public UnaryNodeOp {
       ax_(keywords::Get(keywords::axis, -1, args...)) { }
 
   NodeOps forwardOps() {
-    int left = children_[0]->shape().elements() / val_->shape().elements();
+    int left = child(0)->shape().elements() / val_->shape().elements();
     float scale = 1.f / left;
 
     return {
-      NodeOp(Reduce(_1, val_, children_[0]->val(), scale))
+      NodeOp(Reduce(_1, val_, child(0)->val(), scale))
     };
   }
 
   NodeOps backwardOps() {
-    int left = children_[0]->shape().elements() / val_->shape().elements();
+    int left = child(0)->shape().elements() / val_->shape().elements();
     float scale = 1.f / left;
 
     return {
-      NodeOp(Add(_1, children_[0]->grad(), adj_, scale))
+      NodeOp(Add(_1, child(0)->grad(), adj_, scale))
    };
  }
@@ -358,16 +358,16 @@ struct LogNodeOp : public UnaryNodeOp {
     return {
       NodeOp(Element(_1 = Log(_2),
                      val_,
-                     children_[0]->val()))
+                     child(0)->val()))
     };
   }
 
   NodeOps backwardOps() {
     return {
       NodeOp(Add(_1 * (1.f / _2),
-                 children_[0]->grad(),
+                 child(0)->grad(),
                  adj_,
-                 children_[0]->val()))
+                 child(0)->val()))
     };
   }
@@ -385,16 +385,16 @@ struct ExpNodeOp : public UnaryNodeOp {
     return {
       NodeOp(Element(_1 = Exp(_2),
                      val_,
-                     children_[0]->val()))
+                     child(0)->val()))
     };
   }
 
   NodeOps backwardOps() {
     return {
       NodeOp(Add(_1 * Exp(_2),
-                 children_[0]->grad(),
+                 child(0)->grad(),
                  adj_,
-                 children_[0]->val()))
+                 child(0)->val()))
     };
   }
@@ -416,14 +416,14 @@ struct SqrtNodeOp : public UnaryNodeOp {
     return {
       NodeOp(Element(_1 = Sqrt(_2 + epsilon_),
                      val_,
-                     children_[0]->val()))
+                     child(0)->val()))
     };
   }
 
   NodeOps backwardOps() {
     return {
       NodeOp(Add(0.5f * (1.f / _1) * _2,
-                 children_[0]->grad(),
+                 child(0)->grad(),
                  val_,
                  adj_))
     };
   }
@@ -456,15 +456,15 @@ struct SquareNodeOp : public UnaryNodeOp {
     return {
       NodeOp(Element(_1 = _2 * _2,
                      val_,
-                     children_[0]->val()))
+                     child(0)->val()))
     };
   }
 
   NodeOps backwardOps() {
     return {
       NodeOp(Add(2.f * _1 * _2,
-                 children_[0]->grad(),
-                 children_[0]->val(),
+                 child(0)->grad(),
+                 child(0)->val(),
                  adj_))
     };
   }
@@ -485,14 +485,14 @@ struct NegNodeOp : public UnaryNodeOp {
     return {
       NodeOp(Element(_1 = -_2,
                      val_,
-                     children_[0]->val()))
+                     child(0)->val()))
     };
   }
 
   NodeOps backwardOps() {
     return {
       NodeOp(Add(-_1,
-                 children_[0]->grad(),
+                 child(0)->grad(),
                  adj_))
     };
   }
@@ -514,14 +514,14 @@ struct RowsNodeOp : public UnaryNodeOp {
     return {
       NodeOp(CopyRows(val_,
-                      children_[0]->val(),
+                      child(0)->val(),
                       indeces_))
     };
   }
 
   NodeOps backwardOps() {
     return {
-      NodeOp(PasteRows(children_[0]->grad(),
+      NodeOp(PasteRows(child(0)->grad(),
                        adj_,
                        indeces_))
     };
@@ -568,14 +568,14 @@ struct ColsNodeOp : public UnaryNodeOp {
     return {
       NodeOp(CopyCols(val_,
-                      children_[0]->val(),
+                      child(0)->val(),
                       indeces_))
     };
   }
 
   NodeOps backwardOps() {
     return {
-      NodeOp(PasteCols(children_[0]->grad(),
+      NodeOp(PasteCols(child(0)->grad(),
                        adj_,
                        indeces_))
     };
@@ -619,14 +619,14 @@ struct TransposeNodeOp : public UnaryNodeOp {
   NodeOps forwardOps() {
     return {
       NodeOp(Transpose(getCublasHandle(),
-                       val_, children_[0]->val()))
+                       val_, child(0)->val()))
     };
   }
 
   NodeOps backwardOps() {
     return {
       NodeOp(Transpose(getCublasHandle(),
-                       children_[0]->grad(), adj_))
+                       child(0)->grad(), adj_))
     };
   }
@@ -648,11 +648,18 @@ struct TransposeNodeOp : public UnaryNodeOp {
   }
 };
 
-struct ReshapeNodeOp : public UnaryNodeOp {
+class ReshapeNodeOp : public UnaryNodeOp {
+private:
+  Expr reshapee_;
+
+public:
   template <typename ...Args>
   ReshapeNodeOp(Expr a, Shape shape, Args ...args)
-    : UnaryNodeOp(a, keywords::shape=shape, args...) { }
+    : UnaryNodeOp(a, keywords::shape=shape, args...),
+      reshapee_(a) { }
+
+
   size_t allocate() { return 0; }
   void free() {}
@@ -660,21 +667,21 @@ struct ReshapeNodeOp : public UnaryNodeOp {
   void backward() {}
 
   void init_dependent() {
-    children_[0]->init_dependent();
+    reshapee_->init_dependent();
   }
 
   void set_zero_adjoint() {
-    children_[0]->set_zero_adjoint();
+    reshapee_->set_zero_adjoint();
   }
 
   Tensor& val() {
-    auto childVal = children_[0]->val();
+    auto childVal = reshapee_->val();
     val_.reset(new TensorBase(childVal->data(), shape(), childVal->getDevice()));
     return val_;
   };
 
   Tensor& grad() {
-    auto childGrad = children_[0]->grad();
+    auto childGrad = reshapee_->grad();
     adj_.reset(new TensorBase(childGrad->data(), shape(), childGrad->getDevice()));
     return adj_;
   };
@@ -699,12 +706,15 @@ struct ReshapeNodeOp : public UnaryNodeOp {
 };
 
-struct TimestepNodeOp : public UnaryNodeOp {
+class TimestepNodeOp : public UnaryNodeOp {
+private:
+  Expr stepNode_;
   size_t step_;
 
+public:
   TimestepNodeOp(Expr a, size_t step)
     : UnaryNodeOp(a, keywords::shape=newShape(a)),
-      step_(step)
+      stepNode_(a), step_(step)
     { }
 
   Shape newShape(Expr a) {
@@ -721,22 +731,22 @@ struct TimestepNodeOp : public UnaryNodeOp {
   void backward() {}
 
   void init_dependent() {
-    children_[0]->init_dependent();
+    stepNode_->init_dependent();
   }
 
   void set_zero_adjoint() {
-    children_[0]->set_zero_adjoint();
+    stepNode_->set_zero_adjoint();
   }
 
   Tensor& val() {
-    auto childVal = children_[0]->val();
+    auto childVal = stepNode_->val();
     size_t offset = step_ * shape().elements();
     val_.reset(new TensorBase(childVal->data() + offset, shape(), childVal->getDevice()));
     return val_;
   };
 
   Tensor& grad() {
-    auto childGrad = children_[0]->grad();
+    auto childGrad = stepNode_->grad();
     size_t offset = step_ * shape().elements();
     adj_.reset(new TensorBase(childGrad->data() + offset, shape(), childGrad->getDevice()));
     return adj_;
   };
@@ -770,14 +780,14 @@ struct ShiftNodeOp : public UnaryNodeOp {
   NodeOps forwardOps() {
     return {
       NodeOp(Shift(val_,
-                   children_[0]->val(),
+                   child(0)->val(),
                    shift_))
     };
   }
 
   NodeOps backwardOps() {
     return {
-      NodeOp(Shift(children_[0]->grad(),
+      NodeOp(Shift(child(0)->grad(),
                    adj_,
                    shift_,
                    true))
@@ -812,20 +822,20 @@ struct LexicalProbNodeOp : public NaryNodeOp {
   void forward() {
     sparse::LfaForward(val_,
-                       children_[0]->val(),
-                       children_[1]->val(),
+                       child(0)->val(),
+                       child(1)->val(),
                        lf_);
     // val = x + ln(p + eps)
     Element(_1 = (Log(_1 + eps_) + _2),
-            val_, children_[0]->val());
+            val_, child(0)->val());
   }
 
   void backward() {
-    Add(_1, children_[0]->grad(), adj_);
+    Add(_1, child(0)->grad(), adj_);
     // adj' = adj / (p + eps) = adj / exp(val - x)
     Element(_1 = _1 / Exp(_2 - _3),
-            adj_, val_, children_[0]->val());
-    sparse::LfaBackward(children_[1]->grad(), adj_, lf_);
+            adj_, val_, child(0)->val());
+    sparse::LfaBackward(child(1)->grad(), adj_, lf_);
   }
 
   const std::string type() {
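Two details of this commit are worth spelling out. The bulk of the change is mechanical: every direct `children_[0]`/`children_[1]` access is routed through a `child(i)` accessor. The diff shows only the call sites, not the accessor's definition, so the following is a minimal hypothetical sketch of the pattern, assuming `child(i)` simply wraps access to the `children_` container; the `Expr` and `children_` names mirror the code above, everything else is illustrative:

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical sketch -- not the actual Marian implementation.
struct Node;
using Expr = std::shared_ptr<Node>;

struct Node {
  std::vector<Expr> children_;

  // One accessor instead of children_[i] at every call site: operators no
  // longer depend on how children are stored, giving the graph a single
  // choke point for changing ownership later (e.g. weak references, or
  // freeing child buffers early) -- plausibly the "better memory handling"
  // the commit message refers to.
  Expr child(size_t i) { return children_.at(i); }
};
```

The rewrite of the `TanhNodeOp` backward loop from a range-for to an indexed loop fits the same refactor: the old local variable named `child` would shadow a `child(i)` member accessor, so it has to go. The second detail is that `ReshapeNodeOp` and `TimestepNodeOp` stop touching `children_[0]` altogether and instead keep their own strong `Expr` members (`reshapee_` and `stepNode_`). Both ops are views rather than owners: their `val()` and `grad()` construct a `TensorBase` directly over the child's buffer (`childVal->data()`, plus a `step_` offset in the timestep case). Holding an explicit reference to the node whose memory they alias keeps that buffer alive independently of the generic child bookkeeping, which is presumably the point of the change.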