author    T.J. Alumbaugh <talumbau@google.com>          2020-07-30 01:04:36 +0300
committer Copybara-Service <copybara-worker@google.com> 2020-07-30 01:05:07 +0300
commit    d4822f45c6830b16ea05e7fa40c5680fb8010eb7 (patch)
tree      3e6961d14cc210ace7419b9732919de8cf65d2f9
parent    18e34facf87619fdfd6571b3050796e8f2f2d15b (diff)
Adds AVX path and AVX float kernel.
PiperOrigin-RevId: 323876243
-rw-r--r--  ruy/BUILD                       |  41
-rw-r--r--  ruy/build_defs.bzl              |   6
-rw-r--r--  ruy/cpuinfo.cc                  |   3
-rw-r--r--  ruy/cpuinfo.h                   |   1
-rw-r--r--  ruy/ctx.cc                      |   2
-rw-r--r--  ruy/have_built_path_for.h       |   1
-rw-r--r--  ruy/have_built_path_for_avx.cc  |  35
-rw-r--r--  ruy/kernel_avx.cc               |  75
-rw-r--r--  ruy/kernel_avx2_fma.cc          | 301
-rw-r--r--  ruy/kernel_x86.h                | 349
-rw-r--r--  ruy/path.h                      |   9
-rw-r--r--  ruy/platform.h                  |   6
-rw-r--r--  ruy/test.h                      |  16
13 files changed, 296 insertions(+), 549 deletions(-)
diff --git a/ruy/BUILD b/ruy/BUILD
index 392203f..9d79984 100644
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -1,7 +1,7 @@
# Ruy is not BLAS
load("@bazel_skylib//lib:selects.bzl", "selects")
-load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx", "ruy_copts_avx2_fma", "ruy_copts_avx512")
+load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx2_fma", "ruy_copts_avx512")
load(":build_defs.oss.bzl", "ruy_linkopts_thread_standard_library")
load(":ruy_test_ext.oss.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps")
load(":ruy_test.bzl", "ruy_benchmark", "ruy_test")
@@ -682,43 +682,6 @@ cc_library(
)
cc_library(
- name = "kernel_avx",
- srcs = [
- "kernel_avx.cc",
- ],
- hdrs = [
- "kernel_x86.h",
- ],
- copts = ruy_copts() + ruy_copts_avx(),
- deps = [
- ":check_macros",
- ":kernel_common",
- ":mat",
- ":mul_params",
- ":opt_set",
- ":path",
- ":platform",
- ":tune",
- "//ruy/profiler:instrumentation",
- ],
-)
-
-cc_library(
- name = "have_built_path_for_avx",
- srcs = [
- "have_built_path_for_avx.cc",
- ],
- hdrs = [
- "have_built_path_for.h",
- ],
- copts = ruy_copts() + ruy_copts_avx2_fma(),
- deps = [
- ":opt_set",
- ":platform",
- ],
-)
-
-cc_library(
name = "kernel",
hdrs = [
"kernel.h",
@@ -728,7 +691,6 @@ cc_library(
":apply_multiplier",
":check_macros",
":kernel_arm", # fixdeps: keep
- ":kernel_avx",
":kernel_avx2_fma", # fixdeps: keep
":kernel_avx512", # fixdeps: keep
":kernel_common",
@@ -773,7 +735,6 @@ cc_library(
"have_built_path_for.h",
],
deps = [
- ":have_built_path_for_avx",
":have_built_path_for_avx2_fma",
":have_built_path_for_avx512",
":platform",
diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl
index e2fc325..a36942b 100644
--- a/ruy/build_defs.bzl
+++ b/ruy/build_defs.bzl
@@ -63,12 +63,6 @@ def ruy_copts_avx512():
"//conditions:default": [],
})
-def ruy_copts_avx():
- return select({
- "//ruy:x86_64": ["-mavx"],
- "//conditions:default": [],
- })
-
def ruy_copts_avx2_fma():
return select({
"//ruy:x86_64": ["-mavx2", "-mfma"],
diff --git a/ruy/cpuinfo.cc b/ruy/cpuinfo.cc
index 3203bb2..335158d 100644
--- a/ruy/cpuinfo.cc
+++ b/ruy/cpuinfo.cc
@@ -106,8 +106,6 @@ bool CpuInfo::Avx2Fma() {
cpuinfo_has_x86_fma3();
}
-bool CpuInfo::Avx() { return EnsureInitialized() && cpuinfo_has_x86_avx(); }
-
bool CpuInfo::Avx512() {
return EnsureInitialized() && cpuinfo_has_x86_avx512f() &&
cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512cd() &&
@@ -131,7 +129,6 @@ bool CpuInfo::EnsureInitialized() {
}
bool CpuInfo::NeonDotprod() { return false; }
bool CpuInfo::Sse42() { return false; }
-bool CpuInfo::Avx() { return false; }
bool CpuInfo::Avx2Fma() { return false; }
bool CpuInfo::Avx512() { return false; }
bool CpuInfo::AvxVnni() { return false; }
diff --git a/ruy/cpuinfo.h b/ruy/cpuinfo.h
index 5d816f9..17d061e 100644
--- a/ruy/cpuinfo.h
+++ b/ruy/cpuinfo.h
@@ -31,7 +31,6 @@ class CpuInfo final {
// X86 features
bool Sse42();
- bool Avx();
bool Avx2Fma();
bool Avx512();
bool AvxVnni();
diff --git a/ruy/ctx.cc b/ruy/ctx.cc
index bbf5cad..c4d5b71 100644
--- a/ruy/ctx.cc
+++ b/ruy/ctx.cc
@@ -111,8 +111,6 @@ Path DetectRuntimeSupportedPaths(Path paths_to_detect, CpuInfo* cpuinfo) {
#elif RUY_PLATFORM_X86
// x86 SIMD paths currently require both runtime detection, and detection of
// whether we're building the path at all.
- maybe_add(Path::kAvx,
- [=]() { return HaveBuiltPathForAvx() && cpuinfo->Avx(); });
maybe_add(Path::kAvx2Fma,
[=]() { return HaveBuiltPathForAvx2Fma() && cpuinfo->Avx2Fma(); });
maybe_add(Path::kAvx512,
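Editor's note: the detection logic in ctx.cc only enables a SIMD path when it was compiled in (HaveBuiltPathFor*) and the CPU reports the feature (CpuInfo). A minimal, self-contained sketch of that pattern follows; CpuFlags, DetectPaths, and the built_* parameters are illustrative stand-ins, not ruy's API.

    #include <cstdint>

    enum class Path : std::uint8_t { kNone = 0, kAvx2Fma = 0x10, kAvx512 = 0x20 };
    inline Path operator|(Path a, Path b) {
      return static_cast<Path>(static_cast<std::uint8_t>(a) |
                               static_cast<std::uint8_t>(b));
    }

    struct CpuFlags { bool avx2_fma; bool avx512; };  // stand-in for ruy::CpuInfo

    // A path is added only if the code was built for it AND the CPU supports it.
    Path DetectPaths(const CpuFlags& cpu, bool built_avx2_fma, bool built_avx512) {
      Path detected = Path::kNone;
      if (built_avx2_fma && cpu.avx2_fma) detected = detected | Path::kAvx2Fma;
      if (built_avx512 && cpu.avx512) detected = detected | Path::kAvx512;
      return detected;
    }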
diff --git a/ruy/have_built_path_for.h b/ruy/have_built_path_for.h
index 23cb028..60e98e1 100644
--- a/ruy/have_built_path_for.h
+++ b/ruy/have_built_path_for.h
@@ -21,7 +21,6 @@ limitations under the License.
namespace ruy {
#if RUY_PLATFORM_X86
-bool HaveBuiltPathForAvx();
bool HaveBuiltPathForAvx2Fma();
bool HaveBuiltPathForAvx512();
#endif // RUY_PLATFORM_X86
diff --git a/ruy/have_built_path_for_avx.cc b/ruy/have_built_path_for_avx.cc
deleted file mode 100644
index 948c7a5..0000000
--- a/ruy/have_built_path_for_avx.cc
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2020 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "ruy/have_built_path_for.h"
-#include "ruy/opt_set.h"
-
-namespace ruy {
-
-#if RUY_PLATFORM_X86
-// IMPORTANT:
-// These patterns must match those in the pack and kernel cc files.
-#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))
-
-bool HaveBuiltPathForAvx() { return false; }
-
-#else // RUY_PLATFORM_AVX && RUY_OPT(ASM)
-
-bool HaveBuiltPathForAvx() { return true; }
-
-#endif // RUY_PLATFORM_AVX && RUY_OPT(ASM)
-#endif // RUY_PLATFORM_X86
-
-} // namespace ruy
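Editor's note: the deleted file above shows ruy's "have built path" pattern: each SIMD translation unit is compiled with its own copts, and a tiny function reports at runtime whether that unit was actually built for the target. A minimal sketch of the idea; the macro names are illustrative stand-ins, not ruy's.

    // Compiled into the same translation unit as the AVX kernels, so the
    // preprocessor condition must match the one guarding the kernel code.
    #if defined(__AVX__) && !defined(MYPROJ_DISABLE_ASM)
    bool HaveBuiltPathForAvx() { return true; }
    #else
    bool HaveBuiltPathForAvx() { return false; }
    #endif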
diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc
deleted file mode 100644
index c3bc473..0000000
--- a/ruy/kernel_avx.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2020 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <algorithm>
-#include <cstdint>
-#include <cstring>
-
-#include "ruy/check_macros.h"
-#include "ruy/kernel_common.h"
-#include "ruy/kernel_x86.h"
-#include "ruy/opt_set.h"
-#include "ruy/platform.h"
-#include "ruy/profiler/instrumentation.h"
-
-#if RUY_PLATFORM_AVX && RUY_OPT(ASM)
-#include <immintrin.h> // IWYU pragma: keep
-#endif
-
-namespace ruy {
-
-#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))
-
-void KernelFloatAvx(const KernelParamsFloat<8, 8>&) {
- // CPU-ID-based checks should disable the path that would reach this point.
- RUY_DCHECK(false);
-}
-
-void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) {
- // CPU-ID-based checks should disable the path that would reach this point.
- RUY_DCHECK(false);
-}
-
-#else // RUY_PLATFORM_AVX && RUY_OPT(ASM)
-
-namespace {
-namespace intrin_utils {
-
-
-// AVX doesn't have fused multiply-add so we define an inline function to be
-// used in the common code following.
-template<>
-inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b, const __m256& c) {
- const __m256 prod = _mm256_mul_ps(a, b);
- return _mm256_add_ps(prod, c);
-}
-
-} // namespace intrin_utils
-} // namespace
-
-
-void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) {
- profiler::ScopeLabel label("Kernel kAvx float");
- KernelFloatAvxCommon<Path::kAvx>(params);
-}
-
-void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) {
- profiler::ScopeLabel label("Kernel kAvx float GEMV");
- KernelFloatAvxCommonSingleCol<Path::kAvx>(params);
-}
-
-#endif // RUY_PLATFORM_AVX && RUY_OPT(ASM)
-
-} // namespace ruy
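Editor's note: the removed kernel_avx.cc relied on a small template hook, intrin_utils::MulAdd<Path>, so the shared float kernel could use a fused multiply-add on kAvx2Fma and mul+add on plain kAvx. A minimal sketch of that pattern (illustrative, not the exact ruy code; compile with -mavx2 -mfma so both intrinsics are available):

    #include <immintrin.h>

    enum class Path { kAvx, kAvx2Fma };

    // Generic hook; each path provides its own specialization.
    template <Path path>
    __m256 MulAdd(const __m256& a, const __m256& b, const __m256& c);

    // Plain AVX: no FMA instruction, so multiply then add.
    template <>
    inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b, const __m256& c) {
      return _mm256_add_ps(_mm256_mul_ps(a, b), c);
    }

    // AVX2+FMA: a single fused multiply-add.
    template <>
    inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b, const __m256& c) {
      return _mm256_fmadd_ps(a, b, c);
    }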
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index 0b3535f..4ba73f5 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -282,6 +282,48 @@ inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
}
+inline float mm256_get1_ps(const __m256 a, int i) {
+ __m256i ai = _mm256_castps_si256(a);
+ int float_val_as_int;
+ switch (i) {
+ case 0:
+ float_val_as_int = _mm256_extract_epi32(ai, 0);
+ break;
+ case 1:
+ float_val_as_int = _mm256_extract_epi32(ai, 1);
+ break;
+ case 2:
+ float_val_as_int = _mm256_extract_epi32(ai, 2);
+ break;
+ case 3:
+ float_val_as_int = _mm256_extract_epi32(ai, 3);
+ break;
+ case 4:
+ float_val_as_int = _mm256_extract_epi32(ai, 4);
+ break;
+ case 5:
+ float_val_as_int = _mm256_extract_epi32(ai, 5);
+ break;
+ case 6:
+ float_val_as_int = _mm256_extract_epi32(ai, 6);
+ break;
+ case 7:
+ float_val_as_int = _mm256_extract_epi32(ai, 7);
+ break;
+ default:
+ RUY_DCHECK_LT(i, 8);
+ return .0f;
+ }
+ float float_val;
+ std::memcpy(&float_val, &float_val_as_int, sizeof(float_val));
+ return float_val;
+}
+
+inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst[i] = intrin_utils::mm256_get1_ps(v, i);
+ }
+}
// Transpose an 8x8 matrix of floats.
void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3,
@@ -321,14 +363,6 @@ void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2,
reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5),
reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7));
}
-
-// Make an inline function for FMA so we can share the float kernels
-// with non-FMA code.
-template<>
-inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b, const __m256& c) {
- return _mm256_fmadd_ps(a, b, c);
-}
-
} // namespace intrin_utils
} // namespace
@@ -1174,15 +1208,260 @@ void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
} // NOLINT(readability/fn_size)
-
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
profiler::ScopeLabel label("Kernel kAvx2Fma float");
- KernelFloatAvxCommon<Path::kAvx2Fma>(params);
+
+ // As parameters are defined, we need to scale by sizeof(float).
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+ const std::int64_t dst_stride = params.dst_stride >> 2;
+ const std::int64_t rhs_stride = params.rhs_stride >> 2;
+ //
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ // AVX2 float block size = 8.
+ const int end_row = std::min(params.dst_rows, params.last_row + 8);
+ const int end_col = std::min(params.dst_cols, params.last_col + 8);
+ //
+ const float* adj_rhs_col_ptr =
+ params.rhs_base_ptr - params.start_col * rhs_stride;
+ float* adj_dst_col_ptr =
+ params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_ptr = params.bias;
+
+ const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+ const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+ const bool channel_dimension_is_col =
+ params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
+
+ int col = params.start_col;
+ // Loop through cols by float block size, leaving incomplete remainder
+ for (; col <= end_col - 8; col += 8) {
+ __m256 accum_data_v[8];
+
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+ for (int row = params.start_row; row < end_row; row += 8) {
+ const int residual_rows = std::min(end_row - row, 8);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+ }
+ } else {
+ const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+ const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
+
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+        // In this version, RHS values are loaded individually rather than
+        // loaded together and then extracted with broadcasting. This is
+        // because the combination of AVX flavours, intrinsics, and compilers
+        // does not handle that pattern of extraction very well.
+ const float* rhs_data = rhs_ptr;
+
+ for (int j = 0; j < 8; ++j) {
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
+ accum_data_v[j] =
+ _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+ }
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ if (residual_rows == 8) {
+ for (int j = 0; j < 8; ++j) {
+ float* block_ptr = dst_ptr + j * dst_stride;
+ accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+ _mm256_storeu_ps(block_ptr, accum_data_v[j]);
+ }
+ } else {
+ for (int j = 0; j < 8; ++j) {
+ float* block_ptr = dst_ptr + j * dst_stride;
+ accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+ intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows,
+ accum_data_v[j]);
+ }
+ }
+ } // End row-block loop.
+ } // End col-block loop.
+
+ if (col < end_col) {
+ // Remaining cols in [0, float block size).
+ RUY_DCHECK_GE(end_col - col, 0);
+ RUY_DCHECK_LT(end_col - col, 8);
+
+ __m256 accum_data_v[8];
+
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+ const int residual_cols = std::min(end_col - col, 8);
+
+ for (int row = params.start_row; row < end_row; row += 8) {
+ const int residual_rows = std::min(end_row - row, 8);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+ }
+ } else {
+ const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+ const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
+
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ for (int j = 0; j < 8; ++j) {
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
+ accum_data_v[j] =
+ _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+ }
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ for (int j = 0; j < residual_cols; ++j) {
+ float* block_ptr = dst_ptr + j * dst_stride;
+ accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+ intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows,
+ accum_data_v[j]);
+ }
+ } // End row-block loop.
+ } // End col-block terminal conditional.
}
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV");
- KernelFloatAvxCommonSingleCol<Path::kAvx2Fma>(params);
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
+ // As parameters are defined, we need to scale by sizeof(float).
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+ //
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ // AVX2 float block size = 8.
+ const int end_row = std::min(params.dst_rows, params.last_row + 8);
+
+ float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_col_ptr = params.bias;
+
+ const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+ const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+
+ __m256 accum_data_v;
+
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ float* dst_col_ptr = adj_dst_col_ptr;
+
+ int row = params.start_row;
+ for (; row <= end_row - 8; row += 8) {
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ accum_data_v = _mm256_loadu_ps(bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ int d = 0;
+ for (; d <= params.depth - 4; d += 4) {
+ const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
+ const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
+ accum_data_v =
+ _mm256_fmadd_ps(lhs_data_0, dup_rhs_element_0, accum_data_v);
+ const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
+ const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
+ accum_data_v =
+ _mm256_fmadd_ps(lhs_data_1, dup_rhs_element_1, accum_data_v);
+
+ const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
+ const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
+ accum_data_v =
+ _mm256_fmadd_ps(lhs_data_2, dup_rhs_element_2, accum_data_v);
+ const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
+ const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
+ accum_data_v =
+ _mm256_fmadd_ps(lhs_data_3, dup_rhs_element_3, accum_data_v);
+ lhs_ptr += 32; // Loaded 8 * 4 floats.
+ rhs_ptr += 32;
+ }
+ for (; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
+ accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+ _mm256_storeu_ps(dst_ptr, accum_data_v);
+ } // End row-block loop.
+
+ if (row < end_row) {
+ const int residual_rows = end_row - row;
+ RUY_CHECK_GE(residual_rows, 1);
+ RUY_CHECK_LT(residual_rows, 8);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ accum_data_v = _mm256_loadu_ps(bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
+ accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+ intrin_utils::mm256_n_storeu_ps(dst_ptr, residual_rows, accum_data_v);
+ } // End handling of residual rows.
}
#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
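Editor's note: the mm256_n_storeu_ps helper re-added above exists because the kernel writes 8 rows at a time but the destination may end on a partial block, so only residual_rows lanes may be stored. A standalone sketch of the same idea, using a stack spill instead of per-lane extraction (an illustrative alternative, not ruy's implementation):

    #include <immintrin.h>
    #include <cstring>

    // Store only the first `residual_rows` (0..8) lanes of `v` to `dst`.
    inline void StorePartialPs(float* dst, int residual_rows, __m256 v) {
      float tmp[8];
      _mm256_storeu_ps(tmp, v);                              // spill all 8 lanes
      std::memcpy(dst, tmp, residual_rows * sizeof(float));  // copy only what fits
    }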
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index b330482..641fbb6 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -31,7 +31,6 @@ namespace ruy {
#if RUY_PLATFORM_X86
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2Fma)
-RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx)
RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512)
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
@@ -135,356 +134,8 @@ struct Kernel<Path::kAvx2Fma, float, float, float, float> {
}
};
-void KernelFloatAvx(const KernelParamsFloat<8, 8>& params);
-void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params);
-
-template <>
-struct Kernel<Path::kAvx, float, float, float, float> {
- static constexpr Path kPath = Path::kAvx;
- Tuning tuning = Tuning::kAuto;
- using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
- using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
- explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PMat<float>& lhs, const PMat<float>& rhs,
- const MulParams<float, float>& mul_params, int start_row,
- int start_col, int end_row, int end_col, Mat<float>* dst) const {
- KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
- MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
- end_col, dst, &params);
- if (dst->layout.cols == 1 &&
- mul_params.channel_dimension() == ChannelDimension::kRow) {
- KernelFloatAvxSingleCol(params);
- } else {
- KernelFloatAvx(params);
- }
- }
-};
#endif // RUY_PLATFORM_X86
-#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))
-
-#include <immintrin.h> // IWYU pragma: keep
-
-namespace {
-namespace intrin_utils {
-
-// Defined as a template so clang won't detect it as an unneeded
-// definition.
-template<Path path>
-inline float mm256_get1_ps(const __m256 a, int i) {
- __m256i ai = _mm256_castps_si256(a);
- int float_val_as_int;
- switch (i) {
- case 0:
- float_val_as_int = _mm256_extract_epi32(ai, 0);
- break;
- case 1:
- float_val_as_int = _mm256_extract_epi32(ai, 1);
- break;
- case 2:
- float_val_as_int = _mm256_extract_epi32(ai, 2);
- break;
- case 3:
- float_val_as_int = _mm256_extract_epi32(ai, 3);
- break;
- case 4:
- float_val_as_int = _mm256_extract_epi32(ai, 4);
- break;
- case 5:
- float_val_as_int = _mm256_extract_epi32(ai, 5);
- break;
- case 6:
- float_val_as_int = _mm256_extract_epi32(ai, 6);
- break;
- case 7:
- float_val_as_int = _mm256_extract_epi32(ai, 7);
- break;
- default:
- RUY_DCHECK_LT(i, 8);
- return .0f;
- }
- float float_val;
- std::memcpy(&float_val, &float_val_as_int, sizeof(float_val));
- return float_val;
-}
-
-// Defined as a template so clang won't detect it as an unneeded
-// definition.
-template <Path path>
-inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
- for (int i = 0; i < residual_rows; ++i) {
- dst[i] = intrin_utils::mm256_get1_ps<path>(v, i);
- }
-}
-
-
-template<Path path>
-inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) {
- // Specializations added for AVX and AVX2FMA paths in their respective kernel
- // files.
- RUY_DCHECK(false);
-}
-} // namespace intrin_utils
-} // namespace
-
-template<Path path>
-inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
-
- // As parameters are defined, we need to scale by sizeof(float).
- const std::int64_t lhs_stride = params.lhs_stride >> 2;
- const std::int64_t dst_stride = params.dst_stride >> 2;
- const std::int64_t rhs_stride = params.rhs_stride >> 2;
- //
- int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
- // AVX2 float block size = 8.
- const int end_row = std::min(params.dst_rows, params.last_row + 8);
- const int end_col = std::min(params.dst_cols, params.last_col + 8);
- //
- const float* adj_rhs_col_ptr =
- params.rhs_base_ptr - params.start_col * rhs_stride;
- float* adj_dst_col_ptr =
- params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
- const float* adj_lhs_col_ptr =
- params.lhs_base_ptr - params.start_row * lhs_stride;
- const float* bias_ptr = params.bias;
-
- const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
- const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
- const bool channel_dimension_is_col =
- params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
-
- int col = params.start_col;
- // Loop through cols by float block size, leaving incomplete remainder
- for (; col <= end_col - 8; col += 8) {
- __m256 accum_data_v[8];
-
- const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
- float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
-
- for (int row = params.start_row; row < end_row; row += 8) {
- const int residual_rows = std::min(end_row - row, 8);
-
- const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
- float* dst_ptr = dst_col_ptr + row;
-
- // Initialize with bias.
- if (channel_dimension_is_col) {
- const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
- for (int j = 0; j < 8; ++j) {
- accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
- }
- } else {
- const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
- const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
-
- for (int j = 0; j < 8; ++j) {
- accum_data_v[j] = initial_accum_data;
- }
- }
-
- const float* lhs_ptr = lhs_col_ptr;
- const float* rhs_ptr = rhs_col_ptr;
- for (int d = 0; d < params.depth; ++d) {
- const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-        // In this version, RHS values are loaded individually rather than
-        // loaded together and then extracted with broadcasting. This is
-        // because the combination of AVX flavours, intrinsics, and compilers
-        // does not handle that pattern of extraction very well.
- const float* rhs_data = rhs_ptr;
-
- for (int j = 0; j < 8; ++j) {
- const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
- accum_data_v[j] =
- intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v[j]);
- }
- lhs_ptr += 8;
- rhs_ptr += 8;
- }
-
- if (residual_rows == 8) {
- for (int j = 0; j < 8; ++j) {
- float* block_ptr = dst_ptr + j * dst_stride;
- accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
- accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
- _mm256_storeu_ps(block_ptr, accum_data_v[j]);
- }
- } else {
- for (int j = 0; j < 8; ++j) {
- float* block_ptr = dst_ptr + j * dst_stride;
- accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
- accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
- intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
- accum_data_v[j]);
- }
- }
- } // End row-block loop.
- } // End col-block loop.
-
- if (col < end_col) {
- // Remaining cols in [0, float block size).
- RUY_DCHECK_GE(end_col - col, 0);
- RUY_DCHECK_LT(end_col - col, 8);
-
- __m256 accum_data_v[8];
-
- const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
- float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
- const int residual_cols = std::min(end_col - col, 8);
-
- for (int row = params.start_row; row < end_row; row += 8) {
- const int residual_rows = std::min(end_row - row, 8);
-
- const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
- float* dst_ptr = dst_col_ptr + row;
-
- // Initialize with bias.
- if (channel_dimension_is_col) {
- const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
- for (int j = 0; j < 8; ++j) {
- accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
- }
- } else {
- const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
- const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
-
- for (int j = 0; j < 8; ++j) {
- accum_data_v[j] = initial_accum_data;
- }
- }
-
- const float* lhs_ptr = lhs_col_ptr;
- const float* rhs_ptr = rhs_col_ptr;
- for (int d = 0; d < params.depth; ++d) {
- const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
- const float* rhs_data = rhs_ptr;
-
- for (int j = 0; j < 8; ++j) {
- const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
- accum_data_v[j] =
- intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v[j]);
- }
- lhs_ptr += 8;
- rhs_ptr += 8;
- }
-
- for (int j = 0; j < residual_cols; ++j) {
- float* block_ptr = dst_ptr + j * dst_stride;
- accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
- accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
- intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
- accum_data_v[j]);
- }
- } // End row-block loop.
- } // End col-block terminal conditional.
-}
-
-template<Path path>
-inline void KernelFloatAvxCommonSingleCol(const KernelParamsFloat<8, 8>& params) {
-
- RUY_DCHECK_EQ(params.dst_cols, 1);
- RUY_DCHECK_EQ(params.last_col, 0);
- RUY_DCHECK_EQ(params.start_col, 0);
-
- // As parameters are defined, we need to scale by sizeof(float).
- const std::int64_t lhs_stride = params.lhs_stride >> 2;
- //
- int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
- // AVX2 float block size = 8.
- const int end_row = std::min(params.dst_rows, params.last_row + 8);
-
- float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
- const float* adj_lhs_col_ptr =
- params.lhs_base_ptr - params.start_row * lhs_stride;
- const float* bias_col_ptr = params.bias;
-
- const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
- const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
-
- __m256 accum_data_v;
-
- const float* rhs_col_ptr = params.rhs_base_ptr;
- float* dst_col_ptr = adj_dst_col_ptr;
-
- int row = params.start_row;
- for (; row <= end_row - 8; row += 8) {
- const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
- float* dst_ptr = dst_col_ptr + row;
- const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
-
- // Initialize with bias.
- accum_data_v = _mm256_loadu_ps(bias_ptr);
-
- const float* lhs_ptr = lhs_col_ptr;
- const float* rhs_ptr = rhs_col_ptr;
- int d = 0;
- for (; d <= params.depth - 4; d += 4) {
- const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
- const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
- accum_data_v =
- intrin_utils::MulAdd<path>(lhs_data_0, dup_rhs_element_0, accum_data_v);
- const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
- const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
- accum_data_v =
- intrin_utils::MulAdd<path>(lhs_data_1, dup_rhs_element_1, accum_data_v);
-
- const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
- const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
- accum_data_v =
- intrin_utils::MulAdd<path>(lhs_data_2, dup_rhs_element_2, accum_data_v);
- const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
- const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
- accum_data_v =
- intrin_utils::MulAdd<path>(lhs_data_3, dup_rhs_element_3, accum_data_v);
- lhs_ptr += 32; // Loaded 8 * 4 floats.
- rhs_ptr += 32;
- }
- for (; d < params.depth; ++d) {
- const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
- const float* rhs_data = rhs_ptr;
-
- const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
- accum_data_v = intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
- lhs_ptr += 8;
- rhs_ptr += 8;
- }
-
- accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
- accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
- _mm256_storeu_ps(dst_ptr, accum_data_v);
- } // End row-block loop.
-
- if (row < end_row) {
- const int residual_rows = end_row - row;
- RUY_CHECK_GE(residual_rows, 1);
- RUY_CHECK_LT(residual_rows, 8);
-
- const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
- float* dst_ptr = dst_col_ptr + row;
- const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
-
- // Initialize with bias.
- accum_data_v = _mm256_loadu_ps(bias_ptr);
-
- const float* lhs_ptr = lhs_col_ptr;
- const float* rhs_ptr = rhs_col_ptr;
- for (int d = 0; d < params.depth; ++d) {
- const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
- const float* rhs_data = rhs_ptr;
-
- const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
- accum_data_v = intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
- lhs_ptr += 8;
- rhs_ptr += 8;
- }
-
- accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
- accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
- intrin_utils::mm256_n_storeu_ps<path>(dst_ptr, residual_rows, accum_data_v);
- } // End handling of residual rows.
-}
-#endif // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)
-
-
} // namespace ruy
#endif // RUY_RUY_KERNEL_X86_H_
diff --git a/ruy/path.h b/ruy/path.h
index d3c5b06..a3cd939 100644
--- a/ruy/path.h
+++ b/ruy/path.h
@@ -77,15 +77,12 @@ enum class Path : std::uint8_t {
#endif // RUY_PLATFORM_ARM
#if RUY_PLATFORM_X86
- // Optimized for AVX
- // Compiled with -mavx
- kAvx = 0x10,
// Optimized for AVX2+FMA.
// Compiled with -mavx2 -mfma.
- kAvx2Fma = 0x20,
+ kAvx2Fma = 0x10,
// Optimized for AVX-512.
// Compiled with -mavx512f -mavx512vl -mavx512cd -mavx512bw -mavx512dq.
- kAvx512 = 0x40,
+ kAvx512 = 0x20,
#endif // RUY_PLATFORM_X86
};
@@ -148,7 +145,7 @@ constexpr Path kExtraArchPaths = Path::kNone;
constexpr Path kDefaultArchPaths = Path::kNeon;
constexpr Path kExtraArchPaths = Path::kNone;
#elif RUY_PLATFORM_X86
-constexpr Path kDefaultArchPaths = Path::kAvx | Path::kAvx2Fma | Path::kAvx512;
+constexpr Path kDefaultArchPaths = Path::kAvx2Fma | Path::kAvx512;
constexpr Path kExtraArchPaths = Path::kNone;
#else
constexpr Path kDefaultArchPaths = Path::kNone;
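Editor's note: since Path is a bit set, dropping kAvx lets kAvx2Fma and kAvx512 move down to 0x10 and 0x20 without changing how path sets combine. A small worked example with the new values:

    0x10 | 0x20 = 0x30   // kAvx2Fma | kAvx512: the x86 kDefaultArchPaths value
    0x30 & 0x20 = 0x20   // nonzero: kAvx512 is in the set
    0x30 & 0x10 = 0x10   // nonzero: kAvx2Fma is in the set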
diff --git a/ruy/platform.h b/ruy/platform.h
index c8bed3a..7421613 100644
--- a/ruy/platform.h
+++ b/ruy/platform.h
@@ -139,12 +139,6 @@ limitations under the License.
#define RUY_PLATFORM_AVX2_FMA 0
#endif
-#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX__)
-#define RUY_PLATFORM_AVX 1
-#else
-#define RUY_PLATFORM_AVX 0
-#endif
-
// Detect Emscripten, typically Wasm.
#ifdef __EMSCRIPTEN__
#define RUY_PLATFORM_EMSCRIPTEN 1
diff --git a/ruy/test.h b/ruy/test.h
index 3de8d94..30d48bf 100644
--- a/ruy/test.h
+++ b/ruy/test.h
@@ -107,7 +107,6 @@ inline const char* PathName(Path path) {
#elif RUY_PLATFORM_X86
RUY_PATHNAME_CASE(kAvx2Fma)
RUY_PATHNAME_CASE(kAvx512)
- RUY_PATHNAME_CASE(kAvx)
#endif
default:
RUY_CHECK(false);
@@ -1355,8 +1354,9 @@ bool Agree(ExternalPath external_path1, const Matrix<Scalar>& matrix1,
const int size = matrix1.layout().rows() * matrix1.layout().cols();
double tolerated_max_diff = 0;
double tolerated_mean_diff = 0;
- const float kSmallestAllowedDifference = 4. * std::numeric_limits<Scalar>::epsilon();
if (std::is_floating_point<Scalar>::value) {
+ // TODO: replace hardcoded 100 by something more sensible, probably
+ // roughly sqrt(depth) based on central limit theorem.
double max_abs_val = 0;
for (int row = 0; row < matrix1.layout().rows(); row++) {
for (int col = 0; col < matrix1.layout().cols(); col++) {
@@ -1370,18 +1370,6 @@ bool Agree(ExternalPath external_path1, const Matrix<Scalar>& matrix1,
}
tolerated_max_diff = max_abs_val * std::numeric_limits<Scalar>::epsilon() *
64 * std::sqrt(static_cast<float>(depth));
- if (tolerated_max_diff < kSmallestAllowedDifference) {
- // Clamp the tolerated max diff to be a bit above machine epsilon if the
- // calculated value is too small.
- tolerated_max_diff = kSmallestAllowedDifference;
- if (external_path1 == ExternalPath::kEigen ||
- external_path2 == ExternalPath::kEigen ||
- external_path1 == ExternalPath::kEigenTensor ||
- external_path2 == ExternalPath::kEigenTensor) {
- // Make additional allowance for Eigen differences.
- tolerated_max_diff *= 3.0f;
- }
- }
tolerated_mean_diff = tolerated_max_diff / std::sqrt(size);
} else if (std::is_same<Scalar, std::int32_t>::value) {
// raw integer case, no rounding, so we can require exactness