github.com/OpenNMT/CTranslate2.git
author     Guillaume Klein <guillaumekln@users.noreply.github.com>  2022-10-26 17:32:45 +0300
committer  GitHub <noreply@github.com>  2022-10-26 17:32:45 +0300
commit     1157df7875462b1c46a2403e9335cc864b3c90dd (patch)
tree       0b4511f81444678617a023fd6da5539b4bcb6f28
parent     0d0b580036c433b65e5d7ba925fd07e69bb700c1 (diff)
Add Conv1D operator and layer (#941)
* Add Conv1D operator and layer
* Fix typo in CMakeLists.txt
* Skip the tests when the backend is not available
* Add a separate CMake option for cuDNN
* Minor code cleanup
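For context, a minimal sketch of how the new op is called, using the same shapes as the tests added in tests/ops_test.cc (a `device` variable is assumed to be in scope and the tensors are left uninitialized; this is illustrative only, not part of the patch):

  // Sketch only: the op expects input in (batch, in_channels, length) layout.
  StorageView input({2, 2, 3}, DataType::FLOAT, device);
  StorageView weight({4, 2, 2}, DataType::FLOAT, device);  // (out_channels, in_channels, kernel_size)
  StorageView bias({4}, DataType::FLOAT, device);
  StorageView output(device);
  ops::Conv1D(/*stride=*/2, /*padding=*/1)(input, weight, bias, output);  // output shape: {2, 4, 2}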
-rw-r--r--  CMakeLists.txt                           |  37
-rw-r--r--  docs/installation.md                     |   4
-rw-r--r--  include/ctranslate2/layers/common.h      |  16
-rw-r--r--  include/ctranslate2/ops/conv1d.h         |  40
-rw-r--r--  include/ctranslate2/ops/ops.h            |   1
-rw-r--r--  python/ctranslate2/specs/common_spec.py  |   6
-rw-r--r--  python/ctranslate2/specs/model_spec.py   |   2
-rw-r--r--  src/cuda/utils.cc                        |  41
-rw-r--r--  src/cuda/utils.h                         |  16
-rw-r--r--  src/layers/common.cc                     |  26
-rw-r--r--  src/ops/conv1d.cc                        |  60
-rw-r--r--  src/ops/conv1d_cpu.cc                    | 123
-rw-r--r--  src/ops/conv1d_gpu.cu                    | 149
-rw-r--r--  tests/benchmark_ops.cc                   |  11
-rw-r--r--  tests/ops_test.cc                        | 132
15 files changed, 662 insertions, 2 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7de875fd..81ab9c60 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,6 +13,7 @@ option(WITH_ACCELERATE "Compile with Accelerate backend" OFF)
option(WITH_OPENBLAS "Compile with OpenBLAS backend" OFF)
option(WITH_RUY "Compile with Ruy backend" OFF)
option(WITH_CUDA "Compile with CUDA backend" OFF)
+option(WITH_CUDNN "Compile with cuDNN backend" OFF)
option(CUDA_DYNAMIC_LOADING "Dynamically load CUDA libraries at runtime" OFF)
option(ENABLE_CPU_DISPATCH "Compile CPU kernels for multiple ISA and dispatch at runtime" ON)
option(ENABLE_PROFILING "Compile with profiling support" OFF)
@@ -103,6 +104,8 @@ set(SOURCES
src/ops/bias_add_cpu.cc
src/ops/concat.cc
src/ops/concat_split_cpu.cc
+ src/ops/conv1d.cc
+ src/ops/conv1d_cpu.cc
src/ops/dequantize.cc
src/ops/dequantize_cpu.cc
src/ops/gather.cc
@@ -332,6 +335,8 @@ if(WITH_DNNL)
add_definitions(-DCT2_WITH_DNNL)
list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${DNNL_INCLUDE_DIR})
list(APPEND LIBRARIES ${DNNL_LIBRARY})
+else()
+ message(WARNING "DNNL library is not enabled: convolution layers will not be supported on CPU")
endif()
if (WITH_ACCELERATE)
@@ -416,6 +421,35 @@ if (WITH_CUDA)
cuda_include_directories(${THRUST_INCLUDE_DIRS})
list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${THRUST_INCLUDE_DIRS})
+ if(WITH_CUDNN)
+ # Find cuDNN includes.
+ find_path(CUDNN_INCLUDE_DIR NAMES cudnn.h HINTS ${CUDA_TOOLKIT_ROOT_DIR}/include)
+ if(CUDNN_INCLUDE_DIR)
+ message(STATUS "Found cuDNN include directory: ${CUDNN_INCLUDE_DIR}")
+ else()
+ message(FATAL_ERROR "cuDNN include directory not found")
+ endif()
+
+ # Find cuDNN libraries.
+ find_library(CUDNN_LIBRARIES
+ NAMES cudnn
+ HINTS ${CUDA_TOOLKIT_ROOT_DIR}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64
+ )
+ if(CUDNN_LIBRARIES)
+ message(STATUS "Found cuDNN libraries: ${CUDNN_LIBRARIES}")
+ else()
+ message(FATAL_ERROR "cuDNN libraries not found")
+ endif()
+
+ # libcudnn.so is a shim layer that dynamically loads the correct library at runtime,
+  # so we explicitly link against it even with CUDA_DYNAMIC_LOADING.
+ list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR})
+ list(APPEND LIBRARIES ${CUDNN_LIBRARIES})
+ add_definitions(-DCT2_WITH_CUDNN)
+ else()
+ message(WARNING "cuDNN library is not enabled: convolution layers will not be supported on GPU")
+ endif()
+
if(CUDA_DYNAMIC_LOADING)
if(CUDA_VERSION_MAJOR LESS 11)
message(FATAL_ERROR "Dynamic loading of CUDA libraries requires CUDA 11 or above")
@@ -433,6 +467,7 @@ if (WITH_CUDA)
src/cuda/utils.cc
src/ops/bias_add_gpu.cu
src/ops/concat_split_gpu.cu
+ src/ops/conv1d_gpu.cu
src/ops/dequantize_gpu.cu
src/ops/gather_gpu.cu
src/ops/gumbel_max_gpu.cu
@@ -444,6 +479,8 @@ if (WITH_CUDA)
src/ops/topk_gpu.cu
src/ops/quantize_gpu.cu
)
+elseif(WITH_CUDNN)
+ message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
else()
add_library(${PROJECT_NAME} ${SOURCES})
endif()
diff --git a/docs/installation.md b/docs/installation.md
index 5026590b..1f53d98d 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -98,6 +98,7 @@ The following options can be set with `-DOPTION=VALUE` during the CMake configur
| ENABLE_PROFILING | **OFF**, ON | Enables the integrated profiler (usually disabled in production builds) |
| OPENMP_RUNTIME | **INTEL**, COMP, NONE | Selects or disables the OpenMP runtime:<ul><li>INTEL: Intel OpenMP</li><li>COMP: OpenMP runtime provided by the compiler</li><li>NONE: no OpenMP runtime</li></ul> |
| WITH_CUDA | **OFF**, ON | Compiles with the CUDA backend |
+| WITH_CUDNN | **OFF**, ON | Compiles with the cuDNN backend |
| WITH_DNNL | **OFF**, ON | Compiles with the oneDNN backend (a.k.a. DNNL) |
| WITH_MKL | OFF, **ON** | Compiles with the Intel MKL backend |
| WITH_ACCELERATE | **OFF**, ON | Compiles with the Apple Accelerate backend |
@@ -107,8 +108,9 @@ The following options can be set with `-DOPTION=VALUE` during the CMake configur
Some build options require additional dependencies. See their respective documentation for installation instructions.
* `-DWITH_CUDA=ON` requires [CUDA](https://developer.nvidia.com/cuda-toolkit) >= 10.0
+* `-DWITH_CUDNN=ON` requires [cuDNN](https://developer.nvidia.com/cudnn) >= 8
* `-DWITH_MKL=ON` requires [Intel MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) >= 2019.5
-* `-DWITH_DNNL=ON` requires [oneDNN](https://github.com/oneapi-src/oneDNN) >= 1.5
+* `-DWITH_DNNL=ON` requires [oneDNN](https://github.com/oneapi-src/oneDNN) >= 2.0
* `-DWITH_ACCELERATE=ON` requires [Accelerate](https://developer.apple.com/documentation/accelerate)
* `-DWITH_OPENBLAS=ON` requires [OpenBLAS](https://github.com/xianyi/OpenBLAS)
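As a hedged example, assuming an out-of-source build directory as elsewhere in this guide, convolution support on both GPU and CPU could be enabled with:

  cmake -DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_DNNL=ON ..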
diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h
index 69dd2f10..5fa3f26b 100644
--- a/include/ctranslate2/layers/common.h
+++ b/include/ctranslate2/layers/common.h
@@ -161,5 +161,21 @@ namespace ctranslate2 {
const StorageView& _gamma;
};
+ class Conv1D : public Layer {
+ public:
+ Conv1D(const models::Model& model,
+ const std::string& scope,
+ dim_t stride = 1,
+ dim_t padding = 0,
+ dim_t dilation = 1);
+ DataType output_type() const override;
+ dim_t output_size() const override;
+ void operator()(const StorageView& input, StorageView& output) const;
+ private:
+ const ops::Conv1D _conv_op;
+ const StorageView& _weight;
+ const StorageView* _bias;
+ };
+
}
}
diff --git a/include/ctranslate2/ops/conv1d.h b/include/ctranslate2/ops/conv1d.h
new file mode 100644
index 00000000..b7533dcd
--- /dev/null
+++ b/include/ctranslate2/ops/conv1d.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include "activation.h"
+#include "op.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ class Conv1D : public Op {
+ public:
+ Conv1D(dim_t stride = 1, dim_t padding = 0, dim_t dilation = 1);
+
+ void operator()(const StorageView& input,
+ const StorageView& weight,
+ const StorageView& bias,
+ StorageView& output) const;
+
+ void operator()(const StorageView& input,
+ const StorageView& weight,
+ StorageView& output) const;
+
+ private:
+ dim_t _stride;
+ dim_t _padding;
+ dim_t _dilation;
+
+ void operator()(const StorageView& input,
+ const StorageView& weight,
+ const StorageView* bias,
+ StorageView& output) const;
+
+ template <Device D, typename T>
+ void compute(const StorageView& input,
+ const StorageView& weight,
+ const StorageView* bias,
+ StorageView& output) const;
+ };
+
+ }
+}
diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h
index 215952bd..9262118b 100644
--- a/include/ctranslate2/ops/ops.h
+++ b/include/ctranslate2/ops/ops.h
@@ -5,6 +5,7 @@
#include "add.h"
#include "bias_add.h"
#include "concat.h"
+#include "conv1d.h"
#include "cos.h"
#include "gather.h"
#include "gelu.h"
diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py
index f657dcc7..d16a787f 100644
--- a/python/ctranslate2/specs/common_spec.py
+++ b/python/ctranslate2/specs/common_spec.py
@@ -38,6 +38,12 @@ class LinearSpec(model_spec.LayerSpec):
return isinstance(self.bias, np.ndarray)
+class Conv1DSpec(model_spec.LayerSpec):
+ def __init__(self):
+ self.weight = None
+ self.bias = model_spec.OPTIONAL
+
+
class EmbeddingsSpec(model_spec.LayerSpec):
def __init__(self):
self.weight = None
diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py
index 10b0428b..6fd221d3 100644
--- a/python/ctranslate2/specs/model_spec.py
+++ b/python/ctranslate2/specs/model_spec.py
@@ -159,7 +159,7 @@ class LayerSpec(metaclass=FrozenMeta):
return
scale = None
- is_quantizable = "weight" in name
+ is_quantizable = hasattr(spec, "%s_scale" % name)
if is_quantizable:
if quantization == "int16":
diff --git a/src/cuda/utils.cc b/src/cuda/utils.cc
index 47fdb5ca..621a9088 100644
--- a/src/cuda/utils.cc
+++ b/src/cuda/utils.cc
@@ -102,6 +102,47 @@ namespace ctranslate2 {
return cublas_handle.get();
}
+#ifdef CT2_WITH_CUDNN
+ class CudnnHandle {
+ public:
+ CudnnHandle() {
+ CUDA_CHECK(cudaGetDevice(&_device));
+ CUDNN_CHECK(cudnnCreate(&_handle));
+ CUDNN_CHECK(cudnnSetStream(_handle, get_cuda_stream()));
+ }
+ ~CudnnHandle() {
+ ScopedDeviceSetter scoped_device_setter(Device::CUDA, _device);
+ cudnnDestroy(_handle);
+ }
+ cudnnHandle_t get() const {
+ return _handle;
+ }
+ private:
+ int _device;
+ cudnnHandle_t _handle;
+ };
+
+ cudnnHandle_t get_cudnn_handle() {
+ static thread_local CudnnHandle cudnn_handle;
+ return cudnn_handle.get();
+ }
+
+ cudnnDataType_t get_cudnn_data_type(DataType dtype) {
+ switch (dtype) {
+ case DataType::FLOAT:
+ return CUDNN_DATA_FLOAT;
+ case DataType::FLOAT16:
+ return CUDNN_DATA_HALF;
+ case DataType::INT32:
+ return CUDNN_DATA_INT32;
+ case DataType::INT8:
+ return CUDNN_DATA_INT8;
+ default:
+ throw std::invalid_argument("No cuDNN data type for type " + dtype_name(dtype));
+ }
+ }
+#endif
+
int get_gpu_count() {
int gpu_count = 0;
cudaError_t status = cudaGetDeviceCount(&gpu_count);
diff --git a/src/cuda/utils.h b/src/cuda/utils.h
index 91150872..d9de4d79 100644
--- a/src/cuda/utils.h
+++ b/src/cuda/utils.h
@@ -6,6 +6,10 @@
#include <cublas_v2.h>
#include <thrust/execution_policy.h>
+#ifdef CT2_WITH_CUDNN
+# include <cudnn.h>
+#endif
+
#include "ctranslate2/types.h"
#include "ctranslate2/utils.h"
@@ -28,10 +32,22 @@ namespace ctranslate2 {
+ std::string(ctranslate2::cuda::cublasGetStatusName(status))); \
}
+#define CUDNN_CHECK(ans) \
+ { \
+ cudnnStatus_t status = (ans); \
+ if (status != CUDNN_STATUS_SUCCESS) \
+ THROW_RUNTIME_ERROR("cuDNN failed with status " \
+ + std::string(cudnnGetErrorString(status))); \
+ }
+
const char* cublasGetStatusName(cublasStatus_t status);
cudaStream_t get_cuda_stream();
cublasHandle_t get_cublas_handle();
+#ifdef CT2_WITH_CUDNN
+ cudnnHandle_t get_cudnn_handle();
+ cudnnDataType_t get_cudnn_data_type(DataType dtype);
+#endif
int get_gpu_count();
bool has_gpu();
diff --git a/src/layers/common.cc b/src/layers/common.cc
index e3d7f22d..75ba0c12 100644
--- a/src/layers/common.cc
+++ b/src/layers/common.cc
@@ -365,5 +365,31 @@ namespace ctranslate2 {
_norm_op(_beta, _gamma, input, output);
}
+
+ Conv1D::Conv1D(const models::Model& model,
+ const std::string& scope,
+ dim_t stride,
+ dim_t padding,
+ dim_t dilation)
+ : _conv_op(stride, padding, dilation)
+ , _weight(model.get_variable(scope + "/weight"))
+ , _bias(model.get_variable_if_exists(scope + "/bias")) {
+ }
+
+ DataType Conv1D::output_type() const {
+ return _weight.dtype();
+ }
+
+ dim_t Conv1D::output_size() const {
+ return _weight.dim(0);
+ }
+
+ void Conv1D::operator()(const StorageView& input, StorageView& output) const {
+ if (_bias)
+ _conv_op(input, _weight, *_bias, output);
+ else
+ _conv_op(input, _weight, output);
+ }
+
}
}
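For illustration, a hedged sketch of constructing and applying the new layer; the scope name "encoder/conv_1" and the surrounding `model` and `input` objects are hypothetical, but as the constructor above shows, the scope must provide a `weight` variable and may provide a `bias`:

  // Sketch only: "encoder/conv_1" is a hypothetical variable scope.
  const layers::Conv1D conv(model, "encoder/conv_1", /*stride=*/1, /*padding=*/1);
  StorageView output(conv.output_type(), input.device());
  conv(input, output);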
diff --git a/src/ops/conv1d.cc b/src/ops/conv1d.cc
new file mode 100644
index 00000000..b2721770
--- /dev/null
+++ b/src/ops/conv1d.cc
@@ -0,0 +1,60 @@
+#include "ctranslate2/ops/conv1d.h"
+
+#include "dispatch.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ Conv1D::Conv1D(dim_t stride, dim_t padding, dim_t dilation)
+ : _stride(stride)
+ , _padding(padding)
+ , _dilation(dilation)
+ {
+ }
+
+ void Conv1D::operator()(const StorageView& input,
+ const StorageView& weight,
+ const StorageView& bias,
+ StorageView& output) const {
+ operator()(input, weight, &bias, output);
+ }
+
+ void Conv1D::operator()(const StorageView& input,
+ const StorageView& weight,
+ StorageView& output) const {
+ operator()(input, weight, nullptr, output);
+ }
+
+ void Conv1D::operator()(const StorageView& input,
+ const StorageView& weight,
+ const StorageView* bias,
+ StorageView& output) const {
+ const dim_t batch_size = input.dim(0);
+ const dim_t input_length = input.dim(2);
+ const dim_t out_channels = weight.dim(0);
+ const dim_t kernel_size = weight.dim(2);
+ const dim_t output_length = (
+ input_length + (2 * _padding) - (_dilation * (kernel_size - 1) + 1)) / _stride + 1;
+
+ output.resize({batch_size, out_channels, output_length});
+
+ switch (input.dtype()) {
+ case DataType::FLOAT: {
+ DEVICE_DISPATCH(input.device(), (compute<D, float>(input, weight, bias, output)));
+ break;
+ }
+#ifdef CT2_WITH_CUDA
+ case DataType::FLOAT16: {
+ if (input.device() != Device::CUDA)
+ throw std::invalid_argument("FP16 Conv1D is only supported on GPU");
+ compute<Device::CUDA, float16_t>(input, weight, bias, output);
+ break;
+ }
+#endif
+ default:
+ throw std::invalid_argument("Conv1D only supports float (or float16 on GPU)");
+ }
+ }
+
+ }
+}
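To make the output length formula concrete: with the shapes used in the Conv1DPaddingAndStride test further below (input_length = 3, kernel_size = 2, padding = 1, stride = 2, dilation = 1), output_length = (3 + 2*1 - (1*(2-1) + 1)) / 2 + 1 = 2, which matches the expected {2, 4, 2} output shape.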
diff --git a/src/ops/conv1d_cpu.cc b/src/ops/conv1d_cpu.cc
new file mode 100644
index 00000000..b2eb05be
--- /dev/null
+++ b/src/ops/conv1d_cpu.cc
@@ -0,0 +1,123 @@
+#include "ctranslate2/ops/conv1d.h"
+
+#ifdef CT2_WITH_DNNL
+# include <dnnl.hpp>
+#endif
+
+namespace ctranslate2 {
+ namespace ops {
+
+ template<>
+ void Conv1D::compute<Device::CPU, float>(const StorageView& input,
+ const StorageView& weight,
+ const StorageView* bias,
+ StorageView& output) const {
+#ifndef CT2_WITH_DNNL
+ (void)input;
+ (void)weight;
+ (void)bias;
+ (void)output;
+ throw std::runtime_error("Conv1D on CPU currently requires the oneDNN library (a.k.a. DNNL) "
+ "which is not integrated in this build");
+
+#else
+ dnnl::engine engine(dnnl::engine::kind::cpu, 0);
+ dnnl::stream engine_stream(engine);
+
+ dnnl::memory::dims input_dims(input.shape().begin(), input.shape().end());
+ dnnl::memory::dims output_dims(output.shape().begin(), output.shape().end());
+ dnnl::memory::dims weight_dims(weight.shape().begin(), weight.shape().end());
+
+ using tag = dnnl::memory::format_tag;
+ using dt = dnnl::memory::data_type;
+
+ dnnl::memory::desc input_md(input_dims, dt::f32, tag::any);
+ dnnl::memory::desc output_md(output_dims, dt::f32, tag::any);
+ dnnl::memory::desc weight_md(weight_dims, dt::f32, tag::any);
+
+ dnnl::memory input_mem({input_dims, dt::f32, tag::ncw}, engine,
+ const_cast<void*>(input.buffer()));
+ dnnl::memory output_mem({output_dims, dt::f32, tag::ncw}, engine,
+ output.buffer());
+ dnnl::memory weight_mem({weight_dims, dt::f32, tag::oiw}, engine,
+ const_cast<void*>(weight.buffer()));
+
+ dnnl::memory::dims stride{_stride};
+ dnnl::memory::dims dilation{_dilation > 1 ? _dilation : 0};
+ dnnl::memory::dims padding{_padding};
+
+ std::unique_ptr<dnnl::convolution_forward::desc> conv_desc;
+ std::unordered_map<int, dnnl::memory> args;
+ args.reserve(4);
+
+ if (bias) {
+ dnnl::memory::dims bias_dims(bias->shape().begin(), bias->shape().end());
+ dnnl::memory::desc bias_md(bias_dims, dt::f32, tag::a);
+ dnnl::memory bias_mem(bias_md, engine, const_cast<void*>(bias->buffer()));
+ args.emplace(DNNL_ARG_BIAS, bias_mem);
+
+ conv_desc = std::make_unique<dnnl::convolution_forward::desc>(
+ dnnl::prop_kind::forward_inference,
+ dnnl::algorithm::convolution_direct,
+ input_md,
+ weight_md,
+ bias_md,
+ output_md,
+ stride,
+ dilation,
+ padding,
+ padding);
+
+ } else {
+ conv_desc = std::make_unique<dnnl::convolution_forward::desc>(
+ dnnl::prop_kind::forward_inference,
+ dnnl::algorithm::convolution_direct,
+ input_md,
+ weight_md,
+ output_md,
+ stride,
+ dilation,
+ padding,
+ padding);
+ }
+
+ dnnl::convolution_forward::primitive_desc conv_pd(*conv_desc, engine);
+
+ dnnl::memory conv_input_mem = input_mem;
+ dnnl::memory conv_weight_mem = weight_mem;
+ dnnl::memory conv_output_mem = output_mem;
+
+ if (conv_pd.src_desc() != input_mem.get_desc()) {
+ conv_input_mem = dnnl::memory(conv_pd.src_desc(), engine);
+ dnnl::reorder(input_mem, conv_input_mem)
+ .execute(engine_stream, input_mem, conv_input_mem);
+ }
+
+ if (conv_pd.weights_desc() != weight_mem.get_desc()) {
+ conv_weight_mem = dnnl::memory(conv_pd.weights_desc(), engine);
+ dnnl::reorder(weight_mem, conv_weight_mem)
+ .execute(engine_stream, weight_mem, conv_weight_mem);
+ }
+
+ if (conv_pd.dst_desc() != output_mem.get_desc()) {
+ conv_output_mem = dnnl::memory(conv_pd.dst_desc(), engine);
+ }
+
+ args.emplace(DNNL_ARG_SRC, conv_input_mem);
+ args.emplace(DNNL_ARG_WEIGHTS, conv_weight_mem);
+ args.emplace(DNNL_ARG_DST, conv_output_mem);
+
+ dnnl::convolution_forward conv(conv_pd);
+ conv.execute(engine_stream, args);
+
+ if (conv_pd.dst_desc() != output_mem.get_desc()) {
+ dnnl::reorder(conv_output_mem, output_mem)
+ .execute(engine_stream, conv_output_mem, output_mem);
+ }
+
+ engine_stream.wait();
+#endif
+ }
+
+ }
+}
diff --git a/src/ops/conv1d_gpu.cu b/src/ops/conv1d_gpu.cu
new file mode 100644
index 00000000..bc3ba759
--- /dev/null
+++ b/src/ops/conv1d_gpu.cu
@@ -0,0 +1,149 @@
+#include "ctranslate2/ops/conv1d.h"
+
+#include "cuda/utils.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ template <Device D, typename T>
+ void Conv1D::compute(const StorageView& input,
+ const StorageView& weight,
+ const StorageView* bias,
+ StorageView& output) const {
+#ifndef CT2_WITH_CUDNN
+ (void)input;
+ (void)weight;
+ (void)bias;
+ (void)output;
+ throw std::runtime_error("Conv1D on GPU currently requires the cuDNN library "
+ "which is not integrated in this build");
+
+#else
+ const int batch_size = input.dim(0);
+ const int in_channels = input.dim(1);
+ const int input_length = input.dim(2);
+ const int output_length = output.dim(2);
+ const int out_channels = weight.dim(0);
+ const int kernel_size = weight.dim(2);
+
+ cudnnDataType_t data_type = cuda::get_cudnn_data_type(input.dtype());
+
+ cudnnTensorDescriptor_t input_desc;
+ CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
+ CUDNN_CHECK(cudnnSetTensor4dDescriptor(input_desc, CUDNN_TENSOR_NCHW, data_type,
+ batch_size, in_channels, 1, input_length));
+
+ cudnnTensorDescriptor_t output_desc;
+ CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
+ CUDNN_CHECK(cudnnSetTensor4dDescriptor(output_desc, CUDNN_TENSOR_NCHW, data_type,
+ batch_size, out_channels, 1, output_length));
+
+ cudnnFilterDescriptor_t weight_desc;
+ CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc));
+ CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, CUDNN_TENSOR_NCHW,
+ out_channels, in_channels, 1, kernel_size));
+
+ cudnnConvolutionDescriptor_t conv_desc;
+ CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
+ CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
+ /*pad_h=*/0, /*pad_w=*/_padding,
+ /*stride_h=*/1, /*stride_w=*/_stride,
+ /*dilation_h=*/1, /*dilation_w=*/_dilation,
+ CUDNN_CROSS_CORRELATION,
+ data_type));
+
+ cudnnHandle_t handle = cuda::get_cudnn_handle();
+
+ cudnnConvolutionFwdAlgo_t algo = (bias
+ ? CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
+ : CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM);
+
+ size_t workspace_size = 0;
+ void* workspace = nullptr;
+ CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle,
+ input_desc,
+ weight_desc,
+ conv_desc,
+ output_desc,
+ algo,
+ &workspace_size));
+
+ if (workspace_size > 0)
+ workspace = get_allocator<Device::CUDA>().allocate(workspace_size);
+
+ float alpha = 1;
+ float beta = 0;
+
+ if (bias) {
+ cudnnTensorDescriptor_t bias_desc;
+ CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
+ CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc, CUDNN_TENSOR_NCHW, data_type,
+ 1, out_channels, 1, 1));
+
+ cudnnActivationDescriptor_t activation_desc;
+ CUDNN_CHECK(cudnnCreateActivationDescriptor(&activation_desc));
+ CUDNN_CHECK(cudnnSetActivationDescriptor(activation_desc,
+ CUDNN_ACTIVATION_IDENTITY,
+ CUDNN_NOT_PROPAGATE_NAN,
+ /*coef=*/0));
+
+ CUDNN_CHECK(cudnnConvolutionBiasActivationForward(handle,
+ &alpha,
+ input_desc,
+ input.buffer(),
+ weight_desc,
+ weight.buffer(),
+ conv_desc,
+ algo,
+ workspace,
+ workspace_size,
+ &beta,
+ output_desc,
+ output.buffer(),
+ bias_desc,
+ bias->buffer(),
+ activation_desc,
+ output_desc,
+ output.buffer()));
+
+ CUDNN_CHECK(cudnnDestroyActivationDescriptor(activation_desc));
+ CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
+
+ } else {
+ CUDNN_CHECK(cudnnConvolutionForward(handle,
+ &alpha,
+ input_desc,
+ input.buffer(),
+ weight_desc,
+ weight.buffer(),
+ conv_desc,
+ algo,
+ workspace,
+ workspace_size,
+ &beta,
+ output_desc,
+ output.buffer()));
+ }
+
+ if (workspace)
+ get_allocator<Device::CUDA>().free(workspace);
+
+ CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
+ CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc));
+ CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc));
+ CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc));
+#endif
+ }
+
+#define DECLARE_IMPL(T) \
+ template void \
+ Conv1D::compute<Device::CUDA, T>(const StorageView& input, \
+ const StorageView& weight, \
+ const StorageView* bias, \
+ StorageView& output) const;
+
+ DECLARE_IMPL(float)
+ DECLARE_IMPL(float16_t)
+
+ }
+}
diff --git a/tests/benchmark_ops.cc b/tests/benchmark_ops.cc
index 4d32739c..8ba0fd13 100644
--- a/tests/benchmark_ops.cc
+++ b/tests/benchmark_ops.cc
@@ -105,6 +105,15 @@ void benchmark_dequantize(Device device) {
BENCHMARK(dequantize_op(x, input_scale, weight_scale, false, true, y), 100000);
}
+void benchmark_conv1d(Device device) {
+ StorageView x({1, 768, 3000}, DataType::FLOAT, device);
+ StorageView weight({768, 768, 3}, DataType::FLOAT, device);
+ StorageView bias({768}, DataType::FLOAT, device);
+ StorageView y(device);
+ const ops::Conv1D conv_op{2, 1};
+ BENCHMARK(conv_op(x, weight, bias, y), 100);
+}
+
int main(int argc, char* argv[]) {
if (argc < 3) {
std::cerr << "usage: " << argv[0] << " op device [dtype]" << std::endl;
@@ -140,6 +149,8 @@ int main(int argc, char* argv[]) {
benchmark_quantize(device, dtype);
else if (op == "dequantize")
benchmark_dequantize(device);
+ else if (op == "conv1d")
+ benchmark_conv1d(device);
return 0;
}
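With this addition the new kernel can be benchmarked like the existing ops, e.g. `benchmark_ops conv1d cuda` (the binary name is assumed from the source file; the device argument must name a backend the build supports).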
diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index 9d56a971..a99aac67 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -828,6 +828,138 @@ TEST_P(OpDeviceTest, Max) {
});
}
+#ifndef CT2_WITH_DNNL
+# define GUARD_CONV1D_CPU_TEST GTEST_SKIP() << "Conv1D tests on CPU require oneDNN"
+#else
+# define GUARD_CONV1D_CPU_TEST do {} while (0)
+#endif
+
+#ifndef CT2_WITH_CUDNN
+# define GUARD_CONV1D_GPU_TEST GTEST_SKIP() << "Conv1D tests on GPU require cuDNN"
+#else
+# define GUARD_CONV1D_GPU_TEST do {} while (0)
+#endif
+
+static const StorageView conv_input({2, 2, 3}, std::vector<float>{
+ 0.5728129f, 0.8784890f, 0.2029965f, 0.3689166f, 0.6570600f, 0.9202735f,
+ 0.7081605f, 0.3570334f, 0.9339380f, 0.8162224f, 0.0597404f, 0.4628246f});
+
+static const StorageView conv_weight({4, 2, 2}, std::vector<float>{
+ 0.4969918f, 0.3711241f, 0.1489926f, -0.3010672f,
+ -0.2055028f, 0.2540314f, 0.3566069f, -0.1201057f,
+ -0.0737700f, -0.0630847f, -0.2370351f, -0.0451550f,
+ 0.0186623f, 0.3600836f, -0.2889268f, -0.4857445f});
+
+static const StorageView conv_bias({4}, std::vector<float>{
+ 0.4631361f, -0.1047785f, 0.1047658f, -0.3157263f});
+
+TEST_P(OpDeviceFPTest, Conv1D) {
+ const Device device = GetParam().first;
+ if (device == Device::CUDA)
+ GUARD_CONV1D_GPU_TEST;
+ else
+ GUARD_CONV1D_CPU_TEST;
+ const DataType dtype = GetParam().second;
+ const StorageView expected({2, 4, 2}, std::vector<float>{
+ 0.9309945f, 0.7959076f, 0.0533122f, -0.1099610f,
+ -0.1100256f, -0.1701476f, -0.4144599f, -0.8630960f,
+ 1.0512151f, 0.8567453f, 0.1242856f, 0.0248157f,
+ -0.1661695f, -0.0155492f, -0.4387956f, -0.2148425f});
+ StorageView output(dtype, device);
+ ops::Conv1D()(conv_input.to(device).to(dtype),
+ conv_weight.to(device).to(dtype),
+ conv_bias.to(device).to(dtype),
+ output);
+ EXPECT_EQ(output.dtype(), dtype);
+ expect_storage_eq(output.to_float(), expected, 1e-3);
+}
+
+TEST_P(OpDeviceFPTest, Conv1DNoBias) {
+ const Device device = GetParam().first;
+ if (device == Device::CUDA)
+ GUARD_CONV1D_GPU_TEST;
+ else
+ GUARD_CONV1D_CPU_TEST;
+ const DataType dtype = GetParam().second;
+ const StorageView expected({2, 4, 2}, std::vector<float>{
+ 0.4678584f, 0.3327716f, 0.1580907f, -0.005182412f,
+ -0.2147914f, -0.2749133f, -0.09873369f, -0.5473697f,
+ 0.5880789f, 0.3936091f, 0.2290641f, 0.1295942f,
+ -0.2709353f, -0.120315f, -0.1230693f, 0.1008837f});
+ StorageView output(dtype, device);
+ ops::Conv1D()(conv_input.to(device).to(dtype),
+ conv_weight.to(device).to(dtype),
+ output);
+ EXPECT_EQ(output.dtype(), dtype);
+ expect_storage_eq(output.to_float(), expected, 1e-3);
+}
+
+TEST_P(OpDeviceFPTest, Conv1DPadding) {
+ const Device device = GetParam().first;
+ if (device == Device::CUDA)
+ GUARD_CONV1D_GPU_TEST;
+ else
+ GUARD_CONV1D_CPU_TEST;
+ const DataType dtype = GetParam().second;
+ const StorageView expected({2, 4, 4}, std::vector<float>{
+ 0.5646521f, 0.9309945f, 0.7959076f, 0.7011377f,
+ -0.0035750f, 0.0533122f, -0.1099610f, 0.1816810f,
+ 0.0519716f, -0.1100256f, -0.1701476f, -0.1283464f,
+ -0.2886650f, -0.4144599f, -0.8630960f, -0.5778296f,
+ 0.4802138f, 1.0512151f, 0.8567453f, 0.9962531f,
+ -0.0229165f, 0.1242856f, 0.0248157f, -0.1316590f,
+ 0.0232352f, -0.1661695f, -0.0155492f, -0.0738365f,
+ -0.4572049f, -0.4387956f, -0.2148425f, -0.4320193f});
+ StorageView output(dtype, device);
+ ops::Conv1D(1, 1)(conv_input.to(device).to(dtype),
+ conv_weight.to(device).to(dtype),
+ conv_bias.to(device).to(dtype),
+ output);
+ EXPECT_EQ(output.dtype(), dtype);
+ expect_storage_eq(output.to_float(), expected, 1e-3);
+}
+
+TEST_P(OpDeviceFPTest, Conv1DStride) {
+ const Device device = GetParam().first;
+ if (device == Device::CUDA)
+ GUARD_CONV1D_GPU_TEST;
+ else
+ GUARD_CONV1D_CPU_TEST;
+ const DataType dtype = GetParam().second;
+ const StorageView expected({2, 4, 1}, std::vector<float>{
+ 0.9309945f, 0.0533122f, -0.1100256f, -0.4144599f,
+ 1.0512151f, 0.1242856f, -0.1661695f, -0.4387956f});
+ StorageView output(dtype, device);
+ ops::Conv1D(2)(conv_input.to(device).to(dtype),
+ conv_weight.to(device).to(dtype),
+ conv_bias.to(device).to(dtype),
+ output);
+ EXPECT_EQ(output.dtype(), dtype);
+ expect_storage_eq(output.to_float(), expected, 1e-3);
+}
+
+TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) {
+ const Device device = GetParam().first;
+ if (device == Device::CUDA)
+ GUARD_CONV1D_GPU_TEST;
+ else
+ GUARD_CONV1D_CPU_TEST;
+ const DataType dtype = GetParam().second;
+ const StorageView expected({2, 4, 2}, std::vector<float>{
+ 0.5646521f, 0.7959076f, -0.0035750f, -0.1099610f,
+ 0.0519716f, -0.1701476f, -0.2886650f, -0.8630960f,
+ 0.4802138f, 0.8567453f, -0.0229165f, 0.0248157f,
+ 0.0232352f, -0.0155492f, -0.4572049f, -0.2148425f});
+ StorageView output(dtype, device);
+ ops::Conv1D(2, 1)(conv_input.to(device).to(dtype),
+ conv_weight.to(device).to(dtype),
+ conv_bias.to(device).to(dtype),
+ output);
+ EXPECT_EQ(output.dtype(), dtype);
+ expect_storage_eq(output.to_float(), expected, 1e-3);
+}
+
+
static std::string fp_test_name(::testing::TestParamInfo<std::pair<Device, DataType>> param_info) {
return dtype_name(param_info.param.second);
}