Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTomasz Dwojak <t.dwojak@amu.edu.pl>2017-11-16 02:05:55 +0300
committerTomasz Dwojak <t.dwojak@amu.edu.pl>2017-11-16 02:30:41 +0300
commit0effc8d28d25fbf80e3cea34d6f5b4e42490a7f2 (patch)
tree0581ddd34c53d722bf9184fccc643cceca121c03 /src/graph/node_operators_unary.h
parentce44cd9e287804860538dc80af0f87ae3ff8cfec (diff)
Refactor convolution and and poolingWithMasking
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h41
1 files changed, 41 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 90995ddf..faf21dee 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -1051,4 +1051,45 @@ protected:
PoolingWrapper pooling_;
};
+class PoolingWithMaskingOp : public UnaryNodeOp {
+ public:
+ PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false)
+ : UnaryNodeOp(x),
+ mask_(mask),
+ width_(width),
+ isEven_(isEven)
+ {
+ auto xShape = x->shape();
+ int dimBatch = xShape[0];
+ int dimWord = xShape[1];
+ int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
+ int dimSentence = (cols / width_) + (cols % width_ != 0);
+ shape_ = {dimBatch, dimWord, dimSentence};
+ }
+
+ NodeOps forwardOps() {
+ return {NodeOp(PoolingWithMaskingForward(val_,
+ child(0)->val(),
+ mask_->val(),
+ width_,
+ isEven_))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(PoolingWithMaskingBackward(adj_,
+ child(0)->grad(),
+ child(0)->val(),
+ mask_->val(),
+ width_,
+ isEven_))};
+ }
+
+ const std::string type() {return "layer_pooling";}
+
+ protected:
+ Expr mask_;
+ int width_;
+ bool isEven_;
+};
+
}