diff options
Diffstat (limited to 'avx512_gemm.cc')
-rw-r--r-- | avx512_gemm.cc | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/avx512_gemm.cc b/avx512_gemm.cc index e7df5ad..904b779 100644 --- a/avx512_gemm.cc +++ b/avx512_gemm.cc @@ -35,7 +35,7 @@ inline __m512i QuantizerGrab(const float *input, const __m512 quant_mult_reg) { // rearranging B. // // Convert to 16-bit signed integers. -void AVX512_16bit::Quantize(const float *input, int16_t *output, float quant_mult, int size) { +void AVX512_16bit::Quantize(const float *input, int16_t *output, float quant_mult, Index size) { assert(size % 16 == 0); assert(reinterpret_cast<uintptr_t>(input) % 64 == 0); // Fill with the quantization multiplier. @@ -48,7 +48,7 @@ void AVX512_16bit::Quantize(const float *input, int16_t *output, float quant_mul } // Convert to 8-bit signed integers. -void AVX512_8bit::Quantize(const float *input, int8_t *output, float quant_mult, int size) { +void AVX512_8bit::Quantize(const float *input, int8_t *output, float quant_mult, Index size) { assert(size % 16 == 0); assert(reinterpret_cast<uintptr_t>(input) % 64 == 0); const __m512i neg127 = _mm512_set1_epi32(-127); @@ -91,7 +91,7 @@ class QuantizeTile16 { explicit QuantizeTile16(float mult) : mult_reg_(_mm512_set1_ps(mult)) {} - inline __m512i ForReshape(const float *input, int cols) { + inline __m512i ForReshape(const float *input, Index cols) { __m512i g0 = QuantizerGrabHalves(input, input + 16 * cols, mult_reg_); __m512i g1 = QuantizerGrabHalves(input + 8 * cols, input + 24 * cols, mult_reg_); __m512i packed = _mm512_packs_epi32(g0, g1); @@ -109,7 +109,7 @@ class QuantizeTile8 { explicit QuantizeTile8(float mult) : mult_reg_(_mm512_set1_ps(mult)) {} - inline __m512i ForReshape(const float *input, int cols) { + inline __m512i ForReshape(const float *input, Index cols) { // TODO: try alternative: _mm512_cvtsepi32_epi8 ? const __m512i neg127 = _mm512_set1_epi8(-127); // In reverse order: grabbing the first 32-bit values from each 128-bit register, then the second 32-bit values, etc. @@ -137,30 +137,30 @@ class QuantizeTile8 { } // namespace -void AVX512_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols) { +void AVX512_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) { PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols); } -void AVX512_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) { +void AVX512_16bit::SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end); } -void AVX512_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) { +void AVX512_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols); } -void AVX512_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) { +void AVX512_8bit::SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows, cols_begin, cols_end); } -void AVX512_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) { +void AVX512_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { // The unquantization is only 256-bit wide because there are 8 results. Multiply16<__m512i, __m256> (A, B, C, unquant_mult, A_rows, width, B_cols); } // Special AVX512 implementation due to having 32 registers (so I don't have to // allocate registers manually) and no sign instruction. -void AVX512_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) { +void AVX512_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { typedef __m512i Integer; typedef __m256 Float; // For quantization we only do 8 at a time. // This is copy-paste from Multiply8_SSE2OrAVX2. @@ -264,7 +264,9 @@ void AVX512_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unq sum7 = madd_epi16(sum7, ones); Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); - WriteC(C + A_rowidx * B_cols + B0_colidx, pack0123, pack4567, unquant_reg); + + auto total = PermuteSummer(pack0123, pack4567); + WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg); } } } |