1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
#include "callbacks/configs.h"
#include "callbacks/output_buffer_info.h"
#include "intrinsics.h"
#include "kernels.h"
#include "types.h"
#include "vec_traits.h"
// Per-target configuration, selected by whichever THIS_IS_* macro the includer
// defined (checked in the order SSE2, AVX2, AVX512BW):
//   CPU_NAME - CPUType enumerator the CallbackImpl specializations are keyed on.
//   CPU_ATTR - INTGEMM_* attribute macro applied to every callback member function.
//   vi/vf/vd - int/float/double vector types the callbacks operate on.
#if defined(THIS_IS_SSE2)
#define CPU_NAME SSE2
#define CPU_ATTR INTGEMM_SSE2
#define vi vector_t<CPUType::SSE2, int>
#define vf vector_t<CPUType::SSE2, float>
#define vd vector_t<CPUType::SSE2, double>
#elif defined(THIS_IS_AVX2)
#define CPU_NAME AVX2
#define CPU_ATTR INTGEMM_AVX2
#define vi vector_t<CPUType::AVX2, int>
#define vf vector_t<CPUType::AVX2, float>
#define vd vector_t<CPUType::AVX2, double>
#elif defined(THIS_IS_AVX512BW)
#define CPU_NAME AVX512BW
#define CPU_ATTR INTGEMM_AVX512BW
// NOTE(review): AVX512BW callbacks use the AVX2 (256-bit) vector types, not
// AVX512 ones. Presumably the AVX512BW multiply reduces its results to 256 bits
// before invoking callbacks — confirm against the caller before changing this.
#define vi vector_t<CPUType::AVX2, int>
#define vf vector_t<CPUType::AVX2, float>
#define vd vector_t<CPUType::AVX2, double>
#else
#error "Only SSE2, AVX2 and AVX512BW are supported"
#endif
namespace intgemm {
namespace callbacks {
template <CPUType CpuType, typename CallbackConfig>
class CallbackImpl;
}}
/*
 * Callback implementations: explicit specializations of CallbackImpl for the
 * CPU target selected by CPU_NAME.
 */
namespace intgemm {
namespace callbacks {

/*
 * Dummy: no-op callback. Accepts its config and ignores every result vector.
 */
template <> class CallbackImpl<CPUType::CPU_NAME, Dummy> {
public:
  CPU_ATTR CallbackImpl(const Dummy& /*cfg*/) {}

  CPU_ATTR void operator()(vi /*input*/, const OutputBufferInfo& /*info*/) {}
};

/*
 * UnquantizeAndWrite: scales an int result vector by cfg.unquant_mult
 * (via kernels::unquantize) and hands it to kernels::write with cfg.addr and
 * the offset row_idx * cols + col_idx (presumably the row-major index of the
 * output element — confirm against kernels::write).
 */
template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWrite> {
public:
  CPU_ATTR CallbackImpl(const UnquantizeAndWrite& cfg)
      : config(cfg), unquant_mult(set1_ps<vf>(cfg.unquant_mult)) {}

  CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
    auto scaled = kernels::unquantize(input, unquant_mult);
    const auto offset = info.row_idx * info.cols + info.col_idx;
    kernels::write(scaled, config.addr, offset);
  }

private:
  UnquantizeAndWrite config;
  // The scalar multiplier broadcast to all lanes once, at construction.
  vf unquant_mult;
};

/*
 * UnquantizeAndAddBiasAndWrite: like UnquantizeAndWrite, but after scaling it
 * adds the bias read from cfg.bias_addr at column info.col_idx
 * (via kernels::add_bias), then writes to cfg.output_addr.
 */
template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWrite> {
public:
  CPU_ATTR CallbackImpl(const UnquantizeAndAddBiasAndWrite& cfg)
      : config(cfg), unquant_mult(set1_ps<vf>(cfg.unquant_mult)) {}

  CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
    auto values = kernels::unquantize(input, unquant_mult);
    values = kernels::add_bias(values, config.bias_addr, info.col_idx);
    const auto offset = info.row_idx * info.cols + info.col_idx;
    kernels::write(values, config.output_addr, offset);
  }

private:
  UnquantizeAndAddBiasAndWrite config;
  // The scalar multiplier broadcast to all lanes once, at construction.
  vf unquant_mult;
};

} // namespace callbacks
} // namespace intgemm
// Remove the per-target helper macros so they do not leak into code that
// includes this header (undefined in reverse order of definition).
#undef vd
#undef vf
#undef vi
#undef CPU_ATTR
#undef CPU_NAME
|