Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDayeong Lee <dayeongl@google.com>2021-12-09 04:05:44 +0300
committerCopybara-Service <copybara-worker@google.com>2021-12-09 04:06:10 +0300
commitabaaa6a6a5515cfe958cf7b32ae1f2e5ca1b962f (patch)
tree3a74448eee2922024d62fbe2e6b8d92ed1b99f0f
parent6c292a6e91cd3dab6059334d60c09fb5c7d1a94e (diff)
Ruy:Fix 16bit-packing msan error.
PiperOrigin-RevId: 415133840
-rw-r--r--ruy/pack_avx512.cc13
1 files changed, 7 insertions, 6 deletions
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc
index 29d6f9f..29a1850 100644
--- a/ruy/pack_avx512.cc
+++ b/ruy/pack_avx512.cc
@@ -70,7 +70,7 @@ namespace {
template <typename PackImplAvx512, typename Scalar>
inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point,
- Scalar* packed_ptr) {
+ Scalar* packed_ptr, int chunked_row_mask) {
using Layout = typename PackImplAvx512::Layout;
static constexpr int kHalfLayoutCols =
PackImplAvx512::kHalfLayoutCols; // Half the number of cols in a
@@ -79,7 +79,7 @@ inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point,
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
- const int non_trailing_blocks = (src_rows & ~31) >> 2;
+ const int non_trailing_blocks = (src_rows & ~chunked_row_mask) >> 2;
// This routine fills half blocks, and typically fills the second halves.
// Thus packed_ptr is already offset by 8 * 4.
for (int k = 0; k < non_trailing_blocks; ++k) {
@@ -865,6 +865,7 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
kNumRowChunks * Layout::kCols * Layout::kRows;
std::int8_t trailing_buf[kTrailingBufSize];
memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
+ constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
std::int32_t* second_sums_ptr =
sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
@@ -882,7 +883,8 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
remaining_src_cols, src_rows, packed_ptr, sums_ptr,
trailing_buf);
ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>(
- src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset);
+ src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset,
+ kChunkedRowMask);
// The kernel may not need the second half-blocks sums to be set.
if (second_sums_ptr) {
for (int i = 0; i < kHalfLayoutCols; ++i) {
@@ -890,7 +892,6 @@ void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
}
}
}
- constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
// If the number of source rows is not a multiple of kChunkedRowMask, there
// will be data in the trailing buffer,
@@ -930,6 +931,7 @@ void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
constexpr int kTrailingBufSize =
kNumRowChunks * Layout::kCols * Layout::kRows;
std::int16_t trailing_buf[kTrailingBufSize] = {0};
+ constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
std::int32_t* second_sums_ptr =
sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
@@ -944,7 +946,7 @@ void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_rows, packed_ptr, sums_ptr, trailing_buf);
ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>(
- src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset);
+ src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset, kChunkedRowMask);
// The kernel may not need the second half-blocks sums to be set.
if (second_sums_ptr) {
for (int i = 0; i < kHalfLayoutCols; ++i) {
@@ -952,7 +954,6 @@ void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
}
}
}
- constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
// If the number of source rows is not a multiple of kChunkedRowMask, there
// will be data in the trailing buffer,