author     T.J. Alumbaugh <talumbau@google.com>           2020-07-30 01:04:36 +0300
committer  Copybara-Service <copybara-worker@google.com>  2020-07-30 01:05:07 +0300
commit     d4822f45c6830b16ea05e7fa40c5680fb8010eb7
tree       3e6961d14cc210ace7419b9732919de8cf65d2f9
parent     18e34facf87619fdfd6571b3050796e8f2f2d15b
Adds AVX path and AVX float kernel.
PiperOrigin-RevId: 323876243
-rw-r--r--  ruy/BUILD                        41
-rw-r--r--  ruy/build_defs.bzl                6
-rw-r--r--  ruy/cpuinfo.cc                    3
-rw-r--r--  ruy/cpuinfo.h                     1
-rw-r--r--  ruy/ctx.cc                        2
-rw-r--r--  ruy/have_built_path_for.h         1
-rw-r--r--  ruy/have_built_path_for_avx.cc   35
-rw-r--r--  ruy/kernel_avx.cc                75
-rw-r--r--  ruy/kernel_avx2_fma.cc          301
-rw-r--r--  ruy/kernel_x86.h                349
-rw-r--r--  ruy/path.h                        9
-rw-r--r--  ruy/platform.h                    6
-rw-r--r--  ruy/test.h                       16
13 files changed, 296 insertions, 549 deletions
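The heart of this change is the templated MulAdd helper that let kernel_avx.cc and kernel_avx2_fma.cc share one float kernel body: the plain-AVX specialization emulates fused multiply-add with a separate multiply and add, while the AVX2+FMA specialization lowers to a single fused instruction. Below is a minimal self-contained sketch of that dispatch pattern; it is not part of the commit's sources, only mirrors names from the diff, and assumes g++ or clang++ with -mavx2 -mfma.

#include <immintrin.h>

#include <cstdio>

enum class Path { kAvx, kAvx2Fma };

// Generic template; each SIMD path provides its own specialization.
template <Path path>
__m256 MulAdd(const __m256& a, const __m256& b, const __m256& c);

// AVX has no fused multiply-add: emulate a*b+c with two instructions.
template <>
__m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b, const __m256& c) {
  return _mm256_add_ps(_mm256_mul_ps(a, b), c);
}

// AVX2+FMA computes a*b+c in a single fused instruction (which also skips
// the intermediate rounding step of the separate mul/add pair).
template <>
__m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b,
                              const __m256& c) {
  return _mm256_fmadd_ps(a, b, c);
}

int main() {
  const __m256 a = _mm256_set1_ps(2.0f);
  const __m256 b = _mm256_set1_ps(3.0f);
  const __m256 c = _mm256_set1_ps(1.0f);
  float out[8];
  _mm256_storeu_ps(out, MulAdd<Path::kAvx>(a, b, c));
  std::printf("kAvx:     %.1f\n", out[0]);  // 7.0
  _mm256_storeu_ps(out, MulAdd<Path::kAvx2Fma>(a, b, c));
  std::printf("kAvx2Fma: %.1f\n", out[0]);  // 7.0
  return 0;
}

Routing the path choice through a template parameter keeps the loop bodies identical, so only the innermost multiply-accumulate differs between the two compiled translation units.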
diff --git a/ruy/BUILD b/ruy/BUILD
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -1,7 +1,7 @@
 # Ruy is not BLAS
 
 load("@bazel_skylib//lib:selects.bzl", "selects")
-load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx", "ruy_copts_avx2_fma", "ruy_copts_avx512")
+load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx2_fma", "ruy_copts_avx512")
 load(":build_defs.oss.bzl", "ruy_linkopts_thread_standard_library")
 load(":ruy_test_ext.oss.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps")
 load(":ruy_test.bzl", "ruy_benchmark", "ruy_test")
@@ -682,43 +682,6 @@ cc_library(
 )
 
 cc_library(
-    name = "kernel_avx",
-    srcs = [
-        "kernel_avx.cc",
-    ],
-    hdrs = [
-        "kernel_x86.h",
-    ],
-    copts = ruy_copts() + ruy_copts_avx(),
-    deps = [
-        ":check_macros",
-        ":kernel_common",
-        ":mat",
-        ":mul_params",
-        ":opt_set",
-        ":path",
-        ":platform",
-        ":tune",
-        "//ruy/profiler:instrumentation",
-    ],
-)
-
-cc_library(
-    name = "have_built_path_for_avx",
-    srcs = [
-        "have_built_path_for_avx.cc",
-    ],
-    hdrs = [
-        "have_built_path_for.h",
-    ],
-    copts = ruy_copts() + ruy_copts_avx2_fma(),
-    deps = [
-        ":opt_set",
-        ":platform",
-    ],
-)
-
-cc_library(
     name = "kernel",
     hdrs = [
         "kernel.h",
@@ -728,7 +691,6 @@ cc_library(
         ":apply_multiplier",
         ":check_macros",
         ":kernel_arm",  # fixdeps: keep
-        ":kernel_avx",
         ":kernel_avx2_fma",  # fixdeps: keep
         ":kernel_avx512",  # fixdeps: keep
         ":kernel_common",
@@ -773,7 +735,6 @@ cc_library(
         "have_built_path_for.h",
     ],
     deps = [
-        ":have_built_path_for_avx",
         ":have_built_path_for_avx2_fma",
         ":have_built_path_for_avx512",
         ":platform",
diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl
index e2fc325..a36942b 100644
--- a/ruy/build_defs.bzl
+++ b/ruy/build_defs.bzl
@@ -63,12 +63,6 @@ def ruy_copts_avx512():
         "//conditions:default": [],
     })
 
-def ruy_copts_avx():
-    return select({
-        "//ruy:x86_64": ["-mavx"],
-        "//conditions:default": [],
-    })
-
 def ruy_copts_avx2_fma():
     return select({
         "//ruy:x86_64": ["-mavx2", "-mfma"],
diff --git a/ruy/cpuinfo.cc b/ruy/cpuinfo.cc
index 3203bb2..335158d 100644
--- a/ruy/cpuinfo.cc
+++ b/ruy/cpuinfo.cc
@@ -106,8 +106,6 @@ bool CpuInfo::Avx2Fma() {
          cpuinfo_has_x86_fma3();
 }
 
-bool CpuInfo::Avx() { return EnsureInitialized() && cpuinfo_has_x86_avx(); }
-
 bool CpuInfo::Avx512() {
   return EnsureInitialized() && cpuinfo_has_x86_avx512f() &&
          cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512cd() &&
@@ -131,7 +129,6 @@ bool CpuInfo::EnsureInitialized() {
 }
 bool CpuInfo::NeonDotprod() { return false; }
 bool CpuInfo::Sse42() { return false; }
-bool CpuInfo::Avx() { return false; }
 bool CpuInfo::Avx2Fma() { return false; }
 bool CpuInfo::Avx512() { return false; }
 bool CpuInfo::AvxVnni() { return false; }
diff --git a/ruy/cpuinfo.h b/ruy/cpuinfo.h
index 5d816f9..17d061e 100644
--- a/ruy/cpuinfo.h
+++ b/ruy/cpuinfo.h
@@ -31,7 +31,6 @@ class CpuInfo final {
 
   // X86 features
   bool Sse42();
-  bool Avx();
   bool Avx2Fma();
   bool Avx512();
   bool AvxVnni();
diff --git a/ruy/ctx.cc b/ruy/ctx.cc
--- a/ruy/ctx.cc
+++ b/ruy/ctx.cc
@@ -111,8 +111,6 @@ Path DetectRuntimeSupportedPaths(Path paths_to_detect, CpuInfo* cpuinfo) {
 #elif RUY_PLATFORM_X86
   // x86 SIMD paths currently require both runtime detection, and detection of
   // whether we're building the path at all.
-  maybe_add(Path::kAvx,
-            [=]() { return HaveBuiltPathForAvx() && cpuinfo->Avx(); });
   maybe_add(Path::kAvx2Fma,
             [=]() { return HaveBuiltPathForAvx2Fma() && cpuinfo->Avx2Fma(); });
   maybe_add(Path::kAvx512,
diff --git a/ruy/have_built_path_for.h b/ruy/have_built_path_for.h
index 23cb028..60e98e1 100644
--- a/ruy/have_built_path_for.h
+++ b/ruy/have_built_path_for.h
@@ -21,7 +21,6 @@ limitations under the License.
 
 namespace ruy {
 
 #if RUY_PLATFORM_X86
-bool HaveBuiltPathForAvx();
 bool HaveBuiltPathForAvx2Fma();
 bool HaveBuiltPathForAvx512();
 #endif  // RUY_PLATFORM_X86
diff --git a/ruy/have_built_path_for_avx.cc b/ruy/have_built_path_for_avx.cc
deleted file mode 100644
index 948c7a5..0000000
--- a/ruy/have_built_path_for_avx.cc
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2020 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "ruy/have_built_path_for.h"
-#include "ruy/opt_set.h"
-
-namespace ruy {
-
-#if RUY_PLATFORM_X86
-// IMPORTANT:
-// These patterns must match those in the pack and kernel cc files.
-#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))
-
-bool HaveBuiltPathForAvx() { return false; }
-
-#else  // RUY_PLATFORM_AVX && RUY_OPT(ASM)
-
-bool HaveBuiltPathForAvx() { return true; }
-
-#endif  // RUY_PLATFORM_AVX && RUY_OPT(ASM)
-#endif  // RUY_PLATFORM_X86
-
-}  // namespace ruy
diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc
deleted file mode 100644
index c3bc473..0000000
--- a/ruy/kernel_avx.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2020 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <algorithm>
-#include <cstdint>
-#include <cstring>
-
-#include "ruy/check_macros.h"
-#include "ruy/kernel_common.h"
-#include "ruy/kernel_x86.h"
-#include "ruy/opt_set.h"
-#include "ruy/platform.h"
-#include "ruy/profiler/instrumentation.h"
-
-#if RUY_PLATFORM_AVX && RUY_OPT(ASM)
-#include <immintrin.h>  // IWYU pragma: keep
-#endif
-
-namespace ruy {
-
-#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))
-
-void KernelFloatAvx(const KernelParamsFloat<8, 8>&) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-#else  // RUY_PLATFORM_AVX && RUY_OPT(ASM)
-
-namespace {
-namespace intrin_utils {
-
-// AVX doesn't have fused multiply-add so we define an inline function to be
-// used in the common code following.
-template<>
-inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b, const __m256& c) {
-  const __m256 prod = _mm256_mul_ps(a, b);
-  return _mm256_add_ps(prod, c);
-}
-
-}  // namespace intrin_utils
-}  // namespace
-
-
-void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) {
-  profiler::ScopeLabel label("Kernel kAvx float");
-  KernelFloatAvxCommon<Path::kAvx>(params);
-}
-
-void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) {
-  profiler::ScopeLabel label("Kernel kAvx float GEMV");
-  KernelFloatAvxCommonSingleCol<Path::kAvx>(params);
-}
-
-#endif  // RUY_PLATFORM_AVX && RUY_OPT(ASM)
-
-}  // namespace ruy
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index 0b3535f..4ba73f5 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -282,6 +282,48 @@ inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) {
   _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
 }
 
+inline float mm256_get1_ps(const __m256 a, int i) {
+  __m256i ai = _mm256_castps_si256(a);
+  int float_val_as_int;
+  switch (i) {
+    case 0:
+      float_val_as_int = _mm256_extract_epi32(ai, 0);
+      break;
+    case 1:
+      float_val_as_int = _mm256_extract_epi32(ai, 1);
+      break;
+    case 2:
+      float_val_as_int = _mm256_extract_epi32(ai, 2);
+      break;
+    case 3:
+      float_val_as_int = _mm256_extract_epi32(ai, 3);
+      break;
+    case 4:
+      float_val_as_int = _mm256_extract_epi32(ai, 4);
+      break;
+    case 5:
+      float_val_as_int = _mm256_extract_epi32(ai, 5);
+      break;
+    case 6:
+      float_val_as_int = _mm256_extract_epi32(ai, 6);
+      break;
+    case 7:
+      float_val_as_int = _mm256_extract_epi32(ai, 7);
+      break;
+    default:
+      RUY_DCHECK_LT(i, 8);
+      return .0f;
+  }
+  float float_val;
+  std::memcpy(&float_val, &float_val_as_int, sizeof(float_val));
+  return float_val;
+}
+
+inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
+  for (int i = 0; i < residual_rows; ++i) {
+    dst[i] = intrin_utils::mm256_get1_ps(v, i);
+  }
+}
+
 // Transpose a 8x8 matrix of floats.
 void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3,
@@ -321,14 +363,6 @@ void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2,
                               reinterpret_cast<__m256*>(v4),
                               reinterpret_cast<__m256*>(v5),
                               reinterpret_cast<__m256*>(v6),
                               reinterpret_cast<__m256*>(v7));
 }
-
-// Make an inline function for FMA so we can share the float kernels
-// with non-FMA code.
-template<>
-inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b, const __m256& c) {
-  return _mm256_fmadd_ps(a, b, c);
-}
 }  // namespace intrin_utils
 }  // namespace
@@ -1174,15 +1208,260 @@ void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
     rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
 }  // NOLINT(readability/fn_size)
-
 void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
   profiler::ScopeLabel label("Kernel kAvx2Fma float");
-  KernelFloatAvxCommon<Path::kAvx2Fma>(params);
+
+  // As parameters are defined, we need to scale by sizeof(float).
+  const std::int64_t lhs_stride = params.lhs_stride >> 2;
+  const std::int64_t dst_stride = params.dst_stride >> 2;
+  const std::int64_t rhs_stride = params.rhs_stride >> 2;
+  //
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+  // AVX2 float block size = 8.
+  const int end_row = std::min(params.dst_rows, params.last_row + 8);
+  const int end_col = std::min(params.dst_cols, params.last_col + 8);
+  //
+  const float* adj_rhs_col_ptr =
+      params.rhs_base_ptr - params.start_col * rhs_stride;
+  float* adj_dst_col_ptr =
+      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+  const float* adj_lhs_col_ptr =
+      params.lhs_base_ptr - params.start_row * lhs_stride;
+  const float* bias_ptr = params.bias;
+
+  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+  const bool channel_dimension_is_col =
+      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
+
+  int col = params.start_col;
+  // Loop through cols by float block size, leaving incomplete remainder
+  for (; col <= end_col - 8; col += 8) {
+    __m256 accum_data_v[8];
+
+    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+    for (int row = params.start_row; row < end_row; row += 8) {
+      const int residual_rows = std::min(end_row - row, 8);
+
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+
+      // Initialize with bias.
+      if (channel_dimension_is_col) {
+        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+        for (int j = 0; j < 8; ++j) {
+          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+        }
+      } else {
+        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
+
+        for (int j = 0; j < 8; ++j) {
+          accum_data_v[j] = initial_accum_data;
+        }
+      }
+
+      const float* lhs_ptr = lhs_col_ptr;
+      const float* rhs_ptr = rhs_col_ptr;
+      for (int d = 0; d < params.depth; ++d) {
+        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+        // In this version RHS values are loaded individually rather than first
+        // loading together and then extract with broadcasting. This is because
+        // AVX flavours and instrinsics and compilers in combination do not
+        // handle this pattern of extraction very well.
+        const float* rhs_data = rhs_ptr;
+
+        for (int j = 0; j < 8; ++j) {
+          const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
+          accum_data_v[j] =
+              _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+        }
+        lhs_ptr += 8;
+        rhs_ptr += 8;
+      }
+
+      if (residual_rows == 8) {
+        for (int j = 0; j < 8; ++j) {
+          float* block_ptr = dst_ptr + j * dst_stride;
+          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+          _mm256_storeu_ps(block_ptr, accum_data_v[j]);
+        }
+      } else {
+        for (int j = 0; j < 8; ++j) {
+          float* block_ptr = dst_ptr + j * dst_stride;
+          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+          intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows,
+                                          accum_data_v[j]);
+        }
+      }
+    }  // End row-block loop.
+  }  // End col-block loop.
+
+  if (col < end_col) {
+    // Remaining cols in [0, float block size).
+    RUY_DCHECK_GE(end_col - col, 0);
+    RUY_DCHECK_LT(end_col - col, 8);
+
+    __m256 accum_data_v[8];
+
+    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+    const int residual_cols = std::min(end_col - col, 8);
+
+    for (int row = params.start_row; row < end_row; row += 8) {
+      const int residual_rows = std::min(end_row - row, 8);
+
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+
+      // Initialize with bias.
+      if (channel_dimension_is_col) {
+        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+        for (int j = 0; j < 8; ++j) {
+          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+        }
+      } else {
+        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
+
+        for (int j = 0; j < 8; ++j) {
+          accum_data_v[j] = initial_accum_data;
+        }
+      }
+
+      const float* lhs_ptr = lhs_col_ptr;
+      const float* rhs_ptr = rhs_col_ptr;
+      for (int d = 0; d < params.depth; ++d) {
+        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+        const float* rhs_data = rhs_ptr;
+
+        for (int j = 0; j < 8; ++j) {
+          const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
+          accum_data_v[j] =
+              _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+        }
+        lhs_ptr += 8;
+        rhs_ptr += 8;
+      }
+
+      for (int j = 0; j < residual_cols; ++j) {
+        float* block_ptr = dst_ptr + j * dst_stride;
+        accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+        accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+        intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows,
+                                        accum_data_v[j]);
+      }
+    }  // End row-block loop.
+  }  // End col-block terminal conditional.
 }
 
 void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
   profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV");
-  KernelFloatAvxCommonSingleCol<Path::kAvx2Fma>(params);
+
+  RUY_DCHECK_EQ(params.dst_cols, 1);
+  RUY_DCHECK_EQ(params.last_col, 0);
+  RUY_DCHECK_EQ(params.start_col, 0);
+
+  // As parameters are defined, we need to scale by sizeof(float).
+  const std::int64_t lhs_stride = params.lhs_stride >> 2;
+  //
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+  // AVX2 float block size = 8.
+  const int end_row = std::min(params.dst_rows, params.last_row + 8);
+
+  float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
+  const float* adj_lhs_col_ptr =
+      params.lhs_base_ptr - params.start_row * lhs_stride;
+  const float* bias_col_ptr = params.bias;
+
+  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+
+  __m256 accum_data_v;
+
+  const float* rhs_col_ptr = params.rhs_base_ptr;
+  float* dst_col_ptr = adj_dst_col_ptr;
+
+  int row = params.start_row;
+  for (; row <= end_row - 8; row += 8) {
+    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+    float* dst_ptr = dst_col_ptr + row;
+    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+    // Initialize with bias.
+    accum_data_v = _mm256_loadu_ps(bias_ptr);
+
+    const float* lhs_ptr = lhs_col_ptr;
+    const float* rhs_ptr = rhs_col_ptr;
+    int d = 0;
+    for (; d <= params.depth - 4; d += 4) {
+      const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
+      const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
+      accum_data_v =
+          _mm256_fmadd_ps(lhs_data_0, dup_rhs_element_0, accum_data_v);
+      const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
+      const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
+      accum_data_v =
+          _mm256_fmadd_ps(lhs_data_1, dup_rhs_element_1, accum_data_v);
+
+      const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
+      const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
+      accum_data_v =
+          _mm256_fmadd_ps(lhs_data_2, dup_rhs_element_2, accum_data_v);
+      const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
+      const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
+      accum_data_v =
+          _mm256_fmadd_ps(lhs_data_3, dup_rhs_element_3, accum_data_v);
+      lhs_ptr += 32;  // Loaded 8 * 4 floats.
+      rhs_ptr += 32;
+    }
+    for (; d < params.depth; ++d) {
+      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+      const float* rhs_data = rhs_ptr;
+
+      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
+      accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+      lhs_ptr += 8;
+      rhs_ptr += 8;
+    }
+
+    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+    _mm256_storeu_ps(dst_ptr, accum_data_v);
+  }  // End row-block loop.
+
+  if (row < end_row) {
+    const int residual_rows = end_row - row;
+    RUY_CHECK_GE(residual_rows, 1);
+    RUY_CHECK_LT(residual_rows, 8);
+
+    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+    float* dst_ptr = dst_col_ptr + row;
+    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+    // Initialize with bias.
+    accum_data_v = _mm256_loadu_ps(bias_ptr);
+
+    const float* lhs_ptr = lhs_col_ptr;
+    const float* rhs_ptr = rhs_col_ptr;
+    for (int d = 0; d < params.depth; ++d) {
+      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+      const float* rhs_data = rhs_ptr;
+
+      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
+      accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+      lhs_ptr += 8;
+      rhs_ptr += 8;
+    }
+
+    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+    intrin_utils::mm256_n_storeu_ps(dst_ptr, residual_rows, accum_data_v);
+  }  // End handling of residual rows.
 }
 
 #endif  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index b330482..641fbb6 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -31,7 +31,6 @@ namespace ruy {
 #if RUY_PLATFORM_X86
 
 RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2Fma)
-RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx)
 RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512)
 
 void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
@@ -135,356 +134,8 @@ struct Kernel<Path::kAvx2Fma, float, float, float, float> {
   }
 };
 
-void KernelFloatAvx(const KernelParamsFloat<8, 8>& params);
-void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params);
-
-template <>
-struct Kernel<Path::kAvx, float, float, float, float> {
-  static constexpr Path kPath = Path::kAvx;
-  Tuning tuning = Tuning::kAuto;
-  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
-  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
-  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
-  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
-           const MulParams<float, float>& mul_params, int start_row,
-           int start_col, int end_row, int end_col, Mat<float>* dst) const {
-    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
-    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
-                          end_col, dst, &params);
-    if (dst->layout.cols == 1 &&
-        mul_params.channel_dimension() == ChannelDimension::kRow) {
-      KernelFloatAvxSingleCol(params);
-    } else {
-      KernelFloatAvx(params);
-    }
-  }
-};
 
 #endif  // RUY_PLATFORM_X86
 
-#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))
-
-#include <immintrin.h>  // IWYU pragma: keep
-
-namespace {
-namespace intrin_utils {
-
-// Defined as a template so clang won't detect it as an uneeded
-// definition.
-template<Path path>
-inline float mm256_get1_ps(const __m256 a, int i) {
-  __m256i ai = _mm256_castps_si256(a);
-  int float_val_as_int;
-  switch (i) {
-    case 0:
-      float_val_as_int = _mm256_extract_epi32(ai, 0);
-      break;
-    case 1:
-      float_val_as_int = _mm256_extract_epi32(ai, 1);
-      break;
-    case 2:
-      float_val_as_int = _mm256_extract_epi32(ai, 2);
-      break;
-    case 3:
-      float_val_as_int = _mm256_extract_epi32(ai, 3);
-      break;
-    case 4:
-      float_val_as_int = _mm256_extract_epi32(ai, 4);
-      break;
-    case 5:
-      float_val_as_int = _mm256_extract_epi32(ai, 5);
-      break;
-    case 6:
-      float_val_as_int = _mm256_extract_epi32(ai, 6);
-      break;
-    case 7:
-      float_val_as_int = _mm256_extract_epi32(ai, 7);
-      break;
-    default:
-      RUY_DCHECK_LT(i, 8);
-      return .0f;
-  }
-  float float_val;
-  std::memcpy(&float_val, &float_val_as_int, sizeof(float_val));
-  return float_val;
-}
-
-// Defined as a template so clang won't detect it as an uneeded
-// definition.
-template <Path path>
-inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
-  for (int i = 0; i < residual_rows; ++i) {
-    dst[i] = intrin_utils::mm256_get1_ps<path>(v, i);
-  }
-}
-
-
-template<Path path>
-inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) {
-  // Specializations added for AVX and AVX2FMA paths in their respective kernel
-  // files.
-  RUY_DCHECK(false);
-}
-}  // namespace intrin_utils
-}  // namespace
-
-template<Path path>
-inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
-
-  // As parameters are defined, we need to scale by sizeof(float).
-  const std::int64_t lhs_stride = params.lhs_stride >> 2;
-  const std::int64_t dst_stride = params.dst_stride >> 2;
-  const std::int64_t rhs_stride = params.rhs_stride >> 2;
-  //
-  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
-  // AVX2 float block size = 8.
-  const int end_row = std::min(params.dst_rows, params.last_row + 8);
-  const int end_col = std::min(params.dst_cols, params.last_col + 8);
-  //
-  const float* adj_rhs_col_ptr =
-      params.rhs_base_ptr - params.start_col * rhs_stride;
-  float* adj_dst_col_ptr =
-      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
-  const float* adj_lhs_col_ptr =
-      params.lhs_base_ptr - params.start_row * lhs_stride;
-  const float* bias_ptr = params.bias;
-
-  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
-  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
-  const bool channel_dimension_is_col =
-      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
-
-  int col = params.start_col;
-  // Loop through cols by float block size, leaving incomplete remainder
-  for (; col <= end_col - 8; col += 8) {
-    __m256 accum_data_v[8];
-
-    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
-    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
-
-    for (int row = params.start_row; row < end_row; row += 8) {
-      const int residual_rows = std::min(end_row - row, 8);
-
-      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
-      float* dst_ptr = dst_col_ptr + row;
-
-      // Initialize with bias.
-      if (channel_dimension_is_col) {
-        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
-        for (int j = 0; j < 8; ++j) {
-          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
-        }
-      } else {
-        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
-        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
-
-        for (int j = 0; j < 8; ++j) {
-          accum_data_v[j] = initial_accum_data;
-        }
-      }
-
-      const float* lhs_ptr = lhs_col_ptr;
-      const float* rhs_ptr = rhs_col_ptr;
-      for (int d = 0; d < params.depth; ++d) {
-        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-        // In this version RHS values are loaded individually rather than first
-        // loading together and then extract with broadcasting. This is because
-        // AVX flavours and instrinsics and compilers in combination do not
-        // handle this pattern of extraction very well.
-        const float* rhs_data = rhs_ptr;
-
-        for (int j = 0; j < 8; ++j) {
-          const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
-          accum_data_v[j] =
-              intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v[j]);
-        }
-        lhs_ptr += 8;
-        rhs_ptr += 8;
-      }
-
-      if (residual_rows == 8) {
-        for (int j = 0; j < 8; ++j) {
-          float* block_ptr = dst_ptr + j * dst_stride;
-          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
-          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
-          _mm256_storeu_ps(block_ptr, accum_data_v[j]);
-        }
-      } else {
-        for (int j = 0; j < 8; ++j) {
-          float* block_ptr = dst_ptr + j * dst_stride;
-          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
-          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
-          intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
-                                                accum_data_v[j]);
-        }
-      }
-    }  // End row-block loop.
-  }  // End col-block loop.
-
-  if (col < end_col) {
-    // Remaining cols in [0, float block size).
-    RUY_DCHECK_GE(end_col - col, 0);
-    RUY_DCHECK_LT(end_col - col, 8);
-
-    __m256 accum_data_v[8];
-
-    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
-    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
-    const int residual_cols = std::min(end_col - col, 8);
-
-    for (int row = params.start_row; row < end_row; row += 8) {
-      const int residual_rows = std::min(end_row - row, 8);
-
-      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
-      float* dst_ptr = dst_col_ptr + row;
-
-      // Initialize with bias.
-      if (channel_dimension_is_col) {
-        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
-        for (int j = 0; j < 8; ++j) {
-          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
-        }
-      } else {
-        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
-        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
-
-        for (int j = 0; j < 8; ++j) {
-          accum_data_v[j] = initial_accum_data;
-        }
-      }
-
-      const float* lhs_ptr = lhs_col_ptr;
-      const float* rhs_ptr = rhs_col_ptr;
-      for (int d = 0; d < params.depth; ++d) {
-        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-        const float* rhs_data = rhs_ptr;
-
-        for (int j = 0; j < 8; ++j) {
-          const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
-          accum_data_v[j] =
-              intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v[j]);
-        }
-        lhs_ptr += 8;
-        rhs_ptr += 8;
-      }
-
-      for (int j = 0; j < residual_cols; ++j) {
-        float* block_ptr = dst_ptr + j * dst_stride;
-        accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
-        accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
-        intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
-                                              accum_data_v[j]);
-      }
-    }  // End row-block loop.
-  }  // End col-block terminal conditional.
-}
-
-template<Path path>
-inline void KernelFloatAvxCommonSingleCol(const KernelParamsFloat<8, 8>& params) {
-
-  RUY_DCHECK_EQ(params.dst_cols, 1);
-  RUY_DCHECK_EQ(params.last_col, 0);
-  RUY_DCHECK_EQ(params.start_col, 0);
-
-  // As parameters are defined, we need to scale by sizeof(float).
-  const std::int64_t lhs_stride = params.lhs_stride >> 2;
-  //
-  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
-  // AVX2 float block size = 8.
-  const int end_row = std::min(params.dst_rows, params.last_row + 8);
-
-  float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
-  const float* adj_lhs_col_ptr =
-      params.lhs_base_ptr - params.start_row * lhs_stride;
-  const float* bias_col_ptr = params.bias;
-
-  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
-  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
-
-  __m256 accum_data_v;
-
-  const float* rhs_col_ptr = params.rhs_base_ptr;
-  float* dst_col_ptr = adj_dst_col_ptr;
-
-  int row = params.start_row;
-  for (; row <= end_row - 8; row += 8) {
-    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
-    float* dst_ptr = dst_col_ptr + row;
-    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
-
-    // Initialize with bias.
-    accum_data_v = _mm256_loadu_ps(bias_ptr);
-
-    const float* lhs_ptr = lhs_col_ptr;
-    const float* rhs_ptr = rhs_col_ptr;
-    int d = 0;
-    for (; d <= params.depth - 4; d += 4) {
-      const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
-      const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
-      accum_data_v =
-          intrin_utils::MulAdd<path>(lhs_data_0, dup_rhs_element_0, accum_data_v);
-      const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
-      const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
-      accum_data_v =
-          intrin_utils::MulAdd<path>(lhs_data_1, dup_rhs_element_1, accum_data_v);
-
-      const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
-      const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
-      accum_data_v =
-          intrin_utils::MulAdd<path>(lhs_data_2, dup_rhs_element_2, accum_data_v);
-      const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
-      const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
-      accum_data_v =
-          intrin_utils::MulAdd<path>(lhs_data_3, dup_rhs_element_3, accum_data_v);
-      lhs_ptr += 32;  // Loaded 8 * 4 floats.
-      rhs_ptr += 32;
-    }
-    for (; d < params.depth; ++d) {
-      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-      const float* rhs_data = rhs_ptr;
-
-      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
-      accum_data_v = intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
-      lhs_ptr += 8;
-      rhs_ptr += 8;
-    }
-
-    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
-    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
-    _mm256_storeu_ps(dst_ptr, accum_data_v);
-  }  // End row-block loop.
-
-  if (row < end_row) {
-    const int residual_rows = end_row - row;
-    RUY_CHECK_GE(residual_rows, 1);
-    RUY_CHECK_LT(residual_rows, 8);
-
-    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
-    float* dst_ptr = dst_col_ptr + row;
-    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
-
-    // Initialize with bias.
-    accum_data_v = _mm256_loadu_ps(bias_ptr);
-
-    const float* lhs_ptr = lhs_col_ptr;
-    const float* rhs_ptr = rhs_col_ptr;
-    for (int d = 0; d < params.depth; ++d) {
-      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-      const float* rhs_data = rhs_ptr;
-
-      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
-      accum_data_v = intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
-      lhs_ptr += 8;
-      rhs_ptr += 8;
-    }
-
-    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
-    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
-    intrin_utils::mm256_n_storeu_ps<path>(dst_ptr, residual_rows, accum_data_v);
-  }  // End handling of residual rows.
-}
-#endif  // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)
-
-
 }  // namespace ruy
 
 #endif  // RUY_RUY_KERNEL_X86_H_
diff --git a/ruy/path.h b/ruy/path.h
--- a/ruy/path.h
+++ b/ruy/path.h
@@ -77,15 +77,12 @@ enum class Path : std::uint8_t {
 #endif  // RUY_PLATFORM_ARM
 
 #if RUY_PLATFORM_X86
-  // Optimized for AVX
-  // Compiled with -mavx
-  kAvx = 0x10,
   // Optimized for AVX2+FMA.
   // Compiled with -mavx2 -mfma.
-  kAvx2Fma = 0x20,
+  kAvx2Fma = 0x10,
   // Optimized for AVX-512.
   // Compiled with -mavx512f -mavx512vl -mavx512cd -mavx512bw -mavx512dq.
-  kAvx512 = 0x40,
+  kAvx512 = 0x20,
 #endif  // RUY_PLATFORM_X86
 };
@@ -148,7 +145,7 @@ constexpr Path kExtraArchPaths = Path::kNone;
 constexpr Path kDefaultArchPaths = Path::kNeon;
 constexpr Path kExtraArchPaths = Path::kNone;
 #elif RUY_PLATFORM_X86
-constexpr Path kDefaultArchPaths = Path::kAvx | Path::kAvx2Fma | Path::kAvx512;
+constexpr Path kDefaultArchPaths = Path::kAvx2Fma | Path::kAvx512;
 constexpr Path kExtraArchPaths = Path::kNone;
 #else
 constexpr Path kDefaultArchPaths = Path::kNone;
diff --git a/ruy/platform.h b/ruy/platform.h
index c8bed3a..7421613 100644
--- a/ruy/platform.h
+++ b/ruy/platform.h
@@ -139,12 +139,6 @@ limitations under the License.
 #define RUY_PLATFORM_AVX2_FMA 0
 #endif
 
-#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX__)
-#define RUY_PLATFORM_AVX 1
-#else
-#define RUY_PLATFORM_AVX 0
-#endif
-
 // Detect Emscripten, typically Wasm.
 #ifdef __EMSCRIPTEN__
 #define RUY_PLATFORM_EMSCRIPTEN 1
diff --git a/ruy/test.h b/ruy/test.h
--- a/ruy/test.h
+++ b/ruy/test.h
@@ -107,7 +107,6 @@ inline const char* PathName(Path path) {
 #elif RUY_PLATFORM_X86
     RUY_PATHNAME_CASE(kAvx2Fma)
     RUY_PATHNAME_CASE(kAvx512)
-    RUY_PATHNAME_CASE(kAvx)
 #endif
     default:
       RUY_CHECK(false);
@@ -1355,8 +1354,9 @@ bool Agree(ExternalPath external_path1, const Matrix<Scalar>& matrix1,
   const int size = matrix1.layout().rows() * matrix1.layout().cols();
   double tolerated_max_diff = 0;
   double tolerated_mean_diff = 0;
-  const float kSmallestAllowedDifference = 4. * std::numeric_limits<Scalar>::epsilon();
   if (std::is_floating_point<Scalar>::value) {
+    // TODO: replace hardcoded 100 by something more sensible, probably
+    // roughly sqrt(depth) based on central limit theorem.
     double max_abs_val = 0;
     for (int row = 0; row < matrix1.layout().rows(); row++) {
       for (int col = 0; col < matrix1.layout().cols(); col++) {
@@ -1370,18 +1370,6 @@ bool Agree(ExternalPath external_path1, const Matrix<Scalar>& matrix1,
       }
     }
     tolerated_max_diff = max_abs_val * std::numeric_limits<Scalar>::epsilon() *
                          64 * std::sqrt(static_cast<float>(depth));
-    if (tolerated_max_diff < kSmallestAllowedDifference) {
-      // Clamp the tolerated max diff to be a bit above machine epsilon if the
-      // calculated value is too small.
-      tolerated_max_diff = kSmallestAllowedDifference;
-      if (external_path1 == ExternalPath::kEigen ||
-          external_path2 == ExternalPath::kEigen ||
-          external_path1 == ExternalPath::kEigenTensor ||
-          external_path2 == ExternalPath::kEigenTensor) {
-        // Make additional allowance for Eigen differences.
-        tolerated_max_diff *= 3.0f;
-      }
-    }
     tolerated_mean_diff = tolerated_max_diff / std::sqrt(size);
   } else if (std::is_same<Scalar, std::int32_t>::value) {
     // raw integer case, no rounding, so we can require exactness
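The mm256_get1_ps and mm256_n_storeu_ps helpers that this diff copies back into kernel_avx2_fma.cc implement the edge-block store: AVX offers no variable-index float-lane extraction, so the vector is reinterpreted as eight int32 lanes, the requested lane is extracted with a compile-time immediate, and the bits are copied back into a float. A standalone sketch of the same trick follows, with names shortened and a hypothetical main() driver that is not from the commit (assumes -mavx or stronger):

#include <immintrin.h>

#include <cstdio>
#include <cstring>

// _mm256_extract_epi32 needs a compile-time lane index, hence the switch.
inline float get1_ps(__m256 v, int i) {
  const __m256i vi = _mm256_castps_si256(v);  // reinterpret; no instruction
  int bits = 0;
  switch (i) {
    case 0: bits = _mm256_extract_epi32(vi, 0); break;
    case 1: bits = _mm256_extract_epi32(vi, 1); break;
    case 2: bits = _mm256_extract_epi32(vi, 2); break;
    case 3: bits = _mm256_extract_epi32(vi, 3); break;
    case 4: bits = _mm256_extract_epi32(vi, 4); break;
    case 5: bits = _mm256_extract_epi32(vi, 5); break;
    case 6: bits = _mm256_extract_epi32(vi, 6); break;
    case 7: bits = _mm256_extract_epi32(vi, 7); break;
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));  // type-pun the lane bits back to float
  return f;
}

// Store only the first n lanes: used for partial blocks at the matrix edge,
// where a full 8-float store would write past the destination buffer.
inline void n_storeu_ps(float* dst, int n, __m256 v) {
  for (int i = 0; i < n; ++i) dst[i] = get1_ps(v, i);
}

int main() {
  const __m256 v = _mm256_setr_ps(0, 1, 2, 3, 4, 5, 6, 7);
  float out[8] = {-1, -1, -1, -1, -1, -1, -1, -1};
  n_storeu_ps(out, 5, v);                       // writes lanes 0..4 only
  for (float x : out) std::printf("%.0f ", x);  // 0 1 2 3 4 -1 -1 -1
  std::printf("\n");
  return 0;
}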
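For a sense of scale of the tolerance formula that test.h reverts to: with Scalar = float (epsilon is about 1.19e-7) and depth = 1024, tolerated_max_diff is about max_abs_val * 1.19e-7 * 64 * 32, roughly max_abs_val * 2.4e-4, and tolerated_mean_diff divides that by sqrt(size). The removed kSmallestAllowedDifference floor (4 * epsilon, tripled when comparing against Eigen) only took effect when max_abs_val * 64 * sqrt(depth) fell below 4, i.e. for very small magnitudes at shallow depths.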