diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-29 17:42:01 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-29 17:42:01 +0300 |
commit | bc95140cfb7afc51fa104f4ecab3e5453bd706ef (patch) | |
tree | 8d57999c8c556cae36d7b9f8c992826cd0804445 /src/graph | |
parent | 2e16934080d4bf41d0ab7557836732bedb635efd (diff) | |
parent | 46433253735e79e03613fcbd28e64ff393f72451 (diff) |
merge rnn_test.cpp
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/expression_graph.cu | 2 | ||||
-rw-r--r-- | src/graph/expression_graph.h | 21 | ||||
-rw-r--r-- | src/graph/expression_operators.cu | 8 | ||||
-rw-r--r-- | src/graph/expression_operators.h | 2 | ||||
-rw-r--r-- | src/graph/node_operators.h | 6 | ||||
-rw-r--r-- | src/graph/node_operators_binary.h | 36 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 11 | ||||
-rw-r--r-- | src/graph/parameters.h | 3 |
8 files changed, 49 insertions, 40 deletions
diff --git a/src/graph/expression_graph.cu b/src/graph/expression_graph.cu index fb720ecc..829af085 100644 --- a/src/graph/expression_graph.cu +++ b/src/graph/expression_graph.cu @@ -49,6 +49,6 @@ Expr ExpressionGraph::gaussian(float mean, float stddev, Shape shape) { } void ExpressionGraph::checkNan(Tensor t) { - UTIL_THROW_IF2(throwNaN_ && IsNan(t), "Tensor has NaN"); + ABORT_IF(throwNaN_ && IsNan(t), "Tensor has NaN"); } } diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h index 1ca0204a..67a5f5ef 100644 --- a/src/graph/expression_graph.h +++ b/src/graph/expression_graph.h @@ -178,8 +178,8 @@ public: * and that all backward pass computations have been performed. */ void backward() { - UTIL_THROW_IF2(topNodes_.size() > 1, - "There are more than one top most node for backward step"); + ABORT_IF(topNodes_.size() > 1, + "There are more than one top most node for backward step"); params_->allocateBackward(); params_->set_zero_adjoint(); @@ -269,23 +269,22 @@ public: if(p) { // if yes add to tape and return - UTIL_THROW_IF2(shape != p->shape(), - "Requested shape for existing parameter " - << name - << " does not match original shape"); + ABORT_IF(shape != p->shape(), + "Requested shape for existing parameter '{}' does not match " + "original shape", + name); add(p); return p; } // if graph was reloaded do not allow creation of new parameters - UTIL_THROW_IF2(reloaded_, - "Graph was reloaded and parameter " << name - << " is newly created"); + ABORT_IF(reloaded_, + "Graph was reloaded and parameter '{}' is newly created", + name); // if not check if name is not taken by other node - UTIL_THROW_IF2(get(name), - "Non-parameter with name " << name << "already exists"); + ABORT_IF(get(name), "Non-parameter with name '{}' already exists", name); // create parameter node (adds to tape) p = Expression<ParamNode>( diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu index af592625..10861c14 100644 --- a/src/graph/expression_operators.cu +++ b/src/graph/expression_operators.cu @@ -195,11 +195,11 @@ Expr affine(Expr a, Expr b, Expr c) { } Expr plus(const std::vector<Expr>&) { - UTIL_THROW2("Not implemented"); + ABORT("Not implemented"); } Expr swish(const std::vector<Expr>&) { - UTIL_THROW2("Not implemented"); + ABORT("Not implemented"); } Expr tanh(const std::vector<Expr>& nodes) { @@ -207,11 +207,11 @@ Expr tanh(const std::vector<Expr>& nodes) { } Expr logit(const std::vector<Expr>&) { - UTIL_THROW2("Not implemented"); + ABORT("Not implemented"); } Expr relu(const std::vector<Expr>&) { - UTIL_THROW2("Not implemented"); + ABORT("Not implemented"); } Expr sqrt(Expr a, float eps) { diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index 8824bb40..c99af41d 100644 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -109,7 +109,7 @@ Expr dropout(Expr x, Args... args) { auto mask = Get(keywords::mask, nullptr, args...); float dropout_prob = Get(keywords::dropout_prob, 0.0f, args...); - UTIL_THROW_IF2(!mask && !dropout_prob, "Neither mask nor dropout prob given"); + ABORT_IF(!mask && !dropout_prob, "Neither mask nor dropout prob given"); if(!mask) { auto graph = x->graph(); mask = graph->dropout(dropout_prob, x->shape()); diff --git a/src/graph/node_operators.h b/src/graph/node_operators.h index cec27bfe..1cb112dd 100644 --- a/src/graph/node_operators.h +++ b/src/graph/node_operators.h @@ -11,8 +11,7 @@ struct ConstantNode : public Node { : Node(args...), init_(Get(keywords::init, [](Tensor) {})), initialized_(false) { - UTIL_THROW_IF2(!Has(keywords::shape), - "Constant items require shape information"); + ABORT_IF(!Has(keywords::shape), "Constant items require shape information"); setTrainable(false); } @@ -47,8 +46,7 @@ struct ParamNode : public Node { : Node(args...), init_(Get(keywords::init, [](Tensor) {})), initialized_(false) { - UTIL_THROW_IF2(!Has(keywords::shape), - "Param items require shape information"); + ABORT_IF(!Has(keywords::shape), "Param items require shape information"); setTrainable(!Get(keywords::fixed, false)); } diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index 74da4b14..4ec5c092 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -41,10 +41,15 @@ private: public: template <typename... Args> - DotNodeOp( - Expr a, Expr b, bool transA, bool transB, float scalar, Args... args) - : NaryNodeOp( - {a, b}, keywords::shape = newShape(a, b, transA, transB), args...), + DotNodeOp(Expr a, + Expr b, + bool transA, + bool transB, + float scalar, + Args... args) + : NaryNodeOp({a, b}, + keywords::shape = newShape(a, b, transA, transB), + args...), transA_(transA), transB_(transB), scalar_(scalar) {} @@ -240,10 +245,15 @@ private: public: template <typename... Args> - DotBatchedNodeOp( - Expr a, Expr b, bool transA, bool transB, float scalar, Args... args) - : NaryNodeOp( - {a, b}, keywords::shape = newShape(a, b, transA, transB), args...), + DotBatchedNodeOp(Expr a, + Expr b, + bool transA, + bool transB, + float scalar, + Args... args) + : NaryNodeOp({a, b}, + keywords::shape = newShape(a, b, transA, transB), + args...), transA_(transA), transB_(transB), scalar_(scalar) {} @@ -263,8 +273,8 @@ public: Shape outShape = shapeA; outShape.set(1, shapeB[1]); - UTIL_THROW_IF2(shapeA[1] != shapeB[0], - "matrix product requires dimensions to match"); + ABORT_IF(shapeA[1] != shapeB[0], + "matrix product requires dimensions to match"); return outShape; } @@ -425,8 +435,8 @@ struct ElementBinaryNodeOp : public NaryNodeOp { Shape shape1 = a->shape(); Shape shape2 = b->shape(); for(int i = 0; i < shape1.size(); ++i) { - UTIL_THROW_IF2(shape1[i] != shape2[i] && shape1[i] != 1 && shape2[i] != 1, - "Shapes cannot be broadcasted"); + ABORT_IF(shape1[i] != shape2[i] && shape1[i] != 1 && shape2[i] != 1, + "Shapes cannot be broadcasted"); shape1.set(i, std::max(shape1[i], shape2[i])); } return shape1; @@ -625,7 +635,7 @@ struct TanhPlus3NodeOp : public NaryNodeOp { for(int n = 1; n < nodes.size(); ++n) { Shape shapen = nodes[n]->shape(); for(int i = 0; i < shapen.size(); ++i) { - UTIL_THROW_IF2(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1, + ABORT_IF(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1, "Shapes cannot be broadcasted"); shape.set(i, std::max(shape[i], shapen[i])); } diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 9881357c..a3f60366 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -146,8 +146,8 @@ struct TanhNodeOp : public NaryNodeOp { for(int n = 1; n < nodes.size(); ++n) { Shape shapen = nodes[n]->shape(); for(int i = 0; i < shapen.size(); ++i) { - UTIL_THROW_IF2(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1, - "Shapes cannot be broadcasted"); + ABORT_IF(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1, + "Shapes cannot be broadcasted"); shape.set(i, std::max(shape[i], shapen[i])); } } @@ -237,8 +237,11 @@ struct SwishNodeOp : public UnaryNodeOp { } NodeOps backwardOps() { - return {NodeOp( - Add(_1 * (_3 + Sigma(_2) * (1.f - _3)), child(0)->grad(), adj_, child(0)->val(), val_))}; + return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)), + child(0)->grad(), + adj_, + child(0)->val(), + val_))}; } const std::string type() { return "swish"; } diff --git a/src/graph/parameters.h b/src/graph/parameters.h index bbe02f36..df73dbd5 100644 --- a/src/graph/parameters.h +++ b/src/graph/parameters.h @@ -51,8 +51,7 @@ public: void add(Expr p, const std::string& name) { params_.push_back(p); - UTIL_THROW_IF2(named_.count(name), - "Parameter " << name << "already exists"); + ABORT_IF(named_.count(name), "Parameter '{}' already exists", name); named_[name] = p; } |