diff options
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 90995ddf..faf21dee 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -1051,4 +1051,45 @@ protected: PoolingWrapper pooling_; }; +class PoolingWithMaskingOp : public UnaryNodeOp { + public: + PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false) + : UnaryNodeOp(x), + mask_(mask), + width_(width), + isEven_(isEven) + { + auto xShape = x->shape(); + int dimBatch = xShape[0]; + int dimWord = xShape[1]; + int cols = (isEven_) ? xShape[2] - 1 : xShape[2]; + int dimSentence = (cols / width_) + (cols % width_ != 0); + shape_ = {dimBatch, dimWord, dimSentence}; + } + + NodeOps forwardOps() { + return {NodeOp(PoolingWithMaskingForward(val_, + child(0)->val(), + mask_->val(), + width_, + isEven_))}; + } + + NodeOps backwardOps() { + return {NodeOp(PoolingWithMaskingBackward(adj_, + child(0)->grad(), + child(0)->val(), + mask_->val(), + width_, + isEven_))}; + } + + const std::string type() {return "layer_pooling";} + + protected: + Expr mask_; + int width_; + bool isEven_; +}; + } |