diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-02-16 22:59:12 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-02-16 22:59:12 +0300 |
commit | 327cfc1cc3fbe3ab92927c800aa30aef1c41d517 (patch) | |
tree | 9e67c0035ebb7d2c022982443efc1ad338ebc615 /src/graph/node_operators_unary.h | |
parent | dd296e77f76143033fc1589c7dce6d12196bbfdd (diff) |
pass through backend
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 24 |
1 file changed, 5 insertions, 19 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index a3f27fd2..0170fc73 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -211,21 +211,7 @@ struct TanhNodeOp : public NaryNodeOp { const std::string type() { return "tanh"; } }; -/** - * Represents a <a - * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified - * linear</a> node in an expression graph. - * - * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and - * its derivative: - * \f[ - * f^\prime(x) = - * \begin{cases} - * 0 & \text{if } x \leq 0 \\ - * 1 & \text{if } x > 0 - * \end{cases} - * \f] - */ + struct ReLUNodeOp : public UnaryNodeOp { template <typename... Args> ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {} @@ -877,14 +863,14 @@ public: Tensor& val() { auto childVal = reshapee_->val(); val_.reset( - new TensorBase(childVal->memory(), shape(), childVal->getDevice())); + new TensorBase(childVal->memory(), shape(), childVal->getBackend())); return val_; }; Tensor& grad() { auto childGrad = reshapee_->grad(); adj_.reset( - new TensorBase(childGrad->memory(), shape(), childGrad->getDevice())); + new TensorBase(childGrad->memory(), shape(), childGrad->getBackend())); return adj_; }; @@ -953,7 +939,7 @@ public: size_t offset = step_ * shape().elements() * sizeof(float); auto mem = New<MemoryPiece>(childVal->memory()->data() + offset, childVal->memory()->size()); - val_.reset(new TensorBase(mem, shape(), childVal->getDevice())); + val_.reset(new TensorBase(mem, shape(), childVal->getBackend())); return val_; }; @@ -962,7 +948,7 @@ public: size_t offset = step_ * shape().elements() * sizeof(float); auto mem = New<MemoryPiece>(childGrad->memory()->data() + offset, childGrad->memory()->size()); - adj_.reset(new TensorBase(mem, shape(), childGrad->getDevice())); + adj_.reset(new TensorBase(mem, shape(), childGrad->getBackend())); return adj_; }; |