From 2985958d1d9554789c5cc3004c162d43ad80e361 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 24 Feb 2020 17:41:29 +0000 Subject: MaxAbsolute with arbitrary many arguments --- multiply.h | 26 +++++++++++++++++--------- test/multiply_test.cc | 24 +++++++++++++----------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/multiply.h b/multiply.h index 0aa86aa..9a15e0e 100644 --- a/multiply.h +++ b/multiply.h @@ -562,20 +562,28 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( } \ #define INTGEMM_MAXABSOLUTE(Register, target) \ -target static float MaxAbsolute(const float *begin_float, const float *end_float) { \ +target static inline float MaxAbsolute(const float *begin_float, const float *end_float) { \ assert(end_float > begin_float); \ - assert((end_float - begin_float) % (sizeof(Register) / sizeof(float)) == 0); \ + assert(reinterpret_cast(begin_float) % sizeof(Register) == 0); \ const Register *begin = reinterpret_cast(begin_float); \ - const Register *end = reinterpret_cast(end_float); \ - union {float f; int32_t i;} float_convert; \ - float_convert.i = 0x7fffffff; \ - Register and_me = set1_ps(float_convert.f); \ - Register highest = and_ps(and_me, *begin); \ - for (++begin; begin != end; ++begin) { \ + const float *end_reg = end_float - (reinterpret_cast(end_float) % sizeof(Register)) / sizeof(float); \ + const Register *end = reinterpret_cast(end_reg); \ + union {float f; int32_t i;} and_convert, float_convert; \ + and_convert.i = 0x7fffffff; \ + Register and_me = set1_ps(and_convert.f); \ + Register highest = setzero_ps(); \ + for (; begin < end; ++begin) { \ Register reg = and_ps(and_me, *begin); \ highest = max_ps(highest, reg); \ } \ - return MaxFloat32(highest); \ + float ret = MaxFloat32(highest); \ + /* Overhang: this would be more efficient if done in a single SIMD operation with some zeroing */ \ + for (const float *i = end_reg; i < end_float; ++i) { \ + float_convert.f = *i; \ + float_convert.i &= and_convert.i; \ + ret = std::max(ret, float_convert.f); \ + } \ + return ret; \ } \ } // namespace intgemm diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 59c62a9..97f68a3 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -194,18 +194,20 @@ void CompareMaxAbs(const float *begin, const float *end, float test) { template void TestMaxAbsolute() { std::mt19937 gen; std::uniform_real_distribution dist(-8.0, 8.0); - AlignedVector test(64); - // 64 tries. - for (int t = 0; t < 64; ++t) { - // Fill with [-8, 8). - for (auto& it : test) { - it = dist(gen); + const std::size_t kLengthMax = 65; + AlignedVector test(kLengthMax); + for (std::size_t len = 1; len < kLengthMax; ++len) { + for (int t = 0; t < len; ++t) { + // Fill with [-8, 8). + for (auto& it : test) { + it = dist(gen); + } + CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len)); + test[t] = -32.0; + CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len)); + test[t] = 32.0; + CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len)); } - CompareMaxAbs(test.begin(), test.end(), Backend(test.begin(), test.end())); - test[t] = -32.0; - CompareMaxAbs(test.begin(), test.end(), Backend(test.begin(), test.end())); - test[t] = 32.0; - CompareMaxAbs(test.begin(), test.end(), Backend(test.begin(), test.end())); } } -- cgit v1.2.3