Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-02-25 07:11:02 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-02-25 07:11:02 +0300
commitdbba0f220dc16d6c6104f67010e9ce3b9f2a204b (patch)
tree0bfd3c3edf988aa3cb4360d2f2052d121900ee23 /src/graph/node_operators_unary.h
parent845063b3429f9304b7d09a7c43037308cc4d06a4 (diff)
add cudnn back
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h158
1 files changed, 79 insertions, 79 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 07c06fda..0a76471b 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -1055,84 +1055,84 @@ struct ShiftNodeOp : public UnaryNodeOp {
// Ptr<sparse::CSR> lf_;
//};
-//class PoolingOp : public UnaryNodeOp {
-//public:
-// PoolingOp(Expr x,
-// int height,
-// int width,
-// int padHeight,
-// int padWidth,
-// int strideHeight,
-// int strideWidth,
-// std::string mode)
-// : UnaryNodeOp(x),
-// pooling_(height,
-// width,
-// padHeight,
-// padWidth,
-// strideHeight,
-// strideWidth,
-// mode) {
-// }
-//
-// NodeOps forwardOps() {
-// return {NodeOp(pooling_.forward(child(0)->val(), val_))};
-// }
-//
-// NodeOps backwardOps() {
-// return {NodeOp(pooling_.backward(
-// child(0)->val(),
-// child(0)->grad(),
-// val_,
-// adj_))};
-// }
-//
-// const std::string type() { return "layer_pooling"; }
-//
-//
-//protected:
-// PoolingWrapper pooling_;
-//};
-//
-//class PoolingWithMaskingOp : public UnaryNodeOp {
-// public:
-// PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false)
-// : UnaryNodeOp(x),
-// mask_(mask),
-// width_(width),
-// isEven_(isEven)
-// {
-// auto xShape = x->shape();
-// int dimBatch = xShape[0];
-// int dimWord = xShape[1];
-// int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
-// int dimSentence = (cols / width_) + (cols % width_ != 0);
-// shape_ = {dimBatch, dimWord, dimSentence};
-// }
-//
-// NodeOps forwardOps() {
-// return {NodeOp(PoolingWithMaskingForward(val_,
-// child(0)->val(),
-// mask_->val(),
-// width_,
-// isEven_))};
-// }
-//
-// NodeOps backwardOps() {
-// return {NodeOp(PoolingWithMaskingBackward(adj_,
-// child(0)->grad(),
-// child(0)->val(),
-// mask_->val(),
-// width_,
-// isEven_))};
-// }
-//
-// const std::string type() {return "layer_pooling";}
-//
-// protected:
-// Expr mask_;
-// int width_;
-// bool isEven_;
-//};
+class PoolingOp : public UnaryNodeOp {
+public:
+ PoolingOp(Expr x,
+ int height,
+ int width,
+ int padHeight,
+ int padWidth,
+ int strideHeight,
+ int strideWidth,
+ std::string mode)
+ : UnaryNodeOp(x),
+ pooling_(height,
+ width,
+ padHeight,
+ padWidth,
+ strideHeight,
+ strideWidth,
+ mode) {
+ }
+
+ NodeOps forwardOps() {
+ return {NodeOp(pooling_.forward(child(0)->val(), val_))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(pooling_.backward(
+ child(0)->val(),
+ child(0)->grad(),
+ val_,
+ adj_))};
+ }
+
+ const std::string type() { return "layer_pooling"; }
+
+
+protected:
+ PoolingWrapper pooling_;
+};
+
+class PoolingWithMaskingOp : public UnaryNodeOp {
+ public:
+ PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false)
+ : UnaryNodeOp(x),
+ mask_(mask),
+ width_(width),
+ isEven_(isEven)
+ {
+ auto xShape = x->shape();
+ int dimBatch = xShape[0];
+ int dimWord = xShape[1];
+ int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
+ int dimSentence = (cols / width_) + (cols % width_ != 0);
+ shape_ = {dimBatch, dimWord, dimSentence};
+ }
+
+ NodeOps forwardOps() {
+ return {NodeOp(PoolingWithMaskingForward(val_,
+ child(0)->val(),
+ mask_->val(),
+ width_,
+ isEven_))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(PoolingWithMaskingBackward(adj_,
+ child(0)->grad(),
+ child(0)->val(),
+ mask_->val(),
+ width_,
+ isEven_))};
+ }
+
+ const std::string type() {return "layer_pooling";}
+
+ protected:
+ Expr mask_;
+ int width_;
+ bool isEven_;
+};
}