Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'avx512_gemm.h')
-rw-r--r--avx512_gemm.h56
1 files changed, 28 insertions, 28 deletions
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 91fdd8a..2a5fff1 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -70,12 +70,12 @@ INTGEMM_AVX512BW inline __m512i QuantizerGrabHalves(const float *input0, const f
// being used for the quantizer.
class QuantizeTile16 {
public:
- typedef __m512i Integer;
+ typedef __m512i Register;
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_AVX512BW explicit QuantizeTile16(float mult) : mult_reg_(_mm512_set1_ps(mult)) {}
- INTGEMM_AVX512BW Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
+ INTGEMM_AVX512BW Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
auto input0 = input;
auto input1 = input + 16 + (cols_left <= 16 ? cols * (row_step - 1) : 0);
auto g0 = QuantizerGrabHalves(input0, input1, mult_reg_);
@@ -98,24 +98,24 @@ class QuantizeTile16 {
class QuantizeTile8 {
public:
- typedef __m512i Integer;
+ typedef __m512i Register;
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_AVX512BW explicit QuantizeTile8(float mult) : mult_reg_(_mm512_set1_ps(mult)) {}
- INTGEMM_AVX512BW Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
+ INTGEMM_AVX512BW Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
static const __m512i neg127 = _mm512_set1_epi8(-127);
static const __m512i shuffle_param = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
const float* inputs[4];
for (int i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
- while (cols_left < sizeof(Integer) / sizeof(float)) {
+ while (cols_left < sizeof(Register) / sizeof(float)) {
input += cols * (row_step - 1);
cols_left += cols;
}
inputs[i] = input;
- input += sizeof(Integer) / sizeof(float);
- cols_left -= sizeof(Integer) / sizeof(float);
+ input += sizeof(Register) / sizeof(float);
+ cols_left -= sizeof(Register) / sizeof(float);
}
auto g0 = QuantizerGrab(inputs[0], mult_reg_);
@@ -292,41 +292,41 @@ struct AVX512_8bit {
// allocate registers manually) and no sign instruction.
template <typename Callback>
INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
- typedef __m512i Integer;
+ typedef __m512i Register;
//typedef __m256 Float; // For quantization we only do 8 at a time.
// This is copy-paste from Multiply8_SSE2OrAVX2.
- assert(width % sizeof(Integer) == 0);
+ assert(width % sizeof(Register) == 0);
assert(B_cols % 8 == 0);
- assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
- assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
+ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0);
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
// There's 8 results for INTGEMM_AVX2 to handle.
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
- const int simd_width = width / sizeof(Integer);
- const Integer *B0_col = reinterpret_cast<const Integer*>(B);
+ const int simd_width = width / sizeof(Register);
+ const Register *B0_col = reinterpret_cast<const Register*>(B);
// Added for AVX512.
- Integer zeros = setzero_si<Integer>();
+ Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
- const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width);
- const Integer *A_end = A_live + simd_width;
- const Integer *B_live = B0_col;
+ const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width);
+ const Register *A_end = A_live + simd_width;
+ const Register *B_live = B0_col;
// Do the first iteration to initialize the sums.
__m512i a = *A_live;
__mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
__m512i a_positive = _mm512_abs_epi8(a);
// These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.
- Integer sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0]));
- Integer sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1]));
- Integer sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2]));
- Integer sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3]));
- Integer sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4]));
- Integer sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5]));
- Integer sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6]));
- Integer sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7]));
+ Register sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0]));
+ Register sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1]));
+ Register sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2]));
+ Register sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3]));
+ Register sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4]));
+ Register sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5]));
+ Register sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6]));
+ Register sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7]));
++A_live;
B_live += 8;
@@ -384,7 +384,7 @@ struct AVX512_8bit {
// Unique code ends: can we do an inline function?
}
// Upcast to 32-bit and horizontally add.
- Integer ones = set1_epi16<Integer>(1);
+ Register ones = set1_epi16<Register>(1);
sum0 = madd_epi16(sum0, ones);
sum1 = madd_epi16(sum1, ones);
sum2 = madd_epi16(sum2, ones);
@@ -393,8 +393,8 @@ struct AVX512_8bit {
sum5 = madd_epi16(sum5, ones);
sum6 = madd_epi16(sum6, ones);
sum7 = madd_epi16(sum7, ones);
- Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
- Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+ Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
+ Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));