diff options
author | Dayeong Lee <dayeongl@google.com> | 2021-12-07 04:01:48 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2021-12-07 04:02:15 +0300 |
commit | 6c292a6e91cd3dab6059334d60c09fb5c7d1a94e (patch) | |
tree | c8977f70fedab2f25e39b0d24db1bea197147780 | |
parent | 8c3fd3f266b4a22d542d4aa41329b5018d6b87e1 (diff) |
Ruy:Add new packing for 16bit ColMajor for Avx512.
PiperOrigin-RevId: 414576763
-rw-r--r-- | ruy/pack_avx512.cc | 287 | ||||
-rw-r--r-- | ruy/pack_x86.h | 47 |
2 files changed, 325 insertions, 9 deletions
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc index 5281fa8..29d6f9f 100644 --- a/ruy/pack_avx512.cc +++ b/ruy/pack_avx512.cc @@ -38,6 +38,12 @@ void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t, RUY_DCHECK(false); } +void Pack16bitColMajorForAvx512(const std::int16_t*, const std::int16_t*, int, + int, int, std::int16_t*, std::int32_t*) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + void PackFloatColMajorForAvx512(const float*, const float*, int, int, int, float*) { // CPU-ID-based checks should disable the path that would reach this point. @@ -56,15 +62,19 @@ void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int, using PackImpl8bitAvx512 = PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>; +using PackImpl16bitAvx512 = + PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, + std::int16_t, std::int16_t, std::int32_t, Order::kColMajor>; namespace { -inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point, - std::int8_t* packed_ptr) { - using Layout = PackImpl8bitAvx512::Layout; +template <typename PackImplAvx512, typename Scalar> +inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point, + Scalar* packed_ptr) { + using Layout = typename PackImplAvx512::Layout; static constexpr int kHalfLayoutCols = - PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a - // block. + PackImplAvx512::kHalfLayoutCols; // Half the number of cols in a + // block. RUY_DCHECK_EQ(kHalfLayoutCols, 8); RUY_DCHECK_EQ(Layout::kCols, 16); RUY_DCHECK_EQ(Layout::kRows, 4); @@ -79,8 +89,8 @@ inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point, } } -inline __m512i LoaduTwo(const std::int8_t* addr_lo, - const std::int8_t* addr_hi) { +template <typename Scalar> +inline __m512i LoaduTwo(const Scalar* addr_lo, const Scalar* addr_hi) { __m512i lower_filled = _mm512_castsi256_si512( _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo))); return _mm512_inserti32x8( @@ -98,6 +108,16 @@ inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, 1); } +inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, + const std::int16_t* addr_lo, + const std::int16_t* addr_hi) { + const __m512i lower_filled = _mm512_castsi256_si512( + _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_lo)); + return _mm512_inserti32x8( + lower_filled, _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_hi), + 1); +} + inline void HalfPack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, const std::int8_t* zerobuf, int src_stride, @@ -454,6 +474,193 @@ inline void HalfPack8bitAvx512(const std::int8_t* src_ptr, } } +inline void HalfPack16bitAvx512(const std::int16_t* src_ptr, + const std::int16_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int16_t* packed_ptr, + std::int32_t* sums_ptr, + std::int16_t* trailing_buf) { + using Layout = PackImpl16bitAvx512::Layout; + RUY_DCHECK_EQ(Layout::kCols, 16); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 4 of these chunks at a time, padding std::int16_t input chunks. + constexpr int kNumRowChunks = 4; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + const std::int16_t* src_ptr0 = src_ptr; + const std::int16_t* src_ptr1 = src_ptr0 + src_stride; + const std::int16_t* src_ptr2 = src_ptr1 + src_stride; + const std::int16_t* src_ptr3 = src_ptr2 + src_stride; + const std::int16_t* src_ptr4 = src_ptr3 + src_stride; + const std::int16_t* src_ptr5 = src_ptr4 + src_stride; + const std::int16_t* src_ptr6 = src_ptr5 + src_stride; + const std::int16_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have kHalfLayoutCols (8) columns. + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int16_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: kHalfLayoutCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + std::int32_t sums_adjustment = 0; + const __m512i ones_16bit = _mm512_set1_epi16(1); + __m512i sums_8x2_32bit = _mm512_set1_epi32(0); + + // The overall packing effectively pads the source rows to + // (src_rows + 31) & ~31. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 15) & ~15. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) { + // m: {0, 1} for 2 chunks of rows. + for (int m = 0; m < 2; ++m) { + const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows; + + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + if (available_src_rows > 0) { + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + std::int16_t* dst_ptr = packed_ptr; + + if (available_src_rows >= kNumChunkedSrcRows) { + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + } else { + RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows); + // We do not care what goes into the trailing buffer, but we want + // in_data[...] == zero_point for irrelevant values in the summation. + // + // We compensate for padding-with-zero_point by initializing the + // summations with the compensating offset. + sums_adjustment += + -(zero_point)*4 * (4 - ((available_src_rows + 3) >> 2)); + + const __m256i zero_point_v = _mm256_set1_epi16(zero_point); + const __mmask32 row_mask = + (static_cast<std::uint64_t>(1) << available_src_rows) - 1; + + t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4); + t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5); + t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6); + t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7); + dst_ptr = trailing_buf; + } + + r0 = _mm512_unpacklo_epi64(t0, t1); + r2 = _mm512_unpackhi_epi64(t0, t1); + r1 = _mm512_unpacklo_epi64(t2, t3); + r3 = _mm512_unpackhi_epi64(t2, t3); + + r1 = _mm512_permutex_epi64(r1, 0x4e); + r3 = _mm512_permutex_epi64(r3, 0x4e); + + t0 = _mm512_mask_blend_epi64(0xcc, r0, r1); + t1 = _mm512_mask_blend_epi64(0x33, r0, r1); + t2 = _mm512_mask_blend_epi64(0xcc, r2, r3); + t3 = _mm512_mask_blend_epi64(0x33, r2, r3); + + t1 = _mm512_permutex_epi64(t1, 0x4e); + t3 = _mm512_permutex_epi64(t3, 0x4e); + + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 0 * 16 * 4), + t0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 2 * 16 * 4), + t1); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 1 * 16 * 4), + t2); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 3 * 16 * 4), + t3); + + if (sums_ptr) { + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, + _mm512_madd_epi16(t0, ones_16bit)); + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, + _mm512_madd_epi16(t1, ones_16bit)); + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, + _mm512_madd_epi16(t2, ones_16bit)); + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, + _mm512_madd_epi16(t3, ones_16bit)); + } + } + + packed_ptr += 16 * kNumChunkedSrcRows; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } + + if (sums_ptr) { + const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); + + __m256i sums = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr)); + const __m512i idx = + _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); + + const __m512i sums_2x8_32bit = + _mm512_permutexvar_epi32(idx, sums_8x2_32bit); + sums = _mm256_add_epi32(sums, sums_adjustment_v); + sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit)); + sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1)); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); + } +} + inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) { const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo)); return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1); @@ -674,8 +881,8 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr, HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr, sums_ptr, trailing_buf); - ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor, - packed_ptr + kHalfBlockOffset); + ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>( + src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset); // The kernel may not need the second half-blocks sums to be set. if (second_sums_ptr) { for (int i = 0; i < kHalfLayoutCols; ++i) { @@ -697,6 +904,68 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr, } } +void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr, + const std::int16_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int16_t* packed_ptr, + std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvx512 16bit"); + + using Layout = PackImpl16bitAvx512::Layout; + constexpr int kHalfBlockOffset = 32; + RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols); + static constexpr int kHalfLayoutCols = + PackImpl16bitAvx512::kHalfLayoutCols; // Half the number of cols in a + // block. + RUY_DCHECK_EQ(kHalfLayoutCols, 8); + RUY_DCHECK_EQ(Layout::kCols, 16); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + constexpr int kNumRowChunks = 4; + + // Each packed block is 4*16, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int16_t trailing_buf[kTrailingBufSize] = {0}; + + std::int32_t* second_sums_ptr = + sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; + if (remaining_src_cols > kHalfLayoutCols) { + HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, sums_ptr, trailing_buf); + HalfPack16bitAvx512(src_ptr + src_stride * kHalfLayoutCols, zerobuf, + src_stride, remaining_src_cols - kHalfLayoutCols, + src_rows, packed_ptr + kHalfBlockOffset, + second_sums_ptr, trailing_buf + kHalfBlockOffset); + } else { + HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, sums_ptr, trailing_buf); + ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>( + src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset); + // The kernel may not need the second half-blocks sums to be set. + if (second_sums_ptr) { + for (int i = 0; i < kHalfLayoutCols; ++i) { + second_sums_ptr[i] = (zerobuf[0]) * ((src_rows + 3) & ~3); + } + } + } + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. + const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, + Layout::kCols * trailing_rows * sizeof(std::int16_t)); + } +} + void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf, int src_stride, int remaining_src_cols, int src_rows, float* packed_ptr) { diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h index f3ea54e..a28bbc9 100644 --- a/ruy/pack_x86.h +++ b/ruy/pack_x86.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef RUY_RUY_PACK_X86_H_ #define RUY_RUY_PACK_X86_H_ +#include <algorithm> #include <cstdint> #include <cstring> #include <type_traits> @@ -271,6 +272,52 @@ struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, } }; +void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr, + const std::int16_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int16_t* packed_ptr, + std::int32_t* sums_ptr); + +template <> +struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, + std::int16_t, std::int16_t, std::int32_t, Order::kColMajor> { + using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>; + static constexpr int kHalfLayoutCols = + 8; // Half the number of cols in a block. + + static void Run(Tuning, const Mat<std::int16_t>& src_matrix, + PMat<std::int16_t>* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 16-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); + std::int32_t* sums = packed_matrix->sums; + std::int16_t zerobuf[kHalfLayoutCols * Layout::kRows]; + std::fill(zerobuf, zerobuf + kHalfLayoutCols * Layout::kRows, + static_cast<int16_t>(packed_matrix->zero_point)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const std::int16_t* src_ptr = + src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); + std::int16_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack16bitColMajorForAvx512(src_ptr, zerobuf, src_stride, + remaining_src_cols, src_matrix.layout.rows, + packed_ptr, sums_ptr); + } + } +}; + void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf, int src_stride, int remaining_src_cols, int src_rows, float* packed_ptr); |