diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-01-23 00:56:50 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-01-23 00:56:50 +0300 |
commit | 622260e2006c9ba67d4f0532954a428278ad2e4b (patch) | |
tree | 0614ca6fa7e641c05b2091f12c2187476b511f46 /src/graph/node_operators_unary.h | |
parent | 79c9e20bb1b84733c612e93161b663c45037853e (diff) |
major refactorting
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 464 |
1 files changed, 217 insertions, 247 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index a654f1b2..18687c0a 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -7,27 +7,16 @@ namespace marian { -struct UnaryNodeOp : public Node { - Expr a_; - +struct UnaryNodeOp : public NaryNodeOp { template <typename ...Args> UnaryNodeOp(Expr a, Args ...args) - : Node(a->graph(), - keywords::shape=a->shape(), - args...), - a_(a) - { - setTrainable(a_->trainable()); - remove_children_from_top_nodes(); - } - - ~UnaryNodeOp() {} + : NaryNodeOp({a}, + keywords::shape=a->shape(), + args...) {} - std::vector<Expr> children() { - return { a_ }; + const std::string color() { + return "yellow"; } - - void remove_children_from_top_nodes(); }; struct LogitNodeOp : public UnaryNodeOp { @@ -35,25 +24,26 @@ struct LogitNodeOp : public UnaryNodeOp { LogitNodeOp(Args ...args) : UnaryNodeOp(args...) { } - void forward() { - Element(_1 = Sigma(_2), - val_, a_->val()); + NodeOps forwardOps() { + return { + NodeOp(Element(_1 = Sigma(_2), + val_, + children_[0]->val())) + }; } - void backward() { - if(a_->trainable()) - Element(_1 += _2 * _3 * (1.0f - _3), - a_->grad(), adj_, val_); + NodeOps backwardOps() { + return { + NodeOp(Element(_1 += _2 * _3 * (1.0f - _3), + children_[0]->grad(), + adj_, + val_)) + }; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" << label("logit") - << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; - + const std::string type() { + return "logit"; + } }; struct TanhNodeOp : public UnaryNodeOp { @@ -61,25 +51,26 @@ struct TanhNodeOp : public UnaryNodeOp { TanhNodeOp(Args ...args) : UnaryNodeOp(args...) { } - void forward() { - Element(_1 = Tanh(_2), - val_, a_->val()); + NodeOps forwardOps() { + return { + NodeOp(Element(_1 = Tanh(_2), + val_, + children_[0]->val())) + }; } - void backward() { - if(a_->trainable()) - Element(_1 += _2 * (1.0f - (_3 * _3)), - a_->grad(), adj_, val_); + NodeOps backwardOps() { + return { + NodeOp(Element(_1 += _2 * (1.0f - (_3 * _3)), + children_[0]->grad(), + adj_, + val_)) + }; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" << label("tanh") - << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; - + const std::string type() { + return "tanh"; + } }; /** @@ -102,25 +93,24 @@ struct ReLUNodeOp : public UnaryNodeOp { ReLUNodeOp(Args ...args) : UnaryNodeOp(args...) { } - void forward() { - Element(_1 = ReLU(_2), - val_, a_->val()); + NodeOps forwardOps() { + return { + NodeOp(Element(_1 = ReLU(_2), + val_, + children_[0]->val())) + }; } - void backward() { - if(a_->trainable()) - Element(_1 += _2 * ReLUback(_3), - a_->grad(), adj_, a_->val()); + NodeOps backwardOps() { + return { + NodeOp(Element(_1 += _2 * ReLUback(_3), + children_[0]->grad(), adj_, children_[0]->val())) + }; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" << label("ReLU") - << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; - + const std::string type() { + return "ReLU"; + } }; /** @@ -142,12 +132,12 @@ struct DropoutNodeOp : public UnaryNodeOp { } void inference() { - Element(_1 = _2, val_, a_->val()); + Element(_1 = _2, val_, children_[0]->val()); } void forward() { if(!allocated_) { - CudnnDropoutPrepare(a_->val(), p_, + CudnnDropoutPrepare(children_[0]->val(), p_, &dropDesc_, &space_, &spaceSize_, &states_, (size_t)this); // seeding with pointer address @@ -155,22 +145,18 @@ struct DropoutNodeOp : public UnaryNodeOp { } CudnnDropoutForward(dropDesc_, space_, spaceSize_, - val_, a_->val()); + val_, children_[0]->val()); } void backward() { - if(a_->trainable()) + if(children_[0]->trainable()) CudnnDropoutBackward(dropDesc_, space_, spaceSize_, - a_->grad(), adj_); + children_[0]->grad(), adj_); } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" << label("dropout") - << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; + const std::string type() { + return "dropout"; + } private: bool allocated_; @@ -181,22 +167,28 @@ struct DropoutNodeOp : public UnaryNodeOp { cudnnDropoutDescriptor_t dropDesc_; }; -struct SoftmaxNodeOp : public UnaryNodeOp { +struct SoftmaxNodeOp : public NaryNodeOp { template <typename ...Args> - SoftmaxNodeOp(Expr a, Expr mask = nullptr, Args ...args) - : UnaryNodeOp(a, args...), mask_(mask) { - remove_mask_from_top_nodes(); + SoftmaxNodeOp(Expr a, Args ...args) + : NaryNodeOp(a, args...), mask_(nullptr) { } - Expr mask_; + template <typename ...Args> + SoftmaxNodeOp(Expr a, Expr mask, Args ...args) + : NaryNodeOp({a, mask}, args...), mask_(mask) { + } - void remove_mask_from_top_nodes(); + Expr mask_; - void forward() { - Softmax(val_, a_->val(), mask_->val()); + NodeOps forwardOps() { + return { + NodeOp(Softmax(val_, + children_[0]->val(), + mask_ ? mask_->val() : nullptr)) + }; } - void backward() { + NodeOps backwardOps() { // For each row, the Jacobian times vector is given by: // J * dy = p .* (dy - avg*1) // where avg = p'*dy and p is the softmax output (probabilities). @@ -208,19 +200,15 @@ struct SoftmaxNodeOp : public UnaryNodeOp { // http://jmlr.org/proceedings/papers/v48/martins16.pdf // val_ is already masked if there is a mask, so no need to apply here. - if(a_->trainable()) - SoftmaxGrad(a_->grad(), adj_, val_); - } - - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" << label("softmax") - << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - if(mask_) - ss << "\"" << mask_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; + + return { + NodeOp(SoftmaxGrad(children_[0]->grad(), adj_, val_)) + }; + } + + const std::string type() { + return "softmax"; + } }; struct LogSoftmaxNodeOp : public UnaryNodeOp { @@ -228,56 +216,24 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp { LogSoftmaxNodeOp(Args ...args) : UnaryNodeOp(args...) { } - void forward() { - CudnnLogSoftmax(val_, a_->val()); + NodeOps forwardOps() { + return { + NodeOp(LogSoftmax(val_, children_[0]->val())) + }; } - void backward() { + NodeOps backwardOps() { // Based on the description for softmax, we have logsoftmax: // J * dy = dy - avg*1 // where avg = exp(p)'*dy and p is the softmax output (probabilities). - if(a_->trainable()) - LogSoftmaxGrad(a_->grad(), adj_, val_); + return { + NodeOp(LogSoftmaxGrad(children_[0]->grad(), adj_, val_)) + }; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" << label("log-softmax") - << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; -}; - - -struct ArgmaxNodeOp : public UnaryNodeOp { - template <typename ...Args> - ArgmaxNodeOp(Expr a, Args ...args) - : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { } - - void forward() { - // B = softmax(A). - //Argmax(&val_, &a_->val()); + const std::string type() { + return "logsoftmax"; } - - void backward() { - } - - Shape newShape(Expr a) { - Shape shape = a->shape(); - shape.set(0, 1); - return shape; - } - - - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" - << label("argmax") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; - }; struct SumNodeOp : public UnaryNodeOp { @@ -285,13 +241,12 @@ struct SumNodeOp : public UnaryNodeOp { SumNodeOp(Expr a, Args ...args) : UnaryNodeOp(a, keywords::shape=newShape(a, args...), args...) { } - void forward() { - Reduce(_1, val_, a_->val()); + NodeOps forwardOps() { + return { NodeOp(Reduce(_1, val_, children_[0]->val())) }; } - void backward() { - if(a_->trainable()) - Add(_1, a_->grad(), adj_); + NodeOps backwardOps() { + return { NodeOp(Add(_1, children_[0]->grad(), adj_)) }; } template <class ...Args> @@ -310,13 +265,13 @@ struct SumNodeOp : public UnaryNodeOp { return shape; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" - << label("sum") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; + const std::string type() { + return "sum"; + } + + const std::string color() { + return "orange"; + } }; @@ -325,17 +280,22 @@ struct MeanNodeOp : public UnaryNodeOp { MeanNodeOp(Expr a, Args ...args) : UnaryNodeOp(a, keywords::shape=newShape(a, args...), args...) { } - void forward() { - int left = a_->shape().elements() / val_->shape().elements(); + NodeOps forwardOps() { + int left = children_[0]->shape().elements() / val_->shape().elements(); float scale = 1.f / left; - Reduce(_1 * scale, val_, a_->val()); + + return { + NodeOp(Reduce(_1 * scale, val_, children_[0]->val())) + }; } - void backward() { - int left = a_->shape().elements() / val_->shape().elements(); + NodeOps backwardOps() { + int left = children_[0]->shape().elements() / val_->shape().elements(); float scale = 1.f / left; - if(a_->trainable()) - Add(_1 * scale, a_->grad(), adj_); + + return { + NodeOp(Add(_1 * scale, children_[0]->grad(), adj_)) + }; } template <class ...Args> @@ -354,12 +314,12 @@ struct MeanNodeOp : public UnaryNodeOp { return shape; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" - << label("mean") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); + const std::string type() { + return "mean"; + } + + const std::string color() { + return "orange"; } }; @@ -369,24 +329,26 @@ struct LogNodeOp : public UnaryNodeOp { LogNodeOp(Args ...args) : UnaryNodeOp(args...) {} - void forward() { - Element(_1 = Log(_2), val_, a_->val()); + NodeOps forwardOps() { + return { + NodeOp(Element(_1 = Log(_2), + val_, + children_[0]->val())) + }; } - void backward() { - if(a_->trainable()) - Add(_1 * (1.f / _2), - a_->grad(), adj_, a_->val()); + NodeOps backwardOps() { + return { + NodeOp(Add(_1 * (1.f / _2), + children_[0]->grad(), + adj_, + children_[0]->val())) + }; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" - << label("log") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; - + const std::string type() { + return "log"; + } }; struct ExpNodeOp : public UnaryNodeOp { @@ -394,23 +356,26 @@ struct ExpNodeOp : public UnaryNodeOp { ExpNodeOp(Args ...args) : UnaryNodeOp(args...) { } - void forward() { - Element(_1 = Exp(_2), val_, a_->val()); + NodeOps forwardOps() { + return { + NodeOp(Element(_1 = Exp(_2), + val_, + children_[0]->val())) + }; } - void backward() { - if(a_->trainable()) - Add(_1 * Exp(_2), - a_->grad(), adj_, a_->val()); + NodeOps backwardOps() { + return { + NodeOp(Add(_1 * Exp(_2), + children_[0]->grad(), + adj_, + children_[0]->val())) + }; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" << label("exp") - << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); - }; + const std::string type() { + return "exp"; + } }; @@ -419,21 +384,24 @@ struct NegNodeOp : public UnaryNodeOp { NegNodeOp(Args ...args) : UnaryNodeOp(args...) { } - void forward() { - Element(_1 = -_2, val_, a_->val()); + NodeOps forwardOps() { + return { + NodeOp(Element(_1 = -_2, + val_, + children_[0]->val())) + }; } - void backward() { - if(a_->trainable()) - Add(-_1, a_->grad(), adj_); + NodeOps backwardOps() { + return { + NodeOp(Add(-_1, + children_[0]->grad(), + adj_)) + }; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" - << label("-") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); + const std::string type() { + return "-"; } }; @@ -445,13 +413,18 @@ struct RowsNodeOp : public UnaryNodeOp { thrust::copy(indeces.begin(), indeces.end(), indeces_.begin()); } - void forward() { - CopyRows(val_, a_->val(), indeces_); + NodeOps forwardOps() { + return { + NodeOp(CopyRows(val_, children_[0]->val(), indeces_)) + }; } - void backward() { - if(a_->trainable()) - PasteRows(a_->grad(), adj_, indeces_); + NodeOps backwardOps() { + return { + NodeOp(PasteRows(children_[0]->grad(), + adj_, + indeces_)) + }; } template <class ...Args> @@ -461,12 +434,12 @@ struct RowsNodeOp : public UnaryNodeOp { return shape; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" - << label("rows") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); + const std::string type() { + return "rows"; + } + + const std::string color() { + return "orange"; } DeviceVector<size_t> indeces_; @@ -477,13 +450,18 @@ struct TransposeNodeOp : public UnaryNodeOp { TransposeNodeOp(Expr a, Args ...args) : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { } - void forward() { - Transpose(val_, a_->val()); + NodeOps forwardOps() { + return { + NodeOp(Transpose(getCublasHandle(), + val_, children_[0]->val())) + }; } - void backward() { - if(a_->trainable()) - Transpose(a_->grad(), adj_); + NodeOps backwardOps() { + return { + NodeOp(Transpose(getCublasHandle(), + children_[0]->grad(), adj_)) + }; } template <class ...Args> @@ -495,12 +473,12 @@ struct TransposeNodeOp : public UnaryNodeOp { return shape; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" - << label("transpose") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); + const std::string type() { + return "transpose"; + } + + const std::string color() { + return "orange"; } }; @@ -516,33 +494,29 @@ struct ReshapeNodeOp : public UnaryNodeOp { void backward() {} void init_dependent() { - a_->init_dependent(); + children_[0]->init_dependent(); } void set_zero_adjoint() { - a_->set_zero_adjoint(); + children_[0]->set_zero_adjoint(); } Tensor& val() { - val_.reset(new TensorGPU(a_->val()->data(), shape())); + val_.reset(new TensorGPU(children_[0]->val()->data(), shape())); return val_; }; Tensor& grad() { - adj_.reset(new TensorGPU(a_->grad()->data(), shape())); + adj_.reset(new TensorGPU(children_[0]->grad()->data(), shape())); return adj_; }; - std::vector<Expr> children() { - return a_->children(); + const std::string type() { + return "reshape"; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" - << label("reshape") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); + const std::string color() { + return "grey"; } }; @@ -568,35 +542,31 @@ struct TimestepNodeOp : public UnaryNodeOp { void backward() {} void init_dependent() { - a_->init_dependent(); + children_[0]->init_dependent(); } void set_zero_adjoint() { - a_->set_zero_adjoint(); + children_[0]->set_zero_adjoint(); } Tensor& val() { size_t offset = step_ * shape().elements(); - val_.reset(new TensorGPU(a_->val()->data() + offset, shape())); + val_.reset(new TensorGPU(children_[0]->val()->data() + offset, shape())); return val_; }; Tensor& grad() { size_t offset = step_ * shape().elements(); - adj_.reset(new TensorGPU(a_->grad()->data() + offset, shape())); + adj_.reset(new TensorGPU(children_[0]->grad()->data() + offset, shape())); return adj_; }; - std::vector<Expr> children() { - return a_->children(); + const std::string type() { + return "step"; } - virtual std::string graphviz() { - std::stringstream ss; - ss << "\"" << this << "\" [shape=\"box\", label=" - << label("step") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; - return ss.str(); + const std::string color() { + return "grey"; } }; |