Welcome to mirror list, hosted at ThFree Co, Russian Federation.

stats.inl « intgemm - github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 1f8c5572c786fda0a0e42e840213a061d82a1a5e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/* This file is included multiple times, once per architecture. */
/* Before each inclusion the includer defines one of INTGEMM_THIS_IS_AVX512DQ,
 * INTGEMM_THIS_IS_AVX2 or INTGEMM_THIS_IS_SSE2; if none of them is defined the
 * build stops with #error.  The selected branch defines:
 *   INTGEMM_ARCH   - the namespace (inside intgemm) the functions below go in;
 *   INTGEMM_TARGET - the prefix placed on each function below (presumably a
 *                    compiler target attribute for that instruction set).
 * The #if chain is checked in order, so AVX512DQ wins if several are defined.
 */
#if defined(INTGEMM_THIS_IS_AVX512DQ)
/* NOTE(review): the AVX512DQ build lands in the AVX512BW namespace while being
 * compiled for the AVX512DQ target.  Presumably intentional: floating-point
 * bitwise AND on 512-bit registers (used for the absolute-value masking below)
 * requires AVX512DQ.  Confirm against the includer. */
#define INTGEMM_ARCH AVX512BW
#define INTGEMM_TARGET INTGEMM_AVX512DQ
#elif defined(INTGEMM_THIS_IS_AVX2)
#define INTGEMM_ARCH AVX2
#define INTGEMM_TARGET INTGEMM_AVX2
#elif defined(INTGEMM_THIS_IS_SSE2)
#define INTGEMM_ARCH SSE2
#define INTGEMM_TARGET INTGEMM_SSE2
#else
#error Included with unexpected architecture
#endif
#include <iostream>

namespace intgemm {
namespace INTGEMM_ARCH {

/* Compute the maximum absolute value over floats aligned to register size.
 * Do not call this function directly; it's a subroutine of MaxAbsolute.
 *
 * The loop is an orphaned `omp for`.  When called from inside MaxAbsolute's
 * parallel region, the registers in [begin, end) are split among that region's
 * threads and each thread returns the maximum of its own share.  Without an
 * enclosing parallel region the calling thread scans the whole range.
 */
INTGEMM_TARGET static inline float MaxAbsoluteThread(const FRegister *begin, const FRegister *end) {
  FRegister highest = setzero_ps<FRegister>();
  // AND with kFloatAbsoluteMask (defined elsewhere; presumably 0x7fffffff per
  // lane) clears the sign bit, giving |x| lane-wise.
  const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
#pragma omp for
  for (const FRegister *i = begin; i < end; ++i) {
    FRegister reg = and_ps(abs_mask, *i);
    highest = max_ps(highest, reg);
  }
  // MaxFloat32 (defined elsewhere) presumably reduces the lanes to their maximum.
  return MaxFloat32(highest);
}

/* Compute the maximum absolute value of the floats in [begin_float, end_float).
 * begin_float must be aligned to a multiple of the register size; end_float
 * need not be.  The running maximum starts at 0.
 *
 * The register-aligned prefix is scanned with OpenMP (when enabled).  The
 * remaining floats that do not fill a whole register are handled afterwards
 * on the calling thread.
 */
INTGEMM_TARGET static inline float MaxAbsolute(const float *begin_float, const float *end_float) {
  assert(reinterpret_cast<uintptr_t>(begin_float) % sizeof(FRegister) == 0);
  assert(end_float >= begin_float);
  // Round end_float down to a register boundary.  Because begin_float is
  // aligned, [begin_float, end_reg) is a whole number of registers.
  const float *end_reg = end_float - (reinterpret_cast<uintptr_t>(end_float) % sizeof(FRegister)) / sizeof(float);
  float ret = 0.0f;
  // floor(n / 16384) threads, clamped to [1, omp_get_max_threads()];
  // presumably so short arrays are not split across a whole thread team.
#pragma omp parallel reduction(max:ret) num_threads(std::max<int>(1, std::min<int>(omp_get_max_threads(), (end_float - begin_float) / 16384)))
  {
    // Each thread scans its share of the registers; reduction(max:ret)
    // combines the per-thread results with the initial 0.
    float shard_max = MaxAbsoluteThread(
        reinterpret_cast<const FRegister*>(begin_float),
        reinterpret_cast<const FRegister*>(end_reg));
    ret = std::max(ret, shard_max);
  }
  /* Overhang: the floats in [end_reg, end_float) that do not fill a register. */
#if defined(INTGEMM_THIS_IS_AVX512DQ)
  if (end_float != end_reg) {
    const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
    // One mask bit per remaining float (1 to 15 of them).
    const __mmask16 mask = static_cast<__mmask16>((1u << (end_float - end_reg)) - 1u);
    // Masked load: inactive lanes are zeroed without being read and cannot
    // fault, so no memory past end_float is accessed.  Zeroed lanes stay 0
    // after the absolute-value mask and so cannot raise the maximum.
    const FRegister tail = _mm512_maskz_loadu_ps(mask, end_reg);
    ret = std::max(ret, MaxFloat32(and_ps(abs_mask, tail)));
  }
#else
  for (const float *i = end_reg; i < end_float; ++i) {
    ret = std::max(ret, std::fabs(*i));
  }
#endif
  return ret;
}

/* Compute the mean and standard deviation of the floats in
 * [begin_float, end_float).  When absolute is true, both statistics are taken
 * over |x| instead of x.  The standard deviation is the population one:
 * sqrt(sum(x^2) / n - mean^2).
 *
 * begin_float must be aligned to the register size, because whole registers
 * are loaded from it.  The element count need not be a multiple of the
 * register width; leftover elements are handled with scalar code.
 * The range must be non-empty.
 */
INTGEMM_TARGET static inline MeanStd VectorMeanStd(const float *begin_float, const float *end_float, bool absolute) {
  assert(end_float > begin_float);
  // The vector loops dereference FRegister pointers directly (aligned loads).
  assert(reinterpret_cast<uintptr_t>(begin_float) % sizeof(FRegister) == 0);

  const long num_items = end_float - begin_float;
  constexpr long width = sizeof(FRegister) / sizeof(float);
  const std::ldiv_t split = std::ldiv(num_items, width);

  // [begin, end) covers the whole registers; [tail, end_float) the remainder.
  const FRegister *begin = reinterpret_cast<const FRegister*>(begin_float);
  const FRegister *end = reinterpret_cast<const FRegister*>(begin_float + split.quot * width);
  const float *tail = begin_float + split.quot * width;

  FRegister squares = set1_ps<FRegister>(0);
  FRegister sums = set1_ps<FRegister>(0);
  float squares_sum = 0;
  float normal_sums = 0;
  if (absolute) {
    const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
    for (; begin != end; ++begin) {
      FRegister vec = and_ps(abs_mask, *begin);
      squares = add_ps(squares, mul_ps(vec, vec));
      sums = add_ps(sums, vec);
    }
    for (const float *i = tail; i != end_float; ++i) {
      squares_sum += *i * *i;
      normal_sums += std::fabs(*i);
    }
  } else {
    for (; begin != end; ++begin) {
      FRegister vec = *begin;
      squares = add_ps(squares, mul_ps(vec, vec));
      sums = add_ps(sums, vec);
    }
    for (const float *i = tail; i != end_float; ++i) {
      squares_sum += *i * *i;
      normal_sums += *i;
    }
  }
  // AddFloat32 (defined elsewhere) presumably sums the lanes horizontally.
  squares_sum += AddFloat32(squares);
  normal_sums += AddFloat32(sums);
  MeanStd ret;
  ret.mean = normal_sums / num_items;
  // E[x^2] - mean^2 can come out slightly negative through rounding when all
  // values are (nearly) equal.  Clamp it to 0 so sqrt does not return NaN.
  // A NaN variance (from NaN input) still propagates, because NaN < 0 is false.
  const auto variance = (squares_sum / num_items) - (ret.mean * ret.mean);
  ret.stddev = std::sqrt(variance < 0 ? 0 : variance);
  return ret;
}

} // namespace INTGEMM_ARCH
} // namespace intgemm

/* Release the per-architecture macros so this file can be included again for
 * another architecture, which defines them afresh. */
#undef INTGEMM_ARCH
#undef INTGEMM_TARGET