Welcome to mirror list, hosted at ThFree Co, Russian Federation.

implementations.inl « callbacks « intgemm - github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 47d2aa4503ac3e981ec30ae068b51b0210eeceef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
/* This file is included multiple times, once per architecture. */
/*
 * Per-architecture macro setup. Exactly one CALLBACKS_THIS_IS_* macro is
 * expected to be defined by the including file; it selects:
 *   CPU_NAME  - the CPUType enumerator used by the specializations below,
 *   CPU_ATTR  - the attribute macro applied to the callbacks' member functions
 *               (INTGEMM_* is defined elsewhere),
 *   vi/vf/vd  - vector_t aliases for int / float / double lanes.
 * NOTE(review): the AVX512BW build uses the same AVX2 vector aliases as the
 * AVX2 build - presumably intentional, confirm against the multiply kernels.
 */
#if defined(CALLBACKS_THIS_IS_SSE2)
  #define CPU_NAME SSE2
  #define CPU_ATTR INTGEMM_SSE2
  #define vi vector_t<CPUType::SSE2, int>
  #define vf vector_t<CPUType::SSE2, float>
  #define vd vector_t<CPUType::SSE2, double>
#elif defined(CALLBACKS_THIS_IS_AVX2)
  #define CPU_NAME AVX2
  #define CPU_ATTR INTGEMM_AVX2
  #define vi vector_t<CPUType::AVX2, int>
  #define vf vector_t<CPUType::AVX2, float>
  #define vd vector_t<CPUType::AVX2, double>
#elif defined(CALLBACKS_THIS_IS_AVX512BW)
  #define CPU_NAME AVX512BW
  #define CPU_ATTR INTGEMM_AVX512BW
  #define vi vector_t<CPUType::AVX2, int>
  #define vf vector_t<CPUType::AVX2, float>
  #define vd vector_t<CPUType::AVX2, double>
#else
  #error "Only SSE2, AVX2 and AVX512BW are supported"
#endif

namespace intgemm {
namespace callbacks {

template <CPUType CpuType, typename CallbackConfig>
class CallbackImpl;

}}

/*
 * Callback implementations.
 */
namespace intgemm {
namespace callbacks {

/*
 * Sequence
 *
 * Specialization for a std::tuple of configs. It builds one CallbackImpl per
 * config and, when invoked, runs them in tuple order: the value returned by
 * each callback is handed to the next one, and the last callback's result is
 * not used.
 */
template <typename... Configs>
class CallbackImpl<CPUType::CPU_NAME, std::tuple<Configs...>> {
  // One CallbackImpl per config, in the same order as Configs.
  using CallbacksTupleType = std::tuple<CallbackImpl<CPUType::CPU_NAME, Configs>...>;

public:
  // Constructs every stage from the matching element of `configs`.
  CPU_ATTR CallbackImpl(const std::tuple<Configs...>& configs)
      : callbacks(init_callbacks(configs, make_sequence<sizeof...(Configs)>())) {}

  // Feeds `input` through all stages, passing `info` unchanged to each of them.
  CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
    run_callbacks(input, info, callbacks, make_sequence<sizeof...(Configs)>());
  }

private:
  CallbacksTupleType callbacks;

  // Builds the tuple of stages; the I-th stage is constructed from the I-th config.
  template <unsigned... Is>
  static CallbacksTupleType init_callbacks(const std::tuple<Configs...>& configs, sequence<Is...>) {
    return std::make_tuple(
        typename std::tuple_element<Is, CallbacksTupleType>::type(std::get<Is>(configs))...);
  }

  // Generates the two run_callbacks overloads for one input vector type:
  //  - single index: invoke the final stage and drop whatever it returns;
  //  - two or more indices: invoke the head stage, then recurse on the remaining
  //    indices with its return value as the new input.
  // It is instantiated for vi, vf and vd so that overload resolution can follow
  // the type each stage returns.
#define RUN_CALLBACKS_PIPELINE_IMPL(VecType) \
  template <unsigned Head> \
  CPU_ATTR static inline void run_callbacks(VecType input, const OutputBufferInfo& info, CallbacksTupleType& stages, sequence<Head>) { \
    std::get<Head>(stages)(input, info); \
  } \
  template <unsigned Head, unsigned Next, unsigned... Tail> \
  CPU_ATTR static inline void run_callbacks(VecType input, const OutputBufferInfo& info, CallbacksTupleType& stages, sequence<Head, Next, Tail...>) { \
    auto intermediate = std::get<Head>(stages)(input, info); \
    run_callbacks(intermediate, info, stages, sequence<Next, Tail...>()); \
  }

  RUN_CALLBACKS_PIPELINE_IMPL(vi)
  RUN_CALLBACKS_PIPELINE_IMPL(vf)
  RUN_CALLBACKS_PIPELINE_IMPL(vd)

#undef RUN_CALLBACKS_PIPELINE_IMPL
};

/*
 * Dummy
 *
 * A callback that does nothing: the config is ignored and every input vector
 * is discarded.
 */
template <>
class CallbackImpl<CPUType::CPU_NAME, Dummy> {
public:
  // The Dummy config carries nothing this callback reads.
  CPU_ATTR CallbackImpl(const Dummy&) {}

  // Intentionally empty.
  CPU_ATTR void operator()(vi, const OutputBufferInfo&) {}
};

/*
 * Write
 *
 * Hands the input vector to kernels::write, targeting config.output_addr at
 * offset row_idx * cols + col_idx (all three taken from the OutputBufferInfo).
 */
template <typename Type>
class CallbackImpl<CPUType::CPU_NAME, Write<Type>> {
public:
  // Keeps a private copy of the config.
  CPU_ATTR CallbackImpl(const Write<Type>& config) : config_(config) {}

  // Writes `input` at the offset derived from `info`; returns nothing.
  CPU_ATTR void operator()(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo& info) {
    kernels::write(input, config_.output_addr, info.row_idx * info.cols + info.col_idx);
  }

private:
  Write<Type> config_;
};

/*
 * Unquantize
 *
 * Applies kernels::unquantize to the integer input using a multiplier vector
 * that is built once, at construction, from config.unquant_mult.
 */
template <>
class CallbackImpl<CPUType::CPU_NAME, Unquantize> {
public:
  // Builds the multiplier vector via set1_ps and keeps a copy of the config.
  CPU_ATTR CallbackImpl(const Unquantize& config)
      : multiplier_(set1_ps<vf>(config.unquant_mult)), config_(config) {}

  // Returns the unquantized vector; the OutputBufferInfo is not used.
  CPU_ATTR vf operator()(vi input, const OutputBufferInfo&) {
    return kernels::unquantize(input, multiplier_);
  }

private:
  vf multiplier_;
  Unquantize config_;
};

/*
 * UnquantizeAndWrite
 *
 * Unquantizes the integer input with kernels::unquantize (multiplier vector
 * built once from config.unquant_mult) and passes the result to kernels::write,
 * targeting config.output_addr at offset row_idx * cols + col_idx.
 */
template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWrite> {
public:
  // Keeps a copy of the config and precomputes the multiplier vector.
  CPU_ATTR CallbackImpl(const UnquantizeAndWrite& config) : config(config) {
    unquant_mult = set1_ps<vf>(config.unquant_mult);
  }

  // Unquantizes `input` and writes it at the offset derived from `info`.
  CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
    // Workaround gcc 5 internal compiler error that can't read register members in debug.
    vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
  #if defined(CALLBACKS_THIS_IS_SSE2)
    // The SSE2 build must not emit VEX-encoded instructions: vmovdqa needs AVX
    // and would fault on a CPU that only supports SSE2.
    asm ("movdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
  #else
    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
  #endif
#else
    mult_reg = unquant_mult;
#endif
    auto result = kernels::unquantize(input, mult_reg);
    kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
  }

private:
  vf unquant_mult;
  UnquantizeAndWrite config;
};

/*
 * AddBiasAndWrite
 *
 * Runs kernels::add_bias on the integer input (bias taken from
 * config.bias_addr, indexed by col_idx) and passes the result to
 * kernels::write, targeting config.output_addr at offset
 * row_idx * cols + col_idx.
 */
template <>
class CallbackImpl<CPUType::CPU_NAME, AddBiasAndWrite> {
public:
  // Keeps a private copy of the config.
  CPU_ATTR CallbackImpl(const AddBiasAndWrite& config) : config_(config) {}

  // Adds the bias, then writes the biased vector; returns nothing.
  CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
    auto biased = kernels::add_bias(input, config_.bias_addr, info.col_idx);
    kernels::write(biased, config_.output_addr, info.row_idx * info.cols + info.col_idx);
  }

private:
  AddBiasAndWrite config_;
};

/*
 * UnquantizeAndAddBiasAndWrite
 *
 * Three steps per input vector: kernels::unquantize with a multiplier vector
 * built once from config.unquant_mult, kernels::add_bias with config.bias_addr
 * indexed by col_idx, then kernels::write to config.output_addr at offset
 * row_idx * cols + col_idx.
 */
template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWrite> {
public:
  // Keeps a copy of the config and precomputes the multiplier vector.
  CPU_ATTR CallbackImpl(const UnquantizeAndAddBiasAndWrite& config) : config(config) {
    unquant_mult = set1_ps<vf>(config.unquant_mult);
  }

  // Unquantizes `input`, adds the bias and writes the result; returns nothing.
  CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
    // Workaround gcc 5 internal compiler error that can't read register members in debug.
    vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
  #if defined(CALLBACKS_THIS_IS_SSE2)
    // The SSE2 build must not emit VEX-encoded instructions: vmovdqa needs AVX
    // and would fault on a CPU that only supports SSE2.
    asm ("movdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
  #else
    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
  #endif
#else
    mult_reg = unquant_mult;
#endif
    auto result = kernels::unquantize(input, mult_reg);
    result = kernels::add_bias(result, config.bias_addr, info.col_idx);
    kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
  }
private:
  vf unquant_mult;
  UnquantizeAndAddBiasAndWrite config;
};

}
}

// Drop the per-architecture macros so that the next inclusion of this file
// can define them again for a different architecture.
#undef vd
#undef vf
#undef vi
#undef CPU_ATTR
#undef CPU_NAME