Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'avx512vnni_gemm.h')
-rw-r--r--avx512vnni_gemm.h98
1 files changed, 49 insertions, 49 deletions
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h
index 3f616a6..59f6405 100644
--- a/avx512vnni_gemm.h
+++ b/avx512vnni_gemm.h
@@ -11,39 +11,39 @@ namespace intgemm {
struct AVX512VNNI_8bit : public AVX512_8bit {
template <typename Callback>
INTGEMM_AVX512VNNI static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
- typedef __m512i Integer;
- assert(width % sizeof(Integer) == 0);
+ typedef __m512i Register;
+ assert(width % sizeof(Register) == 0);
assert(B_cols % 8 == 0);
- assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
- assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
+ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0);
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
- const int simd_width = width / sizeof(Integer);
- const Integer *B0_col = reinterpret_cast<const Integer*>(B);
- Integer zeros = setzero_si<Integer>();
+ const int simd_width = width / sizeof(Register);
+ const Register *B0_col = reinterpret_cast<const Register*>(B);
+ Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
- const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width);
- const Integer *A_end = A_live + simd_width;
- const Integer *B_live = B0_col;
+ const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width);
+ const Register *A_end = A_live + simd_width;
+ const Register *B_live = B0_col;
// TODO: separate first step.
- Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
+ Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
for (; A_live != A_end; ++A_live, B_live += 8) {
- Integer a = *A_live;
+ Register a = *A_live;
// Retrieve the conveniently consecutive values of B.
- Integer b0 = *B_live;
- Integer b1 = *(B_live + 1);
- Integer b2 = *(B_live + 2);
- Integer b3 = *(B_live + 3);
- Integer b4 = *(B_live + 4);
- Integer b5 = *(B_live + 5);
- Integer b6 = *(B_live + 6);
- Integer b7 = *(B_live + 7);
+ Register b0 = *B_live;
+ Register b1 = *(B_live + 1);
+ Register b2 = *(B_live + 2);
+ Register b3 = *(B_live + 3);
+ Register b4 = *(B_live + 4);
+ Register b5 = *(B_live + 5);
+ Register b6 = *(B_live + 6);
+ Register b7 = *(B_live + 7);
// Get a mask where a is negative.
__mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
- Integer a_positive = _mm512_abs_epi8(a);
+ Register a_positive = _mm512_abs_epi8(a);
// Negate by subtracting from zero with a mask.
b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0);
b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1);
@@ -62,8 +62,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
sum6 = _mm512_dpbusds_epi32(sum6, a_positive, b6);
sum7 = _mm512_dpbusds_epi32(sum7, a_positive, b7);
}
- Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
- Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+ Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
+ Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
}
@@ -72,27 +72,27 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
template <typename Callback>
INTGEMM_AVX512VNNI static void Multiply8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
- typedef __m512i Integer;
- assert(width % sizeof(Integer) == 0);
+ typedef __m512i Register;
+ assert(width % sizeof(Register) == 0);
assert(B_cols % 8 == 0);
- assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
- assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
+ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0);
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
- const int simd_width = width / sizeof(Integer);
- const Integer *B0_col = reinterpret_cast<const Integer*>(B);
- Integer zeros = setzero_si<Integer>();
+ const int simd_width = width / sizeof(Register);
+ const Register *B0_col = reinterpret_cast<const Register*>(B);
+ Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
- const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width);
- const Integer *A_end = A_live + simd_width;
- const Integer *B_live = B0_col;
+ const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width);
+ const Register *A_end = A_live + simd_width;
+ const Register *B_live = B0_col;
// TODO: separate first step.
- Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
+ Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
for (; A_live != A_end; ++A_live, B_live += 8) {
- Integer a = *A_live;
+ Register a = *A_live;
//MultiplyAdd
sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
@@ -103,8 +103,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
}
- Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
- Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+ Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
+ Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
}
@@ -113,22 +113,22 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
template <typename Callback>
INTGEMM_AVX512VNNI static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) {
- typedef __m512i Integer;
- assert(width % sizeof(Integer) == 0);
+ typedef __m512i Register;
+ assert(width % sizeof(Register) == 0);
assert(B_cols % 8 == 0);
- assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
- const int simd_width = width / sizeof(Integer);
- const Integer *B0_col = reinterpret_cast<const Integer*>(B);
- Integer zeros = setzero_si<Integer>();
- const Integer a = set1_epi8<Integer>(1);
+ const int simd_width = width / sizeof(Register);
+ const Register *B0_col = reinterpret_cast<const Register*>(B);
+ Register zeros = setzero_si<Register>();
+ const Register a = set1_epi8<Register>(1);
// Go over 8 columns of B at a time.
for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
- const Integer *B_live = B0_col; //In order to make the code look as much as possible as the above function
- const Integer *B_end = B_live + simd_width*8;
+ const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function
+ const Register *B_end = B_live + simd_width*8;
// TODO: separate first step.
- Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
+ Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
for (; B_live != B_end; B_live += 8) {
// Retrieve the conveniently consecutive values of B.
sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
@@ -140,8 +140,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
}
- Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
- Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+ Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
+ Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
callback_impl(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols));
}