From dbba0f220dc16d6c6104f67010e9ce3b9f2a204b Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Sat, 24 Feb 2018 20:11:02 -0800 Subject: add cudnn back --- src/graph/node_operators_unary.h | 158 +++++++++++++++++++-------------------- 1 file changed, 79 insertions(+), 79 deletions(-) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 07c06fda..0a76471b 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -1055,84 +1055,84 @@ struct ShiftNodeOp : public UnaryNodeOp { // Ptr lf_; //}; -//class PoolingOp : public UnaryNodeOp { -//public: -// PoolingOp(Expr x, -// int height, -// int width, -// int padHeight, -// int padWidth, -// int strideHeight, -// int strideWidth, -// std::string mode) -// : UnaryNodeOp(x), -// pooling_(height, -// width, -// padHeight, -// padWidth, -// strideHeight, -// strideWidth, -// mode) { -// } -// -// NodeOps forwardOps() { -// return {NodeOp(pooling_.forward(child(0)->val(), val_))}; -// } -// -// NodeOps backwardOps() { -// return {NodeOp(pooling_.backward( -// child(0)->val(), -// child(0)->grad(), -// val_, -// adj_))}; -// } -// -// const std::string type() { return "layer_pooling"; } -// -// -//protected: -// PoolingWrapper pooling_; -//}; -// -//class PoolingWithMaskingOp : public UnaryNodeOp { -// public: -// PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false) -// : UnaryNodeOp(x), -// mask_(mask), -// width_(width), -// isEven_(isEven) -// { -// auto xShape = x->shape(); -// int dimBatch = xShape[0]; -// int dimWord = xShape[1]; -// int cols = (isEven_) ? xShape[2] - 1 : xShape[2]; -// int dimSentence = (cols / width_) + (cols % width_ != 0); -// shape_ = {dimBatch, dimWord, dimSentence}; -// } -// -// NodeOps forwardOps() { -// return {NodeOp(PoolingWithMaskingForward(val_, -// child(0)->val(), -// mask_->val(), -// width_, -// isEven_))}; -// } -// -// NodeOps backwardOps() { -// return {NodeOp(PoolingWithMaskingBackward(adj_, -// child(0)->grad(), -// child(0)->val(), -// mask_->val(), -// width_, -// isEven_))}; -// } -// -// const std::string type() {return "layer_pooling";} -// -// protected: -// Expr mask_; -// int width_; -// bool isEven_; -//}; +class PoolingOp : public UnaryNodeOp { +public: + PoolingOp(Expr x, + int height, + int width, + int padHeight, + int padWidth, + int strideHeight, + int strideWidth, + std::string mode) + : UnaryNodeOp(x), + pooling_(height, + width, + padHeight, + padWidth, + strideHeight, + strideWidth, + mode) { + } + + NodeOps forwardOps() { + return {NodeOp(pooling_.forward(child(0)->val(), val_))}; + } + + NodeOps backwardOps() { + return {NodeOp(pooling_.backward( + child(0)->val(), + child(0)->grad(), + val_, + adj_))}; + } + + const std::string type() { return "layer_pooling"; } + + +protected: + PoolingWrapper pooling_; +}; + +class PoolingWithMaskingOp : public UnaryNodeOp { + public: + PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false) + : UnaryNodeOp(x), + mask_(mask), + width_(width), + isEven_(isEven) + { + auto xShape = x->shape(); + int dimBatch = xShape[0]; + int dimWord = xShape[1]; + int cols = (isEven_) ? xShape[2] - 1 : xShape[2]; + int dimSentence = (cols / width_) + (cols % width_ != 0); + shape_ = {dimBatch, dimWord, dimSentence}; + } + + NodeOps forwardOps() { + return {NodeOp(PoolingWithMaskingForward(val_, + child(0)->val(), + mask_->val(), + width_, + isEven_))}; + } + + NodeOps backwardOps() { + return {NodeOp(PoolingWithMaskingBackward(adj_, + child(0)->grad(), + child(0)->val(), + mask_->val(), + width_, + isEven_))}; + } + + const std::string type() {return "layer_pooling";} + + protected: + Expr mask_; + int width_; + bool isEven_; +}; } -- cgit v1.2.3