diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-02-16 22:59:12 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-02-16 22:59:12 +0300 |
commit | 327cfc1cc3fbe3ab92927c800aa30aef1c41d517 (patch) | |
tree | 9e67c0035ebb7d2c022982443efc1ad338ebc615 /src/graph | |
parent | dd296e77f76143033fc1589c7dce6d12196bbfdd (diff) |
pass through backend
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/expression_graph.cpp | 6 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 24 | ||||
-rw-r--r-- | src/graph/parameters.h | 6 |
3 files changed, 11 insertions, 25 deletions
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp index 183b5787..934e2b73 100644 --- a/src/graph/expression_graph.cpp +++ b/src/graph/expression_graph.cpp @@ -12,15 +12,15 @@ void ExpressionGraph::setDevice(DeviceId deviceId) { if(!backend_) { backend_ = BackendByDevice(deviceId, Config::seed); params_ = New<Parameters>(); - params_->init(backend_->getDevice()); - tensors_ = New<TensorAllocator>(backend_->getDevice()); + params_->init(backend_); + tensors_ = New<TensorAllocator>(backend_); } } Expr ExpressionGraph::dropout(float prob, Shape shape) { return Expression<ConstantNode>(shared_from_this(), keywords::init = [prob, this](Tensor t) { - Dropout(backend_, t, prob); + Dropout(t, prob); }, keywords::shape = shape); } diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index a3f27fd2..0170fc73 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -211,21 +211,7 @@ struct TanhNodeOp : public NaryNodeOp { const std::string type() { return "tanh"; } }; -/** - * Represents a <a - * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified - * linear</a> node in an expression graph. - * - * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and - * its derivative: - * \f[ - * f^\prime(x) = - * \begin{cases} - * 0 & \text{if } x \leq 0 \\ - * 1 & \text{if } x > 0 - * \end{cases} - * \f] - */ + struct ReLUNodeOp : public UnaryNodeOp { template <typename... Args> ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {} @@ -877,14 +863,14 @@ public: Tensor& val() { auto childVal = reshapee_->val(); val_.reset( - new TensorBase(childVal->memory(), shape(), childVal->getDevice())); + new TensorBase(childVal->memory(), shape(), childVal->getBackend())); return val_; }; Tensor& grad() { auto childGrad = reshapee_->grad(); adj_.reset( - new TensorBase(childGrad->memory(), shape(), childGrad->getDevice())); + new TensorBase(childGrad->memory(), shape(), childGrad->getBackend())); return adj_; }; @@ -953,7 +939,7 @@ public: size_t offset = step_ * shape().elements() * sizeof(float); auto mem = New<MemoryPiece>(childVal->memory()->data() + offset, childVal->memory()->size()); - val_.reset(new TensorBase(mem, shape(), childVal->getDevice())); + val_.reset(new TensorBase(mem, shape(), childVal->getBackend())); return val_; }; @@ -962,7 +948,7 @@ public: size_t offset = step_ * shape().elements() * sizeof(float); auto mem = New<MemoryPiece>(childGrad->memory()->data() + offset, childGrad->memory()->size()); - adj_.reset(new TensorBase(mem, shape(), childGrad->getDevice())); + adj_.reset(new TensorBase(mem, shape(), childGrad->getBackend())); return adj_; }; diff --git a/src/graph/parameters.h b/src/graph/parameters.h index ed8b7690..3f282e4a 100644 --- a/src/graph/parameters.h +++ b/src/graph/parameters.h @@ -20,9 +20,9 @@ private: Ptr<TensorAllocator> grads_; public: - void init(DeviceId deviceId) { - vals_ = New<TensorAllocator>(deviceId); - grads_ = New<TensorAllocator>(deviceId); + void init(Ptr<Backend> backend) { + vals_ = New<TensorAllocator>(backend); + grads_ = New<TensorAllocator>(backend); } auto begin() -> decltype(params_.begin()) { return params_.begin(); } |