diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-07-02 06:51:17 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-07-02 06:51:17 +0300 |
commit | 15cd7fdfcd7c2c7279de1b16ed850f25957252a8 (patch) | |
tree | 2cd009820c31f8244fe9d6400d1367e189efca02 /src/graph/node_operators_unary.h | |
parent | ac7e41f3df2504d01eb4e2f8bd08e9e1cb782298 (diff) | |
parent | 27ef488363fb82300f1d2f68a36fbac96264b172 (diff) |
merge with master
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 95227269..ee749e9c 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -7,6 +7,20 @@ #include "kernels/thrust_functions.h" #include "tensors/tensor.h" +#ifdef CUDNN + +#include <cudnn.h> + +#define CUDA_CALL(x) do { if((x) != cudaSuccess) { \ + printf("Error at %s:%d\n",__FILE__,__LINE__); \ + return EXIT_FAILURE;}} while(0) + +#define CUDNN_CALL(x) do { if((x) != CUDNN_STATUS_SUCCESS) { \ + printf("Error (%s) at %s:%d\n",cudnnGetErrorString(x),__FILE__,__LINE__); \ + }} while(0) + +#endif + namespace marian { struct UnaryNodeOp : public NaryNodeOp { @@ -684,4 +698,133 @@ struct LexicalProbNodeOp : public NaryNodeOp { float eps_; Ptr<sparse::CSR> lf_; }; + +#ifdef CUDNN + +class PoolingOp : public UnaryNodeOp { + public: + enum class Mode {MAX_POOLING, AVERAGE_POOLING}; + + PoolingOp( + Expr x, + int height, int width, + int padHeight, int padWidth, + int strideHeight, int strideWidth, + Mode mode = Mode::AVERAGE_POOLING) + : UnaryNodeOp(x) + { + CUDNN_CALL( cudnnCreate(&cudnnHandle_) ); + + + CUDNN_CALL( cudnnCreateTensorDescriptor(&xDesc_) ); + CUDNN_CALL( cudnnSetTensor4dDescriptor(xDesc_, + CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, + x->shape()[0], x->shape()[1], + x->shape()[2], x->shape()[3] + )); + + + cudnnPoolingMode_t cudnnPoolingMode; + switch (mode) { + case Mode::MAX_POOLING: + cudnnPoolingMode = CUDNN_POOLING_MAX; + break; + case Mode::AVERAGE_POOLING: + cudnnPoolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + break; + default: + break; + }; + + height = std::min(height, x->shape()[2]); + strideHeight = std::min(strideHeight, x->shape()[2]); + + CUDNN_CALL( cudnnCreatePoolingDescriptor(&poolingDesc_) ); + CUDNN_CALL( cudnnSetPooling2dDescriptor(poolingDesc_, + cudnnPoolingMode, + CUDNN_NOT_PROPAGATE_NAN, + height, width, + padHeight, padWidth, + strideHeight, strideWidth + )); + + CUDNN_CALL(cudnnGetPooling2dForwardOutputDim( + poolingDesc_, + xDesc_, + shape_.begin(), shape_.begin() + 1, shape_.begin() + 2, shape_.begin() + 3 + )); + + CUDNN_CALL( cudnnCreateTensorDescriptor(&yDesc_) ); + CUDNN_CALL( cudnnSetTensor4dDescriptor(yDesc_, + CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, + shape_[0], shape_[1], + shape_[2], shape_[3]) + ); + CUDNN_CALL( cudnnCreateTensorDescriptor(&adjDesc_) ); + CUDNN_CALL( cudnnSetTensor4dDescriptor(adjDesc_, + CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, + shape_[0], shape_[1], + shape_[2], shape_[3]) + ); + } + + + NodeOps forwardOps() { + const float alpha = 1.0f; + const float beta = 0.0f; + + cudaSetDevice(val_->getDevice()); + + return { + NodeOp( + CUDNN_CALL( cudnnPoolingForward(cudnnHandle_, + poolingDesc_, + &alpha, + xDesc_, children_[0]->val()->data(), + &beta, + yDesc_, val_->data())) + ) + }; + } + + NodeOps backwardOps() { + cudaSetDevice(adj_->getDevice()); + const float alpha = 1.0f; + const float beta = 1.0f; + return { + NodeOp( + CUDNN_CALL( cudnnPoolingBackward(cudnnHandle_, + poolingDesc_, + &alpha, + yDesc_, val_->data(), + adjDesc_, adj_->data(), + xDesc_, children_[0]->val()->data(), + &beta, + xDesc_, children_[0]->grad()->data() + ))) + }; + } + + const std::string type() { + return "layer_max_pooling"; + } + + virtual ~PoolingOp() { + CUDNN_CALL( cudnnDestroy(cudnnHandle_) ); + CUDNN_CALL( cudnnDestroyPoolingDescriptor(poolingDesc_) ); + CUDNN_CALL( cudnnDestroyTensorDescriptor(xDesc_) ); + CUDNN_CALL( cudnnDestroyTensorDescriptor(yDesc_) ); + CUDNN_CALL( cudnnDestroyTensorDescriptor(adjDesc_) ); + } + + protected: + cudnnHandle_t cudnnHandle_; + cudnnPoolingDescriptor_t poolingDesc_; + cudnnTensorDescriptor_t xDesc_; + cudnnTensorDescriptor_t yDesc_; + cudnnTensorDescriptor_t adjDesc_; + +}; + +#endif } |