diff options
Diffstat (limited to 'intgemm/avx2_gemm.h')
-rw-r--r-- | intgemm/avx2_gemm.h | 74 |
1 files changed, 33 insertions, 41 deletions
diff --git a/intgemm/avx2_gemm.h b/intgemm/avx2_gemm.h index d111b32..5e81475 100644 --- a/intgemm/avx2_gemm.h +++ b/intgemm/avx2_gemm.h @@ -19,34 +19,30 @@ INTGEMM_SELECT_COL_B(INTGEMM_AVX2, __m256i) class QuantizeTile16 { public: - INTGEMM_AVX2 explicit QuantizeTile16(float mult) : mult_(_mm256_set1_ps(mult)) {} - - INTGEMM_AVX2 Register Consecutive(const float *input) const { - return Tile(input, input + 8); + INTGEMM_AVX2 static inline Register Consecutive(FRegister mult_reg, const float *input) { + return Tile(mult_reg, input, input + 8); } - INTGEMM_AVX2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { - return Tile( + INTGEMM_AVX2 static inline Register ConsecutiveWithWrapping(FRegister mult_reg, const float *input, Index cols_left, Index cols, Index row_step) { + return Tile(mult_reg, input, input + 8 + (cols_left <= 8 ? cols * (row_step - 1) : 0)); } - INTGEMM_AVX2 Register ForReshape(const float *input, Index cols) const { + INTGEMM_AVX2 static inline Register ForReshape(FRegister mult_reg, const float *input, Index cols) { // 8 rows in the first 128-bit register, 8 in the second register. - return Tile(input, input + 8 * cols); + return Tile(mult_reg, input, input + 8 * cols); } private: - INTGEMM_AVX2 __m256i Tile(const float *input0, const float *input1) const { - Register g0 = QuantizerGrab(input0, mult_); - Register g1 = QuantizerGrab(input1, mult_); + INTGEMM_AVX2 static inline Register Tile(FRegister mult_reg, const float *input0, const float *input1) { + Register g0 = QuantizerGrab(input0, mult_reg); + Register g1 = QuantizerGrab(input1, mult_reg); Register packed = _mm256_packs_epi32(g0, g1); // Reorder the packed values because Intel does 0 1 2 3 8 9 10 11 4 5 6 7 12 13 14 15. // Technically this could be removed if the PrepareB did the same reordering internally. return _mm256_permute4x64_epi64(packed, 0xd8 /* 0, 2, 1, 3 */); } - - const FRegister mult_; }; struct Kernels16 { @@ -61,10 +57,10 @@ struct Kernels16 { INTGEMM_AVX2 static void Quantize(const float *input, int16_t *output, float quant_mult, Index size) { assert(size % 16 == 0); assert(reinterpret_cast<uintptr_t>(input) % 32 == 0); - avx2::QuantizeTile16 q(quant_mult); + FRegister q = set1_ps<FRegister>(quant_mult); const float *end = input + size; for (; input != end; input += 16, output += 16) { - *reinterpret_cast<__m256i*>(output) = q.Consecutive(input); + *reinterpret_cast<__m256i*>(output) = QuantizeTile16::Consecutive(q, input); } } @@ -96,17 +92,15 @@ struct Kernels16 { */ class QuantizeTile8 { public: - INTGEMM_AVX2 explicit QuantizeTile8(float quant_mult) : mult_(_mm256_set1_ps(quant_mult)) {} - - INTGEMM_AVX2 inline __m256i Consecutive(const float *input) const { - return Tile(input, input + 8, input + 16, input + 24); + INTGEMM_AVX2 static inline Register Consecutive(FRegister quant_mult, const float *input) { + return Tile(quant_mult, input, input + 8, input + 16, input + 24); } - INTGEMM_AVX2 inline __m256i ConsecutiveU(const float *input) const { - return TileU(input, input + 8, input + 16, input + 24); + INTGEMM_AVX2 static inline Register ConsecutiveU(FRegister quant_mult, const float *input) { + return TileU(quant_mult, input, input + 8, input + 16, input + 24); } - INTGEMM_AVX2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { + INTGEMM_AVX2 static inline Register ConsecutiveWithWrapping(FRegister quant_mult, const float *input, Index cols_left, Index cols, Index row_step) { const float* inputs[4]; for (Index i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) { while (cols_left < sizeof(Register) / sizeof(float)) { @@ -117,24 +111,24 @@ class QuantizeTile8 { input += sizeof(Register) / sizeof(float); cols_left -= sizeof(Register) / sizeof(float); } - return Tile(inputs[0], inputs[1], inputs[2], inputs[3]); + return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]); } - INTGEMM_AVX2 inline __m256i ForReshape(const float *input, Index cols) const { + INTGEMM_AVX2 static inline Register ForReshape(FRegister quant_mult, const float *input, Index cols) { // Put higher rows in the second half of the register. These will jumble // around in the same way then conveniently land in the right place. - return Tile(input, input + 2 * cols, input + 16 * cols, input + 18 * cols); + return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols, input + 18 * cols); } - INTGEMM_AVX2 inline __m256i Tile(const float *input0, const float *input1, const float *input2, const float *input3) const { + INTGEMM_AVX2 static inline __m256i Tile(FRegister quant_mult, const float *input0, const float *input1, const float *input2, const float *input3) { // Looking at the assembly, gcc has pulled this outside the loops calling this. const __m256i neg127 = _mm256_set1_epi8(-127); const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); // Grab 4 registers at a time in 32-bit format. - __m256i g0 = avx2::QuantizerGrab(input0, mult_); - __m256i g1 = avx2::QuantizerGrab(input1, mult_); - __m256i g2 = avx2::QuantizerGrab(input2, mult_); - __m256i g3 = avx2::QuantizerGrab(input3, mult_); + __m256i g0 = avx2::QuantizerGrab(input0, quant_mult); + __m256i g1 = avx2::QuantizerGrab(input1, quant_mult); + __m256i g2 = avx2::QuantizerGrab(input2, quant_mult); + __m256i g3 = avx2::QuantizerGrab(input3, quant_mult); // Pack 32-bit to 16-bit. __m256i packed0 = _mm256_packs_epi32(g0, g1); __m256i packed1 = _mm256_packs_epi32(g2, g3); @@ -151,16 +145,16 @@ class QuantizeTile8 { private: //A version that produces uint8_ts - INTGEMM_AVX2 inline __m256i TileU(const float *input0, const float *input1, const float *input2, const float *input3) const { + INTGEMM_AVX2 static inline Register TileU(FRegister quant_mult, const float *input0, const float *input1, const float *input2, const float *input3) { // Looking at the assembly, gcc has pulled this outside the loops calling this. const __m256i neg127 = _mm256_set1_epi8(-127); const __m256i pos127 = _mm256_set1_epi8(127); const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); // Grab 4 registers at a time in 32-bit format. - __m256i g0 = avx2::QuantizerGrab(input0, mult_); - __m256i g1 = avx2::QuantizerGrab(input1, mult_); - __m256i g2 = avx2::QuantizerGrab(input2, mult_); - __m256i g3 = avx2::QuantizerGrab(input3, mult_); + __m256i g0 = avx2::QuantizerGrab(input0, quant_mult); + __m256i g1 = avx2::QuantizerGrab(input1, quant_mult); + __m256i g2 = avx2::QuantizerGrab(input2, quant_mult); + __m256i g3 = avx2::QuantizerGrab(input3, quant_mult); // Pack 32-bit to 16-bit. __m256i packed0 = _mm256_packs_epi32(g0, g1); __m256i packed1 = _mm256_packs_epi32(g2, g3); @@ -175,8 +169,6 @@ class QuantizeTile8 { // and the values are only used for GEMM. return _mm256_permutevar8x32_epi32(packed, shuffle_param); } - - const __m256 mult_; }; struct Kernels8 { @@ -187,9 +179,9 @@ struct Kernels8 { Quantize(input, output, quant_mult, rows * cols); } private: - INTGEMM_QUANTIZE_THREAD(INTGEMM_AVX2, __m256i, avx2) + INTGEMM_QUANTIZE_THREAD(INTGEMM_AVX2) public: - INTGEMM_QUANTIZE(INTGEMM_AVX2, __m256i, avx2) + INTGEMM_QUANTIZE(INTGEMM_AVX2) // Currently A is prepared by quantization but this could theoretically change. INTGEMM_AVX2 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) { @@ -200,10 +192,10 @@ struct Kernels8 { INTGEMM_AVX2 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) { assert(size % 32 == 0); assert(reinterpret_cast<uintptr_t>(input) % 32 == 0); - avx2::QuantizeTile8 q(quant_mult); + FRegister q = set1_ps<FRegister>(quant_mult); const float *end = input + size; for (; input != end; input += 32, output += 32) { - *reinterpret_cast<__m256i*>(output) = q.ConsecutiveU(input); + *reinterpret_cast<__m256i*>(output) = QuantizeTile8::ConsecutiveU(q, input); } } |