1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
/* This file is included multiple times, once per architecture. */
#if defined(INTGEMM_THIS_IS_AVX512DQ)
#define INTGEMM_ARCH AVX512BW
#define INTGEMM_TARGET INTGEMM_AVX512DQ
#elif defined(INTGEMM_THIS_IS_AVX2)
#define INTGEMM_ARCH AVX2
#define INTGEMM_TARGET INTGEMM_AVX2
#elif defined(INTGEMM_THIS_IS_SSE2)
#define INTGEMM_ARCH SSE2
#define INTGEMM_TARGET INTGEMM_SSE2
#else
#error Included with unexpected architecture
#endif
#include <iostream>
namespace intgemm {
namespace INTGEMM_ARCH {
/* Compute the maximum absolute value over floats aligned to register size.
* Do not call this function directly; it's a subroutine of MaxAbsolute.
*/
/* Per-thread worker for MaxAbsolute; not meant to be called on its own.
 *
 * Intended to run inside an enclosing `omp parallel` region. The orphaned
 * `omp for` below gives each thread a share of the registers in [begin, end).
 * The function returns the largest absolute value over all lanes of that share,
 * and the caller is responsible for combining the per-thread results.
 * Absolute values are obtained by AND-ing each register with kFloatAbsoluteMask.
 * The running maximum starts at zero, so a thread that receives no iterations
 * contributes a zero register. MaxFloat32 presumably does a horizontal max
 * across lanes.
 * If there is no enclosing parallel region (or OpenMP is disabled), the
 * calling thread scans the whole range.
 */
INTGEMM_TARGET static inline float MaxAbsoluteThread(const FRegister *begin, const FRegister *end) {
  const FRegister sign_clear = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
  FRegister running_max = setzero_ps<FRegister>();
#pragma omp for
  for (const FRegister *cursor = begin; cursor < end; ++cursor) {
    running_max = max_ps(running_max, and_ps(sign_clear, *cursor));
  }
  return MaxFloat32(running_max);
}
/* Compute the maximum absolute value of an array of floats.
* begin_float must be aligned to a multiple of the register size.
*/
/* Return the maximum absolute value of the floats in [begin_float, end_float).
 *
 * Precondition: begin_float is aligned to sizeof(FRegister). Debug builds
 * check this with an assert. end_float does not need to be aligned.
 *
 * The longest register-aligned prefix is scanned by an OpenMP team.
 * - Every thread calls MaxAbsoluteThread, whose `omp for` divides the
 *   registers among the threads.
 * - The per-thread maxima are combined with a max reduction.
 * - The team size is omp_get_max_threads(), capped at one thread per 16384
 *   floats, with a minimum of one thread.
 * The floats after the last whole register are folded in afterwards on the
 * calling thread.
 */
INTGEMM_TARGET static inline float MaxAbsolute(const float *begin_float, const float *end_float) {
  assert(reinterpret_cast<uintptr_t>(begin_float) % sizeof(FRegister) == 0);
  // Number of floats that lie past the last register boundary at or before end_float.
  const uintptr_t overhang = (reinterpret_cast<uintptr_t>(end_float) % sizeof(FRegister)) / sizeof(float);
  const float *aligned_end = end_float - overhang;
  float highest = 0.0f;
  // Keep the thread-count expression inside the pragma. When OpenMP is off the
  // whole pragma is ignored, so the build does not depend on
  // omp_get_max_threads() in that configuration.
#pragma omp parallel reduction(max:highest) num_threads(std::max<int>(1, std::min<int>(omp_get_max_threads(), (end_float - begin_float) / 16384)))
  {
    const float thread_max = MaxAbsoluteThread(
        reinterpret_cast<const FRegister*>(begin_float),
        reinterpret_cast<const FRegister*>(aligned_end));
    highest = std::max(highest, thread_max);
  }
#if defined(INTGEMM_THIS_IS_AVX512DQ)
  // Tail handling on AVX512. aligned_end sits on a register boundary, so one
  // full-register load starting there stays inside a single aligned block and
  // cannot straddle a page. Lanes at or past end_float are zeroed by the mask,
  // so they cannot raise the maximum.
  if (aligned_end != end_float) {
    const FRegister sign_clear = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
    const __mmask16 keep = (1 << (end_float - aligned_end)) - 1;
    const FRegister tail = _mm512_maskz_and_ps(keep, sign_clear, *reinterpret_cast<const FRegister*>(aligned_end));
    highest = std::max(highest, MaxFloat32(tail));
  }
#else
  // Tail handling on other architectures: a scalar pass over the leftover floats.
  for (const float *p = aligned_end; p < end_float; ++p) {
    highest = std::max(highest, std::fabs(*p));
  }
#endif
  return highest;
}
/* Computes the euclidean norm and returns the mean and the standard deviation. Optionally it can be the mean and standard deviation in absolute terms. */
/* Compute the mean and the population standard deviation of the floats in
 * [begin_float, end_float). When `absolute` is true, both statistics are
 * computed over |x| instead of x.
 *
 * Input handling:
 * - At least one element is required (asserted).
 * - Whole registers are accumulated with SIMD. The remaining
 *   (count % register width) floats are accumulated with scalar code, so any
 *   element count is accepted.
 * - begin_float is read through a `const FRegister*`, so it must be aligned
 *   to the register type's alignment.
 *
 * The standard deviation is computed as sqrt(E[x^2] - E[x]^2). In floating
 * point this difference can round to a tiny negative number, for example when
 * all inputs are equal, and sqrt would then return NaN. Negative differences
 * are therefore clamped to 0.
 */
INTGEMM_TARGET static inline MeanStd VectorMeanStd(const float *begin_float, const float *end_float, bool absolute) {
  assert(end_float > begin_float);
  // Split the input into whole registers followed by a scalar tail.
  const long num_items = end_float - begin_float;
  const long constexpr width = sizeof(FRegister) / sizeof(float);
  const std::ldiv_t result = std::ldiv(num_items, width);
  const float *tail_begin = begin_float + result.quot * width;
  const FRegister *begin = reinterpret_cast<const FRegister*>(begin_float);
  const FRegister *end = reinterpret_cast<const FRegister*>(tail_begin);
  FRegister squares = set1_ps<FRegister>(0);
  FRegister sums = set1_ps<FRegister>(0);
  float squares_sum = 0;
  float normal_sums = 0;
  if (absolute) {
    const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
    for (; begin != end; ++begin) {
      FRegister vec = and_ps(abs_mask, *begin);
      squares = add_ps(squares, mul_ps(vec, vec));
      sums = add_ps(sums, vec);
    }
    for (const float *it = tail_begin; it != end_float; ++it) {
      squares_sum += *it * *it;
      normal_sums += std::fabs(*it);
    }
  } else {
    for (; begin != end; ++begin) {
      FRegister vec = *begin;
      squares = add_ps(squares, mul_ps(vec, vec));
      sums = add_ps(sums, vec);
    }
    for (const float *it = tail_begin; it != end_float; ++it) {
      squares_sum += *it * *it;
      normal_sums += *it;
    }
  }
  // Fold the SIMD lanes into the scalar-tail partial sums.
  squares_sum += AddFloat32(squares);
  normal_sums += AddFloat32(sums);
  MeanStd ret;
  ret.mean = normal_sums / num_items;
  const auto variance = (squares_sum / num_items) - (ret.mean * ret.mean);
  // Clamp rounding-induced negatives to 0. `NaN < 0` is false, so a NaN
  // variance (from NaN inputs) is passed through unchanged.
  ret.stddev = std::sqrt(variance < 0 ? 0.0f : variance);
  return ret;
}
} // namespace INTGEMM_ARCH
} // namespace intgemm
#undef INTGEMM_ARCH
#undef INTGEMM_TARGET
|