Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-07-02 06:51:17 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-07-02 06:51:17 +0300
commit15cd7fdfcd7c2c7279de1b16ed850f25957252a8 (patch)
tree2cd009820c31f8244fe9d6400d1367e189efca02 /src/graph/node_operators_unary.h
parentac7e41f3df2504d01eb4e2f8bd08e9e1cb782298 (diff)
parent27ef488363fb82300f1d2f68a36fbac96264b172 (diff)
merge with master
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h143
1 files changed, 143 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 95227269..ee749e9c 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -7,6 +7,20 @@
#include "kernels/thrust_functions.h"
#include "tensors/tensor.h"
+#ifdef CUDNN
+
+#include <cudnn.h>
+
+#define CUDA_CALL(x) do { if((x) != cudaSuccess) { \
+ printf("Error at %s:%d\n",__FILE__,__LINE__); \
+ return EXIT_FAILURE;}} while(0)
+
+#define CUDNN_CALL(x) do { if((x) != CUDNN_STATUS_SUCCESS) { \
+ printf("Error (%s) at %s:%d\n",cudnnGetErrorString(x),__FILE__,__LINE__); \
+ }} while(0)
+
+#endif
+
namespace marian {
struct UnaryNodeOp : public NaryNodeOp {
@@ -684,4 +698,133 @@ struct LexicalProbNodeOp : public NaryNodeOp {
float eps_;
Ptr<sparse::CSR> lf_;
};
+
+#ifdef CUDNN
+
+class PoolingOp : public UnaryNodeOp {
+ public:
+ enum class Mode {MAX_POOLING, AVERAGE_POOLING};
+
+ PoolingOp(
+ Expr x,
+ int height, int width,
+ int padHeight, int padWidth,
+ int strideHeight, int strideWidth,
+ Mode mode = Mode::AVERAGE_POOLING)
+ : UnaryNodeOp(x)
+ {
+ CUDNN_CALL( cudnnCreate(&cudnnHandle_) );
+
+
+ CUDNN_CALL( cudnnCreateTensorDescriptor(&xDesc_) );
+ CUDNN_CALL( cudnnSetTensor4dDescriptor(xDesc_,
+ CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
+ x->shape()[0], x->shape()[1],
+ x->shape()[2], x->shape()[3]
+ ));
+
+
+ cudnnPoolingMode_t cudnnPoolingMode;
+ switch (mode) {
+ case Mode::MAX_POOLING:
+ cudnnPoolingMode = CUDNN_POOLING_MAX;
+ break;
+ case Mode::AVERAGE_POOLING:
+ cudnnPoolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+ break;
+ default:
+ break;
+ };
+
+ height = std::min(height, x->shape()[2]);
+ strideHeight = std::min(strideHeight, x->shape()[2]);
+
+ CUDNN_CALL( cudnnCreatePoolingDescriptor(&poolingDesc_) );
+ CUDNN_CALL( cudnnSetPooling2dDescriptor(poolingDesc_,
+ cudnnPoolingMode,
+ CUDNN_NOT_PROPAGATE_NAN,
+ height, width,
+ padHeight, padWidth,
+ strideHeight, strideWidth
+ ));
+
+ CUDNN_CALL(cudnnGetPooling2dForwardOutputDim(
+ poolingDesc_,
+ xDesc_,
+ shape_.begin(), shape_.begin() + 1, shape_.begin() + 2, shape_.begin() + 3
+ ));
+
+ CUDNN_CALL( cudnnCreateTensorDescriptor(&yDesc_) );
+ CUDNN_CALL( cudnnSetTensor4dDescriptor(yDesc_,
+ CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
+ shape_[0], shape_[1],
+ shape_[2], shape_[3])
+ );
+ CUDNN_CALL( cudnnCreateTensorDescriptor(&adjDesc_) );
+ CUDNN_CALL( cudnnSetTensor4dDescriptor(adjDesc_,
+ CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
+ shape_[0], shape_[1],
+ shape_[2], shape_[3])
+ );
+ }
+
+
+ NodeOps forwardOps() {
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ cudaSetDevice(val_->getDevice());
+
+ return {
+ NodeOp(
+ CUDNN_CALL( cudnnPoolingForward(cudnnHandle_,
+ poolingDesc_,
+ &alpha,
+ xDesc_, children_[0]->val()->data(),
+ &beta,
+ yDesc_, val_->data()))
+ )
+ };
+ }
+
+ NodeOps backwardOps() {
+ cudaSetDevice(adj_->getDevice());
+ const float alpha = 1.0f;
+ const float beta = 1.0f;
+ return {
+ NodeOp(
+ CUDNN_CALL( cudnnPoolingBackward(cudnnHandle_,
+ poolingDesc_,
+ &alpha,
+ yDesc_, val_->data(),
+ adjDesc_, adj_->data(),
+ xDesc_, children_[0]->val()->data(),
+ &beta,
+ xDesc_, children_[0]->grad()->data()
+ )))
+ };
+ }
+
+ const std::string type() {
+ return "layer_max_pooling";
+ }
+
+ virtual ~PoolingOp() {
+ CUDNN_CALL( cudnnDestroy(cudnnHandle_) );
+ CUDNN_CALL( cudnnDestroyPoolingDescriptor(poolingDesc_) );
+ CUDNN_CALL( cudnnDestroyTensorDescriptor(xDesc_) );
+ CUDNN_CALL( cudnnDestroyTensorDescriptor(yDesc_) );
+ CUDNN_CALL( cudnnDestroyTensorDescriptor(adjDesc_) );
+ }
+
+ protected:
+ cudnnHandle_t cudnnHandle_;
+ cudnnPoolingDescriptor_t poolingDesc_;
+ cudnnTensorDescriptor_t xDesc_;
+ cudnnTensorDescriptor_t yDesc_;
+ cudnnTensorDescriptor_t adjDesc_;
+
+};
+
+#endif
}