github.com/google/ruy.git

author     Dayeong Lee <dayeongl@google.com>              2021-12-07 04:01:48 +0300
committer  Copybara-Service <copybara-worker@google.com>  2021-12-07 04:02:15 +0300
commit     6c292a6e91cd3dab6059334d60c09fb5c7d1a94e (patch)
tree       c8977f70fedab2f25e39b0d24db1bea197147780
parent     8c3fd3f266b4a22d542d4aa41329b5018d6b87e1 (diff)
Ruy: Add new packing for 16bit ColMajor for Avx512.
PiperOrigin-RevId: 414576763
-rw-r--r--  ruy/pack_avx512.cc  | 287
-rw-r--r--  ruy/pack_x86.h      |  47
2 files changed, 325 insertions(+), 9 deletions(-)
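
This commit adds a 16-bit column-major packing path for AVX-512, mirroring the
existing 8-bit path: a stub for builds without AVX-512 support, a ZeroHalfAvx512
helper shared with the 8-bit path, masked-load plumbing for std::int16_t, the
HalfPack16bitAvx512 kernel, and a PackImpl specialization in pack_x86.h. All of
it packs into FixedKernelLayout<Order::kColMajor, 4, 16> blocks. As a reading
aid, a scalar model (our name, not ruy's) of where a source element lands:

    // Illustration only: packed offset of source element (row, col) under
    // FixedKernelLayout<Order::kColMajor, 4, 16>, i.e. 4x16 blocks in which
    // the 4 row-elements of each column are stored contiguously.
    int PackedOffset(int row, int col) {
      constexpr int kRows = 4, kCols = 16;
      return (row / kRows) * kRows * kCols  // skip preceding 4x16 blocks
             + col * kRows                  // this column's slot in the block
             + row % kRows;                 // element within the column
    }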
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc
index 5281fa8..29d6f9f 100644
--- a/ruy/pack_avx512.cc
+++ b/ruy/pack_avx512.cc
@@ -38,6 +38,12 @@ void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t,
RUY_DCHECK(false);
}
+void Pack16bitColMajorForAvx512(const std::int16_t*, const std::int16_t*, int,
+ int, int, std::int16_t*, std::int32_t*) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
void PackFloatColMajorForAvx512(const float*, const float*, int, int, int,
float*) {
// CPU-ID-based checks should disable the path that would reach this point.
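
As with the 8-bit stubs above, this body is compiled only when the file is
built without AVX-512 support; runtime CPU detection is expected to keep ruy
from ever dispatching here. A simplified sketch of the guard pattern (the
macro name is an assumption for illustration, not necessarily ruy's exact
spelling):

    #if RUY_PLATFORM_AVX512  // assumed guard name
    // ... the vectorized definitions later in this file ...
    #else
    void Pack16bitColMajorForAvx512(const std::int16_t*, const std::int16_t*,
                                    int, int, int, std::int16_t*,
                                    std::int32_t*) {
      RUY_DCHECK(false);  // CPU-ID-based dispatch should never land here.
    }
    #endif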
@@ -56,15 +62,19 @@ void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int,
using PackImpl8bitAvx512 =
PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;
+using PackImpl16bitAvx512 =
+ PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ std::int16_t, std::int16_t, std::int32_t, Order::kColMajor>;
namespace {
-inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
- std::int8_t* packed_ptr) {
- using Layout = PackImpl8bitAvx512::Layout;
+template <typename PackImplAvx512, typename Scalar>
+inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point,
+ Scalar* packed_ptr) {
+ using Layout = typename PackImplAvx512::Layout;
static constexpr int kHalfLayoutCols =
- PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
- // block.
+ PackImplAvx512::kHalfLayoutCols; // Half the number of cols in a
+ // block.
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
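
ZeroHalf8bitAvx512 becomes ZeroHalfAvx512, templated on the PackImpl (which
supplies kHalfLayoutCols and the layout) and on the scalar type, so the new
16-bit path can reuse it to fill the unused half of each 4x16 block with the
zero point. Both call sites appear later in this diff; they spell out both
template arguments even though Scalar alone would be deducible:

    // The two instantiations used by this commit:
    ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>(
        src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset);
    ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>(
        src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset);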
@@ -79,8 +89,8 @@ inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
}
}
-inline __m512i LoaduTwo(const std::int8_t* addr_lo,
- const std::int8_t* addr_hi) {
+template <typename Scalar>
+inline __m512i LoaduTwo(const Scalar* addr_lo, const Scalar* addr_hi) {
__m512i lower_filled = _mm512_castsi256_si512(
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo)));
return _mm512_inserti32x8(
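
LoaduTwo is generalized the same way: two unaligned 256-bit loads concatenated
into one 512-bit register. A scalar model for Scalar = std::int16_t (names are
ours, for illustration):

    #include <algorithm>
    #include <array>
    #include <cstdint>

    // Lanes 0..15 of the result come from addr_lo, lanes 16..31 from addr_hi,
    // matching the castsi256_si512 + inserti32x8 pair.
    std::array<std::int16_t, 32> LoaduTwoModel(const std::int16_t* addr_lo,
                                               const std::int16_t* addr_hi) {
      std::array<std::int16_t, 32> out;
      std::copy(addr_lo, addr_lo + 16, out.begin());
      std::copy(addr_hi, addr_hi + 16, out.begin() + 16);
      return out;
    }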
@@ -98,6 +108,16 @@ inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
1);
}
+inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
+ const std::int16_t* addr_lo,
+ const std::int16_t* addr_hi) {
+ const __m512i lower_filled = _mm512_castsi256_si512(
+ _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_lo));
+ return _mm512_inserti32x8(
+ lower_filled, _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_hi),
+ 1);
+}
+
inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride,
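
The std::int16_t overload of MaskLoaduTwo relies on _mm256_mask_loadu_epi16
(AVX-512BW + VL): lane i is read from memory only when bit i of the mask is
set, and is taken from default_value_v otherwise; masked-off addresses are not
touched, so loading near the end of a source column is safe. This is what lets
short columns be padded with the zero point without a scalar tail loop. A
scalar model of one 256-bit half (names are ours):

    #include <cstdint>

    // Models one _mm256_mask_loadu_epi16. The packing code passes a
    // __mmask32; each 16-lane half consumes only the low 16 bits.
    void MaskLoadHalfModel(std::uint32_t mask, const std::int16_t defaults[16],
                           const std::int16_t* addr, std::int16_t out[16]) {
      for (int i = 0; i < 16; ++i) {
        out[i] = ((mask >> i) & 1) ? addr[i] : defaults[i];
      }
    }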
@@ -454,6 +474,193 @@ inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
}
}
+inline void HalfPack16bitAvx512(const std::int16_t* src_ptr,
+ const std::int16_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int16_t* packed_ptr,
+ std::int32_t* sums_ptr,
+ std::int16_t* trailing_buf) {
+ using Layout = PackImpl16bitAvx512::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 16);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+ // Each Layout::kRows group is 4 elements, contiguous in input and packing.
+ // We process 4 of these chunks at a time, padding incomplete input chunks.
+ constexpr int kNumRowChunks = 4;
+ constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
+
+ const std::int16_t* src_ptr0 = src_ptr;
+ const std::int16_t* src_ptr1 = src_ptr0 + src_stride;
+ const std::int16_t* src_ptr2 = src_ptr1 + src_stride;
+ const std::int16_t* src_ptr3 = src_ptr2 + src_stride;
+ const std::int16_t* src_ptr4 = src_ptr3 + src_stride;
+ const std::int16_t* src_ptr5 = src_ptr4 + src_stride;
+ const std::int16_t* src_ptr6 = src_ptr5 + src_stride;
+ const std::int16_t* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = kNumChunkedSrcRows;
+ std::int64_t src_inc1 = kNumChunkedSrcRows;
+ std::int64_t src_inc2 = kNumChunkedSrcRows;
+ std::int64_t src_inc3 = kNumChunkedSrcRows;
+ std::int64_t src_inc4 = kNumChunkedSrcRows;
+ std::int64_t src_inc5 = kNumChunkedSrcRows;
+ std::int64_t src_inc6 = kNumChunkedSrcRows;
+ std::int64_t src_inc7 = kNumChunkedSrcRows;
+ // Handle cases where source does not have kHalfLayoutCols (8) columns.
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ const std::int16_t zero_point = zerobuf[0];
+
+ if (sums_ptr) {
+ // i: kHalfLayoutCols.
+ for (int i = 0; i < 8; ++i) {
+ sums_ptr[i] = 0;
+ }
+ }
+ std::int32_t sums_adjustment = 0;
+ const __m512i ones_16bit = _mm512_set1_epi16(1);
+ __m512i sums_8x2_32bit = _mm512_set1_epi32(0);
+
+ // The overall packing effectively pads the source rows to
+ // (src_rows + 31) & ~31. The iteration over k may skip when m=1, and then we
+ // only pack for (src_rows + 15) & ~15. When there is an incomplete
+ // destination block, this is stored into trailing_buf instead of packed_ptr.
+ for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
+ // m: {0, 1} for 2 chunks of rows.
+ for (int m = 0; m < 2; ++m) {
+ const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
+
+ // Available source rows.
+ // If this is less than 0 (for m=1), we skip, having filled trailing
+ // buffer for m=0. Also, if source rows is zero on m=1, then we filled
+ // exactly to the end of the column in the packed buffer.
+ if (available_src_rows > 0) {
+ __m512i t0, t1, t2, t3;
+ __m512i r0, r1, r2, r3;
+ std::int16_t* dst_ptr = packed_ptr;
+
+ if (available_src_rows >= kNumChunkedSrcRows) {
+ t0 = LoaduTwo(src_ptr0, src_ptr4);
+ t1 = LoaduTwo(src_ptr1, src_ptr5);
+ t2 = LoaduTwo(src_ptr2, src_ptr6);
+ t3 = LoaduTwo(src_ptr3, src_ptr7);
+ } else {
+ RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
+ // We do not care what goes into the trailing buffer, but we want
+ // in_data[...] == zero_point for irrelevant values in the summation.
+ //
+ // We compensate for padding-with-zero_point by initializing the
+ // summations with the compensating offset.
+ sums_adjustment +=
+ -(zero_point)*4 * (4 - ((available_src_rows + 3) >> 2));
+
+ const __m256i zero_point_v = _mm256_set1_epi16(zero_point);
+ const __mmask32 row_mask =
+ (static_cast<std::uint64_t>(1) << available_src_rows) - 1;
+
+ t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
+ t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
+ t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
+ t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);
+ dst_ptr = trailing_buf;
+ }
+
+ r0 = _mm512_unpacklo_epi64(t0, t1);
+ r2 = _mm512_unpackhi_epi64(t0, t1);
+ r1 = _mm512_unpacklo_epi64(t2, t3);
+ r3 = _mm512_unpackhi_epi64(t2, t3);
+
+ r1 = _mm512_permutex_epi64(r1, 0x4e);
+ r3 = _mm512_permutex_epi64(r3, 0x4e);
+
+ t0 = _mm512_mask_blend_epi64(0xcc, r0, r1);
+ t1 = _mm512_mask_blend_epi64(0x33, r0, r1);
+ t2 = _mm512_mask_blend_epi64(0xcc, r2, r3);
+ t3 = _mm512_mask_blend_epi64(0x33, r2, r3);
+
+ t1 = _mm512_permutex_epi64(t1, 0x4e);
+ t3 = _mm512_permutex_epi64(t3, 0x4e);
+
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 0 * 16 * 4),
+ t0);
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 2 * 16 * 4),
+ t1);
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 1 * 16 * 4),
+ t2);
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 3 * 16 * 4),
+ t3);
+
+ if (sums_ptr) {
+ sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
+ _mm512_madd_epi16(t0, ones_16bit));
+ sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
+ _mm512_madd_epi16(t1, ones_16bit));
+ sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
+ _mm512_madd_epi16(t2, ones_16bit));
+ sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
+ _mm512_madd_epi16(t3, ones_16bit));
+ }
+ }
+
+ packed_ptr += 16 * kNumChunkedSrcRows;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+ }
+
+ if (sums_ptr) {
+ const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
+
+ __m256i sums =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
+ const __m512i idx =
+ _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
+
+ const __m512i sums_2x8_32bit =
+ _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
+ sums = _mm256_add_epi32(sums, sums_adjustment_v);
+ sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
+ sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
+ }
+}
+
inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
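
HalfPack16bitAvx512 above follows the structure of the 8-bit kernel: eight
column pointers walk 16-row chunks, the unpack/permutex/blend sequence
transposes the loaded columns into packed 4x16-block order, and each 512-bit
store covers 8 columns x 4 rows of std::int16_t. Column sums are accumulated
with _mm512_madd_epi16 against a vector of ones: madd multiplies adjacent
int16 lanes pairwise and adds each pair into an int32, so one madd per stored
register yields 8 columns x 2 partial sums; the epilogue's 32-bit permute then
separates the interleaved partials and folds them into the 8 per-column sums.
A scalar model of one update (names are ours):

    #include <cstdint>

    // One sums update. `packed` is one stored register, laid out as 8 columns
    // of 4 contiguous rows; lanes 2c and 2c+1 are column c's two pair sums.
    void SumsUpdateModel(const std::int16_t packed[32],
                         std::int32_t partial_sums[16]) {
      for (int lane = 0; lane < 16; ++lane) {
        partial_sums[lane] += packed[2 * lane] + packed[2 * lane + 1];
      }
      // Epilogue (modeled): sums[c] += partial_sums[2*c] + partial_sums[2*c+1],
      // plus sums_adjustment to compensate for zero-point padding.
    }

The sums_adjustment keeps each column's total consistent with a column padded
to (src_rows + 3) & ~3 entries. For example, with zero_point = 5 and
available_src_rows = 6, the masked loads pad 10 of the 16 lanes per column
with 5; the adjustment of -(5) * 4 * (4 - 2) = -40 removes the two fully
padded 4-row chunks, leaving 6 real entries plus 2 padded ones, i.e.
(6 + 3) & ~3 = 8 entries per column.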
@@ -674,8 +881,8 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
remaining_src_cols, src_rows, packed_ptr, sums_ptr,
trailing_buf);
- ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor,
- packed_ptr + kHalfBlockOffset);
+ ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>(
+ src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset);
// The kernel may not need the second half-blocks sums to be set.
if (second_sums_ptr) {
for (int i = 0; i < kHalfLayoutCols; ++i) {
@@ -697,6 +904,68 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
}
}
+void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
+ const std::int16_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int16_t* packed_ptr,
+ std::int32_t* sums_ptr) {
+ profiler::ScopeLabel label("Pack kAvx512 16bit");
+
+ using Layout = PackImpl16bitAvx512::Layout;
+ constexpr int kHalfBlockOffset = 32;
+ RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
+ static constexpr int kHalfLayoutCols =
+ PackImpl16bitAvx512::kHalfLayoutCols; // Half the number of cols in a
+ // block.
+ RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+ RUY_DCHECK_EQ(Layout::kCols, 16);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+
+ // Each Layout::kRows group is 4 elements, contiguous in input and packing.
+ // We process 4 of these chunks at a time, padding incomplete input chunks.
+ constexpr int kNumRowChunks = 4;
+
+ // Each packed block is 4*16, and there are normally 4 (kNumRowChunks). The
+ // trailing block is only slightly shorter.
+ constexpr int kTrailingBufSize =
+ kNumRowChunks * Layout::kCols * Layout::kRows;
+ std::int16_t trailing_buf[kTrailingBufSize] = {0};
+
+ std::int32_t* second_sums_ptr =
+ sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
+ if (remaining_src_cols > kHalfLayoutCols) {
+ HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, sums_ptr, trailing_buf);
+ HalfPack16bitAvx512(src_ptr + src_stride * kHalfLayoutCols, zerobuf,
+ src_stride, remaining_src_cols - kHalfLayoutCols,
+ src_rows, packed_ptr + kHalfBlockOffset,
+ second_sums_ptr, trailing_buf + kHalfBlockOffset);
+ } else {
+ HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, sums_ptr, trailing_buf);
+ ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>(
+ src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset);
+ // The kernel may not need the second half-blocks sums to be set.
+ if (second_sums_ptr) {
+ for (int i = 0; i < kHalfLayoutCols; ++i) {
+ second_sums_ptr[i] = (zerobuf[0]) * ((src_rows + 3) & ~3);
+ }
+ }
+ }
+ constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
+ const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
+ // If the number of source rows is not a multiple of kChunkedRowMask + 1,
+ // there will be data in the trailing buffer.
+ if (trailing_data) {
+ const int non_trailing_rows = src_rows & ~kChunkedRowMask;
+ // Destination "rows" are padded to next highest multiple of Layout::kRows.
+ const int dst_rows = (src_rows + 3) & ~3;
+ const int trailing_rows = dst_rows - non_trailing_rows;
+ memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+ Layout::kCols * trailing_rows * sizeof(std::int16_t));
+ }
+}
+
void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
int src_stride, int remaining_src_cols,
int src_rows, float* packed_ptr) {
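
Pack16bitColMajorForAvx512 splits the 16 destination columns into two
half-blocks of 8, zero-filling the second half (and, if requested, setting its
sums to zero_point times the padded row count) when 8 or fewer source columns
remain. Because rows are consumed in 16-row chunks, a final incomplete chunk
is written to the on-stack trailing_buf and copied back only up to the
4-row-padded destination height. A worked example of the copy-back arithmetic:

    // Illustration only: src_rows = 70.
    //   kChunkedRowMask   = 4 * 4 - 1        = 15
    //   trailing_data     = (70 & 15) > 0    -> true (6 leftover source rows)
    //   non_trailing_rows = 70 & ~15         -> 64
    //   dst_rows          = (70 + 3) & ~3    -> 72 (padded to Layout::kRows)
    //   trailing_rows     = 72 - 64          -> 8
    // The memcpy then moves 16 cols * 8 rows * sizeof(std::int16_t) = 256
    // bytes from trailing_buf into packed_ptr + 16 * 64.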
diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h
index f3ea54e..a28bbc9 100644
--- a/ruy/pack_x86.h
+++ b/ruy/pack_x86.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef RUY_RUY_PACK_X86_H_
#define RUY_RUY_PACK_X86_H_
+#include <algorithm>
#include <cstdint>
#include <cstring>
#include <type_traits>
@@ -271,6 +272,52 @@ struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
}
};
+void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
+ const std::int16_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int16_t* packed_ptr,
+ std::int32_t* sums_ptr);
+
+template <>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ std::int16_t, std::int16_t, std::int32_t, Order::kColMajor> {
+ using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ static constexpr int kHalfLayoutCols =
+ 8; // Half the number of cols in a block.
+
+ static void Run(Tuning, const Mat<std::int16_t>& src_matrix,
+ PMat<std::int16_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX-512 16-bit)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
+ std::int32_t* sums = packed_matrix->sums;
+ std::int16_t zerobuf[kHalfLayoutCols * Layout::kRows];
+ std::fill(zerobuf, zerobuf + kHalfLayoutCols * Layout::kRows,
+ static_cast<std::int16_t>(packed_matrix->zero_point));
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ int src_stride = src_matrix.layout.stride;
+ const std::int16_t* src_ptr =
+ src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1);
+ std::int16_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ Pack16bitColMajorForAvx512(src_ptr, zerobuf, src_stride,
+ remaining_src_cols, src_matrix.layout.rows,
+ packed_ptr, sums_ptr);
+ }
+ }
+};
+
void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
int src_stride, int remaining_src_cols,
int src_rows, float* packed_ptr);
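
The PackImpl specialization above is what ruy's path dispatch selects for
std::int16_t input on kAvx512: Run fills a small zerobuf with the zero point
for column padding, then invokes Pack16bitColMajorForAvx512 once per 16-column
block, producing the per-column sums used downstream for zero-point
corrections. Tying the pieces together, a scalar reference (names are ours)
for what one block's packing computes, under the layout sketched at the top of
this page:

    #include <cstdint>

    // Scalar reference for one 16-column block. Padding, both past the last
    // source column and past the last source row (up to the 4-row-padded
    // height), uses the zero point and is included in the sums.
    void PackBlockRefModel(const std::int16_t* src, int src_stride,
                           int src_cols, int src_rows, std::int16_t zero_point,
                           std::int16_t* packed, std::int32_t* sums) {
      constexpr int kRows = 4, kCols = 16;
      const int dst_rows = (src_rows + kRows - 1) & ~(kRows - 1);
      for (int c = 0; c < kCols; ++c) {
        std::int32_t sum = 0;
        for (int r = 0; r < dst_rows; ++r) {
          const bool in_src = (c < src_cols) && (r < src_rows);
          const std::int16_t v =
              in_src ? src[c * src_stride + r] : zero_point;
          packed[(r / kRows) * kRows * kCols + c * kRows + (r % kRows)] = v;
          sum += v;
        }
        if (sums != nullptr) sums[c] = sum;
      }
    }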