diff options
author | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-11-16 02:05:55 +0300 |
---|---|---|
committer | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-11-16 02:30:41 +0300 |
commit | 0effc8d28d25fbf80e3cea34d6f5b4e42490a7f2 (patch) | |
tree | 0581ddd34c53d722bf9184fccc643cceca121c03 /src/graph/node_operators_unary.h | |
parent | ce44cd9e287804860538dc80af0f87ae3ff8cfec (diff) |
Refactor convolution and and poolingWithMasking
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 90995ddf..faf21dee 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -1051,4 +1051,45 @@ protected: PoolingWrapper pooling_; }; +class PoolingWithMaskingOp : public UnaryNodeOp { + public: + PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false) + : UnaryNodeOp(x), + mask_(mask), + width_(width), + isEven_(isEven) + { + auto xShape = x->shape(); + int dimBatch = xShape[0]; + int dimWord = xShape[1]; + int cols = (isEven_) ? xShape[2] - 1 : xShape[2]; + int dimSentence = (cols / width_) + (cols % width_ != 0); + shape_ = {dimBatch, dimWord, dimSentence}; + } + + NodeOps forwardOps() { + return {NodeOp(PoolingWithMaskingForward(val_, + child(0)->val(), + mask_->val(), + width_, + isEven_))}; + } + + NodeOps backwardOps() { + return {NodeOp(PoolingWithMaskingBackward(adj_, + child(0)->grad(), + child(0)->val(), + mask_->val(), + width_, + isEven_))}; + } + + const std::string type() {return "layer_pooling";} + + protected: + Expr mask_; + int width_; + bool isEven_; +}; + } |