diff options
author | Chao Mei <chaomei@google.com> | 2021-04-06 05:07:24 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2021-04-06 05:07:45 +0300 |
commit | 38a9266b832767a3f535a74a9e0cf39f7892e594 (patch) | |
tree | 1cb63d0a0092cc974d3ff87246037f3e2a45bf57 | |
parent | 939449243eb36e5b668cc00a1c936f2b1ad4dc27 (diff) |
Create a utility library to suppress floating-point denormals, and apply it to every task execution of every thread.
PiperOrigin-RevId: 366919663
-rw-r--r-- | ruy/BUILD | 10 | ||||
-rw-r--r-- | ruy/CMakeLists.txt | 16 | ||||
-rw-r--r-- | ruy/denormal.cc | 121 | ||||
-rw-r--r-- | ruy/denormal.h | 53 | ||||
-rw-r--r-- | ruy/thread_pool.cc | 4 | ||||
-rw-r--r-- | ruy/trmul.cc | 7 |
6 files changed, 211 insertions, 0 deletions
@@ -357,6 +357,7 @@ cc_library( deps = [ ":blocking_counter", ":check_macros", + ":denormal", ":time", ":trace", ":wait", @@ -420,6 +421,14 @@ cc_library( ) cc_library( + name = "denormal", + srcs = ["denormal.cc"], + hdrs = ["denormal.h"], + copts = ruy_copts(), + visibility = ["//visibility:public"], +) + +cc_library( name = "performance_advisory", hdrs = ["performance_advisory.h"], copts = ruy_copts(), @@ -956,6 +965,7 @@ cc_library( ":cpu_cache_params", ":cpuinfo", ":ctx", + ":denormal", ":mat", ":matrix", ":mul_params", diff --git a/ruy/CMakeLists.txt b/ruy/CMakeLists.txt index 4c3e394..b83bc8c 100644 --- a/ruy/CMakeLists.txt +++ b/ruy/CMakeLists.txt @@ -376,6 +376,7 @@ ruy_cc_library( DEPS ruy_blocking_counter ruy_check_macros + ruy_denormal ruy_time ruy_trace ruy_wait @@ -455,6 +456,20 @@ ruy_cc_library( ruy_cc_library( NAME + ruy_denormal + SRCS + denormal.cc + HDRS + denormal.h + COPTS + ${ruy_0_Wall_Wcxx14_compat_Wextra_Wundef} + ${ruy_1_mfpu_neon} + ${ruy_2_O3} + PUBLIC +) + +ruy_cc_library( + NAME ruy_performance_advisory HDRS performance_advisory.h @@ -1102,6 +1117,7 @@ ruy_cc_library( ruy_cpu_cache_params ruy_cpuinfo ruy_ctx + ruy_denormal ruy_mat ruy_matrix ruy_mul_params diff --git a/ruy/denormal.cc b/ruy/denormal.cc new file mode 100644 index 0000000..b3c0850 --- /dev/null +++ b/ruy/denormal.cc @@ -0,0 +1,121 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "ruy/denormal.h"
+
+// NOTE: this is simply a copy of pthreadpool/src/threadpool-utils.h that's not
+// exposed by the pthreadpool library
+// (https://github.com/Maratyszcza/pthreadpool), but with an additional C++
+// helper class to suppress floating-point denormal values.
+
+/* SSE-specific headers */
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+#include <xmmintrin.h>
+#endif
+
+/* MSVC-specific headers */
+#if defined(_MSC_VER)
+#include <intrin.h>
+#endif
+
+namespace ruy {
+namespace {
+// Reads the current floating-point control register of the target
+// architecture into an fpu_state: MXCSR on SSE-capable x86, FPSCR on 32-bit
+// ARM, FPCR on AArch64. On targets that match none of the branches below, a
+// value-initialized fpu_state is returned.
+inline struct fpu_state get_fpu_state() {
+  struct fpu_state state = {};
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  state.mxcsr = static_cast<std::uint32_t>(_mm_getcsr());
+#elif defined(_MSC_VER) && defined(_M_ARM)
+  state.fpscr =
+      static_cast<std::uint32_t>(_MoveFromCoprocessor(10, 7, 1, 0, 0));
+#elif defined(_MSC_VER) && defined(_M_ARM64)
+  // 0x5A20 is the system-register id handed to _ReadStatusReg; the value is
+  // stored as fpcr, so it presumably encodes FPCR.
+  state.fpcr = static_cast<std::uint64_t>(_ReadStatusReg(0x5A20));
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+    (__ARM_FP != 0)
+  __asm__ __volatile__("VMRS %[fpscr], fpscr" : [fpscr] "=r"(state.fpscr));
+#elif defined(__GNUC__) && defined(__aarch64__)
+  __asm__ __volatile__("MRS %[fpcr], fpcr" : [fpcr] "=r"(state.fpcr));
+#endif
+  return state;
+}
+
+// Writes the register value held in `state` (as produced by get_fpu_state())
+// back to the floating-point control register. Does nothing on targets that
+// match none of the branches below.
+inline void set_fpu_state(const struct fpu_state state) {
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  _mm_setcsr(static_cast<unsigned int>(state.mxcsr));
+#elif defined(_MSC_VER) && defined(_M_ARM)
+  // The cast covers only the register value; 10, 7, 1, 0, 0 are the same
+  // coprocessor operands that get_fpu_state() passes to _MoveFromCoprocessor.
+  _MoveToCoprocessor(static_cast<int>(state.fpscr), 10, 7, 1, 0, 0);
+#elif defined(_MSC_VER) && defined(_M_ARM64)
+  _WriteStatusReg(0x5A20, static_cast<__int64>(state.fpcr));
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+    (__ARM_FP != 0)
+  __asm__ __volatile__("VMSR fpscr, %[fpscr]" : : [fpscr] "r"(state.fpscr));
+#elif defined(__GNUC__) && defined(__aarch64__)
+  __asm__ __volatile__("MSR fpcr, %[fpcr]" : : [fpcr] "r"(state.fpcr));
+#else
+  // No register to restore on this target; silence the unused-parameter
+  // warning.
+  (void)state;
+#endif
+}
+
+// ORs the denormal-suppression bits into the floating-point control register,
+// leaving every other bit as it was. Does nothing on targets that match none
+// of the branches below.
+inline void disable_fpu_denormals() {
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  // 0x8040 == 0x8000 | 0x0040: flush-to-zero (bit 15) and denormals-are-zero
+  // (bit 6) in the MXCSR layout.
+  _mm_setcsr(_mm_getcsr() | 0x8040);
+#elif defined(_MSC_VER) && defined(_M_ARM)
+  int fpscr = _MoveFromCoprocessor(10, 7, 1, 0, 0);
+  // 0x1000000 is bit 24, the FPSCR flush-to-zero bit.
+  fpscr |= 0x1000000;
+  _MoveToCoprocessor(fpscr, 10, 7, 1, 0, 0);
+#elif defined(_MSC_VER) && defined(_M_ARM64)
+  __int64 fpcr = _ReadStatusReg(0x5A20);
+  // 0x1080000 == 0x1000000 | 0x80000, the same two bits that the GCC AArch64
+  // branch below sets with two ORR instructions.
+  fpcr |= 0x1080000;
+  _WriteStatusReg(0x5A20, fpcr);
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+    (__ARM_FP != 0)
+  std::uint32_t fpscr;
+#if defined(__thumb__) && !defined(__thumb2__)
+  // Thumb-1 path: the mask is supplied in a low register ("l" constraint) and
+  // ORRS updates the condition flags, hence the "cc" clobber.
+  __asm__ __volatile__(
+      "VMRS %[fpscr], fpscr\n"
+      "ORRS %[fpscr], %[bitmask]\n"
+      "VMSR fpscr, %[fpscr]\n"
+      : [fpscr] "=l"(fpscr)
+      : [bitmask] "l"(0x1000000)
+      : "cc");
+#else
+  __asm__ __volatile__(
+      "VMRS %[fpscr], fpscr\n"
+      "ORR %[fpscr], #0x1000000\n"
+      "VMSR fpscr, %[fpscr]\n"
+      : [fpscr] "=r"(fpscr));
+#endif
+#elif defined(__GNUC__) && defined(__aarch64__)
+  std::uint64_t fpcr;
+  __asm__ __volatile__(
+      "MRS %[fpcr], fpcr\n"
+      "ORR %w[fpcr], %w[fpcr], 0x1000000\n"
+      "ORR %w[fpcr], %w[fpcr], 0x80000\n"
+      "MSR fpcr, %[fpcr]\n"
+      : [fpcr] "=r"(fpcr));
+#endif
+}
+}  // namespace
+
+// Saves the current floating-point control state in restore_, then turns
+// denormal suppression on.
+ScopedSuppressDenormals::ScopedSuppressDenormals() {
+  restore_ = get_fpu_state();
+  disable_fpu_denormals();
+}
+
+// Writes back the control state captured by the constructor.
+ScopedSuppressDenormals::~ScopedSuppressDenormals() { set_fpu_state(restore_); }
+}  // namespace ruy
diff --git a/ruy/denormal.h b/ruy/denormal.h
new file mode 100644
index 0000000..e5b836c
--- /dev/null
+++ b/ruy/denormal.h
@@ -0,0 +1,53 @@
+/* Copyright 2021 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef RUY_RUY_DENORMAL_H_
+#define RUY_RUY_DENORMAL_H_
+
+#include <cstdint>
+
+namespace ruy {
+// NOTE: the following 'fpu_state' struct is copied from
+// pthreadpool/src/threadpool-utils.h that's not exposed by the pthreadpool
+// library (https://github.com/Maratyszcza/pthreadpool).
+//
+// Holds a saved copy of the floating-point control register. At most one
+// member is declared, chosen by the target architecture; when none of the
+// conditions below match, the struct has no members. In the #elif
+// expressions, && binds tighter than ||, so each line reads as
+// "(GCC-style condition) || (MSVC condition)".
+struct fpu_state {
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+  // x86 with SSE: copy of the MXCSR register.
+  std::uint32_t mxcsr;
+#elif defined(__GNUC__) && defined(__arm__) && defined(__ARM_FP) && \
+        (__ARM_FP != 0) ||                                          \
+    defined(_MSC_VER) && defined(_M_ARM)
+  // 32-bit ARM with floating-point support: copy of the FPSCR register.
+  std::uint32_t fpscr;
+#elif defined(__GNUC__) && defined(__aarch64__) || \
+    defined(_MSC_VER) && defined(_M_ARM64)
+  // 64-bit ARM: copy of the FPCR register.
+  std::uint64_t fpcr;
+#endif
+};
+
+// While this class is active, denormal floating point numbers are suppressed.
+// The destructor restores the original flags.
+class ScopedSuppressDenormals {
+ public:
+  // Defined in denormal.cc: the constructor records the current control
+  // state and then disables denormals; the destructor writes that recorded
+  // state back.
+  ScopedSuppressDenormals();
+  ~ScopedSuppressDenormals();
+
+ private:
+  // Control state captured at construction, restored at destruction.
+  fpu_state restore_;
+
+  // Not copyable.
+  ScopedSuppressDenormals(const ScopedSuppressDenormals&) = delete;
+  void operator=(const ScopedSuppressDenormals&) = delete;
+};
+}  // namespace ruy
+
+#endif  // RUY_RUY_DENORMAL_H_
diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc
index 100cfe3..5f22a13 100644
--- a/ruy/thread_pool.cc
+++ b/ruy/thread_pool.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include <thread> // NOLINT(build/c++11) #include "ruy/check_macros.h" +#include "ruy/denormal.h" #include "ruy/trace.h" #include "ruy/wait.h" @@ -113,6 +114,9 @@ class Thread { RUY_TRACE_SCOPE_NAME("Ruy worker thread function"); ChangeState(State::Ready); + // Suppress denormals to avoid computation inefficiency. + ScopedSuppressDenormals suppress_denormals; + // Thread main loop while (true) { RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration"); diff --git a/ruy/trmul.cc b/ruy/trmul.cc index 9345f0c..602660b 100644 --- a/ruy/trmul.cc +++ b/ruy/trmul.cc @@ -30,6 +30,7 @@ limitations under the License. #include "ruy/cpu_cache_params.h" #include "ruy/cpuinfo.h" #include "ruy/ctx.h" +#include "ruy/denormal.h" #include "ruy/mat.h" #include "ruy/matrix.h" #include "ruy/mul_params.h" @@ -307,6 +308,12 @@ void TrMul(Ctx* ctx, TrMulParams* params) { GetTentativeThreadCount(ctx, rows, cols, depth); const auto& cpu_cache_params = ctx->mutable_cpuinfo()->CacheParams(); + // Suppress denormals to avoid computation inefficiency. + // Note this only handles the denormal suppression on the main thread. As for + // worker threads, the suppression is handled in each thread's main loop. See + // the corresponding code in thread_pool.cc for details. + ScopedSuppressDenormals suppress_denormals; + // Case of running this TrMul as a simple loop. // This is a good place to start reading this function: all the rest // of this function is just an optimized, but functionally equivalent, |