Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src/graph
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-11-02 22:58:56 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-11-02 22:58:56 +0300
commit9d06d786ba568e74f54be98f5a23d9b53d87ed7b (patch)
tree8d82c2e2c33d9fa113d5bd9fd52bb8cd099db3bd /src/graph
parentca64c429e4aa4dcd49b68bab6a7d744fe06b44c2 (diff)
farewell thrust
Diffstat (limited to 'src/graph')
-rw-r--r--src/graph/node_operators_binary.h24
-rw-r--r--src/graph/node_operators_unary.h14
2 files changed, 35 insertions, 3 deletions
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index f5fe6cf5..86fafdfb 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -193,6 +193,8 @@ struct AffineNodeOp : public NaryNodeOp {
}
NodeOps forwardOps() {
+ using namespace functional;
+
return {
NodeOp(Prod(std::static_pointer_cast<BackendGPU>(getBackend())
->getCublasHandle(),
@@ -207,6 +209,8 @@ struct AffineNodeOp : public NaryNodeOp {
}
NodeOps backwardOps() {
+ using namespace functional;
+
// D is the adjoint, the matrix of derivatives
// df/dA += D*B.T
// df/dB += A.T*D
@@ -405,10 +409,14 @@ struct ScalarProductNodeOp : public NaryNodeOp {
}
NodeOps forwardOps() {
+ using namespace functional;
+
return {NodeOp(Reduce(_1 * _2, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
+
return {NodeOp(Add(_1 * _2, child(0)->grad(), child(1)->val(), adj_)),
NodeOp(Add(_1 * _2, child(1)->grad(), child(0)->val(), adj_))};
}
@@ -435,11 +443,15 @@ struct PlusNodeOp : public ElementBinaryNodeOp {
PlusNodeOp(Args... args) : ElementBinaryNodeOp(args...) {}
NodeOps forwardOps() {
+ using namespace functional;
+
return {
NodeOp(Element(_1 = _2 + _3, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
+
return {NodeOp(Add(_1, child(0)->grad(), adj_)),
NodeOp(Add(_1, child(1)->grad(), adj_))};
}
@@ -452,11 +464,15 @@ struct MinusNodeOp : public ElementBinaryNodeOp {
MinusNodeOp(Args... args) : ElementBinaryNodeOp(args...) {}
NodeOps forwardOps() {
+ using namespace functional;
+
return {
NodeOp(Element(_1 = _2 - _3, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
+
return {NodeOp(Add(_1, child(0)->grad(), adj_)),
NodeOp(Add(-_1, child(1)->grad(), adj_))};
}
@@ -469,11 +485,15 @@ struct MultNodeOp : public ElementBinaryNodeOp {
MultNodeOp(Args... args) : ElementBinaryNodeOp(args...) {}
NodeOps forwardOps() {
+ using namespace functional;
+
return {
NodeOp(Element(_1 = _2 * _3, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
+
return {NodeOp(Add(_1 * _2, child(0)->grad(), adj_, child(1)->val())),
NodeOp(Add(_1 * _2, child(1)->grad(), adj_, child(0)->val()))};
}
@@ -486,11 +506,15 @@ struct DivNodeOp : public ElementBinaryNodeOp {
DivNodeOp(Args... args) : ElementBinaryNodeOp(args...) {}
NodeOps forwardOps() {
+ using namespace functional;
+
return {
NodeOp(Element(_1 = _2 / _3, val_, child(0)->val(), child(1)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
+
return {
NodeOp(Add(_1 * 1.0f / _2, child(0)->grad(), adj_, child(1)->val())),
NodeOp(Add(-_1 * _2 / (_3 * _3),
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 05294bee..9d5b8287 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -55,7 +55,10 @@ public:
return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))};
}
- NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }
+ NodeOps backwardOps() {
+ using namespace functional;
+ return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+ }
const std::string type() { return "scalar_add"; }
};
@@ -392,9 +395,14 @@ struct SumNodeOp : public UnaryNodeOp {
SumNodeOp(Expr a, Args... args)
: UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
- NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }
+ NodeOps forwardOps() {
+ using namespace functional;
+
+ return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }
- NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }
+ NodeOps backwardOps() {
+ using namespace functional;
+ return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }
template <class... Args>
Shape newShape(Expr a, Args... args) {