src/ops/conv1d.cc (github.com/OpenNMT/CTranslate2.git)
#include "ctranslate2/ops/conv1d.h"

#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

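    // The convolution hyperparameters are fixed at construction and applied in operator().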
    Conv1D::Conv1D(dim_t stride, dim_t padding, dim_t dilation)
      : _stride(stride)
      , _padding(padding)
      , _dilation(dilation)
    {
    }

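    // Convenience overloads: both forward to the pointer-based implementation below,
    // where the bias is optional.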
    void Conv1D::operator()(const StorageView& input,
                            const StorageView& weight,
                            const StorageView& bias,
                            StorageView& output) const {
      operator()(input, weight, &bias, output);
    }

    void Conv1D::operator()(const StorageView& input,
                            const StorageView& weight,
                            StorageView& output) const {
      operator()(input, weight, nullptr, output);
    }

    void Conv1D::operator()(const StorageView& input,
                            const StorageView& weight,
                            const StorageView* bias,
                            StorageView& output) const {
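      // Expected shapes: input is {batch_size, in_channels, input_length} and
      // weight is {out_channels, in_channels, kernel_size}; only dims 0 and 2 are read here.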
      const dim_t batch_size = input.dim(0);
      const dim_t input_length = input.dim(2);
      const dim_t out_channels = weight.dim(0);
      const dim_t kernel_size = weight.dim(2);
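      // Standard 1D convolution output size:
      // floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1.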
      const dim_t output_length = (
        input_length + (2 * _padding) - (_dilation * (kernel_size - 1) + 1)) / _stride + 1;

      output.resize({batch_size, out_channels, output_length});

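      // Dispatch on the input data type; DEVICE_DISPATCH then selects the CPU or CUDA kernel.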
      switch (input.dtype()) {
      case DataType::FLOAT: {
        DEVICE_DISPATCH(input.device(), (compute<D, float>(input, weight, bias, output)));
        break;
      }
#ifdef CT2_WITH_CUDA
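      // The FP16 path is compiled only with CUDA support and must run on a CUDA device.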
      case DataType::FLOAT16: {
        if (input.device() != Device::CUDA)
          throw std::invalid_argument("FP16 Conv1D is only supported on GPU");
        compute<Device::CUDA, float16_t>(input, weight, bias, output);
        break;
      }
#endif
      default:
        throw std::invalid_argument("Conv1D only supports float (or float16 on GPU)");
      }
    }

  }
}
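
A minimal usage sketch (not part of the original file), assuming the public
ctranslate2::StorageView fill constructor StorageView(Shape, float) and the
operator() overloads defined above; shapes and values are illustrative only:

  #include "ctranslate2/ops/conv1d.h"
  #include "ctranslate2/storage_view.h"

  using namespace ctranslate2;

  void run_conv1d_example() {
    StorageView input({1, 2, 8}, 0.f);   // {batch, in_channels, length}, zero-filled
    StorageView weight({4, 2, 3}, 0.f);  // {out_channels, in_channels, kernel_size}
    StorageView bias({4}, 0.f);          // one bias value per output channel
    StorageView output;                  // resized by the op

    const ops::Conv1D conv(/*stride=*/1, /*padding=*/1, /*dilation=*/1);
    conv(input, weight, bias, output);   // output shape: {1, 4, 8}
  }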