
github.com/marian-nmt/marian.git
path: root/src/graph
author     Tomasz Dwojak <t.dwojak@amu.edu.pl>    2017-11-10 16:11:07 +0300
committer  Tomasz Dwojak <t.dwojak@amu.edu.pl>    2017-11-13 11:40:38 +0300
commit     9b24f9c6ac23d50ca698daa7ca2f13ecc1e96ead (patch)
tree       884348ae51a39fef93fc5b6eca1d5ddb0d89f1e5 /src/graph
parent     72761ed08425cc1d20a9f522c619052bef6f5dd2 (diff)
Move pooling code to cudnn files
Diffstat (limited to 'src/graph')
-rw-r--r--  src/graph/expression_operators.cu  |  42
-rw-r--r--  src/graph/node_operators_binary.h  |  11
-rw-r--r--  src/graph/node_operators_unary.h   | 151
3 files changed, 54 insertions(+), 150 deletions(-)
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 98570862..b271d971 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -303,33 +303,43 @@ Expr convolution(Expr x, Expr filters, Expr bias) {
return Expression<ConvolutionOp>(nodes);
}
-// clang-format off
Expr avg_pooling(
Expr x,
- int height, int width,
- int padHeight, int padWidth,
- int strideHeight, int strideWidth)
+ int height,
+ int width,
+ int padHeight,
+ int padWidth,
+ int strideHeight,
+ int strideWidth)
{
return Expression<PoolingOp>(x,
- height, width,
- padHeight, padWidth,
- strideHeight, strideWidth,
- PoolingOp::Mode::AVERAGE_POOLING);
+ height,
+ width,
+ padHeight,
+ padWidth,
+ strideHeight,
+ strideWidth,
+ "avg");
}
Expr max_pooling(
Expr x,
- int height, int width,
- int padHeight, int padWidth,
- int strideHeight, int strideWidth)
+ int height,
+ int width,
+ int padHeight,
+ int padWidth,
+ int strideHeight,
+ int strideWidth)
{
return Expression<PoolingOp>(x,
- height, width,
- padHeight, padWidth,
- strideHeight, strideWidth,
- PoolingOp::Mode::MAX_POOLING);
+ height,
+ width,
+ padHeight,
+ padWidth,
+ strideHeight,
+ strideWidth,
+ "max");
}
-// clang-format on
#endif
}
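
Note on the change above: the PoolingOp::Mode enum is gone, and avg_pooling()/max_pooling() now select the pooling kind by passing the strings "avg" and "max" through to PoolingOp. A minimal call-site sketch, assuming an existing graph expression x (the parameter values here are illustrative, not taken from this commit):

    // Hypothetical usage of the post-change API; `x` is an existing Expr.
    // 3x3 window, padding 1, stride 1 in both dimensions.
    Expr pooled = max_pooling(x,
                              /*height=*/3, /*width=*/3,
                              /*padHeight=*/1, /*padWidth=*/1,
                              /*strideHeight=*/1, /*strideWidth=*/1);
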
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 4b8dd70e..a3d24179 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -712,18 +712,23 @@ class ConvolutionOp : public NaryNodeOp {
public:
ConvolutionOp(const std::vector<Expr>& nodes)
: NaryNodeOp(nodes),
- conv_(nodes[1]->val(), nodes[2]->val()) {
- conv_.getOutputShape(nodes[0]->val(), shape_);
+ conv_(nodes[1]->shape(), nodes[2]->shape()) {
+ conv_.getOutputShape(nodes[0]->shape(), shape_);
}
NodeOps forwardOps() {
- return {NodeOp(conv_.forward(child(0)->val(), val_))};
+ return {NodeOp(conv_.forward(
+ child(0)->val(),
+ child(1)->val(),
+ child(2)->val(),
+ val_))};
}
NodeOps backwardOps() {
return {NodeOp(conv_.backward(
child(0)->val(),
child(0)->grad(),
+ child(1)->val(),
child(1)->grad(),
child(2)->grad(),
adj_))};
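
The conv_ member that ConvolutionOp now delegates to is defined in the cudnn wrapper files, which fall outside this src/graph-filtered diff. From the call sites above its interface can be reconstructed roughly as below; the class name and exact signatures are inferred, not copied from kernels/cudnn_wrappers.h:

    // Inferred (not verbatim) interface of the convolution wrapper,
    // based only on how conv_ is constructed and called in ConvolutionOp.
    // Shape and Tensor are marian's own types.
    class ConvolutionWrapper {
    public:
      // Built from the kernel and bias shapes (nodes[1], nodes[2]).
      ConvolutionWrapper(const Shape& kernelShape, const Shape& biasShape);

      // Computes the output shape from the input shape (nodes[0]).
      void getOutputShape(const Shape& xShape, Shape& outShape);

      // y = conv(x, kernels) + bias
      void forward(Tensor x, Tensor kernels, Tensor bias, Tensor y);

      // Accumulates input, kernel, and bias gradients from the adjoint.
      void backward(Tensor x, Tensor xGrad, Tensor kernels,
                    Tensor kernelsGrad, Tensor biasGrad, Tensor adj);
    };
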
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 6554c4f0..90995ddf 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -6,30 +6,8 @@
#include "kernels/tensor_operators.h"
#include "tensors/tensor.h"
#include "functional/functional.h"
+#include "kernels/cudnn_wrappers.h"
-#ifdef CUDNN
-
-#include <cudnn.h>
-
-#define CUDA_CALL(x) \
- do { \
- if((x) != cudaSuccess) { \
- printf("Error at %s:%d\n", __FILE__, __LINE__); \
- return EXIT_FAILURE; \
- } \
- } while(0)
-
-#define CUDNN_CALL(x) \
- do { \
- if((x) != CUDNN_STATUS_SUCCESS) { \
- printf("Error (%s) at %s:%d\n", \
- cudnnGetErrorString(x), \
- __FILE__, \
- __LINE__); \
- } \
- } while(0)
-
-#endif
namespace marian {
@@ -1034,12 +1012,8 @@ struct ShiftNodeOp : public UnaryNodeOp {
// Ptr<sparse::CSR> lf_;
//};
-#ifdef CUDNN
-
class PoolingOp : public UnaryNodeOp {
public:
- enum class Mode { MAX_POOLING, AVERAGE_POOLING };
-
PoolingOp(Expr x,
int height,
int width,
@@ -1047,119 +1021,34 @@ public:
int padWidth,
int strideHeight,
int strideWidth,
- Mode mode = Mode::AVERAGE_POOLING)
- : UnaryNodeOp(x) {
- CUDNN_CALL(cudnnCreate(&cudnnHandle_));
-
- CUDNN_CALL(cudnnCreateTensorDescriptor(&xDesc_));
- CUDNN_CALL(cudnnSetTensor4dDescriptor(xDesc_,
- CUDNN_TENSOR_NCHW,
- CUDNN_DATA_FLOAT,
- x->shape()[0],
- x->shape()[1],
- x->shape()[2],
- x->shape()[3]));
-
- cudnnPoolingMode_t cudnnPoolingMode;
- switch(mode) {
- case Mode::MAX_POOLING: cudnnPoolingMode = CUDNN_POOLING_MAX; break;
- case Mode::AVERAGE_POOLING:
- cudnnPoolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
- break;
- default: break;
- };
-
- height = std::min(height, x->shape()[2]);
- strideHeight = std::min(strideHeight, x->shape()[2]);
-
- CUDNN_CALL(cudnnCreatePoolingDescriptor(&poolingDesc_));
- CUDNN_CALL(cudnnSetPooling2dDescriptor(poolingDesc_,
- cudnnPoolingMode,
- CUDNN_NOT_PROPAGATE_NAN,
- height,
- width,
- padHeight,
- padWidth,
- strideHeight,
- strideWidth));
- /* @TODO: does not compile
- CUDNN_CALL(cudnnGetPooling2dForwardOutputDim(poolingDesc_,
- xDesc_,
- shape_.begin(),
- shape_.begin() + 1,
- shape_.begin() + 2,
- shape_.begin() + 3));
-*/
- CUDNN_CALL(cudnnCreateTensorDescriptor(&yDesc_));
- CUDNN_CALL(cudnnSetTensor4dDescriptor(yDesc_,
- CUDNN_TENSOR_NCHW,
- CUDNN_DATA_FLOAT,
- shape_[0],
- shape_[1],
- shape_[2],
- shape_[3]));
- CUDNN_CALL(cudnnCreateTensorDescriptor(&adjDesc_));
- CUDNN_CALL(cudnnSetTensor4dDescriptor(adjDesc_,
- CUDNN_TENSOR_NCHW,
- CUDNN_DATA_FLOAT,
- shape_[0],
- shape_[1],
- shape_[2],
- shape_[3]));
+ std::string mode)
+ : UnaryNodeOp(x),
+ pooling_(height,
+ width,
+ padHeight,
+ padWidth,
+ strideHeight,
+ strideWidth,
+ mode) {
}
NodeOps forwardOps() {
- const float alpha = 1.0f;
- const float beta = 0.0f;
-
- cudaSetDevice(val_->getDevice());
-
- return {NodeOp(CUDNN_CALL(cudnnPoolingForward(cudnnHandle_,
- poolingDesc_,
- &alpha,
- xDesc_,
- children_[0]->val()->data(),
- &beta,
- yDesc_,
- val_->data())))};
+ return {NodeOp(pooling_.forward(child(0)->val(), val_))};
}
NodeOps backwardOps() {
- cudaSetDevice(adj_->getDevice());
- const float alpha = 1.0f;
- const float beta = 1.0f;
- return {
- NodeOp(CUDNN_CALL(cudnnPoolingBackward(cudnnHandle_,
- poolingDesc_,
- &alpha,
- yDesc_,
- val_->data(),
- adjDesc_,
- adj_->data(),
- xDesc_,
- children_[0]->val()->data(),
- &beta,
- xDesc_,
- children_[0]->grad()->data())))};
- }
-
- const std::string type() { return "layer_max_pooling"; }
-
- virtual ~PoolingOp() {
- CUDNN_CALL(cudnnDestroy(cudnnHandle_));
- CUDNN_CALL(cudnnDestroyPoolingDescriptor(poolingDesc_));
- CUDNN_CALL(cudnnDestroyTensorDescriptor(xDesc_));
- CUDNN_CALL(cudnnDestroyTensorDescriptor(yDesc_));
- CUDNN_CALL(cudnnDestroyTensorDescriptor(adjDesc_));
+ return {NodeOp(pooling_.backward(
+ child(0)->val(),
+ child(0)->grad(),
+ val_,
+ adj_))};
}
+ const std::string type() { return "layer_pooling"; }
+
+
protected:
- cudnnHandle_t cudnnHandle_;
- cudnnPoolingDescriptor_t poolingDesc_;
- cudnnTensorDescriptor_t xDesc_;
- cudnnTensorDescriptor_t yDesc_;
- cudnnTensorDescriptor_t adjDesc_;
+ PoolingWrapper pooling_;
};
-#endif
}
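
The PoolingWrapper member declared at the end of PoolingOp replaces all of the deleted inline cuDNN handle, descriptor, and cudnnPoolingForward/Backward code; its definition also lives in kernels/cudnn_wrappers.h, outside this diff. A sketch of its interface as inferred from the constructor and the forward/backward call sites above (a reconstruction, not the actual header):

    // Inferred (not verbatim) interface of PoolingWrapper, reconstructed
    // from the call sites in PoolingOp. Tensor is marian's tensor type.
    class PoolingWrapper {
    public:
      // mode is "avg" or "max", matching the strings passed in from
      // avg_pooling() / max_pooling() in expression_operators.cu.
      PoolingWrapper(int height, int width,
                     int padHeight, int padWidth,
                     int strideHeight, int strideWidth,
                     std::string mode);

      // y = pool(x)
      void forward(Tensor x, Tensor y);

      // Accumulates the input gradient from the saved output and adjoint.
      void backward(Tensor x, Tensor xGrad, Tensor y, Tensor adj);
    };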