| author | T.J. Alumbaugh <talumbau@google.com> | 2020-08-31 23:19:34 +0300 |
|---|---|---|
| committer | Copybara-Service <copybara-worker@google.com> | 2020-08-31 23:19:54 +0300 |
| commit | 9e637492489df547a9a6db555fca6756df02ace6 (patch) | |
| tree | ca80f30967cc2a05327ba8175b1fbb94773434d8 | |
| parent | 29a155b0b0cff2c3da2b54201f039e9c07a4a695 (diff) | |
AVX 8bit kernel. Forked from AVX2+FMA version
PiperOrigin-RevId: 329365097
-rw-r--r-- | ruy/kernel_avx.cc | 1323
-rw-r--r-- | ruy/kernel_avx2_fma.cc | 455
-rw-r--r-- | ruy/kernel_x86.h | 317
3 files changed, 1754 insertions, 341 deletions
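The central pattern in this patch is emulating AVX2 256-bit integer intrinsics on AVX1-only hardware: split each `__m256i` into two 128-bit lanes, apply the SSE equivalent to each half, and recombine. Below is a minimal standalone sketch of that pattern, not ruy code; the wrapper name `mm256_add_epi32_avx` and the `main` driver are illustrative only, mirroring the `intrin_utils::mm256_add_epi32<Path::kAvx>` helper added in `ruy/kernel_avx.cc`.

```cpp
// Sketch of the AVX1 lane-splitting emulation used throughout this patch.
// Compile with -mavx (no -mavx2 required).
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

static inline __m256i mm256_add_epi32_avx(const __m256i& a, const __m256i& b) {
  // Split each 256-bit vector into its low and high 128-bit lanes.
  const __m128i a_lo = _mm256_extractf128_si256(a, 0);
  const __m128i a_hi = _mm256_extractf128_si256(a, 1);
  const __m128i b_lo = _mm256_extractf128_si256(b, 0);
  const __m128i b_hi = _mm256_extractf128_si256(b, 1);
  // _mm_add_epi32 is SSE2, so it is legal in an AVX1-only build; recombine
  // the two halves into a single 256-bit result.
  return _mm256_set_m128i(_mm_add_epi32(a_hi, b_hi), _mm_add_epi32(a_lo, b_lo));
}

int main() {
  const __m256i a = _mm256_set1_epi32(3);
  const __m256i b = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  std::int32_t out[8];
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), mm256_add_epi32_avx(a, b));
  for (int i = 0; i < 8; ++i) std::printf("%d ", out[i]);  // prints: 3 4 5 6 7 8 9 10
  std::printf("\n");
  return 0;
}
```

The same split-apply-recombine shape recurs in the diff below for min/max, madd, mullo, the variable shifts, and the 8x8 transpose helpers.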
diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc index 159e754..a005f8c 100644 --- a/ruy/kernel_avx.cc +++ b/ruy/kernel_avx.cc @@ -32,6 +32,16 @@ namespace ruy { #if !(RUY_PLATFORM_AVX && RUY_OPT(ASM)) +void Kernel8bitAvx(const KernelParams8bit<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>&) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + void KernelFloatAvx(const KernelParamsFloat<8, 8>&) { // CPU-ID-based checks should disable the path that would reach this point. RUY_DCHECK(false); @@ -44,9 +54,373 @@ void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) { #else // RUY_PLATFORM_AVX && RUY_OPT(ASM) +static constexpr int kAvx8bitBlockSize = 8; +static constexpr int kAvx8bitInnerSize = 4; + namespace { namespace intrin_utils { +template <> +inline __m256i mm256_shuffle_epi8<Path::kAvx>(const __m256i& a, + const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i dst_lo = _mm_shuffle_epi8(a_lo, b_lo); + __m128i dst_hi = _mm_shuffle_epi8(a_hi, b_hi); + return _mm256_set_m128i(dst_hi, dst_lo); +} + +template <> +inline __m128i mm256_extracti128_si256<Path::kAvx>(const __m256i& a, + const int imm) { + switch (imm) { + case 0: + return _mm256_extractf128_si256(a, 0); + case 1: + return _mm256_extractf128_si256(a, 1); + default: + RUY_DCHECK_LT(imm, 2); + return _mm_setzero_si128(); + } +} + +template <Path path> +inline __m256i mm256_cvtepi8_epi16(const __m128i& a) { + // Take the upper 64 bits of a and put in the first 64 bits of 'hi' + __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128()); + return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a)); +} + +template <Path path> +inline __m256i mm256_cvtepi32_epi64(const __m128i& a) { + // sign extend the 32-bit values in the lower 64 bits of a. 
+ __m128i lo = _mm_cvtepi32_epi64(a); + __m128i hi = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(a, _mm_setzero_si128())); + return _mm256_set_m128i(hi, lo); +} + +inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b, + const int imm) { + __m128i tmp = _mm_setzero_si128(); + if (!(imm & 8)) { + switch (imm & 3) { + case 0: + return _mm256_extractf128_si256(a, 0); + case 1: + return _mm256_extractf128_si256(a, 1); + case 2: + return _mm256_extractf128_si256(b, 0); + case 3: + return _mm256_extractf128_si256(b, 1); + } + } + return tmp; +} + +template <Path path> +inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b, + const int imm) { + const int lo_imm = imm & 15; + __m128i lo = mm_permute_helper(a, b, lo_imm); + const int hi_imm = (imm >> 4) & 15; + __m128i hi = mm_permute_helper(a, b, hi_imm); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_max_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_max_epi32(a_lo, b_lo); + __m128i hi = _mm_max_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_min_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_min_epi32(a_lo, b_lo); + __m128i hi = _mm_min_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_add_epi32(a_lo, b_lo); + __m128i hi = _mm_add_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_add_epi64(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_add_epi64(a_lo, b_lo); + __m128i hi = _mm_add_epi64(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_slli_epi64(const __m256i& a, int imm) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i lo = _mm_slli_epi64(a_lo, imm); + __m128i hi = _mm_slli_epi64(a_hi, imm); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_mullo_epi32(a_lo, b_lo); + __m128i hi = _mm_mullo_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +// Defined as a macro since `imm` must be an immediate. +#define BlendM128_epi32(a, b, imm) \ + _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm)) + +// Defined as a macro since `imm` must be an immediate. 
+#define mm256_blend_epi32(ans, a, b, imm) \ + __m128i a_lo = _mm256_extractf128_si256(a, 0); \ + __m128i a_hi = _mm256_extractf128_si256(a, 1); \ + __m128i b_lo = _mm256_extractf128_si256(b, 0); \ + __m128i b_hi = _mm256_extractf128_si256(b, 1); \ + __m128i lo = BlendM128_epi32(a_lo, b_lo, imm & 0xe); \ + __m128i hi = BlendM128_epi32(a_hi, b_hi, imm >> 4); \ + ans = _mm256_set_m128i(hi, lo); + +template <Path path> +inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_madd_epi16(a_lo, b_lo); + __m128i hi = _mm_madd_epi16(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +inline __m128i mm_srlv_epi64(const __m128i& a, const __m128i& b) { + // shift both elements of a by lower 64bits of b. + __m128i res_lo = _mm_srl_epi64(a, b); + // shift both elements of a by upper 64bits of b. + __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128()); + __m128i res_hi = _mm_srl_epi64(a, hi_count); + // Take the lower 64 bits of res_lo and upper 64 bits of res hi + // 1. Swap the upper and lower 64 bits of res_hi + __m128i tmp_hi = + _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1)); + // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi. + return _mm_unpacklo_epi64(res_lo, tmp_hi); +} + +template <Path path> +inline __m256i mm256_srlv_epi64(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = mm_srlv_epi64(a_lo, b_lo); + __m128i hi = mm_srlv_epi64(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m128i mm_sllv_epi64(const __m128i& a, const __m128i& b) { + // shift both elements of a by lower 64bits of b. + __m128i res_lo = _mm_sll_epi64(a, b); + // shift both elements of a by upper 64bits of b. + __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128()); + __m128i res_hi = _mm_sll_epi64(a, hi_count); + // Take the lower 64 bits of res_lo and upper 64 bits of res hi + // 1. Swap the upper and lower 64 bits of res_hi + __m128i tmp_hi = + _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1)); + // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi. + return _mm_unpacklo_epi64(res_lo, tmp_hi); +} + +template <Path path> +inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = mm_sllv_epi64<path>(a_lo, b_lo); + __m128i hi = mm_sllv_epi64<path>(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +#define PermuteM128_epi32(a, imm) \ + _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm)); + +inline __m128i mm_srlv_epi32(const __m128i& a, const __m128i& b) { + // shift all elements of a by first 32bits of b. + + __m128i res0 = _mm_srl_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1)); + // put bits 32-63 of b in the first slot. + __m128i tmp1 = PermuteM128_epi32(b, 1); + // put bits 32-63 of a in the first slot. + __m128i a1 = PermuteM128_epi32(a, 1); + // shift all elements of a by second 32bits of b. 
+ __m128i res1 = + _mm_srl_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1)); + // put bits 64-95 of b in the first slot. + __m128i tmp2 = PermuteM128_epi32(b, 2); + // shift all elements of a by third 32bits of b. + __m128i res2 = + _mm_srl_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1)); + // put bits 96-127 of b in the first slot. + __m128i tmp3 = PermuteM128_epi32(b, 3); + // put bits 96-127 of a in the third slot. + __m128i a3 = PermuteM128_epi32(a, 48); + // shift all elements of a3 by fourth 32bits of b. + __m128i res3 = + _mm_srl_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1)); + // Take bits 0-31 of res0, bits 0-31 of res1, + // bits 64-95 of res2, and bits 64-95 of res3. + // res0 _ _ _ 0 + // res1 _ _ _ 1 + // res2 _ 2 _ _ + // res3 _ 3 _ _ + // f_01 _ _ 1 0 + // f_23 _ _ 3 2 + + __m128i f_01 = _mm_unpacklo_epi32(res0, res1); + __m128i f_23 = _mm_unpacklo_epi32(res2, res3); + // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi. + return _mm_unpacklo_epi64(f_01, f_23); +} + +template <Path path> +inline __m256i mm256_srlv_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = mm_srlv_epi32(a_lo, b_lo); + __m128i hi = mm_srlv_epi32(a_hi, b_hi); + __m256i ans = _mm256_set_m128i(hi, lo); + return ans; +} + +inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) { + // shift all elements of a by first 32bits of b. + __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1)); + + // put bits 32-63 of b in the first slot. + __m128i tmp1 = PermuteM128_epi32(b, 1); + // put bits 32-63 of a in the first slot. + __m128i a1 = PermuteM128_epi32(a, 1); + // shift all elements of a by second 32bits of b. + __m128i res1 = + _mm_sll_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1)); + + // put bits 64-95 of b in the first slot. + __m128i tmp2 = PermuteM128_epi32(b, 2); + // shift all elements of a by third 32bits of b. + __m128i res2 = + _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1)); + + // put bits 96-127 of b in the first slot. + __m128i tmp3 = PermuteM128_epi32(b, 3); + // put bits 96-127 of a in the third slot. + __m128i a3 = PermuteM128_epi32(a, 48); + // shift all elements of a3 by fourth 32bits of b. + __m128i res3 = + _mm_sll_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1)); + + // Take bits 0-31 of res0, bits 0-31 of res1, + // bits 64-95 of res2, and bits 64-95 of res3. + // res0 _ _ _ 0 + // res1 _ _ _ 1 + // res2 _ 2 _ _ + // res3 _ 3 _ _ + // f_01 _ _ 1 0 + // f_23 _ _ 3 2 + __m128i f_01 = _mm_unpacklo_epi32(res0, res1); + __m128i f_23 = _mm_unpackhi_epi32(res2, res3); + // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi. 
+ return _mm_unpacklo_epi64(f_01, f_23); +} + +template <Path path> +inline __m256i mm256_sllv_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = mm_sllv_epi32(a_lo, b_lo); + __m128i hi = mm_sllv_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_sub_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_sub_epi32(a_lo, b_lo); + __m128i hi = _mm_sub_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +template <Path path> +inline __m256i mm256_mul_epi32(const __m256i& a, const __m256i& b) { + __m128i a_lo = _mm256_extractf128_si256(a, 0); + __m128i a_hi = _mm256_extractf128_si256(a, 1); + __m128i b_lo = _mm256_extractf128_si256(b, 0); + __m128i b_hi = _mm256_extractf128_si256(b, 1); + __m128i lo = _mm_mul_epi32(a_lo, b_lo); + __m128i hi = _mm_mul_epi32(a_hi, b_hi); + return _mm256_set_m128i(hi, lo); +} + +// Perform the equivalent of mm256_permutevar8x32 with +// a second argument of {0, 2, 4, 6, 1, 3, 5, 7} for the +// 7 6 5 4 3 2 1 0 +template <Path path> +inline __m256i PermuteEpi32EvenOdds(const __m256i& a) { + // a_lo = 3 2 1 0 + __m128i a_lo = _mm256_extractf128_si256(a, 0); + // a_hi = 7 6 5 4 + __m128i a_hi = _mm256_extractf128_si256(a, 1); + // shuffle a_lo to get 3 1 2 0 + __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8); + // shuffle a_hi to get 7 5 6 4 + __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8); + // unpack lo 64 of res_lo and res hi to get 6 4 2 0 + __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi); + // unpack hi 64 of res_lo and res hi to get 7 5 1 3 + __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi); + return _mm256_set_m128i(res_hi, res_lo); +} + +template <Path path> +inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) { + const __m256i bias0 = _mm256_set1_epi32(*(bias + offset)); + return mm256_add_epi32<path>(a, bias0); +} + // AVX doesn't have fused multiply-add so we define an inline function to be // used in the common code following. 
template <> @@ -59,6 +433,955 @@ inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b, } // namespace intrin_utils } // namespace +template <Path path> +void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx 8-bit"); + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + std::int32_t dst_stride = 0; + if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) || + (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) { + dst_stride = params.dst_stride; + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + dst_stride = params.dst_stride / sizeof(std::int16_t); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + dst_stride = params.dst_stride / sizeof(std::int32_t); + } else { + RUY_DCHECK(false); + } + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + + for (int col = params.start_col; col <= params.last_col; + col += kAvx8bitBlockSize) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[8]; + if (has_rhs_sums_offsets) { + const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>( + _mm256_set1_epi32(lhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + int channel = + (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row; + int multiplier_channel = + (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0; + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvx8bitBlockSize); + + const __m256i splitter_idx = _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(splitter_idx_data)); + + __m256i accum_data_v0; + __m256i accum_data_v1; + __m256i accum_data_v2; + __m256i accum_data_v3; + __m256i accum_data_v4; + __m256i accum_data_v5; + __m256i accum_data_v6; + __m256i accum_data_v7; + + // initial_accum_data will be the initialize of each of the + // accum_data_* accumulator registers. We compute into it terms that are + // identical across columns. + __m128i initial_accum_data_lo = _mm_set1_epi32(params.prod_zp_depth); + __m128i initial_accum_data_hi = _mm_set1_epi32(params.prod_zp_depth); + + // In the channels-are-rows case, we can load bias here. + if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && + !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { + initial_accum_data_lo = _mm_add_epi32( + initial_accum_data_lo, + _mm_loadu_si128( + reinterpret_cast<const __m128i*>(params.bias + row))); + initial_accum_data_hi = _mm_add_epi32( + initial_accum_data_hi, + _mm_loadu_si128( + reinterpret_cast<const __m128i*>(params.bias + row + 4))); + } + + // Adjustments common across columns. 
+ const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m128i rhs_zp = _mm_set1_epi32(rhs_zero_point); + const __m128i lhs_sums_offset_lo = _mm_mullo_epi32( + rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>( + ¶ms.lhs_sums[row]))); + const __m128i lhs_sums_offset_hi = _mm_mullo_epi32( + rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>( + ¶ms.lhs_sums[row + 4]))); + + initial_accum_data_lo = + _mm_sub_epi32(initial_accum_data_lo, lhs_sums_offset_lo); + initial_accum_data_hi = + _mm_sub_epi32(initial_accum_data_hi, lhs_sums_offset_hi); + } + + // Adjustments differing across columns. + if (has_rhs_sums_offsets) { + __m256i initial_accum_data = + _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo); + + accum_data_v0 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); + accum_data_v1 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); + accum_data_v2 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); + accum_data_v3 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); + accum_data_v4 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); + accum_data_v5 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); + accum_data_v6 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); + accum_data_v7 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); + } else { + __m256i initial_accum_data = + _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo); + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } + + // Finally, in the channels-are-columns case, load bias data here. + if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && + (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { + accum_data_v0 = intrin_utils::AddBiasEpi32<path>(accum_data_v0, + params.bias + col, 0); + accum_data_v1 = intrin_utils::AddBiasEpi32<path>(accum_data_v1, + params.bias + col, 1); + accum_data_v2 = intrin_utils::AddBiasEpi32<path>(accum_data_v2, + params.bias + col, 2); + accum_data_v3 = intrin_utils::AddBiasEpi32<path>(accum_data_v3, + params.bias + col, 3); + accum_data_v4 = intrin_utils::AddBiasEpi32<path>(accum_data_v4, + params.bias + col, 4); + accum_data_v5 = intrin_utils::AddBiasEpi32<path>(accum_data_v5, + params.bias + col, 5); + accum_data_v6 = intrin_utils::AddBiasEpi32<path>(accum_data_v6, + params.bias + col, 6); + accum_data_v7 = intrin_utils::AddBiasEpi32<path>(accum_data_v7, + params.bias + col, 7); + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + const __m256i lhs_data = + _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); + const __m256i rhs_data_8bit = + _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr)); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. 
+ std::int32_t rhs_data[16]; + const __m128i rhs_data_bottom_lane = + _mm256_castsi256_si128(rhs_data_8bit); + const __m128i rhs_data_top_lane = + _mm256_extractf128_si256(rhs_data_8bit, 1); + const __m256i rhs_16_bit_dup_low = + intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_bottom_lane); + const __m256i rhs_16_bit_dup_high = + intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), + rhs_16_bit_dup_low); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), + rhs_16_bit_dup_high); + + // NOTE: There may be opportunities for permuting the data in the + // packing code instead of here. + const __m256i lhs_data_split = + intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx); + const __m256i lhs_data_split_expand_bottom = + intrin_utils::mm256_cvtepi8_epi16<path>( + _mm256_extractf128_si256(lhs_data_split, 0)); + const __m256i lhs_data_split_expand_top = + intrin_utils::mm256_cvtepi8_epi16<path>( + _mm256_extractf128_si256(lhs_data_split, 1)); + + // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. + const __m256i lhs_16_bit_low = + intrin_utils::mm256_permute2x128_si256<path>( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. + const __m256i lhs_16_bit_high = + intrin_utils::mm256_permute2x128_si256<path>( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + auto process_column = [=](int col, __m256i& accum) { + const std::int32_t low_rhs_value = rhs_data[col * 2]; + const std::int32_t high_rhs_value = rhs_data[col * 2 + 1]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum = intrin_utils::mm256_add_epi32<path>( + accum, intrin_utils::mm256_madd_epi16<path>(lhs_16_bit_low, + rhs_16_bit_dup_low)); + accum = intrin_utils::mm256_add_epi32<path>( + accum, intrin_utils::mm256_madd_epi16<path>(lhs_16_bit_high, + rhs_16_bit_dup_high)); + }; + process_column(0, accum_data_v0); + process_column(1, accum_data_v1); + process_column(2, accum_data_v2); + process_column(3, accum_data_v3); + process_column(4, accum_data_v4); + process_column(5, accum_data_v5); + process_column(6, accum_data_v6); + process_column(7, accum_data_v7); + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. 
+ m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_fixedpoint + multiplier_channel)); + e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_exponent + multiplier_channel)); + + const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(m_vector, 0)); + const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = + intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector); + const __m256i neg_e_vector = + intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector); + const __m256i right_shift = + intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector); + const __m256i final_right_shift = intrin_utils::mm256_add_epi32<path>( + right_shift, _mm256_set1_epi32(31)); + const __m256i final_right_shift_low = + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(final_right_shift, 1)); + // Really we want 0x100000000, but use half to avoid overflowing. + + const __m256i convert_to_signed_halved = + intrin_utils::mm256_srlv_epi32<path>(_mm256_set1_epi32(0x80000000), + right_shift); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = intrin_utils::mm256_add_epi32<path>( + convert_to_signed_halved, convert_to_signed_halved); + + const __m256i offset_vector = + intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30); + // Really these should be shifted by neg_e_vector, but tests pass when + // using right_shift. + const __m256i offset_vector_low = intrin_utils::mm256_add_epi64<path>( + intrin_utils::mm256_sllv_epi64<path>( + offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(right_shift, 0))), + convert_to_unsigned_64); + const __m256i offset_vector_high = intrin_utils::mm256_add_epi64<path>( + intrin_utils::mm256_sllv_epi64<path>( + offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(right_shift, 1))), + convert_to_unsigned_64); + + if (params.dst_zero_point) { + const __m256i dst_zero_point = + _mm256_set1_epi32(params.dst_zero_point); + // The post-scaling offset is subtracted later, so this has the effect + // of adding the zero point. + post_scaling_offset = intrin_utils::mm256_sub_epi32<path>( + post_scaling_offset, dst_zero_point); + } + + // We cannot do + // + // scaled_v_low = + // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); + // scaled_v_high = + // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); + // + // since this instruction is not in AVX2. Instead we use + // _mm256_srlv_epi64, but this is an unsigned shift, so we applied + // offsets before (convert_to_unsigned_64) and after + // (convert_to_signed_halved). + // + // The overall process is, for 64-bit scaled accumulator: + // unsigned_accum = signed_accum + 1 << 63; + // unsigned_accum = (unsigned_accum >> right_shift) >> 31; + // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; + + // There are various ways to repack the results, in the absence of + // _mm256_cvtepi64_epi32() or anything like it. + // A. 
+ // accum_data_v[j] = + // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), + // _mm256_extract_epi32(scaled_v_high, 4), + // _mm256_extract_epi32(scaled_v_high, 2), + // _mm256_extract_epi32(scaled_v_high, 0), + // _mm256_extract_epi32(scaled_v_low, 6), + // _mm256_extract_epi32(scaled_v_low, 4), + // _mm256_extract_epi32(scaled_v_low, 2), + // _mm256_extract_epi32(scaled_v_low, 0)); + // B. + // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); + // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); + // accum_data_v[j] = + // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), + // _mm256_extract_epi64(scaled_v_high, 0), + // _mm256_extract_epi64(scaled_v_low, 2), + // _mm256_extract_epi64(scaled_v_low, 0)); + // C. + // scaled_v_low = + // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); + // scaled_v_high = + // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); + // accum_data_v[j] = + // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); + // + // However, we choose the following because it uses two lighter + // instructions. The permutation does have a longer latency, but this + // loop can be unrolled. + // D. + // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + // __m256i results = + // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + // results = _mm256_permutevar8x32_epi32(results, repack_perm); + // accum_data_v[j] = intrin_utils::mm256_sub_epi32<path>(results, + // post_scaling_offset); + + // This multiplier code is complex and expensive enough on x86, that + // we prefer to implement the channels-are-columns case by transposing + // around it, rather than duplicate it (which would also require + // duplicating the above code computing the multiplier constants). + // This is one instance where channels-are-columns has lower performance + // than channels-are-rows. + const bool transpose_around_multiplier = + (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) && + (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL); + if (transpose_around_multiplier) { + // Transpose the 8x8 accumulators block. Will be un-transposed below + // after the multplier implementation. + intrin_utils::mm256_transpose8x8_epi32<path>( + &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, + &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); + } + auto apply_multiplier = [=](__m256i& accum) { + __m256i shifted_accum = + intrin_utils::mm256_sllv_epi32<path>(accum, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>( + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>( + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(shifted_accum, 1)), + m_64bit_high); + scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low, + offset_vector_low); + scaled_v_high = intrin_utils::mm256_add_epi64<path>( + scaled_v_high, offset_vector_high); + + scaled_v_low = intrin_utils::mm256_srlv_epi64<path>( + scaled_v_low, final_right_shift_low); + scaled_v_high = intrin_utils::mm256_srlv_epi64<path>( + scaled_v_high, final_right_shift_high); + scaled_v_high = + intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32); + __m256i results; + mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa); + // Permute results to this ordering of int32 elements + // lo->hi (0, 2, 4, 6, 1, 3, 5, 7); + results = intrin_utils::PermuteEpi32EvenOdds<path>(results); + + accum = + intrin_utils::mm256_sub_epi32<path>(results, post_scaling_offset); + }; + apply_multiplier(accum_data_v0); + apply_multiplier(accum_data_v1); + apply_multiplier(accum_data_v2); + apply_multiplier(accum_data_v3); + apply_multiplier(accum_data_v4); + apply_multiplier(accum_data_v5); + apply_multiplier(accum_data_v6); + apply_multiplier(accum_data_v7); + // See above comment: here we transpose again to undo the transposition + // of the 8x8 block of accumulators used to implement the + // channels-are-columns case. + if (transpose_around_multiplier) { + intrin_utils::mm256_transpose8x8_epi32<path>( + &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, + &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); + } + } + const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); + const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); + const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && + (residual_cols == kAvx8bitBlockSize); + + __m256i accum_data_v[kAvx8bitBlockSize]; + if (!store_full_block) { + accum_data_v[0] = accum_data_v0; + accum_data_v[1] = accum_data_v1; + accum_data_v[2] = accum_data_v2; + accum_data_v[3] = accum_data_v3; + accum_data_v[4] = accum_data_v4; + accum_data_v[5] = accum_data_v5; + accum_data_v[6] = accum_data_v6; + accum_data_v[7] = accum_data_v7; + } + + if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { + std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); + if (store_full_block) { + accum_data_v0 = + intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); + accum_data_v0 = + intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); + accum_data_v1 = + intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); + accum_data_v1 = + intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); + accum_data_v2 = + intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); + accum_data_v2 = + intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); + accum_data_v3 = + intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); + accum_data_v3 = + intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); + accum_data_v4 = + intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); + accum_data_v4 = + intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); + accum_data_v5 = + intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); + accum_data_v5 = + intrin_utils::mm256_max_epi32<path>(accum_data_v5, 
clamp_min_v); + accum_data_v6 = + intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); + accum_data_v6 = + intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); + accum_data_v7 = + intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); + accum_data_v7 = + intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[0 * dst_stride], accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[1 * dst_stride], accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256i result = accum_data_v[j]; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + tmp_ptr, residual_rows, result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { + std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); + if (store_full_block) { + accum_data_v0 = + intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); + accum_data_v0 = + intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); + accum_data_v1 = + intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); + accum_data_v1 = + intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); + accum_data_v2 = + intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); + accum_data_v2 = + intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); + accum_data_v3 = + intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); + accum_data_v3 = + intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); + accum_data_v4 = + intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); + accum_data_v4 = + intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); + accum_data_v5 = + intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); + accum_data_v5 = + intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); + accum_data_v6 = + intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); + accum_data_v6 = + intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); + accum_data_v7 = + intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); + accum_data_v7 = + intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + 
intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256i result = accum_data_v[j]; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + tmp_ptr, residual_rows, result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); + if (store_full_block) { + accum_data_v0 = + intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); + accum_data_v0 = + intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); + accum_data_v1 = + intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); + accum_data_v1 = + intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); + accum_data_v2 = + intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); + accum_data_v2 = + intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); + accum_data_v3 = + intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); + accum_data_v3 = + intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); + accum_data_v4 = + intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); + accum_data_v4 = + intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); + accum_data_v5 = + intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); + accum_data_v5 = + intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); + accum_data_v6 = + intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); + accum_data_v6 = + intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); + accum_data_v7 = + intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); + accum_data_v7 = + intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256i result = accum_data_v[j]; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>( + tmp_ptr, residual_rows, result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + if (store_full_block) { + std::int32_t* tmp_ptr = 
static_cast<std::int32_t*>(dst_ptr); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + intrin_utils::mm256_n_storeu_epi32<path>( + dst_block_ptr, residual_rows, accum_data_v[j]); + dst_block_ptr += dst_stride; + } + } + dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. + + dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + } // End col-block loop. +} // NOLINT(readability/fn_size) + +void Kernel8bitAvx(const KernelParams8bit<8, 8>& params) { + Kernel8bitAvxImpl<Path::kAvx>(params); +} + +template <Path path> +void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[8]; + if (has_rhs_sums_offsets) { + const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>( + _mm256_set1_epi32(lhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.rhs_sums[0]))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + + const __m256i splitter_idx = + _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data)); + + __m256i accum_data_v0; + + // Initialize with bias. + __m256i initial_accum_data = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr)); + bias_ptr += bias_ptr_block_increment; + + // Adjustments common across columns. 
+ const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m256i lhs_sums_offset = intrin_utils::mm256_mullo_epi32<path>( + _mm256_set1_epi32(rhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); + initial_accum_data = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, lhs_sums_offset); + } + const std::int32_t prod_zp_depth = params.prod_zp_depth; + if (prod_zp_depth) { + initial_accum_data = intrin_utils::mm256_add_epi32<path>( + initial_accum_data, _mm256_set1_epi32(prod_zp_depth)); + } + + // Adjustments differing across columns. + if (has_rhs_sums_offsets) { + accum_data_v0 = intrin_utils::mm256_sub_epi32<path>( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); + } else { + accum_data_v0 = initial_accum_data; + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + const __m256i lhs_data = + _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); + const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. + // For simplicity we load 4x the data that we need and process twice the + // data that we need and store only the data we need. + std::int32_t rhs_data[2]; + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + + // NOTE: There may be opportunities for permuting the data in the packing + // code instead of here. + const __m256i lhs_data_split = + intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx); + const __m256i lhs_data_split_expand_bottom = + intrin_utils::mm256_cvtepi8_epi16<path>( + _mm256_extractf128_si256(lhs_data_split, 0)); + const __m256i lhs_data_split_expand_top = + intrin_utils::mm256_cvtepi8_epi16<path>( + _mm256_extractf128_si256(lhs_data_split, 1)); + + // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. + const __m256i lhs_16_bit_low = + intrin_utils::mm256_permute2x128_si256<path>( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. + const __m256i lhs_16_bit_high = + intrin_utils::mm256_permute2x128_si256<path>( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + // Accumulate for column 0. + const std::int32_t low_rhs_value = rhs_data[0]; + const std::int32_t high_rhs_value = rhs_data[1]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v0 = intrin_utils::mm256_add_epi32<path>( + accum_data_v0, intrin_utils::mm256_madd_epi16<path>( + lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v0 = intrin_utils::mm256_add_epi32<path>( + accum_data_v0, intrin_utils::mm256_madd_epi16<path>( + lhs_16_bit_high, rhs_16_bit_dup_high)); + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? 
row : 0; + m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_fixedpoint + channel)); + e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + params.multiplier_exponent + channel)); + + const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(m_vector, 0)); + const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = + intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector); + const __m256i neg_e_vector = + intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector); + const __m256i right_shift = + intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector); + const __m256i final_right_shift = intrin_utils::mm256_add_epi32<path>( + right_shift, _mm256_set1_epi32(31)); + const __m256i final_right_shift_low = + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(final_right_shift, 1)); + // Really we want 0x100000000, but use half to avoid overflowing. + const __m256i convert_to_signed_halved = + intrin_utils::mm256_srlv_epi32<path>(_mm256_set1_epi32(0x80000000), + right_shift); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = intrin_utils::mm256_add_epi32<path>( + convert_to_signed_halved, convert_to_signed_halved); + + const __m256i offset_vector = + intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30); + // Really these should be shifted by neg_e_vector, but tests pass when + // using right_shift. + const __m256i offset_vector_low = intrin_utils::mm256_add_epi64<path>( + intrin_utils::mm256_sllv_epi64<path>( + offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(right_shift, 0))), + convert_to_unsigned_64); + const __m256i offset_vector_high = intrin_utils::mm256_add_epi64<path>( + intrin_utils::mm256_sllv_epi64<path>( + offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(right_shift, 1))), + convert_to_unsigned_64); + + if (params.dst_zero_point) { + const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point); + // The post-scaling offset is subtracted later, so this has the effect + // of adding the zero point. + post_scaling_offset = intrin_utils::mm256_sub_epi32<path>( + post_scaling_offset, dst_zero_point); + } + + // See GEMM version for details of this process. + { + __m256i shifted_accum = + intrin_utils::mm256_sllv_epi32<path>(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>( + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>( + intrin_utils::mm256_cvtepi32_epi64<path>( + _mm256_extractf128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low, + offset_vector_low); + scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high, + offset_vector_high); + + scaled_v_low = intrin_utils::mm256_srlv_epi64<path>( + scaled_v_low, final_right_shift_low); + scaled_v_high = intrin_utils::mm256_srlv_epi64<path>( + scaled_v_high, final_right_shift_high); + + scaled_v_high = intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32); + __m256i results; + mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa); + // Permute results to this ordering of int32 elements + // lo->hi (0, 2, 4, 6, 1, 3, 5, 7); + results = intrin_utils::PermuteEpi32EvenOdds<path>(results); + accum_data_v0 = + intrin_utils::mm256_sub_epi32<path>(results, post_scaling_offset); + } + } + const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); + const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); + + if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { + std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); + __m256i result = accum_data_v0; + int32_t res = _mm256_extract_epi32(result, 0); + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + res = _mm256_extract_epi32(result, 0); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { + std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); + __m256i result = accum_data_v0; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { + std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); + __m256i result = accum_data_v0; + result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); + result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { + std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); + intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows, + accum_data_v0); + dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. 
+ + dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; +} // NOLINT(readability/fn_size) + +void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) { + Kernel8bitAvxSingleColImpl<Path::kAvx>(params); +} + void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) { profiler::ScopeLabel label("Kernel kAvx float"); KernelFloatAvxCommon<Path::kAvx>(params); diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc index 66da091..92cc2f0 100644 --- a/ruy/kernel_avx2_fma.cc +++ b/ruy/kernel_avx2_fma.cc @@ -60,266 +60,10 @@ static constexpr int kAvx8bitInnerSize = 4; namespace { namespace intrin_utils { -// Polyfill for _mm_storeu_si16(dst, v). -inline void mm_storeu_si16(void* dst, __m128i v) { -#if defined __clang__ - _mm_storeu_si16(dst, v); -#else - // GCC 9 lacks support for __mm_storeu_si16. - *static_cast<std::int16_t*>(dst) = _mm_extract_epi16(v, 0); -#endif -} - -// Polyfill for _mm_storeu_si32(dst, v). -inline void mm_storeu_si32(void* dst, __m128i v) { -#if defined __clang__ - _mm_storeu_si32(dst, v); -#else - // GCC 9 lacks support for __mm_storeu_si32. - *static_cast<std::int32_t*>(dst) = _mm_extract_epi32(v, 0); -#endif -} - -// Polyfill for _mm_loadu_si32(src). -inline __m128i mm_loadu_si32(const void* src) { -#if defined __clang__ - return _mm_loadu_si32(src); -#else - // GCC 9 lacks support for _mm_loadu_si32. - __m128i res; - asm("movss %[src], %[res]" - : [res] "=x"(res) - : [src] "m"(*static_cast<const int*>(src))); - return res; -#endif -} - -inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows, - const __m256i v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. - const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - __m256i shuffled_v; - if (residual_rows > 1) { - // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4 - // in each 128-bit lane. - shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - } - switch (residual_rows) { - case 0: - break; - case 1: - dst[0] = _mm256_extract_epi8(v, 0); - break; - case 2: - mm_storeu_si16(dst, _mm256_extracti128_si256(shuffled_v, 0)); - break; - case 3: { - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 0); - mm_storeu_si16(dst, trailing_packed); - dst[2] = _mm_extract_epi8(trailing_packed, 2); - break; - } - case 4: - mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - break; - case 5: - mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - dst[4] = _mm256_extract_epi8(shuffled_v, 16); - break; - case 6: - mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - mm_storeu_si16(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - case 7: { - mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); - mm_storeu_si16(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi8(trailing_packed, 2); - break; - } - case 8: - mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. 
- const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows, - const __m256i v) { - intrin_utils::mm256_n_storeu_cvtepi32_epi8( - reinterpret_cast<std::uint8_t*>(dst), residual_rows, v); -} - -inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. - const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows, - const __m256i v) { - // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively - // truncating each 16-bit integer. - const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); - __m256i shuffled_v; - __m128i shuffled_v_low; - if (residual_rows > 1) { - shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - shuffled_v_low = _mm256_extracti128_si256(shuffled_v, 0); - } else { - shuffled_v_low = _mm256_extracti128_si256(v, 0); - } - switch (residual_rows) { - case 0: - break; - case 1: - mm_storeu_si16(dst, shuffled_v_low); - break; - case 2: - mm_storeu_si32(dst, shuffled_v_low); - break; - case 3: { - mm_storeu_si32(dst, shuffled_v_low); - dst[2] = _mm_extract_epi16(shuffled_v_low, 2); - break; - } - case 4: - _mm_storeu_si64(dst, shuffled_v_low); - break; - case 5: - _mm_storeu_si64(dst, shuffled_v_low); - dst[4] = _mm256_extract_epi16(shuffled_v, 8); - break; - case 6: - _mm_storeu_si64(dst, shuffled_v_low); - mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - case 7: { - _mm_storeu_si64(dst, shuffled_v_low); - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); - mm_storeu_si32(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi16(trailing_packed, 2); - break; - } - case 8: - _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) { - // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively - // truncating each 16-bit integer. 
- const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows, - const __m256i v) { - const __m128i v_low = _mm256_extracti128_si256(v, 0); - switch (residual_rows) { - case 0: - break; - case 1: - mm_storeu_si32(dst, v_low); - break; - case 2: - _mm_storeu_si64(dst, v_low); - break; - case 3: { - __m128i trailing_packed = v_low; - _mm_storeu_si64(dst, trailing_packed); - dst[2] = _mm_extract_epi32(trailing_packed, 2); - break; - } - case 4: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - break; - case 5: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - dst[4] = _mm256_extract_epi32(v, 4); - break; - case 6: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(v, 1)); - break; - case 7: { - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - __m128i trailing_packed = _mm256_extracti128_si256(v, 1); - _mm_storeu_si64(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi32(trailing_packed, 2); - break; - } - case 8: - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) { - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); -} - - -// Transpose a 8x8 matrix of floats. -void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3, - __m256* v4, __m256* v5, __m256* v6, __m256* v7) { - __m256 t2x2_0 = _mm256_unpacklo_ps(*v0, *v1); - __m256 t2x2_1 = _mm256_unpackhi_ps(*v0, *v1); - __m256 t2x2_2 = _mm256_unpacklo_ps(*v2, *v3); - __m256 t2x2_3 = _mm256_unpackhi_ps(*v2, *v3); - __m256 t2x2_4 = _mm256_unpacklo_ps(*v4, *v5); - __m256 t2x2_5 = _mm256_unpackhi_ps(*v4, *v5); - __m256 t2x2_6 = _mm256_unpacklo_ps(*v6, *v7); - __m256 t2x2_7 = _mm256_unpackhi_ps(*v6, *v7); - __m256 t4x4_0 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 t4x4_1 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(3, 2, 3, 2)); - __m256 t4x4_2 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 t4x4_3 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(3, 2, 3, 2)); - __m256 t4x4_4 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 t4x4_5 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(3, 2, 3, 2)); - __m256 t4x4_6 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 t4x4_7 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(3, 2, 3, 2)); - *v0 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x20); - *v1 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x20); - *v2 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x20); - *v3 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x20); - *v4 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x31); - *v5 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x31); - *v6 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x31); - *v7 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x31); -} -// Transpose a 8x8 matrix of int32's. 
-void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2, - __m256i* v3, __m256i* v4, __m256i* v5, - __m256i* v6, __m256i* v7) { - mm256_transpose8x8_ps( - reinterpret_cast<__m256*>(v0), reinterpret_cast<__m256*>(v1), - reinterpret_cast<__m256*>(v2), reinterpret_cast<__m256*>(v3), - reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5), - reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7)); +template <> +inline __m256i mm256_shuffle_epi8<Path::kAvx2Fma>(const __m256i& a, + const __m256i& b) { + return _mm256_shuffle_epi8(a, b); } // Make an inline function for FMA so we can share the float kernels @@ -330,10 +74,25 @@ inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b, return _mm256_fmadd_ps(a, b, c); } +template <> +inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a, + const int imm) { + switch (imm) { + case 0: + return _mm256_extracti128_si256(a, 0); + case 1: + return _mm256_extracti128_si256(a, 1); + default: + RUY_DCHECK_LT(imm, 2); + return _mm_setzero_si128(); + } +} + } // namespace intrin_utils } // namespace -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { +template <Path path> +void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit"); const std::int8_t splitter_idx_data[32] = { 0, 1, 4, 5, 8, 9, 12, 13, // @@ -674,7 +433,7 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { if (transpose_around_multiplier) { // Transpose the 8x8 accumulators block. Will be un-transposed below // after the multplier implementation. - intrin_utils::mm256_transpose8x8_epi32( + intrin_utils::mm256_transpose8x8_epi32<path>( &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); } @@ -714,7 +473,7 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { // of the 8x8 block of accumulators used to implement the // channels-are-columns case. 
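// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this change: the kernel bodies are now
// templated on Path, and the 256-bit integer operations go through
// intrin_utils wrappers parameterized on that Path. The kAvx2Fma
// specializations that exist (such as mm256_shuffle_epi8 above) forward to the
// native AVX2 instruction; the kAvx specializations emulate the same operation
// on two 128-bit halves, since plain AVX lacks most 256-bit integer
// instructions. A composed use of the wrappers, mirroring the clamping done in
// the kernel epilogue (ClampAccum is an illustrative helper, not a ruy
// function):
template <Path path>
inline __m256i ClampAccum(const __m256i& acc, const __m256i& clamp_min_v,
                          const __m256i& clamp_max_v) {
  return intrin_utils::mm256_min_epi32<path>(
      intrin_utils::mm256_max_epi32<path>(acc, clamp_min_v), clamp_max_v);
}
// On the kAvx path these wrappers lower to the SSE4.1-based two-half
// emulation; the AVX2 kernel calls _mm256_min_epi32 / _mm256_max_epi32
// directly, as the hunks below show.
// ---------------------------------------------------------------------------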
if (transpose_around_multiplier) { - intrin_utils::mm256_transpose8x8_epi32( + intrin_utils::mm256_transpose8x8_epi32<path>( &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); } @@ -755,29 +514,29 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0 * dst_stride], - accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[1 * dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], - accum_data_v7); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[0 * dst_stride], accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[1 * dst_stride], accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); } else { for (int j = 0; j < residual_cols; ++j) { __m256i result = accum_data_v[j]; result = _mm256_min_epi32(result, clamp_max_v); result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + tmp_ptr, residual_rows, result); tmp_ptr += dst_stride; } } @@ -802,28 +561,29 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], - accum_data_v7); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); 
+ intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); } else { for (int j = 0; j < residual_cols; ++j) { __m256i result = accum_data_v[j]; result = _mm256_min_epi32(result, clamp_max_v); result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + tmp_ptr, residual_rows, result); tmp_ptr += dst_stride; } } @@ -848,28 +608,29 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[7 * dst_stride], - accum_data_v7); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[2 * dst_stride], accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[3 * dst_stride], accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[4 * dst_stride], accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[5 * dst_stride], accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[6 * dst_stride], accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi16<path>( + &tmp_ptr[7 * dst_stride], accum_data_v7); } else { for (int j = 0; j < residual_cols; ++j) { __m256i result = accum_data_v[j]; result = _mm256_min_epi32(result, clamp_max_v); result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, - result); + intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>( + tmp_ptr, residual_rows, result); tmp_ptr += dst_stride; } } @@ -878,25 +639,26 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { if (store_full_block) { std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[dst_stride], accum_data_v1); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[4 * dst_stride], - accum_data_v4); - 
intrin_utils::mm256_storeu_epi32(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[7 * dst_stride], - accum_data_v7); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride], + accum_data_v7); } else { std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); for (int j = 0; j < residual_cols; ++j) { - intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, - accum_data_v[j]); + intrin_utils::mm256_n_storeu_epi32<path>( + dst_block_ptr, residual_rows, accum_data_v[j]); dst_block_ptr += dst_stride; } } @@ -915,7 +677,12 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { } // End col-block loop. } // NOLINT(readability/fn_size) -void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { +void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { + Kernel8bitAvx2Impl<Path::kAvx2Fma>(params); +} + +template <Path path> +void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV"); RUY_DCHECK_EQ(params.dst_cols, 1); @@ -1000,7 +767,7 @@ void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { const __m256i lhs_data = _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); - const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32(rhs_ptr); + const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr); // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. 
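// ---------------------------------------------------------------------------
// Illustrative note, restating wrappers visible in this diff: the exported
// entry points stay non-templated and simply pin the Path for the shared
// templated bodies, e.g.
//
//   void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
//     Kernel8bitAvx2Impl<Path::kAvx2Fma>(params);
//   }
//   void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) {
//     Kernel8bitAvxSingleColImpl<Path::kAvx>(params);
//   }
//
// so CPU-feature dispatch remains at the symbol level while the Path template
// parameter selects the matching intrin_utils specializations at compile time.
// ---------------------------------------------------------------------------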
// For simplicity we load 4x the data that we need and process twice the @@ -1133,10 +900,12 @@ void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); __m256i result = accum_data_v0; + int32_t res = _mm256_extract_epi32(result, 0); result = _mm256_min_epi32(result, clamp_max_v); result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); + res = _mm256_extract_epi32(result, 0); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, + result); dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + kAvx8bitBlockSize); } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { @@ -1144,8 +913,8 @@ void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { __m256i result = accum_data_v0; result = _mm256_min_epi32(result, clamp_max_v); result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, + result); dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + kAvx8bitBlockSize); } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { @@ -1153,14 +922,14 @@ void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { __m256i result = accum_data_v0; result = _mm256_min_epi32(result, clamp_max_v); result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, - result); + intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows, + result); dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + kAvx8bitBlockSize); } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); - intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, - accum_data_v0); + intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows, + accum_data_v0); dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + kAvx8bitBlockSize); } else { @@ -1175,6 +944,10 @@ void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; } // NOLINT(readability/fn_size) +void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { + Kernel8bitAvx2SingleColImpl<Path::kAvx2Fma>(params); +} + void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { profiler::ScopeLabel label("Kernel kAvx2Fma float"); KernelFloatAvxCommon<Path::kAvx2Fma>(params); diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h index bad86cd..c530a1f 100644 --- a/ruy/kernel_x86.h +++ b/ruy/kernel_x86.h @@ -159,6 +159,32 @@ struct Kernel<Path::kAvx, float, float, float, float> { } } }; + +void Kernel8bitAvx(const KernelParams8bit<8, 8>& params); +void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params); + +template <typename DstScalar> +struct Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar> { + static constexpr Path kPath = Path::kAvx2Fma; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, 
int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitAvxSingleCol(params); + } else { + Kernel8bitAvx(params); + } + } +}; + #endif // RUY_PLATFORM_X86 } // namespace ruy @@ -225,6 +251,297 @@ inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) { // files. RUY_DCHECK(false); } + +template <Path path> +inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) { + // Specializations added for AVX and AVX2FMA paths in their respective kernel + // files. + RUY_DCHECK(false); +} + +// Polyfill for _mm_storeu_si16(dst, v). +template <Path path> +inline void mm_storeu_si16(void* dst, __m128i v) { +#if defined __clang__ + _mm_storeu_si16(dst, v); +#else + // GCC 9 lacks support for __mm_storeu_si16. + *static_cast<std::int16_t*>(dst) = _mm_extract_epi16(v, 0); +#endif +} + +// Polyfill for _mm_storeu_si32(dst, v). +template <Path path> +inline void mm_storeu_si32(void* dst, __m128i v) { +#if defined __clang__ + _mm_storeu_si32(dst, v); +#else + // GCC 9 lacks support for __mm_storeu_si32. + *static_cast<std::int32_t*>(dst) = _mm_extract_epi32(v, 0); +#endif +} + +// Polyfill for _mm_loadu_si32(src). +template <Path path> +inline __m128i mm_loadu_si32(const void* src) { +#if defined __clang__ + return _mm_loadu_si32(src); +#else + // GCC 9 lacks support for _mm_loadu_si32. + __m128i res; + asm("movss %[src], %[res]" + : [res] "=x"(res) + : [src] "m"(*static_cast<const int*>(src))); + return res; +#endif +} + +template <Path path> +inline __m128i mm256_extracti128_si256(const __m256i&, const int) { + RUY_DCHECK(false); +} + +template <Path path> +inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows, + const __m256i v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + __m256i shuffled_v; + if (residual_rows > 1) { + // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4 + // in each 128-bit lane. 
+ shuffled_v = intrin_utils::mm256_shuffle_epi8<path>(v, repack_perm); + } + switch (residual_rows) { + case 0: + break; + case 1: + dst[0] = _mm256_extract_epi8(v, 0); + break; + case 2: + mm_storeu_si16<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + break; + case 3: { + __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 0); + mm_storeu_si16<path>(dst, trailing_packed); + dst[2] = _mm_extract_epi8(trailing_packed, 2); + break; + } + case 4: + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + break; + case 5: + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + dst[4] = _mm256_extract_epi8(shuffled_v, 16); + break; + case 6: + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + mm_storeu_si16<path>(dst + 4, + mm256_extracti128_si256<path>(shuffled_v, 1)); + break; + case 7: { + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1); + mm_storeu_si16<path>(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi8(trailing_packed, 2); + break; + } + case 8: + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + mm_storeu_si32<path>(dst + 4, + mm256_extracti128_si256<path>(shuffled_v, 1)); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +template <Path path> +inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); +} + +template <Path path> +inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows, + const __m256i v) { + intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( + reinterpret_cast<std::uint8_t*>(dst), residual_rows, v); +} + +template <Path path> +inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); + mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); +} + +template <Path path> +inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows, + const __m256i v) { + // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively + // truncating each 16-bit integer. 
+ const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); + __m256i shuffled_v; + __m128i shuffled_v_low; + if (residual_rows > 1) { + shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); + shuffled_v_low = mm256_extracti128_si256<path>(shuffled_v, 0); + } else { + shuffled_v_low = mm256_extracti128_si256<path>(v, 0); + } + switch (residual_rows) { + case 0: + break; + case 1: + mm_storeu_si16<path>(dst, shuffled_v_low); + break; + case 2: + mm_storeu_si32<path>(dst, shuffled_v_low); + break; + case 3: { + mm_storeu_si32<path>(dst, shuffled_v_low); + dst[2] = _mm_extract_epi16(shuffled_v_low, 2); + break; + } + case 4: + _mm_storeu_si64(dst, shuffled_v_low); + break; + case 5: + _mm_storeu_si64(dst, shuffled_v_low); + dst[4] = _mm256_extract_epi16(shuffled_v, 8); + break; + case 6: + _mm_storeu_si64(dst, shuffled_v_low); + mm_storeu_si32<path>(dst + 4, + mm256_extracti128_si256<path>(shuffled_v, 1)); + break; + case 7: { + _mm_storeu_si64(dst, shuffled_v_low); + __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1); + mm_storeu_si32<path>(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi16(trailing_packed, 2); + break; + } + case 8: + _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +template <Path path> +inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) { + // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively + // truncating each 16-bit integer. + const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); + const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); + _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); + _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); +} + +template <Path path> +inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows, + const __m256i v) { + const __m128i v_low = mm256_extracti128_si256<path>(v, 0); + switch (residual_rows) { + case 0: + break; + case 1: + mm_storeu_si32<path>(dst, v_low); + break; + case 2: + _mm_storeu_si64(dst, v_low); + break; + case 3: { + __m128i trailing_packed = v_low; + _mm_storeu_si64(dst, trailing_packed); + dst[2] = _mm_extract_epi32(trailing_packed, 2); + break; + } + case 4: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + break; + case 5: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + dst[4] = _mm256_extract_epi32(v, 4); + break; + case 6: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(v, 1)); + break; + case 7: { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + __m128i trailing_packed = mm256_extracti128_si256<path>(v, 1); + _mm_storeu_si64(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi32(trailing_packed, 2); + break; + } + case 8: + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +template <Path path> +inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); +} + +// Transpose a 8x8 matrix of floats. 
+template <Path path> +void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3, + __m256* v4, __m256* v5, __m256* v6, __m256* v7) { + __m256 t2x2_0 = _mm256_unpacklo_ps(*v0, *v1); + __m256 t2x2_1 = _mm256_unpackhi_ps(*v0, *v1); + __m256 t2x2_2 = _mm256_unpacklo_ps(*v2, *v3); + __m256 t2x2_3 = _mm256_unpackhi_ps(*v2, *v3); + __m256 t2x2_4 = _mm256_unpacklo_ps(*v4, *v5); + __m256 t2x2_5 = _mm256_unpackhi_ps(*v4, *v5); + __m256 t2x2_6 = _mm256_unpacklo_ps(*v6, *v7); + __m256 t2x2_7 = _mm256_unpackhi_ps(*v6, *v7); + __m256 t4x4_0 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 t4x4_1 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 t4x4_2 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 t4x4_3 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 t4x4_4 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 t4x4_5 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 t4x4_6 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 t4x4_7 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(3, 2, 3, 2)); + *v0 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x20); + *v1 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x20); + *v2 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x20); + *v3 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x20); + *v4 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x31); + *v5 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x31); + *v6 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x31); + *v7 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x31); +} + +// Transpose a 8x8 matrix of int32's. +template <Path path> +void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2, + __m256i* v3, __m256i* v4, __m256i* v5, + __m256i* v6, __m256i* v7) { + mm256_transpose8x8_ps<path>( + reinterpret_cast<__m256*>(v0), reinterpret_cast<__m256*>(v1), + reinterpret_cast<__m256*>(v2), reinterpret_cast<__m256*>(v3), + reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5), + reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7)); +} + } // namespace intrin_utils } // namespace |
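// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this change: a plain scalar 8x8 transpose,
// useful as a reference for what mm256_transpose8x8_ps computes. The vector
// version above does the same work in three stages: 2x2 blocks via
// unpacklo/unpackhi, 4x4 blocks via shuffle_ps, and a final 128-bit lane swap
// via permute2f128 (masks 0x20 and 0x31). The int32 variant just reinterprets
// the registers as floats and reuses it.
#include <utility>

inline void Transpose8x8Reference(float m[8][8]) {
  for (int r = 0; r < 8; ++r) {
    for (int c = r + 1; c < 8; ++c) {
      std::swap(m[r][c], m[c][r]);
    }
  }
}
// ---------------------------------------------------------------------------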