From 327cfc1cc3fbe3ab92927c800aa30aef1c41d517 Mon Sep 17 00:00:00 2001
From: Marcin Junczys-Dowmunt
Date: Fri, 16 Feb 2018 11:59:12 -0800
Subject: pass through backend

---
 src/graph/node_operators_unary.h | 24 +++++-------------------
 1 file changed, 5 insertions(+), 19 deletions(-)

diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index a3f27fd2..0170fc73 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -211,21 +211,7 @@ struct TanhNodeOp : public NaryNodeOp {
   const std::string type() { return "tanh"; }
 };
 
-/**
- * Represents a <a
- * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
- * linear</a> node in an expression graph.
- *
- * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and
- * its derivative:
- * \f[
- *   f^\prime(x) =
- *   \begin{cases}
- *     0 & \text{if } x \leq 0 \\
- *     1 & \text{if } x > 0
- *   \end{cases}
- * \f]
- */
+
 struct ReLUNodeOp : public UnaryNodeOp {
   template <typename... Args>
   ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -877,14 +863,14 @@ public:
   Tensor& val() {
     auto childVal = reshapee_->val();
     val_.reset(
-        new TensorBase(childVal->memory(), shape(), childVal->getDevice()));
+        new TensorBase(childVal->memory(), shape(), childVal->getBackend()));
     return val_;
   };
 
   Tensor& grad() {
     auto childGrad = reshapee_->grad();
     adj_.reset(
-        new TensorBase(childGrad->memory(), shape(), childGrad->getDevice()));
+        new TensorBase(childGrad->memory(), shape(), childGrad->getBackend()));
     return adj_;
   };
 
@@ -953,7 +939,7 @@ public:
     size_t offset = step_ * shape().elements() * sizeof(float);
     auto mem = New<MemoryPiece>(childVal->memory()->data() + offset,
                                 childVal->memory()->size());
-    val_.reset(new TensorBase(mem, shape(), childVal->getDevice()));
+    val_.reset(new TensorBase(mem, shape(), childVal->getBackend()));
     return val_;
   };
 
@@ -962,7 +948,7 @@ public:
     size_t offset = step_ * shape().elements() * sizeof(float);
     auto mem = New<MemoryPiece>(childGrad->memory()->data() + offset,
                                 childGrad->memory()->size());
-    adj_.reset(new TensorBase(mem, shape(), childGrad->getDevice()));
+    adj_.reset(new TensorBase(mem, shape(), childGrad->getBackend()));
     return adj_;
   };
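
Note: the patch above passes the child tensor's full backend object, rather than just its device, into TensorBase when constructing views that alias the child's memory, so reshape and timestep views stay bound to the same device, allocator, and handles as their parent. Below is a minimal, self-contained C++ sketch of that pass-through pattern; MemoryPiece, Backend, TensorBase, and reshapeView here are simplified, hypothetical stand-ins for illustration, not Marian's actual classes.

// Sketch only: simplified stand-ins, not the real Marian API.
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

struct MemoryPiece {   // stand-in for a raw memory region
  float* data;
  std::size_t size;    // size in bytes
};

struct Backend {       // stand-in: carries more than a bare device id
  std::size_t deviceId;
  // allocator handles, BLAS/RNG state, etc. would live here
};

class TensorBase {     // stand-in tensor that owns a shape and aliases memory
public:
  TensorBase(std::shared_ptr<MemoryPiece> memory,
             std::vector<int> shape,
             std::shared_ptr<Backend> backend)
      : memory_(std::move(memory)),
        shape_(std::move(shape)),
        backend_(std::move(backend)) {}

  std::shared_ptr<MemoryPiece> memory() const { return memory_; }
  std::shared_ptr<Backend> getBackend() const { return backend_; }

private:
  std::shared_ptr<MemoryPiece> memory_;
  std::vector<int> shape_;
  std::shared_ptr<Backend> backend_;
};

// Reshape-style view: alias the child's memory, use a new shape, and pass the
// child's backend straight through so the view stays on the same device.
std::shared_ptr<TensorBase> reshapeView(const std::shared_ptr<TensorBase>& child,
                                        std::vector<int> newShape) {
  return std::make_shared<TensorBase>(child->memory(), std::move(newShape),
                                      child->getBackend());
}

int main() {
  auto backend = std::make_shared<Backend>(Backend{0});
  std::vector<float> buffer(12);
  auto mem = std::make_shared<MemoryPiece>(
      MemoryPiece{buffer.data(), buffer.size() * sizeof(float)});

  auto parent = std::make_shared<TensorBase>(mem, std::vector<int>{3, 4}, backend);
  auto view = reshapeView(parent, {4, 3});

  // The view shares both memory and backend with its parent.
  return (view->memory() == parent->memory() &&
          view->getBackend() == parent->getBackend()) ? 0 : 1;
}

In this sketch, holding the backend by shared_ptr also keeps it alive for as long as any view that references it, which is a side benefit of passing the object through instead of a plain device index.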