Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Grundkiewicz <romang@amu.edu.pl>2017-06-05 14:05:57 +0300
committerRoman Grundkiewicz <romang@amu.edu.pl>2017-06-05 14:05:57 +0300
commitd2c9e9fc740af88f81c3d226610c9cfec35a5a34 (patch)
tree0918fa1f98d7376877d265c18f4a46e0be6dfd33 /src/graph/node_operators_unary.h
parent8c9743121b5b397a02eef0218af1d5a3fd691788 (diff)
Autoformat .h and .cpp files
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h666
1 files changed, 217 insertions, 449 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index fe3c4056..81c58b7e 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -1,113 +1,78 @@
#pragma once
+#include "graph/backend_gpu.h"
#include "graph/node.h"
-#include "tensors/tensor.h"
+#include "kernels/sparse.h"
#include "kernels/tensor_operators.h"
#include "kernels/thrust_functions.h"
-#include "kernels/sparse.h"
-#include "graph/backend_gpu.h"
+#include "tensors/tensor.h"
namespace marian {
struct UnaryNodeOp : public NaryNodeOp {
- template <typename ...Args>
- UnaryNodeOp(Expr a, Args ...args)
- : NaryNodeOp({a},
- keywords::shape=a->shape(),
- args...) {}
-
- const std::string color() {
- return "yellow";
- }
+ template <typename... Args>
+ UnaryNodeOp(Expr a, Args... args)
+ : NaryNodeOp({a}, keywords::shape = a->shape(), args...) {}
+
+ const std::string color() { return "yellow"; }
};
struct ScalarAddNodeOp : public UnaryNodeOp {
- private:
- float scalar_{0};
-
- public:
- template <typename ...Args>
- ScalarAddNodeOp(Expr a, float scalar, Args ...args)
- : UnaryNodeOp(a, args...), scalar_{scalar} { }
-
- NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = _2 + scalar_,
- val_,
- child(0)->val()))
- };
- }
-
- NodeOps backwardOps() {
- return {
- NodeOp(Add(_1, child(0)->grad(), adj_))
- };
- }
-
- const std::string type() {
- return "scalar_add";
- }
+private:
+ float scalar_{0};
+
+public:
+ template <typename... Args>
+ ScalarAddNodeOp(Expr a, float scalar, Args... args)
+ : UnaryNodeOp(a, args...), scalar_{scalar} {}
+
+ NodeOps forwardOps() {
+ return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }
+
+ const std::string type() { return "scalar_add"; }
};
struct ScalarMultNodeOp : public UnaryNodeOp {
- private:
- float scalar_{0};
-
- public:
-
- template <typename ...Args>
- ScalarMultNodeOp(Expr a, float scalar, Args ...args)
- : UnaryNodeOp(a, args...), scalar_{scalar} { }
-
- NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = scalar_ * _2,
- val_,
- child(0)->val()))
- };
- }
-
- NodeOps backwardOps() {
- return {
- NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))
- };
- }
-
- const std::string type() {
- return "scalar_add";
- }
-};
+private:
+ float scalar_{0};
+
+public:
+ template <typename... Args>
+ ScalarMultNodeOp(Expr a, float scalar, Args... args)
+ : UnaryNodeOp(a, args...), scalar_{scalar} {}
+
+ NodeOps forwardOps() {
+ return {NodeOp(Element(_1 = scalar_ * _2, val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))};
+ }
+ const std::string type() { return "scalar_add"; }
+};
struct LogitNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- LogitNodeOp(Args ...args)
- : UnaryNodeOp(args...) { }
+ template <typename... Args>
+ LogitNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = Sigma(_2),
- val_,
- child(0)->val()))
- };
+ return {NodeOp(Element(_1 = Sigma(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
- return {
- NodeOp(Add(_1 * _2 * (1.0f - _2),
- child(0)->grad(),
- adj_, val_))
- };
+ return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))};
}
- const std::string type() {
- return "logit";
- }
+ const std::string type() { return "logit"; }
};
struct TanhNodeOp : public NaryNodeOp {
TanhNodeOp(const std::vector<Expr>& nodes)
- : NaryNodeOp(nodes, keywords::shape=newShape(nodes)) { }
+ : NaryNodeOp(nodes, keywords::shape = newShape(nodes)) {}
Shape newShape(const std::vector<Expr>& nodes) {
Shape shape = nodes[0]->shape();
@@ -124,34 +89,27 @@ struct TanhNodeOp : public NaryNodeOp {
}
NodeOps forwardOps() {
- switch (children_.size()) {
- case 1:
- return { NodeOp(Element(_1 = Tanh(_2),
- val_,
- child(0)->val())) };
+ switch(children_.size()) {
+ case 1: return {NodeOp(Element(_1 = Tanh(_2), val_, child(0)->val()))};
case 2:
- return { NodeOp(Element(_1 = Tanh(_2 + _3),
- val_,
- child(0)->val(),
- child(1)->val())) };
+ return {NodeOp(Element(
+ _1 = Tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))};
case 3:
- return { NodeOp(Element(_1 = Tanh(_2 + _3 + _4),
- val_,
- child(0)->val(),
- child(1)->val(),
- child(2)->val())) };
+ return {NodeOp(Element(_1 = Tanh(_2 + _3 + _4),
+ val_,
+ child(0)->val(),
+ child(1)->val(),
+ child(2)->val()))};
default:
return {
- NodeOp(
- Element(_1 = _2 + _3 + _4,
- val_,
- child(0)->val(),
- child(1)->val(),
- child(2)->val());
- for(int i = 3; i < children_.size(); ++i)
- Element(_1 += _2, val_, child(i)->val());
- Element(_1 = Tanh(_1), val_);
- )
+ NodeOp(Element(_1 = _2 + _3 + _4,
+ val_,
+ child(0)->val(),
+ child(1)->val(),
+ child(2)->val());
+ for(int i = 3; i < children_.size(); ++i)
+ Element(_1 += _2, val_, child(i)->val());
+ Element(_1 = Tanh(_1), val_);)
};
}
}
@@ -160,27 +118,24 @@ struct TanhNodeOp : public NaryNodeOp {
NodeOps ops;
for(int i = 0; i < children_.size(); i++) {
ops.push_back(
- NodeOp(Add(_1 * (1.0f - (_2 * _2)),
- child(i)->grad(), adj_, val_))
- );
+ NodeOp(Add(_1 * (1.0f - (_2 * _2)), child(i)->grad(), adj_, val_)));
}
return ops;
}
- const std::string color() {
- return "yellow";
- }
+ const std::string color() { return "yellow"; }
- const std::string type() {
- return "tanh";
- }
+ const std::string type() { return "tanh"; }
};
/**
- * Represents a <a href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified linear</a> node
+ * Represents a <a
+href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
+linear</a> node
* in an expression graph.
*
- * This node implements the <a href="https://en.wikipedia.org/wiki/Activation_function">activation function</a>
+ * This node implements the <a
+href="https://en.wikipedia.org/wiki/Activation_function">activation function</a>
* \f$f(x) = \max(0, x)\f$ and its derivative:
*
\f[
@@ -192,50 +147,35 @@ struct TanhNodeOp : public NaryNodeOp {
\f]
*/
struct ReLUNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- ReLUNodeOp(Args ...args)
- : UnaryNodeOp(args...) { }
+ template <typename... Args>
+ ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = ReLU(_2),
- val_,
- child(0)->val()))
- };
+ return {NodeOp(Element(_1 = ReLU(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
- return {
- NodeOp(Add(_1 * ReLUback(_2),
- child(0)->grad(),
- adj_, child(0)->val()))
- };
+ return {NodeOp(
+ Add(_1 * ReLUback(_2), child(0)->grad(), adj_, child(0)->val()))};
}
- const std::string type() {
- return "ReLU";
- }
+ const std::string type() { return "ReLU"; }
};
struct SoftmaxNodeOp : public NaryNodeOp {
- template <typename ...Args>
- SoftmaxNodeOp(Expr a, Args ...args)
- : NaryNodeOp(a, args...), mask_(nullptr) {
- }
+ template <typename... Args>
+ SoftmaxNodeOp(Expr a, Args... args)
+ : NaryNodeOp(a, args...), mask_(nullptr) {}
- template <typename ...Args>
- SoftmaxNodeOp(Expr a, Expr mask, Args ...args)
- : NaryNodeOp({a, mask}, args...), mask_(mask) {
- }
+ template <typename... Args>
+ SoftmaxNodeOp(Expr a, Expr mask, Args... args)
+ : NaryNodeOp({a, mask}, args...), mask_(mask) {}
Expr mask_;
NodeOps forwardOps() {
return {
- NodeOp(Softmax(val_,
- child(0)->val(),
- mask_ ? mask_->val() : nullptr))
- };
+ NodeOp(Softmax(val_, child(0)->val(), mask_ ? mask_->val() : nullptr))};
}
virtual size_t hash() {
@@ -247,7 +187,6 @@ struct SoftmaxNodeOp : public NaryNodeOp {
return hash_;
}
-
NodeOps backwardOps() {
// For each row, the Jacobian times vector is given by:
// J * dy = p .* (dy - avg*1)
@@ -261,65 +200,47 @@ struct SoftmaxNodeOp : public NaryNodeOp {
// val_ is already masked if there is a mask, so no need to apply here.
- return {
- NodeOp(SoftmaxGrad(child(0)->grad(), adj_, val_))
- };
+ return {NodeOp(SoftmaxGrad(child(0)->grad(), adj_, val_))};
}
- const std::string type() {
- return "softmax";
- }
+ const std::string type() { return "softmax"; }
};
struct LogSoftmaxNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- LogSoftmaxNodeOp(Args ...args)
- : UnaryNodeOp(args...) { }
+ template <typename... Args>
+ LogSoftmaxNodeOp(Args... args) : UnaryNodeOp(args...) {}
- NodeOps forwardOps() {
- return {
- NodeOp(LogSoftmax(val_, child(0)->val()))
- };
- }
+ NodeOps forwardOps() { return {NodeOp(LogSoftmax(val_, child(0)->val()))}; }
NodeOps backwardOps() {
// Based on the description for softmax, we have logsoftmax:
// J * dy = dy - avg*1
// where avg = exp(p)'*dy and p is the softmax output (probabilities).
- return {
- NodeOp(LogSoftmaxGrad(child(0)->grad(), adj_, val_))
- };
+ return {NodeOp(LogSoftmaxGrad(child(0)->grad(), adj_, val_))};
}
- const std::string type() {
- return "logsoftmax";
- }
+ const std::string type() { return "logsoftmax"; }
};
struct SumNodeOp : public UnaryNodeOp {
int ax_;
- template <typename ...Args>
- SumNodeOp(Expr a, Args ...args)
- : UnaryNodeOp(a, keywords::shape=newShape(a, args...), args...),
- ax_(keywords::Get(keywords::axis, -1, args...)) { }
+ template <typename... Args>
+ SumNodeOp(Expr a, Args... args)
+ : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...),
+ ax_(keywords::Get(keywords::axis, -1, args...)) {}
- NodeOps forwardOps() {
- return { NodeOp(Reduce(_1, val_, child(0)->val())) };
- }
+ NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }
- NodeOps backwardOps() {
- return { NodeOp(Add(_1, child(0)->grad(), adj_)) };
- }
+ NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }
- template <class ...Args>
- Shape newShape(Expr a, Args ...args) {
+ template <class... Args>
+ Shape newShape(Expr a, Args... args) {
int ax = keywords::Get(keywords::axis, -1, args...);
Shape shape = a->shape();
if(ax != -1) {
shape.set(ax, 1);
- }
- else {
+ } else {
shape.set(0, 1);
shape.set(1, 1);
shape.set(2, 1);
@@ -328,13 +249,9 @@ struct SumNodeOp : public UnaryNodeOp {
return shape;
}
- const std::string type() {
- return "sum";
- }
+ const std::string type() { return "sum"; }
- const std::string color() {
- return "orange";
- }
+ const std::string color() { return "orange"; }
virtual size_t hash() {
if(!hash_) {
@@ -343,44 +260,37 @@ struct SumNodeOp : public UnaryNodeOp {
}
return hash_;
}
-
-
};
struct MeanNodeOp : public UnaryNodeOp {
int ax_;
- template <typename ...Args>
- MeanNodeOp(Expr a, Args ...args)
- : UnaryNodeOp(a, keywords::shape=newShape(a, args...), args...),
- ax_(keywords::Get(keywords::axis, -1, args...)) { }
+ template <typename... Args>
+ MeanNodeOp(Expr a, Args... args)
+ : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...),
+ ax_(keywords::Get(keywords::axis, -1, args...)) {}
NodeOps forwardOps() {
int left = child(0)->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
- return {
- NodeOp(Reduce(_1, val_, child(0)->val(), scale))
- };
+ return {NodeOp(Reduce(_1, val_, child(0)->val(), scale))};
}
NodeOps backwardOps() {
int left = child(0)->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
- return {
- NodeOp(Add(_1, child(0)->grad(), adj_, scale))
- };
+ return {NodeOp(Add(_1, child(0)->grad(), adj_, scale))};
}
- template <class ...Args>
- Shape newShape(Expr a, Args ...args) {
+ template <class... Args>
+ Shape newShape(Expr a, Args... args) {
int ax = keywords::Get(keywords::axis, -1, args...);
Shape shape = a->shape();
if(ax != -1) {
shape.set(ax, 1);
- }
- else {
+ } else {
shape.set(0, 1);
shape.set(1, 1);
shape.set(2, 1);
@@ -389,13 +299,9 @@ struct MeanNodeOp : public UnaryNodeOp {
return shape;
}
- const std::string type() {
- return "mean";
- }
+ const std::string type() { return "mean"; }
- const std::string color() {
- return "orange";
- }
+ const std::string color() { return "orange"; }
virtual size_t hash() {
if(!hash_) {
@@ -404,93 +310,55 @@ struct MeanNodeOp : public UnaryNodeOp {
}
return hash_;
}
-
};
-
struct LogNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- LogNodeOp(Args ...args)
- : UnaryNodeOp(args...) {}
+ template <typename... Args>
+ LogNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = Log(_2),
- val_,
- child(0)->val()))
- };
+ return {NodeOp(Element(_1 = Log(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
return {
- NodeOp(Add(_1 * (1.f / _2),
- child(0)->grad(),
- adj_,
- child(0)->val()))
- };
+ NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))};
}
- const std::string type() {
- return "log";
- }
+ const std::string type() { return "log"; }
};
struct ExpNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- ExpNodeOp(Args ...args)
- : UnaryNodeOp(args...) { }
+ template <typename... Args>
+ ExpNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = Exp(_2),
- val_,
- child(0)->val()))
- };
+ return {NodeOp(Element(_1 = Exp(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
- return {
- NodeOp(Add(_1 * Exp(_2),
- child(0)->grad(),
- adj_,
- child(0)->val()))
- };
- }
-
- const std::string type() {
- return "exp";
+ return {NodeOp(Add(_1 * Exp(_2), child(0)->grad(), adj_, child(0)->val()))};
}
+ const std::string type() { return "exp"; }
};
struct SqrtNodeOp : public UnaryNodeOp {
float epsilon_;
- template <typename ...Args>
- SqrtNodeOp(Expr a, float epsilon, Args ...args)
- : UnaryNodeOp(a, args...),
- epsilon_(epsilon) { }
+ template <typename... Args>
+ SqrtNodeOp(Expr a, float epsilon, Args... args)
+ : UnaryNodeOp(a, args...), epsilon_(epsilon) {}
NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = Sqrt(_2 + epsilon_),
- val_,
- child(0)->val()))
- };
+ return {NodeOp(Element(_1 = Sqrt(_2 + epsilon_), val_, child(0)->val()))};
}
NodeOps backwardOps() {
- return {
- NodeOp(Add(0.5f * (1.f / _1) * _2,
- child(0)->grad(),
- val_,
- adj_))
- };
+ return {NodeOp(Add(0.5f * (1.f / _1) * _2, child(0)->grad(), val_, adj_))};
}
- const std::string type() {
- return "sqrt";
- }
+ const std::string type() { return "sqrt"; }
virtual size_t hash() {
if(!hash_) {
@@ -500,106 +368,65 @@ struct SqrtNodeOp : public UnaryNodeOp {
}
return hash_;
}
-
-
};
struct SquareNodeOp : public UnaryNodeOp {
float epsilon_;
- template <typename ...Args>
- SquareNodeOp(Args ...args)
- : UnaryNodeOp(args...) { }
+ template <typename... Args>
+ SquareNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = _2 * _2,
- val_,
- child(0)->val()))
- };
+ return {NodeOp(Element(_1 = _2 * _2, val_, child(0)->val()))};
}
NodeOps backwardOps() {
return {
- NodeOp(Add(2.f * _1 * _2,
- child(0)->grad(),
- child(0)->val(),
- adj_))
- };
- }
-
- const std::string type() {
- return "square";
+ NodeOp(Add(2.f * _1 * _2, child(0)->grad(), child(0)->val(), adj_))};
}
+ const std::string type() { return "square"; }
};
-
struct NegNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- NegNodeOp(Args ...args)
- : UnaryNodeOp(args...) { }
+ template <typename... Args>
+ NegNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = -_2,
- val_,
- child(0)->val()))
- };
+ return {NodeOp(Element(_1 = -_2, val_, child(0)->val()))};
}
- NodeOps backwardOps() {
- return {
- NodeOp(Add(-_1,
- child(0)->grad(),
- adj_))
- };
- }
+ NodeOps backwardOps() { return {NodeOp(Add(-_1, child(0)->grad(), adj_))}; }
- const std::string type() {
- return "-";
- }
+ const std::string type() { return "-"; }
};
struct RowsNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args ...args)
- : UnaryNodeOp(a, keywords::shape=newShape(a, indeces), args...),
- indeces_(indeces) {
- }
+ template <typename... Args>
+ RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
+ : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
+ indeces_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
- return {
- NodeOp(CopyRows(val_,
- child(0)->val(),
- indeces_))
- };
+ return {NodeOp(CopyRows(val_, child(0)->val(), indeces_))};
}
NodeOps backwardOps() {
- return {
- NodeOp(PasteRows(child(0)->grad(),
- adj_,
- indeces_))
- };
+ return {NodeOp(PasteRows(child(0)->grad(), adj_, indeces_))};
}
- template <class ...Args>
+ template <class... Args>
Shape newShape(Expr a, const std::vector<size_t>& indeces) {
Shape shape = a->shape();
shape.set(0, indeces.size());
return shape;
}
- const std::string type() {
- return "rows";
- }
+ const std::string type() { return "rows"; }
- const std::string color() {
- return "orange";
- }
+ const std::string color() { return "orange"; }
virtual size_t hash() {
if(!hash_) {
@@ -611,49 +438,35 @@ struct RowsNodeOp : public UnaryNodeOp {
return hash_;
}
-
std::vector<size_t> indeces_;
};
struct ColsNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args ...args)
- : UnaryNodeOp(a, keywords::shape=newShape(a, indeces), args...),
- indeces_(indeces) {
- }
+ template <typename... Args>
+ ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
+ : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
+ indeces_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
- return {
- NodeOp(CopyCols(val_,
- child(0)->val(),
- indeces_))
- };
+ return {NodeOp(CopyCols(val_, child(0)->val(), indeces_))};
}
NodeOps backwardOps() {
- return {
- NodeOp(PasteCols(child(0)->grad(),
- adj_,
- indeces_))
- };
+ return {NodeOp(PasteCols(child(0)->grad(), adj_, indeces_))};
}
- template <class ...Args>
+ template <class... Args>
Shape newShape(Expr a, const std::vector<size_t>& indeces) {
Shape shape = a->shape();
shape.set(1, indeces.size());
return shape;
}
- const std::string type() {
- return "cols";
- }
+ const std::string type() { return "cols"; }
- const std::string color() {
- return "orange";
- }
+ const std::string color() { return "orange"; }
virtual size_t hash() {
if(!hash_) {
@@ -665,31 +478,29 @@ struct ColsNodeOp : public UnaryNodeOp {
return hash_;
}
-
std::vector<size_t> indeces_;
};
-
struct TransposeNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- TransposeNodeOp(Expr a, Args ...args)
- : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }
+ template <typename... Args>
+ TransposeNodeOp(Expr a, Args... args)
+ : UnaryNodeOp(a, keywords::shape = newShape(a), args...) {}
NodeOps forwardOps() {
- return {
- NodeOp(Transpose(std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(),
- val_, child(0)->val()))
- };
+ return {NodeOp(Transpose(
+ std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(),
+ val_,
+ child(0)->val()))};
}
NodeOps backwardOps() {
- return {
- NodeOp(Transpose(std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(),
- child(0)->grad(), adj_))
- };
+ return {NodeOp(Transpose(
+ std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(),
+ child(0)->grad(),
+ adj_))};
}
- template <class ...Args>
+ template <class... Args>
Shape newShape(Expr a) {
Shape shape = a->shape();
int temp = shape[0];
@@ -698,13 +509,9 @@ struct TransposeNodeOp : public UnaryNodeOp {
return shape;
}
- const std::string type() {
- return "transpose";
- }
+ const std::string type() { return "transpose"; }
- const std::string color() {
- return "orange";
- }
+ const std::string color() { return "orange"; }
};
class ReshapeNodeOp : public UnaryNodeOp {
@@ -712,46 +519,37 @@ private:
Expr reshapee_;
public:
- template <typename ...Args>
- ReshapeNodeOp(Expr a, Shape shape, Args ...args)
- : UnaryNodeOp(a, keywords::shape=shape, args...),
- reshapee_(a) { }
+ template <typename... Args>
+ ReshapeNodeOp(Expr a, Shape shape, Args... args)
+ : UnaryNodeOp(a, keywords::shape = shape, args...), reshapee_(a) {}
-
-
size_t allocate() { return 0; }
void free() {}
void forward() {}
void backward() {}
- void init_dependent() {
- reshapee_->init_dependent();
- }
+ void init_dependent() { reshapee_->init_dependent(); }
- void set_zero_adjoint() {
- reshapee_->set_zero_adjoint();
- }
+ void set_zero_adjoint() { reshapee_->set_zero_adjoint(); }
- Tensor& val() {
+ Tensor& val() {
auto childVal = reshapee_->val();
- val_.reset(new TensorBase(childVal->data(), shape(), childVal->getDevice()));
+ val_.reset(
+ new TensorBase(childVal->data(), shape(), childVal->getDevice()));
return val_;
};
Tensor& grad() {
auto childGrad = reshapee_->grad();
- adj_.reset(new TensorBase(childGrad->data(), shape(), childGrad->getDevice()));
+ adj_.reset(
+ new TensorBase(childGrad->data(), shape(), childGrad->getDevice()));
return adj_;
};
- const std::string type() {
- return "reshape";
- }
+ const std::string type() { return "reshape"; }
- const std::string color() {
- return "grey";
- }
+ const std::string color() { return "grey"; }
virtual size_t hash() {
if(!hash_) {
@@ -762,7 +560,6 @@ public:
}
return hash_;
}
-
};
class TimestepNodeOp : public UnaryNodeOp {
@@ -772,9 +569,9 @@ private:
public:
TimestepNodeOp(Expr a, size_t step)
- : UnaryNodeOp(a, keywords::shape=newShape(a)),
- stepNode_(a), step_(step)
- { }
+ : UnaryNodeOp(a, keywords::shape = newShape(a)),
+ stepNode_(a),
+ step_(step) {}
Shape newShape(Expr a) {
Shape outShape = a->shape();
@@ -789,73 +586,53 @@ public:
void forward() {}
void backward() {}
- void init_dependent() {
- stepNode_->init_dependent();
- }
+ void init_dependent() { stepNode_->init_dependent(); }
- void set_zero_adjoint() {
- stepNode_->set_zero_adjoint();
- }
+ void set_zero_adjoint() { stepNode_->set_zero_adjoint(); }
- Tensor& val() {
+ Tensor& val() {
auto childVal = stepNode_->val();
size_t offset = step_ * shape().elements();
- val_.reset(new TensorBase(childVal->data() + offset, shape(), childVal->getDevice()));
+ val_.reset(new TensorBase(
+ childVal->data() + offset, shape(), childVal->getDevice()));
return val_;
};
Tensor& grad() {
auto childGrad = stepNode_->grad();
size_t offset = step_ * shape().elements();
- adj_.reset(new TensorBase(childGrad->data() + offset, shape(), childGrad->getDevice()));
+ adj_.reset(new TensorBase(
+ childGrad->data() + offset, shape(), childGrad->getDevice()));
return adj_;
};
- const std::string type() {
- return "step";
- }
+ const std::string type() { return "step"; }
- const std::string color() {
- return "grey";
- }
+ const std::string color() { return "grey"; }
virtual size_t hash() {
if(!hash_) {
- hash_ = NaryNodeOp::hash();
- boost::hash_combine(hash_, step_);
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, step_);
}
return hash_;
}
-
};
struct ShiftNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- ShiftNodeOp(Expr a, Shape shift, Args ...args)
- : UnaryNodeOp(a, keywords::shape=a->shape(), args...),
- shift_(shift) {
- }
+ template <typename... Args>
+ ShiftNodeOp(Expr a, Shape shift, Args... args)
+ : UnaryNodeOp(a, keywords::shape = a->shape(), args...), shift_(shift) {}
NodeOps forwardOps() {
- return {
- NodeOp(Shift(val_,
- child(0)->val(),
- shift_))
- };
+ return {NodeOp(Shift(val_, child(0)->val(), shift_))};
}
NodeOps backwardOps() {
- return {
- NodeOp(Shift(child(0)->grad(),
- adj_,
- shift_,
- true))
- };
+ return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true))};
}
- const std::string type() {
- return "shift";
- }
+ const std::string type() { return "shift"; }
virtual size_t hash() {
if(!hash_) {
@@ -867,39 +644,31 @@ struct ShiftNodeOp : public UnaryNodeOp {
return hash_;
}
-
Shape shift_;
};
struct LexicalProbNodeOp : public NaryNodeOp {
- template <typename ...Args>
- LexicalProbNodeOp(Expr logits, Expr att, float eps, Ptr<sparse::CSR> lf, Args ...args)
- : NaryNodeOp({logits, att}, keywords::shape=logits->shape(), args...),
- eps_(eps),
- lf_(lf) {
- }
+ template <typename... Args>
+ LexicalProbNodeOp(
+ Expr logits, Expr att, float eps, Ptr<sparse::CSR> lf, Args... args)
+ : NaryNodeOp({logits, att}, keywords::shape = logits->shape(), args...),
+ eps_(eps),
+ lf_(lf) {}
void forward() {
- sparse::LfaForward(val_,
- child(0)->val(),
- child(1)->val(),
- lf_);
+ sparse::LfaForward(val_, child(0)->val(), child(1)->val(), lf_);
// val = x + ln(p + eps)
- Element(_1 = (Log(_1 + eps_) + _2),
- val_, child(0)->val());
+ Element(_1 = (Log(_1 + eps_) + _2), val_, child(0)->val());
}
-
+
void backward() {
- Add(_1, child(0)->grad(), adj_);
- // adj' = adj / (p + eps) = adj / exp(val - x)
- Element(_1 = _1 / Exp(_2 - _3),
- adj_, val_, child(0)->val());
- sparse::LfaBackward(child(1)->grad(), adj_, lf_);
+ Add(_1, child(0)->grad(), adj_);
+ // adj' = adj / (p + eps) = adj / exp(val - x)
+ Element(_1 = _1 / Exp(_2 - _3), adj_, val_, child(0)->val());
+ sparse::LfaBackward(child(1)->grad(), adj_, lf_);
}
- const std::string type() {
- return "lexical_prob";
- }
+ const std::string type() { return "lexical_prob"; }
virtual size_t hash() {
if(!hash_) {
@@ -913,5 +682,4 @@ struct LexicalProbNodeOp : public NaryNodeOp {
float eps_;
Ptr<sparse::CSR> lf_;
};
-
}