diff options
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 96 |
1 files changed, 42 insertions, 54 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 2da3c463..0aa7f0a6 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -157,60 +157,6 @@ struct ReLUNodeOp : public UnaryNodeOp { } }; -/** - * @brief Represents a <a href="https://en.wikipedia.org/wiki/Dropout_(neural_networks)">dropout</a> node - * in an expression graph. - * - * @see \cite dropout - * @see \cite cudnn - */ -//struct DropoutNodeOp : public UnaryNodeOp { -// template <typename ...Args> -// DropoutNodeOp(Args ...args) -// : UnaryNodeOp(args...), -// allocated_(false), p_(Get(keywords::value, 0.5)) {} -// -// ~DropoutNodeOp() { -// if(allocated_) -// CudnnDropoutDestroy(dropDesc_, space_, states_); -// } -// -// void inference() { -// Element(_1 = _2, val_, children_[0]->val()); -// } -// -// void forward() { -// if(!allocated_) { -// CudnnDropoutPrepare(children_[0]->val(), p_, -// &dropDesc_, -// &space_, &spaceSize_, -// &states_, (size_t)this); // seeding with pointer address -// allocated_ = true; -// } -// -// CudnnDropoutForward(dropDesc_, space_, spaceSize_, -// val_, children_[0]->val()); -// } -// -// void backward() { -// if(children_[0]->trainable()) -// CudnnDropoutBackward(dropDesc_, space_, spaceSize_, -// children_[0]->grad(), adj_); -// } -// -// const std::string type() { -// return "dropout"; -// } -// -// private: -// bool allocated_; -// float p_; -// void* states_; -// void* space_; -// size_t spaceSize_; -// cudnnDropoutDescriptor_t dropDesc_; -//}; - struct SoftmaxNodeOp : public NaryNodeOp { template <typename ...Args> SoftmaxNodeOp(Expr a, Args ...args) @@ -758,4 +704,46 @@ struct TimestepNodeOp : public UnaryNodeOp { }; +struct ShiftNodeOp : public UnaryNodeOp { + template <typename ...Args> + ShiftNodeOp(Expr a, Shape shift, Args ...args) + : UnaryNodeOp(a, keywords::shape=a->shape(), args...), + shift_(shift) { + } + + NodeOps forwardOps() { + return { + NodeOp(Shift(val_, + children_[0]->val(), + shift_)) + }; + } + + NodeOps backwardOps() { + return { + NodeOp(Shift(children_[0]->grad(), + adj_, + shift_, + true)) + }; + } + + const std::string type() { + return "shift"; + } + + virtual size_t hash() { + if(!hash_) { + size_t seed = NaryNodeOp::hash(); + for(auto i : shape_) + boost::hash_combine(seed, i); + hash_ = seed; + } + return hash_; + } + + + Shape shift_; +}; + } |