diff options
author | Dayeong Lee <dayeongl@google.com> | 2021-12-09 04:05:44 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2021-12-09 04:06:10 +0300 |
commit | abaaa6a6a5515cfe958cf7b32ae1f2e5ca1b962f (patch) | |
tree | 3a74448eee2922024d62fbe2e6b8d92ed1b99f0f | |
parent | 6c292a6e91cd3dab6059334d60c09fb5c7d1a94e (diff) |
Ruy:Fix 16bit-packing msan error.
PiperOrigin-RevId: 415133840
-rw-r--r-- | ruy/pack_avx512.cc | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc index 29d6f9f..29a1850 100644 --- a/ruy/pack_avx512.cc +++ b/ruy/pack_avx512.cc @@ -70,7 +70,7 @@ namespace { template <typename PackImplAvx512, typename Scalar> inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point, - Scalar* packed_ptr) { + Scalar* packed_ptr, int chunked_row_mask) { using Layout = typename PackImplAvx512::Layout; static constexpr int kHalfLayoutCols = PackImplAvx512::kHalfLayoutCols; // Half the number of cols in a @@ -79,7 +79,7 @@ inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point, RUY_DCHECK_EQ(Layout::kCols, 16); RUY_DCHECK_EQ(Layout::kRows, 4); - const int non_trailing_blocks = (src_rows & ~31) >> 2; + const int non_trailing_blocks = (src_rows & ~chunked_row_mask) >> 2; // This routine fills half blocks, and typically fills the second halves. // Thus packed_ptr is already offset by 8 * 4. for (int k = 0; k < non_trailing_blocks; ++k) { @@ -865,6 +865,7 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr, kNumRowChunks * Layout::kCols * Layout::kRows; std::int8_t trailing_buf[kTrailingBufSize]; memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; std::int32_t* second_sums_ptr = sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; @@ -882,7 +883,8 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr, remaining_src_cols, src_rows, packed_ptr, sums_ptr, trailing_buf); ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>( - src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset); + src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset, + kChunkedRowMask); // The kernel may not need the second half-blocks sums to be set. if (second_sums_ptr) { for (int i = 0; i < kHalfLayoutCols; ++i) { @@ -890,7 +892,6 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr, } } } - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; const bool trailing_data = (src_rows & kChunkedRowMask) > 0; // If the number of source rows is not a multiple of kChunkedRowMask, there // will be data in the trailing buffer, @@ -930,6 +931,7 @@ void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr, constexpr int kTrailingBufSize = kNumRowChunks * Layout::kCols * Layout::kRows; std::int16_t trailing_buf[kTrailingBufSize] = {0}; + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; std::int32_t* second_sums_ptr = sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; @@ -944,7 +946,7 @@ void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr, HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr, sums_ptr, trailing_buf); ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>( - src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset); + src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset, kChunkedRowMask); // The kernel may not need the second half-blocks sums to be set. if (second_sums_ptr) { for (int i = 0; i < kHalfLayoutCols; ++i) { @@ -952,7 +954,6 @@ void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr, } } } - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; const bool trailing_data = (src_rows & kChunkedRowMask) > 0; // If the number of source rows is not a multiple of kChunkedRowMask, there // will be data in the trailing buffer, |