
github.com/marian-nmt/marian.git - Marian: fast neural machine translation in C++
author     Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk>  2018-03-12 23:34:10 +0300
committer  Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk>  2018-03-12 23:34:10 +0300
commit     6d0c75cf48bab913e2c9c52f1c4c6cd0d656005d (patch)
tree       717342edade369af33a771f00a7dd05354ea8afb /src/graph
parent     5f2eedc6e505eecf5bdef474be3e4f7066702fa7 (diff)
Autoformat files
Diffstat (limited to 'src/graph')
-rw-r--r--  src/graph/expression_graph.cpp       11
-rw-r--r--  src/graph/expression_graph.h         10
-rw-r--r--  src/graph/expression_operators.cpp   87
-rw-r--r--  src/graph/expression_operators.h      4
-rw-r--r--  src/graph/node.cpp                    1
-rw-r--r--  src/graph/node.h                      5
-rw-r--r--  src/graph/node_initializers.cpp      11
-rw-r--r--  src/graph/node_initializers.h         6
-rw-r--r--  src/graph/node_operators.h           11
-rw-r--r--  src/graph/node_operators_binary.h   248
-rw-r--r--  src/graph/node_operators_unary.h    126
11 files changed, 220 insertions, 300 deletions
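
Note: the "// clang-format off" / "// clang-format on" guards added in expression_operators.cpp indicate this reformatting was done with clang-format. As a hedged sketch only (the option values below are assumptions inferred from the diff, not taken from the repository's actual .clang-format), a configuration along these lines would reproduce the style visible here: an 80-column limit, two-space indent, one argument per line when a call wraps, no space before control-statement parentheses, and access specifiers flush with the class keyword.

    # Hypothetical .clang-format; BasedOnStyle and all values are assumptions.
    BasedOnStyle: Google
    ColumnLimit: 80
    IndentWidth: 2
    AccessModifierOffset: -2    # "public:" flush left, members indented by 2
    BinPackArguments: false     # wrapped calls put one argument per line
    BinPackParameters: false    # wrapped declarations likewise
    AllowShortFunctionsOnASingleLine: Inline
    SpaceBeforeParens: Never    # matches "for(int b = 0; ...)" in the diff

Running the tool in-place, e.g. "clang-format -i src/graph/*.cpp src/graph/*.h", would then regenerate a commit of this shape.
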
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp
index f0ae1ffa..4a0edb34 100644
--- a/src/graph/expression_graph.cpp
+++ b/src/graph/expression_graph.cpp
@@ -1,5 +1,5 @@
-#include <sstream>
#include "graph/expression_graph.h"
+#include <sstream>
#include "tensors/tensor_operators.h"
@@ -18,15 +18,12 @@ void ExpressionGraph::setDevice(DeviceId deviceId) {
}
Expr ExpressionGraph::dropout(float prob, const Shape& shape) {
- return Expression<ConstantNode>(shared_from_this(),
- shape,
- [prob, this](Tensor t) {
- Dropout(t, prob);
- });
+ return Expression<ConstantNode>(
+ shared_from_this(), shape, [prob, this](Tensor t) { Dropout(t, prob); });
}
void ExpressionGraph::checkNan(Tensor t) {
ABORT_IF(throwNaN_, "Not implemented");
- //ABORT_IF(throwNaN_ && IsNan(t), "Tensor has NaN");
+ // ABORT_IF(throwNaN_ && IsNan(t), "Tensor has NaN");
}
}
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index ea1645ec..c6cd4558 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -215,7 +215,9 @@ public:
ABORT_IF(shape != p->shape(),
"Requested shape {} for existing parameter '{}' does not match "
"original shape {}",
- shape, name, p->shape());
+ shape,
+ name,
+ p->shape());
p->setTrainable(!fixed);
add(p);
@@ -239,10 +241,8 @@ public:
return p;
}
- Expr constant(const Shape& shape,
- const NodeInitializer& init) {
- return Expression<ConstantNode>(
- shared_from_this(), shape, init);
+ Expr constant(const Shape& shape, const NodeInitializer& init) {
+ return Expression<ConstantNode>(shared_from_this(), shape, init);
}
Expr ones(const Shape& shape) {
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index a1c9faa4..a4a8b079 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -126,7 +126,6 @@ Expr repeat(Expr a, size_t repeats, keywords::axis_k ax) {
return concatenate(std::vector<Expr>(repeats, a), ax);
}
-
Expr reshape(Expr a, Shape shape) {
return Expression<ReshapeNodeOp>(a, shape);
}
@@ -165,10 +164,7 @@ Expr flatten(Expr a) {
}
Expr flatten_2d(Expr a) {
- Shape shape = {
- a->shape().elements() / a->shape()[-1],
- a->shape()[-1]
- };
+ Shape shape = {a->shape().elements() / a->shape()[-1], a->shape()[-1]};
return Expression<ReshapeNodeOp>(a, shape);
}
@@ -232,17 +228,16 @@ Expr step(Expr a, int step, int axis) {
}
Expr cross_entropy(Expr a, Expr b) {
- //auto sOrig = a->shape();
- //auto sOut = a->shape();
- //Shape sTemp({sOrig[0] * sOrig[2] * sOrig[3], sOrig[1], 1, 1});
- //sOut.set(1, 1);
- //return reshape(Expression<CrossEntropyNodeOp>(reshape(a, sTemp), b), sOut);
+ // auto sOrig = a->shape();
+ // auto sOut = a->shape();
+ // Shape sTemp({sOrig[0] * sOrig[2] * sOrig[3], sOrig[1], 1, 1});
+ // sOut.set(1, 1);
+ // return reshape(Expression<CrossEntropyNodeOp>(reshape(a, sTemp), b), sOut);
return Expression<CrossEntropyNodeOp>(a, b);
}
-Expr affine(Expr a, Expr b, Expr c,
- bool transA, bool transB, float scalar) {
+Expr affine(Expr a, Expr b, Expr c, bool transA, bool transB, float scalar) {
std::vector<Expr> nodes = {a, b, c};
return Expression<AffineNodeOp>(nodes, transA, transB, scalar);
}
@@ -299,6 +294,7 @@ Expr highway(Expr y, Expr x, Expr t) {
}
Expr highway(const std::string prefix, Expr x) {
+ // clang-format off
size_t outDim = x->shape()[-1];
auto g = mlp::dense(x->graph())
("prefix", prefix + "_highway_d1")
@@ -311,6 +307,7 @@ Expr highway(const std::string prefix, Expr x) {
("activation", mlp::act::ReLU)
.construct()->apply(x);
return (g * relued) + ((1 - g) * x);
+ // clang-format on
}
// Expr batch_norm(Expr x, Expr gamma, Expr beta) {
@@ -334,41 +331,26 @@ Expr shift(Expr a, Shape shift) {
#ifdef CUDA_FOUND
-Expr avg_pooling(
- Expr x,
- int height,
- int width,
- int padHeight,
- int padWidth,
- int strideHeight,
- int strideWidth) {
- return Expression<PoolingOp>(x,
- height,
- width,
- padHeight,
- padWidth,
- strideHeight,
- strideWidth,
- "avg");
-}
-
-Expr max_pooling(
- Expr x,
- int height,
- int width,
- int padHeight,
- int padWidth,
- int strideHeight,
- int strideWidth)
-{
- return Expression<PoolingOp>(x,
- height,
- width,
- padHeight,
- padWidth,
- strideHeight,
- strideWidth,
- "max");
+Expr avg_pooling(Expr x,
+ int height,
+ int width,
+ int padHeight,
+ int padWidth,
+ int strideHeight,
+ int strideWidth) {
+ return Expression<PoolingOp>(
+ x, height, width, padHeight, padWidth, strideHeight, strideWidth, "avg");
+}
+
+Expr max_pooling(Expr x,
+ int height,
+ int width,
+ int padHeight,
+ int padWidth,
+ int strideHeight,
+ int strideWidth) {
+ return Expression<PoolingOp>(
+ x, height, width, padHeight, padWidth, strideHeight, strideWidth, "max");
}
Expr convert2cudnnFormat(Expr x) {
@@ -377,13 +359,13 @@ Expr convert2cudnnFormat(Expr x) {
int embSize = x->shape()[2];
std::vector<size_t> newIndeces;
- for (int b = 0; b < numExamples; ++b) {
- for (int t = 0; t < numWords; ++t) {
+ for(int b = 0; b < numExamples; ++b) {
+ for(int t = 0; t < numWords; ++t) {
newIndeces.push_back((t * numExamples) + b);
}
}
- auto xRows = reshape(x, {x->shape()[0] * x ->shape()[1], x->shape()[2]});
+ auto xRows = reshape(x, {x->shape()[0] * x->shape()[1], x->shape()[2]});
Shape outShape({numExamples, 1, numWords, embSize});
return reshape(rows(xRows, newIndeces), outShape);
@@ -397,8 +379,8 @@ Expr convertFromcudnnFormat(Expr x) {
auto reshapedX = reshape(x, {batchDim * sentenceDim, embSize});
std::vector<size_t> newIndeces;
- for (int t = 0; t < sentenceDim; ++t) {
- for (int b = 0; b < batchDim; ++b) {
+ for(int t = 0; t < sentenceDim; ++t) {
+ for(int b = 0; b < batchDim; ++b) {
newIndeces.push_back(b * sentenceDim + t);
}
}
@@ -412,5 +394,4 @@ Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven) {
}
#endif
-
}
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 1145be3c..c637105f 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -110,7 +110,6 @@ Expr mean(Expr a, keywords::axis_k ax = 0);
Expr cross_entropy(Expr a, Expr b);
-
Expr scalar_product(Expr a, Expr b, keywords::axis_k ax = 0);
Expr weighted_average(Expr in, Expr weights, keywords::axis_k ax = 0);
@@ -161,6 +160,5 @@ Expr max_pooling(Expr x,
int strideHeight = 1,
int strideWidth = 1);
-Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven=false);
-
+Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven = false);
}
diff --git a/src/graph/node.cpp b/src/graph/node.cpp
index 1c93683c..721cb30f 100644
--- a/src/graph/node.cpp
+++ b/src/graph/node.cpp
@@ -1,4 +1,5 @@
#include "tensors/backend.h"
+
#include "graph/expression_graph.h"
#include "graph/node.h"
diff --git a/src/graph/node.h b/src/graph/node.h
index 74af5771..15f223aa 100644
--- a/src/graph/node.h
+++ b/src/graph/node.h
@@ -33,8 +33,7 @@ protected:
public:
Node(Ptr<ExpressionGraph> graph, Shape shape)
- : graph_(graph),
- shape_(shape) {}
+ : graph_(graph), shape_(shape) {}
virtual ~Node() {
if(destroy_) {
@@ -152,7 +151,7 @@ struct NaryNodeOp : public Node {
}
NaryNodeOp(const std::vector<Expr>& nodes)
- : NaryNodeOp(nodes, nodes[0]->shape()) {}
+ : NaryNodeOp(nodes, nodes[0]->shape()) {}
virtual ~NaryNodeOp() {}
diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp
index 0d131c61..6650ede6 100644
--- a/src/graph/node_initializers.cpp
+++ b/src/graph/node_initializers.cpp
@@ -109,9 +109,8 @@ void ortho(Tensor t) {
NodeInitializer from_vector(const std::vector<float>& v) {
auto vPtr = New<std::vector<float>>(v.begin(), v.end());
- return [vPtr](Tensor t) {
- t->set(vPtr->data(), vPtr->data() + vPtr->size());
- };
+ return
+ [vPtr](Tensor t) { t->set(vPtr->data(), vPtr->data() + vPtr->size()); };
}
NodeInitializer from_vector(const std::vector<size_t>& v) {
@@ -138,9 +137,9 @@ NodeInitializer from_numpy(const cnpy::NpyArrayPtr& np) {
// move this somewhere else
NodeInitializer from_word2vec(const std::string& file,
- int dimVoc,
- int dimEmb,
- bool normalize /*= false*/) {
+ int dimVoc,
+ int dimEmb,
+ bool normalize /*= false*/) {
return [file, dimVoc, dimEmb, normalize](Tensor t) {
auto embs = Word2VecReader().read(file, dimVoc, dimEmb);
diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h
index 5b069657..bd74c6c4 100644
--- a/src/graph/node_initializers.h
+++ b/src/graph/node_initializers.h
@@ -70,9 +70,9 @@ NodeInitializer from_sparse_vector(
NodeInitializer from_numpy(const cnpy::NpyArrayPtr& np);
NodeInitializer from_word2vec(const std::string& file,
- int dimVoc,
- int dimEmb,
- bool normalize = false);
+ int dimVoc,
+ int dimEmb,
+ bool normalize = false);
}
} // namespace marian
diff --git a/src/graph/node_operators.h b/src/graph/node_operators.h
index 8720d0bb..4e97fff3 100644
--- a/src/graph/node_operators.h
+++ b/src/graph/node_operators.h
@@ -7,11 +7,12 @@
namespace marian {
struct ConstantNode : public Node {
- ConstantNode(Ptr<ExpressionGraph> graph, const Shape& shape, const NodeInitializer& init)
+ ConstantNode(Ptr<ExpressionGraph> graph,
+ const Shape& shape,
+ const NodeInitializer& init)
: Node(graph, shape),
init_(new NodeInitializer(init)),
initialized_(false) {
-
setTrainable(false);
}
@@ -41,11 +42,13 @@ private:
};
struct ParamNode : public Node {
- ParamNode(Ptr<ExpressionGraph> graph, const Shape& shape, const NodeInitializer& init, bool fixed = false)
+ ParamNode(Ptr<ExpressionGraph> graph,
+ const Shape& shape,
+ const NodeInitializer& init,
+ bool fixed = false)
: Node(graph, shape),
init_(new NodeInitializer(init)),
initialized_(false) {
-
setTrainable(!fixed);
}
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index c9e67cd7..6fc08690 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -16,13 +16,8 @@ private:
float scalar_;
public:
- DotNodeOp(Expr a,
- Expr b,
- bool transA,
- bool transB,
- float scalar)
- : NaryNodeOp({a, b},
- newShape(a, b, transA, transB)),
+ DotNodeOp(Expr a, Expr b, bool transA, bool transB, float scalar)
+ : NaryNodeOp({a, b}, newShape(a, b, transA, transB)),
transA_(transA),
transB_(transB),
scalar_(scalar) {}
@@ -49,14 +44,13 @@ public:
NodeOps forwardOps() {
// C = alpha * dot(op(A), op(B))
- return {NodeOp(Prod(
- val_,
- child(0)->val(),
- child(1)->val(),
- transA_,
- transB_,
- 0.f,
- scalar_))};
+ return {NodeOp(Prod(val_,
+ child(0)->val(),
+ child(1)->val(),
+ transA_,
+ transB_,
+ 0.f,
+ scalar_))};
}
NodeOps backwardOps() {
@@ -149,7 +143,7 @@ public:
: NaryNodeOp(nodes, newShape(nodes[0], nodes[1], transA, transB)),
transA_(transA),
transB_(transB),
- scalar_(scalar){}
+ scalar_(scalar) {}
Shape newShape(Expr a, Expr b, bool transA, bool transB) {
auto shapeA = a->shape();
@@ -171,19 +165,17 @@ public:
return outShape;
}
-
NodeOps forwardOps() {
using namespace functional;
return {
- NodeOp(Prod(
- val_,
- child(0)->val(),
- child(1)->val(),
- transA_,
- transB_,
- 0.f,
- scalar_);
- Add(_1, val_, child(2)->val()))
+ NodeOp(Prod(val_,
+ child(0)->val(),
+ child(1)->val(),
+ transA_,
+ transB_,
+ 0.f,
+ scalar_);
+ Add(_1, val_, child(2)->val()))
};
}
@@ -266,7 +258,6 @@ public:
const std::string type() { return "affine"; }
};
-
class DotBatchedNodeOp : public NaryNodeOp {
private:
bool transA_;
@@ -274,13 +265,8 @@ private:
float scalar_;
public:
- DotBatchedNodeOp(Expr a,
- Expr b,
- bool transA,
- bool transB,
- float scalar)
- : NaryNodeOp({a, b},
- newShape(a, b, transA, transB)),
+ DotBatchedNodeOp(Expr a, Expr b, bool transA, bool transB, float scalar)
+ : NaryNodeOp({a, b}, newShape(a, b, transA, transB)),
transA_(transA),
transB_(transB),
scalar_(scalar) {}
@@ -307,14 +293,13 @@ public:
NodeOps forwardOps() {
// C = alpha * dot(op(A), op(B))
- return {NodeOp(ProdBatched(
- val_,
- child(0)->val(),
- child(1)->val(),
- transA_,
- transB_,
- 0.f,
- scalar_))};
+ return {NodeOp(ProdBatched(val_,
+ child(0)->val(),
+ child(1)->val(),
+ transA_,
+ transB_,
+ 0.f,
+ scalar_))};
}
NodeOps backwardOps() {
@@ -325,71 +310,67 @@ public:
// to sum gradients from different graph parts
if(!transA_ && transB_)
- return {
- NodeOp(ProdBatched(child(0)->grad(),
- adj_,
- child(1)->val(),
- false,
- false,
- 1.0,
- scalar_)),
- NodeOp(ProdBatched(child(1)->grad(),
- adj_,
- child(0)->val(),
- true,
- false,
- 1.0,
- scalar_))};
+ return {NodeOp(ProdBatched(child(0)->grad(),
+ adj_,
+ child(1)->val(),
+ false,
+ false,
+ 1.0,
+ scalar_)),
+ NodeOp(ProdBatched(child(1)->grad(),
+ adj_,
+ child(0)->val(),
+ true,
+ false,
+ 1.0,
+ scalar_))};
if(transA_ && !transB_)
- return {
- NodeOp(ProdBatched(child(0)->grad(),
- child(1)->val(),
- adj_,
- false,
- true,
- 1.0,
- scalar_)),
- NodeOp(ProdBatched(child(1)->grad(),
- child(0)->val(),
- adj_,
- false,
- false,
- 1.0,
- scalar_))};
+ return {NodeOp(ProdBatched(child(0)->grad(),
+ child(1)->val(),
+ adj_,
+ false,
+ true,
+ 1.0,
+ scalar_)),
+ NodeOp(ProdBatched(child(1)->grad(),
+ child(0)->val(),
+ adj_,
+ false,
+ false,
+ 1.0,
+ scalar_))};
if(transA_ && transB_)
- return {
- NodeOp(ProdBatched(child(0)->grad(),
- child(1)->val(),
- adj_,
- true,
- true,
- 1.0,
- scalar_)),
- NodeOp(ProdBatched(child(1)->grad(),
- adj_,
- child(0)->val(),
- true,
- true,
- 1.0,
- scalar_))};
-
- return {
- NodeOp(ProdBatched(child(0)->grad(),
- adj_,
- child(1)->val(),
- false,
- true,
- 1.0,
- scalar_)),
- NodeOp(ProdBatched(child(1)->grad(),
- child(0)->val(),
- adj_,
- true,
- false,
- 1.0,
- scalar_))};
+ return {NodeOp(ProdBatched(child(0)->grad(),
+ child(1)->val(),
+ adj_,
+ true,
+ true,
+ 1.0,
+ scalar_)),
+ NodeOp(ProdBatched(child(1)->grad(),
+ adj_,
+ child(0)->val(),
+ true,
+ true,
+ 1.0,
+ scalar_))};
+
+ return {NodeOp(ProdBatched(child(0)->grad(),
+ adj_,
+ child(1)->val(),
+ false,
+ true,
+ 1.0,
+ scalar_)),
+ NodeOp(ProdBatched(child(1)->grad(),
+ child(0)->val(),
+ adj_,
+ true,
+ false,
+ 1.0,
+ scalar_))};
}
const std::string type() { return "•"; }
@@ -400,8 +381,7 @@ public:
struct ScalarProductNodeOp : public NaryNodeOp {
template <typename... Args>
ScalarProductNodeOp(Expr a, Expr b, Args... args)
- : NaryNodeOp({a, b}, newShape(a, b, args...)) {
- }
+ : NaryNodeOp({a, b}, newShape(a, b, args...)) {}
template <typename... Args>
Shape newShape(Expr a, Expr b, Args... args) {
@@ -433,12 +413,9 @@ struct ScalarProductNodeOp : public NaryNodeOp {
};
struct ElementBinaryNodeOp : public NaryNodeOp {
- ElementBinaryNodeOp(Expr a, Expr b)
- : NaryNodeOp({a, b}, newShape(a, b)) {}
+ ElementBinaryNodeOp(Expr a, Expr b) : NaryNodeOp({a, b}, newShape(a, b)) {}
- Shape newShape(Expr a, Expr b) {
- return Shape::broadcast({a, b});
- }
+ Shape newShape(Expr a, Expr b) { return Shape::broadcast({a, b}); }
const std::string color() { return "yellow"; }
};
@@ -553,8 +530,7 @@ struct DivNodeOp : public ElementBinaryNodeOp {
// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
struct CrossEntropyNodeOp : public NaryNodeOp {
- CrossEntropyNodeOp(Expr a, Expr b)
- : NaryNodeOp({a, b}, newShape(a)) {}
+ CrossEntropyNodeOp(Expr a, Expr b) : NaryNodeOp({a, b}, newShape(a)) {}
Shape newShape(Expr a) {
Shape shape1 = a->shape();
@@ -578,7 +554,9 @@ struct CrossEntropyNodeOp : public NaryNodeOp {
struct ConcatenateNodeOp : public NaryNodeOp {
template <typename... Args>
ConcatenateNodeOp(const std::vector<Expr>& nodes, Args... args)
- : NaryNodeOp(nodes, newShape(nodes, keywords::Get(keywords::axis, 0, args...))) {}
+ : NaryNodeOp(nodes,
+ newShape(nodes, keywords::Get(keywords::axis, 0, args...))) {
+ }
Shape newShape(const std::vector<Expr>& nodes, int ax) {
Shape shape = nodes.back()->shape();
@@ -730,38 +708,33 @@ struct HighwayNodeOp : public NaryNodeOp {
class ConvolutionOp : public NaryNodeOp {
public:
- ConvolutionOp(
- const std::vector<Expr>& nodes,
- int hPad = 0,
- int wPad = 0,
- int hStride = 1,
- int wStride = 1)
- : NaryNodeOp(nodes),
- conv_(nodes[1]->shape(),
- nodes[2]->shape(),
- hPad,
- wPad,
- hStride,
- wStride) {
+ ConvolutionOp(const std::vector<Expr>& nodes,
+ int hPad = 0,
+ int wPad = 0,
+ int hStride = 1,
+ int wStride = 1)
+ : NaryNodeOp(nodes),
+ conv_(nodes[1]->shape(),
+ nodes[2]->shape(),
+ hPad,
+ wPad,
+ hStride,
+ wStride) {
conv_.getOutputShape(nodes[0]->shape(), shape_);
}
NodeOps forwardOps() {
return {NodeOp(conv_.forward(
- child(0)->val(),
- child(1)->val(),
- child(2)->val(),
- val_))};
+ child(0)->val(), child(1)->val(), child(2)->val(), val_))};
}
NodeOps backwardOps() {
- return {NodeOp(conv_.backward(
- child(0)->val(),
- child(0)->grad(),
- child(1)->val(),
- child(1)->grad(),
- child(2)->grad(),
- adj_))};
+ return {NodeOp(conv_.backward(child(0)->val(),
+ child(0)->grad(),
+ child(1)->val(),
+ child(1)->grad(),
+ child(2)->grad(),
+ adj_))};
}
const std::string type() { return "layer_convolution"; }
@@ -769,5 +742,4 @@ public:
protected:
ConvolutionWrapper conv_;
};
-
}
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 0ca2c2a2..8d81a63a 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -12,11 +12,9 @@
namespace marian {
struct UnaryNodeOp : public NaryNodeOp {
- UnaryNodeOp(Expr a, Shape shape)
- : NaryNodeOp({a}, shape) {}
+ UnaryNodeOp(Expr a, Shape shape) : NaryNodeOp({a}, shape) {}
- UnaryNodeOp(Expr a)
- : NaryNodeOp({a}, a->shape()) {}
+ UnaryNodeOp(Expr a) : NaryNodeOp({a}, a->shape()) {}
const std::string color() { return "yellow"; }
};
@@ -26,9 +24,7 @@ private:
float scalar_{0};
public:
- ScalarAddNodeOp(Expr a, float scalar)
- : UnaryNodeOp(a),
- scalar_{scalar} {}
+ ScalarAddNodeOp(Expr a, float scalar) : UnaryNodeOp(a), scalar_{scalar} {}
NodeOps forwardOps() {
using namespace functional;
@@ -67,8 +63,7 @@ private:
float scalar_{0};
public:
- ScalarMultNodeOp(Expr a, float scalar)
- : UnaryNodeOp(a), scalar_{scalar} {}
+ ScalarMultNodeOp(Expr a, float scalar) : UnaryNodeOp(a), scalar_{scalar} {}
NodeOps forwardOps() {
using namespace functional;
@@ -210,7 +205,6 @@ struct TanhNodeOp : public NaryNodeOp {
const std::string type() { return "tanh"; }
};
-
struct ReLUNodeOp : public UnaryNodeOp {
ReLUNodeOp(Expr a) : UnaryNodeOp(a) {}
@@ -262,8 +256,7 @@ struct ReLUNodeOp : public UnaryNodeOp {
* \f]
*/
struct PReLUNodeOp : public UnaryNodeOp {
- PReLUNodeOp(float alpha, Expr a)
- : UnaryNodeOp(a), alpha_(alpha) {}
+ PReLUNodeOp(float alpha, Expr a) : UnaryNodeOp(a), alpha_(alpha) {}
NodeOps forwardOps() {
using namespace functional;
@@ -334,11 +327,9 @@ struct SwishNodeOp : public UnaryNodeOp {
};
struct SoftmaxNodeOp : public UnaryNodeOp {
- SoftmaxNodeOp(Expr a)
- : UnaryNodeOp(a), mask_(nullptr) {}
+ SoftmaxNodeOp(Expr a) : UnaryNodeOp(a), mask_(nullptr) {}
- SoftmaxNodeOp(Expr a, Expr mask)
- : UnaryNodeOp(a), mask_(mask) {}
+ SoftmaxNodeOp(Expr a, Expr mask) : UnaryNodeOp(a), mask_(mask) {}
Expr mask_;
@@ -407,17 +398,18 @@ struct SumNodeOp : public UnaryNodeOp {
int ax_;
template <typename... Args>
- SumNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, newShape(a, args...)) {}
+ SumNodeOp(Expr a, Args... args) : UnaryNodeOp(a, newShape(a, args...)) {}
NodeOps forwardOps() {
using namespace functional;
- return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }
+ return {NodeOp(Reduce(_1, val_, child(0)->val()))};
+ }
NodeOps backwardOps() {
using namespace functional;
- return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }
+ return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+ }
template <class... Args>
Shape newShape(Expr a, Args... args) {
@@ -456,8 +448,7 @@ struct MeanNodeOp : public UnaryNodeOp {
int ax_;
template <typename... Args>
- MeanNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, newShape(a, args...)) {}
+ MeanNodeOp(Expr a, Args... args) : UnaryNodeOp(a, newShape(a, args...)) {}
NodeOps forwardOps() {
using namespace functional;
@@ -543,8 +534,7 @@ struct ExpNodeOp : public UnaryNodeOp {
struct SqrtNodeOp : public UnaryNodeOp {
float epsilon_;
- SqrtNodeOp(Expr a, float epsilon)
- : UnaryNodeOp(a), epsilon_(epsilon) {}
+ SqrtNodeOp(Expr a, float epsilon) : UnaryNodeOp(a), epsilon_(epsilon) {}
NodeOps forwardOps() {
using namespace functional;
@@ -614,8 +604,7 @@ struct NegNodeOp : public UnaryNodeOp {
struct RowsNodeOp : public UnaryNodeOp {
RowsNodeOp(Expr a, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, indeces)),
- indices_(indeces) {}
+ : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
@@ -666,8 +655,7 @@ struct RowsNodeOp : public UnaryNodeOp {
struct ColsNodeOp : public UnaryNodeOp {
ColsNodeOp(Expr a, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, indeces)),
- indices_(indeces) {}
+ : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
@@ -716,8 +704,7 @@ struct ColsNodeOp : public UnaryNodeOp {
struct SelectNodeOp : public UnaryNodeOp {
SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, axis, indeces)),
- indices_(indeces) {}
+ : UnaryNodeOp(a, newShape(a, axis, indeces)), indices_(indeces) {}
NodeOps forwardOps() {
return {NodeOp(
@@ -772,8 +759,7 @@ struct TransposeNodeOp : public UnaryNodeOp {
std::vector<int> axes_;
TransposeNodeOp(Expr a, const std::vector<int>& axes)
- : UnaryNodeOp(a, newShape(a, axes)),
- axes_{axes} {}
+ : UnaryNodeOp(a, newShape(a, axes)), axes_{axes} {}
NodeOps forwardOps() {
return {NodeOp(TransposeND(val_, child(0)->val(), axes_))};
@@ -788,7 +774,7 @@ struct TransposeNodeOp : public UnaryNodeOp {
Shape shape = a->shape();
ABORT_IF(shape.size() != axes.size(),
- "Shape and transpose axes have different number of dimensions");
+ "Shape and transpose axes have different number of dimensions");
for(int i = 0; i < shape.size(); ++i)
shape.set(i, a->shape()[axes[i]]);
@@ -829,8 +815,7 @@ private:
public:
template <typename... Args>
- ReshapeNodeOp(Expr a, Shape shape)
- : UnaryNodeOp(a, shape), reshapee_(a) {
+ ReshapeNodeOp(Expr a, Shape shape) : UnaryNodeOp(a, shape), reshapee_(a) {
Node::destroy_ = false;
}
@@ -894,9 +879,7 @@ private:
public:
StepNodeOp(Expr a, int step, int axis)
- : UnaryNodeOp(a, newShape(a, axis)),
- stepNode_(a),
- step_(step) {
+ : UnaryNodeOp(a, newShape(a, axis)), stepNode_(a), step_(step) {
Node::destroy_ = false;
}
@@ -1056,67 +1039,54 @@ public:
padWidth,
strideHeight,
strideWidth,
- mode) {
- }
+ mode) {}
NodeOps forwardOps() {
return {NodeOp(pooling_.forward(child(0)->val(), val_))};
}
NodeOps backwardOps() {
- return {NodeOp(pooling_.backward(
- child(0)->val(),
- child(0)->grad(),
- val_,
- adj_))};
+ return {NodeOp(
+ pooling_.backward(child(0)->val(), child(0)->grad(), val_, adj_))};
}
const std::string type() { return "layer_pooling"; }
-
protected:
PoolingWrapper pooling_;
};
class PoolingWithMaskingOp : public UnaryNodeOp {
- public:
- PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false)
- : UnaryNodeOp(x),
- mask_(mask),
- width_(width),
- isEven_(isEven)
- {
- auto xShape = x->shape();
- int dimBatch = xShape[0];
- int dimWord = xShape[1];
- int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
- int dimSentence = (cols / width_) + (cols % width_ != 0);
- shape_ = {dimBatch, dimWord, dimSentence};
- }
+public:
+ PoolingWithMaskingOp(Expr x, Expr mask, int width, bool isEven = false)
+ : UnaryNodeOp(x), mask_(mask), width_(width), isEven_(isEven) {
+ auto xShape = x->shape();
+ int dimBatch = xShape[0];
+ int dimWord = xShape[1];
+ int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
+ int dimSentence = (cols / width_) + (cols % width_ != 0);
+ shape_ = {dimBatch, dimWord, dimSentence};
+ }
- NodeOps forwardOps() {
- return {NodeOp(PoolingWithMaskingForward(val_,
+ NodeOps forwardOps() {
+ return {NodeOp(PoolingWithMaskingForward(
+ val_, child(0)->val(), mask_->val(), width_, isEven_))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(PoolingWithMaskingBackward(adj_,
+ child(0)->grad(),
child(0)->val(),
mask_->val(),
width_,
isEven_))};
- }
-
- NodeOps backwardOps() {
- return {NodeOp(PoolingWithMaskingBackward(adj_,
- child(0)->grad(),
- child(0)->val(),
- mask_->val(),
- width_,
- isEven_))};
- }
+ }
- const std::string type() {return "layer_pooling";}
+ const std::string type() { return "layer_pooling"; }
- protected:
- Expr mask_;
- int width_;
- bool isEven_;
+protected:
+ Expr mask_;
+ int width_;
+ bool isEven_;
};
-
}
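
A side note on the "// clang-format off" guards introduced around highway() in expression_operators.cpp above: they mark a region the formatter must leave untouched, here protecting the manually aligned fluent ("key", value) builder chains. A minimal self-contained C++ sketch of the same guard mechanism (the code itself is illustrative, not from marian):

    #include <iostream>

    int main() {
      // clang-format off
      // Hand-aligned layout that clang-format would otherwise reflow;
      // the off/on guards tell the tool to skip this region entirely.
      int identity[3][3] = { {1, 0, 0},
                             {0, 1, 0},
                             {0, 0, 1} };
      // clang-format on
      std::cout << identity[0][0] << identity[1][1] << identity[2][2] << "\n";
      return 0;
    }
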