Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-02-13 20:35:36 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-02-13 20:35:36 +0300
commita562a5dcc7f8c8ba72f2848f364e9965b61188a3 (patch)
treeda39f56d37dbb8420dabb8636220ae0b6738ff2d /src/graph/node_operators_unary.h
parent7a095eadcd68d2a83095203ce4af709a14a8f9e4 (diff)
bn
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h42
1 files changed, 35 insertions, 7 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 5a6da1b4..80ee0801 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -423,19 +423,17 @@ struct ExpNodeOp : public UnaryNodeOp {
};
-struct PowNodeOp : public UnaryNodeOp {
- float exponent_;
+struct SqrtNodeOp : public UnaryNodeOp {
float epsilon_;
template <typename ...Args>
- PowNodeOp(Expr a, float exponent, float epsilon, Args ...args)
+ SqrtNodeOp(Expr a, float epsilon, Args ...args)
: UnaryNodeOp(a, args...),
- exponent_(exponent),
epsilon_(epsilon) { }
NodeOps forwardOps() {
return {
- NodeOp(Element(_1 = Pow(epsilon_ + _2, exponent_),
+ NodeOp(Element(_1 = Sqrt(_2 + epsilon_),
val_,
children_[0]->val()))
};
@@ -443,7 +441,37 @@ struct PowNodeOp : public UnaryNodeOp {
NodeOps backwardOps() {
return {
- NodeOp(Add(exponent_ * Pow(epsilon_ + _1, exponent_ - 1.f) * _2,
+ NodeOp(Add(0.5f * (1.f / _1) * _2,
+ children_[0]->grad(),
+ val_,
+ adj_))
+ };
+ }
+
+ const std::string type() {
+ return "sqrt";
+ }
+
+};
+
+struct SquareNodeOp : public UnaryNodeOp {
+ float epsilon_;
+
+ template <typename ...Args>
+ SquareNodeOp(Args ...args)
+ : UnaryNodeOp(args...) { }
+
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Element(_1 = _2 * _2,
+ val_,
+ children_[0]->val()))
+ };
+ }
+
+ NodeOps backwardOps() {
+ return {
+ NodeOp(Add(2.f * _1 * _2,
children_[0]->grad(),
children_[0]->val(),
adj_))
@@ -451,7 +479,7 @@ struct PowNodeOp : public UnaryNodeOp {
}
const std::string type() {
- return "pow";
+ return "square";
}
};