Welcome to mirror list, hosted at ThFree Co, Russian Federation.

conv1d_gpu.cu « ops « src - github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: bc3ba7598e326be018b4a162cd8e3c63b55aa0c4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include "ctranslate2/ops/conv1d.h"

#include "cuda/utils.h"

namespace ctranslate2 {
  namespace ops {

    // Runs the 1D convolution on the GPU through cuDNN, expressed as a 2D
    // convolution whose height dimension is 1.
    //
    // input:  dims 0, 1 and 2 are read as batch size, input channels and input length.
    // weight: dim 0 is read as the number of output channels and dim 2 as the
    //         kernel size; the filter is described with the input's channel count.
    // bias:   optional. When non-null, it is described as one value per output
    //         channel and applied by the fused convolution + bias + activation call.
    // output: only dim 2 (output length) is read; the result is written to
    //         output.buffer(), so the caller must have sized it beforehand.
    //
    // Throws std::runtime_error when the build does not define CT2_WITH_CUDNN.
    // Failures detected by CUDNN_CHECK are reported through that macro, which is
    // defined outside this file (presumably it throws -- confirm). All cuDNN
    // descriptors and the workspace are released on both success and failure.
    template <Device D, typename T>
    void Conv1D::compute(const StorageView& input,
                         const StorageView& weight,
                         const StorageView* bias,
                         StorageView& output) const {
#ifndef CT2_WITH_CUDNN
      (void)input;
      (void)weight;
      (void)bias;
      (void)output;
      throw std::runtime_error("Conv1D on GPU currently requires the cuDNN library "
                               "which is not integrated in this build");

#else
      // Owns every cuDNN descriptor and the workspace created by this call so
      // that nothing leaks when an error is raised half-way through.
      struct Resources {
        cudnnTensorDescriptor_t input_desc = nullptr;
        cudnnTensorDescriptor_t output_desc = nullptr;
        cudnnTensorDescriptor_t bias_desc = nullptr;
        cudnnFilterDescriptor_t weight_desc = nullptr;
        cudnnConvolutionDescriptor_t conv_desc = nullptr;
        cudnnActivationDescriptor_t activation_desc = nullptr;
        void* workspace = nullptr;

        Resources() = default;
        Resources(const Resources&) = delete;
        Resources& operator=(const Resources&) = delete;

        // Checked release for the success path. Each member is cleared before
        // the release call so the destructor never releases it a second time
        // if that call reports an error.
        void release() {
          if (activation_desc) {
            cudnnActivationDescriptor_t desc = activation_desc;
            activation_desc = nullptr;
            CUDNN_CHECK(cudnnDestroyActivationDescriptor(desc));
          }
          if (bias_desc) {
            cudnnTensorDescriptor_t desc = bias_desc;
            bias_desc = nullptr;
            CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
          }
          if (workspace) {
            void* ptr = workspace;
            workspace = nullptr;
            get_allocator<Device::CUDA>().free(ptr);
          }
          if (conv_desc) {
            cudnnConvolutionDescriptor_t desc = conv_desc;
            conv_desc = nullptr;
            CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(desc));
          }
          if (weight_desc) {
            cudnnFilterDescriptor_t desc = weight_desc;
            weight_desc = nullptr;
            CUDNN_CHECK(cudnnDestroyFilterDescriptor(desc));
          }
          if (input_desc) {
            cudnnTensorDescriptor_t desc = input_desc;
            input_desc = nullptr;
            CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
          }
          if (output_desc) {
            cudnnTensorDescriptor_t desc = output_desc;
            output_desc = nullptr;
            CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
          }
        }

        // Best-effort cleanup for the error path. A destructor must not throw,
        // so statuses and exceptions are deliberately ignored here.
        ~Resources() {
          if (activation_desc)
            (void)cudnnDestroyActivationDescriptor(activation_desc);
          if (bias_desc)
            (void)cudnnDestroyTensorDescriptor(bias_desc);
          if (workspace) {
            try {
              get_allocator<Device::CUDA>().free(workspace);
            } catch (...) {
            }
          }
          if (conv_desc)
            (void)cudnnDestroyConvolutionDescriptor(conv_desc);
          if (weight_desc)
            (void)cudnnDestroyFilterDescriptor(weight_desc);
          if (input_desc)
            (void)cudnnDestroyTensorDescriptor(input_desc);
          if (output_desc)
            (void)cudnnDestroyTensorDescriptor(output_desc);
        }
      };

      const int batch_size = input.dim(0);
      const int in_channels = input.dim(1);
      const int input_length = input.dim(2);
      const int output_length = output.dim(2);
      const int out_channels = weight.dim(0);
      const int kernel_size = weight.dim(2);

      cudnnDataType_t data_type = cuda::get_cudnn_data_type(input.dtype());

      Resources resources;

      // The sequence is laid out along the width axis; the height is fixed to 1.
      CUDNN_CHECK(cudnnCreateTensorDescriptor(&resources.input_desc));
      CUDNN_CHECK(cudnnSetTensor4dDescriptor(resources.input_desc, CUDNN_TENSOR_NCHW, data_type,
                                             batch_size, in_channels, 1, input_length));

      CUDNN_CHECK(cudnnCreateTensorDescriptor(&resources.output_desc));
      CUDNN_CHECK(cudnnSetTensor4dDescriptor(resources.output_desc, CUDNN_TENSOR_NCHW, data_type,
                                             batch_size, out_channels, 1, output_length));

      CUDNN_CHECK(cudnnCreateFilterDescriptor(&resources.weight_desc));
      CUDNN_CHECK(cudnnSetFilter4dDescriptor(resources.weight_desc, data_type, CUDNN_TENSOR_NCHW,
                                             out_channels, in_channels, 1, kernel_size));

      // NOTE(review): the last argument is the compute type and follows the
      // tensor data type, so half inputs also request half accumulation.
      // Confirm this is intended rather than CUDNN_DATA_FLOAT.
      CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&resources.conv_desc));
      CUDNN_CHECK(cudnnSetConvolution2dDescriptor(resources.conv_desc,
                                                  /*pad_h=*/0, /*pad_w=*/_padding,
                                                  /*stride_h=*/1, /*stride_w=*/_stride,
                                                  /*dilation_h=*/1, /*dilation_w=*/_dilation,
                                                  CUDNN_CROSS_CORRELATION,
                                                  data_type));

      cudnnHandle_t handle = cuda::get_cudnn_handle();

      // The fused bias path below uses the identity activation, for which the
      // cuDNN documentation requires the IMPLICIT_PRECOMP_GEMM algorithm.
      cudnnConvolutionFwdAlgo_t algo = (bias
                                        ? CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
                                        : CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM);

      size_t workspace_size = 0;
      CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle,
                                                          resources.input_desc,
                                                          resources.weight_desc,
                                                          resources.conv_desc,
                                                          resources.output_desc,
                                                          algo,
                                                          &workspace_size));

      if (workspace_size > 0)
        resources.workspace = get_allocator<Device::CUDA>().allocate(workspace_size);

      // Scaling factors: keep the convolution result as is (alpha = 1) and do
      // not blend with the previous content of the output (beta = 0).
      float alpha = 1;
      float beta = 0;

      if (bias) {
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&resources.bias_desc));
        CUDNN_CHECK(cudnnSetTensor4dDescriptor(resources.bias_desc, CUDNN_TENSOR_NCHW, data_type,
                                               1, out_channels, 1, 1));

        CUDNN_CHECK(cudnnCreateActivationDescriptor(&resources.activation_desc));
        CUDNN_CHECK(cudnnSetActivationDescriptor(resources.activation_desc,
                                                 CUDNN_ACTIVATION_IDENTITY,
                                                 CUDNN_NOT_PROPAGATE_NAN,
                                                 /*coef=*/0));

        // The output buffer is passed twice: as the "z" operand scaled by
        // beta (0) and as the destination of the result.
        CUDNN_CHECK(cudnnConvolutionBiasActivationForward(handle,
                                                          &alpha,
                                                          resources.input_desc,
                                                          input.buffer(),
                                                          resources.weight_desc,
                                                          weight.buffer(),
                                                          resources.conv_desc,
                                                          algo,
                                                          resources.workspace,
                                                          workspace_size,
                                                          &beta,
                                                          resources.output_desc,
                                                          output.buffer(),
                                                          resources.bias_desc,
                                                          bias->buffer(),
                                                          resources.activation_desc,
                                                          resources.output_desc,
                                                          output.buffer()));

      } else {
        CUDNN_CHECK(cudnnConvolutionForward(handle,
                                            &alpha,
                                            resources.input_desc,
                                            input.buffer(),
                                            resources.weight_desc,
                                            weight.buffer(),
                                            resources.conv_desc,
                                            algo,
                                            resources.workspace,
                                            workspace_size,
                                            &beta,
                                            resources.output_desc,
                                            output.buffer()));
      }

      // Success path: release everything with error checking. Anything left
      // after an error is handled by the Resources destructor.
      resources.release();
#endif
    }

// Explicitly instantiates the Device::CUDA version of Conv1D::compute for the
// element type given as argument.
#define DECLARE_IMPL(DataType)                                          \
    template void Conv1D::compute<Device::CUDA, DataType>(              \
      const StorageView& input,                                         \
      const StorageView& weight,                                        \
      const StorageView* bias,                                          \
      StorageView& output) const;

    // Element types for which the GPU implementation is instantiated.
    DECLARE_IMPL(float)
    DECLARE_IMPL(float16_t)

  }
}