Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-02-16 22:59:12 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-02-16 22:59:12 +0300
commit327cfc1cc3fbe3ab92927c800aa30aef1c41d517 (patch)
tree9e67c0035ebb7d2c022982443efc1ad338ebc615 /src/graph/node_operators_unary.h
parentdd296e77f76143033fc1589c7dce6d12196bbfdd (diff)
pass through backend
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h24
1 files changed, 5 insertions, 19 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index a3f27fd2..0170fc73 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -211,21 +211,7 @@ struct TanhNodeOp : public NaryNodeOp {
const std::string type() { return "tanh"; }
};
-/**
- * Represents a <a
- * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
- * linear</a> node in an expression graph.
- *
- * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and
- * its derivative:
- * \f[
- * f^\prime(x) =
- * \begin{cases}
- * 0 & \text{if } x \leq 0 \\
- * 1 & \text{if } x > 0
- * \end{cases}
- * \f]
- */
+
struct ReLUNodeOp : public UnaryNodeOp {
template <typename... Args>
ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -877,14 +863,14 @@ public:
Tensor& val() {
auto childVal = reshapee_->val();
val_.reset(
- new TensorBase(childVal->memory(), shape(), childVal->getDevice()));
+ new TensorBase(childVal->memory(), shape(), childVal->getBackend()));
return val_;
};
Tensor& grad() {
auto childGrad = reshapee_->grad();
adj_.reset(
- new TensorBase(childGrad->memory(), shape(), childGrad->getDevice()));
+ new TensorBase(childGrad->memory(), shape(), childGrad->getBackend()));
return adj_;
};
@@ -953,7 +939,7 @@ public:
size_t offset = step_ * shape().elements() * sizeof(float);
auto mem = New<MemoryPiece>(childVal->memory()->data() + offset,
childVal->memory()->size());
- val_.reset(new TensorBase(mem, shape(), childVal->getDevice()));
+ val_.reset(new TensorBase(mem, shape(), childVal->getBackend()));
return val_;
};
@@ -962,7 +948,7 @@ public:
size_t offset = step_ * shape().elements() * sizeof(float);
auto mem = New<MemoryPiece>(childGrad->memory()->data() + offset,
childGrad->memory()->size());
- adj_.reset(new TensorBase(mem, shape(), childGrad->getDevice()));
+ adj_.reset(new TensorBase(mem, shape(), childGrad->getBackend()));
return adj_;
};