diff options
author | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2018-07-02 19:17:14 +0300 |
---|---|---|
committer | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2018-07-02 19:17:14 +0300 |
commit | 9bcb4ee5a30de27f9ac8976f06f1c5ef952f7930 (patch) | |
tree | ab7c6ce05d4c8758566bd6dd41cb2119e0c69715 /src/graph | |
parent | f85e63d77972b4214550509558a2306f3ddbc989 (diff) |
Fix #262 PoolingWrapper dependency
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/expression_operators.cpp | 2 | ||||
-rw-r--r-- | src/graph/node_operators_binary.h | 7 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 6 |
3 files changed, 13 insertions, 2 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index aaf8cb7e..d7e25b5a 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -441,6 +441,7 @@ Expr shift(Expr a, Shape shift) { //} #ifdef CUDA_FOUND +#ifdef CUDNN Expr avg_pooling(Expr x, int height, @@ -505,4 +506,5 @@ Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven) { } #endif +#endif } diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index dfc75542..b5476e68 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -4,9 +4,12 @@ #include "functional/functional.h" #include "graph/node.h" -#include "tensors/gpu/cudnn_wrappers.h" #include "tensors/tensor_operators.h" +#ifdef CUDNN +#include "tensors/gpu/cudnn_wrappers.h" +#endif + namespace marian { class DotNodeOp : public NaryNodeOp { @@ -737,6 +740,7 @@ struct HighwayNodeOp : public NaryNodeOp { const std::string type() { return "highway"; } }; +#ifdef CUDNN class ConvolutionOp : public NaryNodeOp { public: ConvolutionOp(const std::vector<Expr>& nodes, @@ -773,4 +777,5 @@ public: protected: ConvolutionWrapper conv_; }; +#endif } diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 1e0c71e6..acc84c9c 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -7,7 +7,9 @@ #include "graph/node.h" #include "tensors/tensor_operators.h" -//#include "tensors/gpu/cudnn_wrappers.h" +#ifdef CUDNN +#include "tensors/gpu/cudnn_wrappers.h" +#endif namespace marian { @@ -1068,6 +1070,7 @@ struct ShiftNodeOp : public UnaryNodeOp { // Ptr<sparse::CSR> lf_; //}; +#ifdef CUDNN class PoolingOp : public UnaryNodeOp { public: PoolingOp(Expr x, @@ -1101,6 +1104,7 @@ public: protected: PoolingWrapper pooling_; }; +#endif class PoolingWithMaskingOp : public UnaryNodeOp { public: |