Welcome to mirror list, hosted at ThFree Co, Russian Federation.

implementations.inl « callbacks « intgemm - github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 9a8f9e1220b19e159a43bb6a409c5393184e2f01 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
/* This file is included multiple times, once per architecture.
 *
 * Before each inclusion exactly one selector macro is expected to be defined:
 * CALLBACKS_THIS_IS_SSE2, CALLBACKS_THIS_IS_AVX2 or CALLBACKS_THIS_IS_AVX512BW.
 * From it this block derives:
 *   CPU_NAME       - the CPUType enumerator the specializations are keyed on;
 *   INTGEMM_TARGET - the target-attribute macro for that architecture;
 *   vi / vf / vd   - the int / float / double vector types callbacks operate on.
 * These are #undef'd again at the end of the file so that the next inclusion
 * can redefine them for another architecture.
 *
 * NOTE(review): the AVX512BW variant deliberately keeps the AVX2 (256-bit)
 * vector types, exactly as before; presumably the AVX512BW kernels hand their
 * callbacks 256-bit results -- confirm before changing.
 */
#if defined(CALLBACKS_THIS_IS_SSE2)
  #define CPU_NAME SSE2
  #define INTGEMM_TARGET INTGEMM_SSE2
  #define vi vector_t<CPUType::SSE2, int>
  #define vf vector_t<CPUType::SSE2, float>
  #define vd vector_t<CPUType::SSE2, double>
#elif defined(CALLBACKS_THIS_IS_AVX2)
  #define CPU_NAME AVX2
  #define INTGEMM_TARGET INTGEMM_AVX2
  #define vi vector_t<CPUType::AVX2, int>
  #define vf vector_t<CPUType::AVX2, float>
  #define vd vector_t<CPUType::AVX2, double>
#elif defined(CALLBACKS_THIS_IS_AVX512BW)
  #define CPU_NAME AVX512BW
  #define INTGEMM_TARGET INTGEMM_AVX512BW
  /* Callbacks compiled for AVX512BW still work on AVX2-width vectors. */
  #define vi vector_t<CPUType::AVX2, int>
  #define vf vector_t<CPUType::AVX2, float>
  #define vd vector_t<CPUType::AVX2, double>
#else
  #error "Only SSE2, AVX2 and AVX512BW are supported"
#endif

/* Constructors normally carry the same target attribute as the rest of the
 * callback code. Intel compiler 19.1.0.166 20191121 fails to link constructors
 * with target attributes, so for that compiler the attribute is omitted. */
#if !defined(__INTEL_COMPILER)
#define INTGEMM_TARGET_CONSTRUCTOR INTGEMM_TARGET
#else
#define INTGEMM_TARGET_CONSTRUCTOR
#endif

namespace intgemm {
namespace callbacks {

template <CPUType CpuType, typename CallbackConfig>
class CallbackImpl;

}}

/*
 * Callback implementations.
 */
namespace intgemm {
namespace callbacks {

/*
 * Sequence: chains several callbacks into one pipeline.
 *
 * One CallbackImpl is constructed per config in the tuple. Run() hands the
 * integer input vector to the first callback; the value returned by each
 * callback's Run() becomes the input of the next one, and the last callback's
 * return value (if any) is discarded. Consequently every stage except the last
 * must return a vi, vf or vd vector that the following stage's Run() accepts.
 */
template <typename... Configs>
class CallbackImpl<CPUType::CPU_NAME, std::tuple<Configs...>> {
public:
  explicit CallbackImpl(const std::tuple<Configs...>& configs) : callbacks(init_callbacks(configs, make_sequence<sizeof...(Configs)>())) {}

  /* Runs the whole pipeline on one integer result vector. */
  INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
    run_callbacks(input, info, callbacks, make_sequence<sizeof...(Configs)>());
  }

private:
  using CallbacksTupleType = std::tuple<CallbackImpl<CPUType::CPU_NAME, Configs>...>;

  CallbacksTupleType callbacks;

  // Builds one callback object per config entry, in tuple order.
  // Static because it is called from the constructor's member-initializer
  // list, before `callbacks` exists, and needs no object state.
  template <unsigned... Indices>
  static CallbacksTupleType init_callbacks(const std::tuple<Configs...>& configs, sequence<Indices...>) {
    return std::make_tuple(CallbackImpl<CPUType::CPU_NAME, typename std::tuple_element<Indices, std::tuple<Configs...>>::type>(std::get<Indices>(configs))...);
  }

  // run_callbacks(input, info, tuple, sequence<I0, I1, ...>) invokes stage I0
  // on `input` and recurses on the remaining indices with that stage's result.
  // One overload pair is generated per vector type a stage may produce
  // (vi, vf, vd). Callback objects expose Run() (they have no operator()),
  // so each stage is invoked through Run().
#define RUN_CALLBACKS_PIPELINE_IMPL(vtype) \
  template <unsigned FirstIndex> \
  INTGEMM_TARGET static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex>) { \
    std::get<FirstIndex>(tuple).Run(input, info); \
  } \
  template <unsigned FirstIndex, unsigned SecondIndex, unsigned... RestIndices> \
  INTGEMM_TARGET static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex, SecondIndex, RestIndices...>) { \
    auto output = std::get<FirstIndex>(tuple).Run(input, info); \
    run_callbacks(output, info, tuple, sequence<SecondIndex, RestIndices...>()); \
  }

  RUN_CALLBACKS_PIPELINE_IMPL(vi)
  RUN_CALLBACKS_PIPELINE_IMPL(vf)
  RUN_CALLBACKS_PIPELINE_IMPL(vd)

#undef RUN_CALLBACKS_PIPELINE_IMPL
};

/*
 * Dummy: a no-op callback. Both the configuration passed to the constructor
 * and the integer vector passed to Run() are ignored.
 */
template <>
class CallbackImpl<CPUType::CPU_NAME, Dummy> {
public:
  explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const Dummy& /* config */) {}

  INTGEMM_TARGET void Run(vi /* input */, const OutputBufferInfo& /* info */) {
    // Intentionally empty: the results are discarded.
  }
};

/*
 * Write: passes each result vector, unchanged, to kernels::write with
 * destination config.output_addr at element offset
 * info.row_idx * info.cols + info.col_idx.
 *
 * NOTE(review): the input type here is vector_t<CPUType::CPU_NAME, Type>,
 * whereas the other callbacks in this file take vi/vf, which map to AVX2-width
 * vectors when CPU_NAME is AVX512BW. Confirm that Write really receives
 * 512-bit vectors on AVX512BW (or is never used there / after a vf-returning
 * stage in a Sequence).
 */
template <typename Type>
class CallbackImpl<CPUType::CPU_NAME, Write<Type>> {
public:
  explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const Write<Type>& config) : config(config) {}

  INTGEMM_TARGET void Run(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo& info) {
    const auto offset = info.row_idx * info.cols + info.col_idx;
    kernels::write(input, config.output_addr, offset);
  }

private:
  Write<Type> config;
};

/*
 * Unquantize: applies kernels::unquantize to the integer result vector, using
 * config.unquant_mult broadcast into every lane (set1_ps). The resulting float
 * vector is returned, so this callback can feed a following stage.
 */
template <> class CallbackImpl<CPUType::CPU_NAME, Unquantize> {
public:
  explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const Unquantize& config)
      : unquant_mult(set1_ps<vf>(config.unquant_mult)), config(config) {}

  INTGEMM_TARGET vf Run(vi input, const OutputBufferInfo& /* info */) {
    return kernels::unquantize(input, unquant_mult);
  }

private:
  vf unquant_mult;     // config.unquant_mult replicated across all lanes
  Unquantize config;
};

/*
 * UnquantizeAndWrite: applies kernels::unquantize to the integer result vector,
 * using config.unquant_mult broadcast into every lane, then passes the float
 * vector to kernels::write with destination config.output_addr at element
 * offset info.row_idx * info.cols + info.col_idx.
 */
template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWrite> {
public:
  explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndWrite& config) : config(config) {
    unquant_mult = set1_ps<vf>(config.unquant_mult);
  }

  INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
    // Workaround gcc 5 internal compiler error that can't read register members in debug.
    vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
#  if defined(CALLBACKS_THIS_IS_SSE2)
    // The SSE2 build may run on CPUs without AVX, so it must not use the
    // VEX-encoded vmovdqa; movdqa is the SSE2 form of the same aligned load.
    asm ("movdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
#  else
    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
#  endif
#else
    mult_reg = unquant_mult;
#endif
    auto result = kernels::unquantize(input, mult_reg);
    kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
  }

private:
  vf unquant_mult;     // config.unquant_mult replicated across all lanes
  UnquantizeAndWrite config;
};

/*
 * AddBiasAndWrite: combines the integer result vector with the bias at
 * config.bias_addr (indexed by info.col_idx) via kernels::add_bias, then passes
 * the outcome to kernels::write with destination config.output_addr at element
 * offset info.row_idx * info.cols + info.col_idx. No unquantization is done.
 */
template <> class CallbackImpl<CPUType::CPU_NAME, AddBiasAndWrite> {
public:
  explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const AddBiasAndWrite& config) : config(config) {}

  INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
    const auto output_offset = info.row_idx * info.cols + info.col_idx;
    kernels::write(kernels::add_bias(input, config.bias_addr, info.col_idx),
                   config.output_addr,
                   output_offset);
  }

private:
  AddBiasAndWrite config;
};

/*
 * UnquantizeAndAddBiasAndWrite: applies kernels::unquantize to the integer
 * result vector (config.unquant_mult broadcast into every lane), combines the
 * result with the bias at config.bias_addr (indexed by info.col_idx) via
 * kernels::add_bias, and passes it to kernels::write with destination
 * config.output_addr at element offset info.row_idx * info.cols + info.col_idx.
 */
template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWrite> {
public:
  explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndAddBiasAndWrite& config) : config(config) {
    unquant_mult = set1_ps<vf>(config.unquant_mult);
  }

  INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
    // Workaround gcc 5 internal compiler error that can't read register members in debug.
    vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
#  if defined(CALLBACKS_THIS_IS_SSE2)
    // The SSE2 build may run on CPUs without AVX, so it must not use the
    // VEX-encoded vmovdqa; movdqa is the SSE2 form of the same aligned load.
    asm ("movdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
#  else
    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
#  endif
#else
    mult_reg = unquant_mult;
#endif
    auto result = kernels::unquantize(input, mult_reg);
    result = kernels::add_bias(result, config.bias_addr, info.col_idx);
    kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
  }
private:
  vf unquant_mult;     // config.unquant_mult replicated across all lanes
  UnquantizeAndAddBiasAndWrite config;
};

}
}

/* Drop every per-inclusion macro defined at the top of this file (including
 * INTGEMM_TARGET_CONSTRUCTOR, which would otherwise leak past the last
 * inclusion) so the next inclusion starts from a clean slate. */
#undef CPU_NAME
#undef INTGEMM_TARGET
#undef INTGEMM_TARGET_CONSTRUCTOR
#undef vi
#undef vf
#undef vd