Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-01-23 00:56:50 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-01-23 00:56:50 +0300
commit622260e2006c9ba67d4f0532954a428278ad2e4b (patch)
tree0614ca6fa7e641c05b2091f12c2187476b511f46 /src/graph/node_operators_unary.h
parent79c9e20bb1b84733c612e93161b663c45037853e (diff)
major refactorting
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h464
1 files changed, 217 insertions, 247 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index a654f1b2..18687c0a 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -7,27 +7,16 @@
namespace marian {
-struct UnaryNodeOp : public Node {
- Expr a_;
-
+struct UnaryNodeOp : public NaryNodeOp {
template <typename ...Args>
UnaryNodeOp(Expr a, Args ...args)
- : Node(a->graph(),
- keywords::shape=a->shape(),
- args...),
- a_(a)
- {
- setTrainable(a_->trainable());
- remove_children_from_top_nodes();
- }
-
- ~UnaryNodeOp() {}
+ : NaryNodeOp({a},
+ keywords::shape=a->shape(),
+ args...) {}
- std::vector<Expr> children() {
- return { a_ };
+ const std::string color() {
+ return "yellow";
}
-
- void remove_children_from_top_nodes();
};
struct LogitNodeOp : public UnaryNodeOp {
@@ -35,25 +24,26 @@ struct LogitNodeOp : public UnaryNodeOp {
LogitNodeOp(Args ...args)
: UnaryNodeOp(args...) { }
- void forward() {
- Element(_1 = Sigma(_2),
- val_, a_->val());
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Element(_1 = Sigma(_2),
+ val_,
+ children_[0]->val()))
+ };
}
- void backward() {
- if(a_->trainable())
- Element(_1 += _2 * _3 * (1.0f - _3),
- a_->grad(), adj_, val_);
+ NodeOps backwardOps() {
+ return {
+ NodeOp(Element(_1 += _2 * _3 * (1.0f - _3),
+ children_[0]->grad(),
+ adj_,
+ val_))
+ };
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label=" << label("logit")
- << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
-
+ const std::string type() {
+ return "logit";
+ }
};
struct TanhNodeOp : public UnaryNodeOp {
@@ -61,25 +51,26 @@ struct TanhNodeOp : public UnaryNodeOp {
TanhNodeOp(Args ...args)
: UnaryNodeOp(args...) { }
- void forward() {
- Element(_1 = Tanh(_2),
- val_, a_->val());
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Element(_1 = Tanh(_2),
+ val_,
+ children_[0]->val()))
+ };
}
- void backward() {
- if(a_->trainable())
- Element(_1 += _2 * (1.0f - (_3 * _3)),
- a_->grad(), adj_, val_);
+ NodeOps backwardOps() {
+ return {
+ NodeOp(Element(_1 += _2 * (1.0f - (_3 * _3)),
+ children_[0]->grad(),
+ adj_,
+ val_))
+ };
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label=" << label("tanh")
- << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
-
+ const std::string type() {
+ return "tanh";
+ }
};
/**
@@ -102,25 +93,24 @@ struct ReLUNodeOp : public UnaryNodeOp {
ReLUNodeOp(Args ...args)
: UnaryNodeOp(args...) { }
- void forward() {
- Element(_1 = ReLU(_2),
- val_, a_->val());
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Element(_1 = ReLU(_2),
+ val_,
+ children_[0]->val()))
+ };
}
- void backward() {
- if(a_->trainable())
- Element(_1 += _2 * ReLUback(_3),
- a_->grad(), adj_, a_->val());
+ NodeOps backwardOps() {
+ return {
+ NodeOp(Element(_1 += _2 * ReLUback(_3),
+ children_[0]->grad(), adj_, children_[0]->val()))
+ };
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label=" << label("ReLU")
- << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
-
+ const std::string type() {
+ return "ReLU";
+ }
};
/**
@@ -142,12 +132,12 @@ struct DropoutNodeOp : public UnaryNodeOp {
}
void inference() {
- Element(_1 = _2, val_, a_->val());
+ Element(_1 = _2, val_, children_[0]->val());
}
void forward() {
if(!allocated_) {
- CudnnDropoutPrepare(a_->val(), p_,
+ CudnnDropoutPrepare(children_[0]->val(), p_,
&dropDesc_,
&space_, &spaceSize_,
&states_, (size_t)this); // seeding with pointer address
@@ -155,22 +145,18 @@ struct DropoutNodeOp : public UnaryNodeOp {
}
CudnnDropoutForward(dropDesc_, space_, spaceSize_,
- val_, a_->val());
+ val_, children_[0]->val());
}
void backward() {
- if(a_->trainable())
+ if(children_[0]->trainable())
CudnnDropoutBackward(dropDesc_, space_, spaceSize_,
- a_->grad(), adj_);
+ children_[0]->grad(), adj_);
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label=" << label("dropout")
- << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
+ const std::string type() {
+ return "dropout";
+ }
private:
bool allocated_;
@@ -181,22 +167,28 @@ struct DropoutNodeOp : public UnaryNodeOp {
cudnnDropoutDescriptor_t dropDesc_;
};
-struct SoftmaxNodeOp : public UnaryNodeOp {
+struct SoftmaxNodeOp : public NaryNodeOp {
template <typename ...Args>
- SoftmaxNodeOp(Expr a, Expr mask = nullptr, Args ...args)
- : UnaryNodeOp(a, args...), mask_(mask) {
- remove_mask_from_top_nodes();
+ SoftmaxNodeOp(Expr a, Args ...args)
+ : NaryNodeOp(a, args...), mask_(nullptr) {
}
- Expr mask_;
+ template <typename ...Args>
+ SoftmaxNodeOp(Expr a, Expr mask, Args ...args)
+ : NaryNodeOp({a, mask}, args...), mask_(mask) {
+ }
- void remove_mask_from_top_nodes();
+ Expr mask_;
- void forward() {
- Softmax(val_, a_->val(), mask_->val());
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Softmax(val_,
+ children_[0]->val(),
+ mask_ ? mask_->val() : nullptr))
+ };
}
- void backward() {
+ NodeOps backwardOps() {
// For each row, the Jacobian times vector is given by:
// J * dy = p .* (dy - avg*1)
// where avg = p'*dy and p is the softmax output (probabilities).
@@ -208,19 +200,15 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
// http://jmlr.org/proceedings/papers/v48/martins16.pdf
// val_ is already masked if there is a mask, so no need to apply here.
- if(a_->trainable())
- SoftmaxGrad(a_->grad(), adj_, val_);
- }
-
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label=" << label("softmax")
- << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- if(mask_)
- ss << "\"" << mask_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
+
+ return {
+ NodeOp(SoftmaxGrad(children_[0]->grad(), adj_, val_))
+ };
+ }
+
+ const std::string type() {
+ return "softmax";
+ }
};
struct LogSoftmaxNodeOp : public UnaryNodeOp {
@@ -228,56 +216,24 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
LogSoftmaxNodeOp(Args ...args)
: UnaryNodeOp(args...) { }
- void forward() {
- CudnnLogSoftmax(val_, a_->val());
+ NodeOps forwardOps() {
+ return {
+ NodeOp(LogSoftmax(val_, children_[0]->val()))
+ };
}
- void backward() {
+ NodeOps backwardOps() {
// Based on the description for softmax, we have logsoftmax:
// J * dy = dy - avg*1
// where avg = exp(p)'*dy and p is the softmax output (probabilities).
- if(a_->trainable())
- LogSoftmaxGrad(a_->grad(), adj_, val_);
+ return {
+ NodeOp(LogSoftmaxGrad(children_[0]->grad(), adj_, val_))
+ };
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label=" << label("log-softmax")
- << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
-};
-
-
-struct ArgmaxNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- ArgmaxNodeOp(Expr a, Args ...args)
- : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }
-
- void forward() {
- // B = softmax(A).
- //Argmax(&val_, &a_->val());
+ const std::string type() {
+ return "logsoftmax";
}
-
- void backward() {
- }
-
- Shape newShape(Expr a) {
- Shape shape = a->shape();
- shape.set(0, 1);
- return shape;
- }
-
-
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label="
- << label("argmax") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
-
};
struct SumNodeOp : public UnaryNodeOp {
@@ -285,13 +241,12 @@ struct SumNodeOp : public UnaryNodeOp {
SumNodeOp(Expr a, Args ...args)
: UnaryNodeOp(a, keywords::shape=newShape(a, args...), args...) { }
- void forward() {
- Reduce(_1, val_, a_->val());
+ NodeOps forwardOps() {
+ return { NodeOp(Reduce(_1, val_, children_[0]->val())) };
}
- void backward() {
- if(a_->trainable())
- Add(_1, a_->grad(), adj_);
+ NodeOps backwardOps() {
+ return { NodeOp(Add(_1, children_[0]->grad(), adj_)) };
}
template <class ...Args>
@@ -310,13 +265,13 @@ struct SumNodeOp : public UnaryNodeOp {
return shape;
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label="
- << label("sum") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
+ const std::string type() {
+ return "sum";
+ }
+
+ const std::string color() {
+ return "orange";
+ }
};
@@ -325,17 +280,22 @@ struct MeanNodeOp : public UnaryNodeOp {
MeanNodeOp(Expr a, Args ...args)
: UnaryNodeOp(a, keywords::shape=newShape(a, args...), args...) { }
- void forward() {
- int left = a_->shape().elements() / val_->shape().elements();
+ NodeOps forwardOps() {
+ int left = children_[0]->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
- Reduce(_1 * scale, val_, a_->val());
+
+ return {
+ NodeOp(Reduce(_1 * scale, val_, children_[0]->val()))
+ };
}
- void backward() {
- int left = a_->shape().elements() / val_->shape().elements();
+ NodeOps backwardOps() {
+ int left = children_[0]->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
- if(a_->trainable())
- Add(_1 * scale, a_->grad(), adj_);
+
+ return {
+ NodeOp(Add(_1 * scale, children_[0]->grad(), adj_))
+ };
}
template <class ...Args>
@@ -354,12 +314,12 @@ struct MeanNodeOp : public UnaryNodeOp {
return shape;
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label="
- << label("mean") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
+ const std::string type() {
+ return "mean";
+ }
+
+ const std::string color() {
+ return "orange";
}
};
@@ -369,24 +329,26 @@ struct LogNodeOp : public UnaryNodeOp {
LogNodeOp(Args ...args)
: UnaryNodeOp(args...) {}
- void forward() {
- Element(_1 = Log(_2), val_, a_->val());
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Element(_1 = Log(_2),
+ val_,
+ children_[0]->val()))
+ };
}
- void backward() {
- if(a_->trainable())
- Add(_1 * (1.f / _2),
- a_->grad(), adj_, a_->val());
+ NodeOps backwardOps() {
+ return {
+ NodeOp(Add(_1 * (1.f / _2),
+ children_[0]->grad(),
+ adj_,
+ children_[0]->val()))
+ };
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label="
- << label("log") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
-
+ const std::string type() {
+ return "log";
+ }
};
struct ExpNodeOp : public UnaryNodeOp {
@@ -394,23 +356,26 @@ struct ExpNodeOp : public UnaryNodeOp {
ExpNodeOp(Args ...args)
: UnaryNodeOp(args...) { }
- void forward() {
- Element(_1 = Exp(_2), val_, a_->val());
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Element(_1 = Exp(_2),
+ val_,
+ children_[0]->val()))
+ };
}
- void backward() {
- if(a_->trainable())
- Add(_1 * Exp(_2),
- a_->grad(), adj_, a_->val());
+ NodeOps backwardOps() {
+ return {
+ NodeOp(Add(_1 * Exp(_2),
+ children_[0]->grad(),
+ adj_,
+ children_[0]->val()))
+ };
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label=" << label("exp")
- << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
- };
+ const std::string type() {
+ return "exp";
+ }
};
@@ -419,21 +384,24 @@ struct NegNodeOp : public UnaryNodeOp {
NegNodeOp(Args ...args)
: UnaryNodeOp(args...) { }
- void forward() {
- Element(_1 = -_2, val_, a_->val());
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Element(_1 = -_2,
+ val_,
+ children_[0]->val()))
+ };
}
- void backward() {
- if(a_->trainable())
- Add(-_1, a_->grad(), adj_);
+ NodeOps backwardOps() {
+ return {
+ NodeOp(Add(-_1,
+ children_[0]->grad(),
+ adj_))
+ };
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label="
- << label("-") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
+ const std::string type() {
+ return "-";
}
};
@@ -445,13 +413,18 @@ struct RowsNodeOp : public UnaryNodeOp {
thrust::copy(indeces.begin(), indeces.end(), indeces_.begin());
}
- void forward() {
- CopyRows(val_, a_->val(), indeces_);
+ NodeOps forwardOps() {
+ return {
+ NodeOp(CopyRows(val_, children_[0]->val(), indeces_))
+ };
}
- void backward() {
- if(a_->trainable())
- PasteRows(a_->grad(), adj_, indeces_);
+ NodeOps backwardOps() {
+ return {
+ NodeOp(PasteRows(children_[0]->grad(),
+ adj_,
+ indeces_))
+ };
}
template <class ...Args>
@@ -461,12 +434,12 @@ struct RowsNodeOp : public UnaryNodeOp {
return shape;
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label="
- << label("rows") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
+ const std::string type() {
+ return "rows";
+ }
+
+ const std::string color() {
+ return "orange";
}
DeviceVector<size_t> indeces_;
@@ -477,13 +450,18 @@ struct TransposeNodeOp : public UnaryNodeOp {
TransposeNodeOp(Expr a, Args ...args)
: UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }
- void forward() {
- Transpose(val_, a_->val());
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Transpose(getCublasHandle(),
+ val_, children_[0]->val()))
+ };
}
- void backward() {
- if(a_->trainable())
- Transpose(a_->grad(), adj_);
+ NodeOps backwardOps() {
+ return {
+ NodeOp(Transpose(getCublasHandle(),
+ children_[0]->grad(), adj_))
+ };
}
template <class ...Args>
@@ -495,12 +473,12 @@ struct TransposeNodeOp : public UnaryNodeOp {
return shape;
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label="
- << label("transpose") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
+ const std::string type() {
+ return "transpose";
+ }
+
+ const std::string color() {
+ return "orange";
}
};
@@ -516,33 +494,29 @@ struct ReshapeNodeOp : public UnaryNodeOp {
void backward() {}
void init_dependent() {
- a_->init_dependent();
+ children_[0]->init_dependent();
}
void set_zero_adjoint() {
- a_->set_zero_adjoint();
+ children_[0]->set_zero_adjoint();
}
Tensor& val() {
- val_.reset(new TensorGPU(a_->val()->data(), shape()));
+ val_.reset(new TensorGPU(children_[0]->val()->data(), shape()));
return val_;
};
Tensor& grad() {
- adj_.reset(new TensorGPU(a_->grad()->data(), shape()));
+ adj_.reset(new TensorGPU(children_[0]->grad()->data(), shape()));
return adj_;
};
- std::vector<Expr> children() {
- return a_->children();
+ const std::string type() {
+ return "reshape";
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label="
- << label("reshape") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
+ const std::string color() {
+ return "grey";
}
};
@@ -568,35 +542,31 @@ struct TimestepNodeOp : public UnaryNodeOp {
void backward() {}
void init_dependent() {
- a_->init_dependent();
+ children_[0]->init_dependent();
}
void set_zero_adjoint() {
- a_->set_zero_adjoint();
+ children_[0]->set_zero_adjoint();
}
Tensor& val() {
size_t offset = step_ * shape().elements();
- val_.reset(new TensorGPU(a_->val()->data() + offset, shape()));
+ val_.reset(new TensorGPU(children_[0]->val()->data() + offset, shape()));
return val_;
};
Tensor& grad() {
size_t offset = step_ * shape().elements();
- adj_.reset(new TensorGPU(a_->grad()->data() + offset, shape()));
+ adj_.reset(new TensorGPU(children_[0]->grad()->data() + offset, shape()));
return adj_;
};
- std::vector<Expr> children() {
- return a_->children();
+ const std::string type() {
+ return "step";
}
- virtual std::string graphviz() {
- std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"box\", label="
- << label("step") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
- ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
- return ss.str();
+ const std::string color() {
+ return "grey";
}
};