
github.com/marian-nmt/marian.git
author    Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-10-29 17:42:01 +0300
committer Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-10-29 17:42:01 +0300
commit    bc95140cfb7afc51fa104f4ecab3e5453bd706ef (patch)
tree      8d57999c8c556cae36d7b9f8c992826cd0804445 /src/graph
parent    2e16934080d4bf41d0ab7557836732bedb635efd (diff)
parent    46433253735e79e03613fcbd28e64ff393f72451 (diff)
merge rnn_test.cpp
Diffstat (limited to 'src/graph')
-rw-r--r--  src/graph/expression_graph.cu      |  2
-rw-r--r--  src/graph/expression_graph.h       | 21
-rw-r--r--  src/graph/expression_operators.cu  |  8
-rw-r--r--  src/graph/expression_operators.h   |  2
-rw-r--r--  src/graph/node_operators.h         |  6
-rw-r--r--  src/graph/node_operators_binary.h  | 36
-rw-r--r--  src/graph/node_operators_unary.h   | 11
-rw-r--r--  src/graph/parameters.h             |  3
8 files changed, 49 insertions(+), 40 deletions(-)
diff --git a/src/graph/expression_graph.cu b/src/graph/expression_graph.cu
index fb720ecc..829af085 100644
--- a/src/graph/expression_graph.cu
+++ b/src/graph/expression_graph.cu
@@ -49,6 +49,6 @@ Expr ExpressionGraph::gaussian(float mean, float stddev, Shape shape) {
}
void ExpressionGraph::checkNan(Tensor t) {
- UTIL_THROW_IF2(throwNaN_ && IsNan(t), "Tensor has NaN");
+ ABORT_IF(throwNaN_ && IsNan(t), "Tensor has NaN");
}
}
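
The recurring change in this commit is the switch from the stream-style
UTIL_THROW_IF2/UTIL_THROW2 macros to ABORT_IF/ABORT, which take a condition
plus an fmt-style format string with '{}' placeholders. As a rough sketch of
the semantics only (Marian's actual macros live in its logging headers and
route through its logger; the fmt include below is an assumption, not taken
from this diff):

    #include <cstdlib>
    #include <iostream>
    #include <fmt/format.h>  // assumed fmt-compatible formatter

    // Hypothetical stand-ins for Marian's real ABORT/ABORT_IF macros.
    #define ABORT(...)                                       \
      do {                                                   \
        std::cerr << fmt::format(__VA_ARGS__) << std::endl;  \
        std::abort();                                        \
      } while(0)

    #define ABORT_IF(condition, ...) \
      do {                           \
        if(condition)                \
          ABORT(__VA_ARGS__);        \
      } while(0)

Under this scheme a call such as ABORT_IF(get(name), "Non-parameter with name
'{}' already exists", name) substitutes name into the placeholder, replacing
the operator<< chaining that UTIL_THROW_IF2 required.
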
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index 1ca0204a..67a5f5ef 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -178,8 +178,8 @@ public:
* and that all backward pass computations have been performed.
*/
void backward() {
- UTIL_THROW_IF2(topNodes_.size() > 1,
- "There are more than one top most node for backward step");
+ ABORT_IF(topNodes_.size() > 1,
+ "There are more than one top most node for backward step");
params_->allocateBackward();
params_->set_zero_adjoint();
@@ -269,23 +269,22 @@ public:
if(p) {
// if yes add to tape and return
- UTIL_THROW_IF2(shape != p->shape(),
- "Requested shape for existing parameter "
- << name
- << " does not match original shape");
+ ABORT_IF(shape != p->shape(),
+ "Requested shape for existing parameter '{}' does not match "
+ "original shape",
+ name);
add(p);
return p;
}
// if graph was reloaded do not allow creation of new parameters
- UTIL_THROW_IF2(reloaded_,
- "Graph was reloaded and parameter " << name
- << " is newly created");
+ ABORT_IF(reloaded_,
+ "Graph was reloaded and parameter '{}' is newly created",
+ name);
// if not check if name is not taken by other node
- UTIL_THROW_IF2(get(name),
- "Non-parameter with name " << name << "already exists");
+ ABORT_IF(get(name), "Non-parameter with name '{}' already exists", name);
// create parameter node (adds to tape)
p = Expression<ParamNode>(
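
For reference, these three checks guard named-parameter creation: re-requesting
an existing parameter must repeat its original shape, a reloaded graph rejects
newly created parameters, and a name already taken by a non-parameter node is
an error. A hypothetical call sequence (parameter name and initializer invented
for illustration) that would trip the first check:

    auto W  = graph->param("W_hyp", {512, 512}, keywords::init = someInit);
    // Same name, different shape: aborts with
    // "Requested shape for existing parameter 'W_hyp' does not match original shape"
    auto W2 = graph->param("W_hyp", {256, 512}, keywords::init = someInit);
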
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index af592625..10861c14 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -195,11 +195,11 @@ Expr affine(Expr a, Expr b, Expr c) {
}
Expr plus(const std::vector<Expr>&) {
- UTIL_THROW2("Not implemented");
+ ABORT("Not implemented");
}
Expr swish(const std::vector<Expr>&) {
- UTIL_THROW2("Not implemented");
+ ABORT("Not implemented");
}
Expr tanh(const std::vector<Expr>& nodes) {
@@ -207,11 +207,11 @@ Expr tanh(const std::vector<Expr>& nodes) {
}
Expr logit(const std::vector<Expr>&) {
- UTIL_THROW2("Not implemented");
+ ABORT("Not implemented");
}
Expr relu(const std::vector<Expr>&) {
- UTIL_THROW2("Not implemented");
+ ABORT("Not implemented");
}
Expr sqrt(Expr a, float eps) {
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 8824bb40..c99af41d 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -109,7 +109,7 @@ Expr dropout(Expr x, Args... args) {
auto mask = Get(keywords::mask, nullptr, args...);
float dropout_prob = Get(keywords::dropout_prob, 0.0f, args...);
- UTIL_THROW_IF2(!mask && !dropout_prob, "Neither mask nor dropout prob given");
+ ABORT_IF(!mask && !dropout_prob, "Neither mask nor dropout prob given");
if(!mask) {
auto graph = x->graph();
mask = graph->dropout(dropout_prob, x->shape());
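
The check above means dropout() needs either an explicit mask or a non-zero
dropout probability, both passed through Marian's keyword-argument mechanism;
if only a probability is given, a mask is sampled from the graph. A
hypothetical call site (x and graph assumed to exist):

    // Let the graph sample a dropout mask from the probability ...
    auto y1 = dropout(x, keywords::dropout_prob = 0.1f);
    // ... or build and pass a mask explicitly.
    auto mask = graph->dropout(0.1f, x->shape());
    auto y2 = dropout(x, keywords::mask = mask);
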
diff --git a/src/graph/node_operators.h b/src/graph/node_operators.h
index cec27bfe..1cb112dd 100644
--- a/src/graph/node_operators.h
+++ b/src/graph/node_operators.h
@@ -11,8 +11,7 @@ struct ConstantNode : public Node {
: Node(args...),
init_(Get(keywords::init, [](Tensor) {})),
initialized_(false) {
- UTIL_THROW_IF2(!Has(keywords::shape),
- "Constant items require shape information");
+ ABORT_IF(!Has(keywords::shape), "Constant items require shape information");
setTrainable(false);
}
@@ -47,8 +46,7 @@ struct ParamNode : public Node {
: Node(args...),
init_(Get(keywords::init, [](Tensor) {})),
initialized_(false) {
- UTIL_THROW_IF2(!Has(keywords::shape),
- "Param items require shape information");
+ ABORT_IF(!Has(keywords::shape), "Param items require shape information");
setTrainable(!Get(keywords::fixed, false));
}
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 74da4b14..4ec5c092 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -41,10 +41,15 @@ private:
public:
template <typename... Args>
- DotNodeOp(
- Expr a, Expr b, bool transA, bool transB, float scalar, Args... args)
- : NaryNodeOp(
- {a, b}, keywords::shape = newShape(a, b, transA, transB), args...),
+ DotNodeOp(Expr a,
+ Expr b,
+ bool transA,
+ bool transB,
+ float scalar,
+ Args... args)
+ : NaryNodeOp({a, b},
+ keywords::shape = newShape(a, b, transA, transB),
+ args...),
transA_(transA),
transB_(transB),
scalar_(scalar) {}
@@ -240,10 +245,15 @@ private:
public:
template <typename... Args>
- DotBatchedNodeOp(
- Expr a, Expr b, bool transA, bool transB, float scalar, Args... args)
- : NaryNodeOp(
- {a, b}, keywords::shape = newShape(a, b, transA, transB), args...),
+ DotBatchedNodeOp(Expr a,
+ Expr b,
+ bool transA,
+ bool transB,
+ float scalar,
+ Args... args)
+ : NaryNodeOp({a, b},
+ keywords::shape = newShape(a, b, transA, transB),
+ args...),
transA_(transA),
transB_(transB),
scalar_(scalar) {}
@@ -263,8 +273,8 @@ public:
Shape outShape = shapeA;
outShape.set(1, shapeB[1]);
- UTIL_THROW_IF2(shapeA[1] != shapeB[0],
- "matrix product requires dimensions to match");
+ ABORT_IF(shapeA[1] != shapeB[0],
+ "matrix product requires dimensions to match");
return outShape;
}
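
The check above is the standard (batched) matrix-product shape rule in the
[rows, cols, ...] layout used here: the output copies A's shape, takes its
column count from B, and requires A's columns to equal B's rows. A standalone
sketch with std::vector<int> standing in for Marian's Shape, purely for
illustration:

    #include <stdexcept>
    #include <vector>

    // C = A * B: out copies shapeA and replaces dim 1 with shapeB's dim 1.
    std::vector<int> dotShape(const std::vector<int>& shapeA,
                              const std::vector<int>& shapeB) {
      if(shapeA[1] != shapeB[0])
        throw std::runtime_error("matrix product requires dimensions to match");
      std::vector<int> out(shapeA);
      out[1] = shapeB[1];
      return out;
    }
    // e.g. dotShape({64, 512, 80}, {512, 2048, 80}) == {64, 2048, 80}
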
@@ -425,8 +435,8 @@ struct ElementBinaryNodeOp : public NaryNodeOp {
Shape shape1 = a->shape();
Shape shape2 = b->shape();
for(int i = 0; i < shape1.size(); ++i) {
- UTIL_THROW_IF2(shape1[i] != shape2[i] && shape1[i] != 1 && shape2[i] != 1,
- "Shapes cannot be broadcasted");
+ ABORT_IF(shape1[i] != shape2[i] && shape1[i] != 1 && shape2[i] != 1,
+ "Shapes cannot be broadcasted");
shape1.set(i, std::max(shape1[i], shape2[i]));
}
return shape1;
@@ -625,7 +635,7 @@ struct TanhPlus3NodeOp : public NaryNodeOp {
for(int n = 1; n < nodes.size(); ++n) {
Shape shapen = nodes[n]->shape();
for(int i = 0; i < shapen.size(); ++i) {
- UTIL_THROW_IF2(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1,
+ ABORT_IF(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1,
"Shapes cannot be broadcasted");
shape.set(i, std::max(shape[i], shapen[i]));
}
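
Both hunks above (and the matching one in node_operators_unary.h below)
implement the same elementwise broadcasting rule: two shapes are compatible
when each dimension either matches or equals 1, and the result takes the
per-dimension maximum. A standalone sketch of that rule, again with
std::vector<int> standing in for Marian's Shape:

    #include <algorithm>
    #include <stdexcept>
    #include <vector>

    // Dimensions must match or be 1; the result takes the maximum of each.
    std::vector<int> broadcastShape(const std::vector<int>& a,
                                    const std::vector<int>& b) {
      std::vector<int> out(a);
      for(size_t i = 0; i < a.size(); ++i) {
        if(a[i] != b[i] && a[i] != 1 && b[i] != 1)
          throw std::runtime_error("Shapes cannot be broadcasted");
        out[i] = std::max(a[i], b[i]);
      }
      return out;
    }
    // e.g. broadcastShape({4, 1, 3}, {1, 5, 3}) == {4, 5, 3}
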
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 9881357c..a3f60366 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -146,8 +146,8 @@ struct TanhNodeOp : public NaryNodeOp {
for(int n = 1; n < nodes.size(); ++n) {
Shape shapen = nodes[n]->shape();
for(int i = 0; i < shapen.size(); ++i) {
- UTIL_THROW_IF2(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1,
- "Shapes cannot be broadcasted");
+ ABORT_IF(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1,
+ "Shapes cannot be broadcasted");
shape.set(i, std::max(shape[i], shapen[i]));
}
}
@@ -237,8 +237,11 @@ struct SwishNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
- return {NodeOp(
- Add(_1 * (_3 + Sigma(_2) * (1.f - _3)), child(0)->grad(), adj_, child(0)->val(), val_))};
+ return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)),
+ child(0)->grad(),
+ adj_,
+ child(0)->val(),
+ val_))};
}
const std::string type() { return "swish"; }
diff --git a/src/graph/parameters.h b/src/graph/parameters.h
index bbe02f36..df73dbd5 100644
--- a/src/graph/parameters.h
+++ b/src/graph/parameters.h
@@ -51,8 +51,7 @@ public:
void add(Expr p, const std::string& name) {
params_.push_back(p);
- UTIL_THROW_IF2(named_.count(name),
- "Parameter " << name << "already exists");
+ ABORT_IF(named_.count(name), "Parameter '{}' already exists", name);
named_[name] = p;
}