Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src/graph
diff options
context:
space:
mode:
authorRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2018-07-02 19:17:14 +0300
committerRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2018-07-02 19:17:14 +0300
commit9bcb4ee5a30de27f9ac8976f06f1c5ef952f7930 (patch)
treeab7c6ce05d4c8758566bd6dd41cb2119e0c69715 /src/graph
parentf85e63d77972b4214550509558a2306f3ddbc989 (diff)
Fix #262 PoolingWrapper dependency
Diffstat (limited to 'src/graph')
-rw-r--r--src/graph/expression_operators.cpp2
-rw-r--r--src/graph/node_operators_binary.h7
-rw-r--r--src/graph/node_operators_unary.h6
3 files changed, 13 insertions, 2 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index aaf8cb7e..d7e25b5a 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -441,6 +441,7 @@ Expr shift(Expr a, Shape shift) {
//}
#ifdef CUDA_FOUND
+#ifdef CUDNN
Expr avg_pooling(Expr x,
int height,
@@ -505,4 +506,5 @@ Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven) {
}
#endif
+#endif
}
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index dfc75542..b5476e68 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -4,9 +4,12 @@
#include "functional/functional.h"
#include "graph/node.h"
-#include "tensors/gpu/cudnn_wrappers.h"
#include "tensors/tensor_operators.h"
+#ifdef CUDNN
+#include "tensors/gpu/cudnn_wrappers.h"
+#endif
+
namespace marian {
class DotNodeOp : public NaryNodeOp {
@@ -737,6 +740,7 @@ struct HighwayNodeOp : public NaryNodeOp {
const std::string type() { return "highway"; }
};
+#ifdef CUDNN
class ConvolutionOp : public NaryNodeOp {
public:
ConvolutionOp(const std::vector<Expr>& nodes,
@@ -773,4 +777,5 @@ public:
protected:
ConvolutionWrapper conv_;
};
+#endif
}
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 1e0c71e6..acc84c9c 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -7,7 +7,9 @@
#include "graph/node.h"
#include "tensors/tensor_operators.h"
-//#include "tensors/gpu/cudnn_wrappers.h"
+#ifdef CUDNN
+#include "tensors/gpu/cudnn_wrappers.h"
+#endif
namespace marian {
@@ -1068,6 +1070,7 @@ struct ShiftNodeOp : public UnaryNodeOp {
// Ptr<sparse::CSR> lf_;
//};
+#ifdef CUDNN
class PoolingOp : public UnaryNodeOp {
public:
PoolingOp(Expr x,
@@ -1101,6 +1104,7 @@ public:
protected:
PoolingWrapper pooling_;
};
+#endif
class PoolingWithMaskingOp : public UnaryNodeOp {
public: