diff options
-rw-r--r-- | avx2_gemm.cc | 8 | ||||
-rw-r--r-- | avx2_gemm.h | 4 | ||||
-rw-r--r-- | avx512_gemm.cc | 8 | ||||
-rw-r--r-- | avx512_gemm.h | 4 | ||||
-rw-r--r-- | interleave.h | 31 | ||||
-rw-r--r-- | intgemm.cc | 8 | ||||
-rw-r--r-- | intgemm.h | 6 | ||||
-rw-r--r-- | multiply.h | 10 | ||||
-rw-r--r-- | sse2_gemm.cc | 4 | ||||
-rw-r--r-- | sse2_gemm.h | 2 | ||||
-rw-r--r-- | ssse3_gemm.cc | 4 | ||||
-rw-r--r-- | ssse3_gemm.h | 2 |
12 files changed, 81 insertions, 10 deletions
diff --git a/avx2_gemm.cc b/avx2_gemm.cc index 59a4c48..6ac500a 100644 --- a/avx2_gemm.cc +++ b/avx2_gemm.cc @@ -125,10 +125,18 @@ void AVX2_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols); } +void AVX2_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) { + SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end); +} + void AVX2_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) { PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols); } +void AVX2_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) { + SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end); +} + void AVX2_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) { Multiply16<__m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols); } diff --git a/avx2_gemm.h b/avx2_gemm.h index 23faa2a..31db7f6 100644 --- a/avx2_gemm.h +++ b/avx2_gemm.h @@ -20,6 +20,8 @@ struct AVX2_16bit { static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols); + static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end); + static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); static const char *const kName; @@ -43,6 +45,8 @@ struct AVX2_8bit { static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols); + static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end); + static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); static const char *const kName; diff --git a/avx512_gemm.cc b/avx512_gemm.cc index 70a785b..efa43d8 100644 --- a/avx512_gemm.cc +++ b/avx512_gemm.cc @@ -142,10 +142,18 @@ void AVX512_16bit::PrepareB(const float *input, int16_t *output, float quant_mul PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols); } +void AVX512_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) { + SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end); +} + void AVX512_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) { PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols); } +void AVX512_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) { + SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows, cols_begin, cols_end); +} + void AVX512_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) { // The unquantization is only 256-bit wide because there are 8 results. Multiply16<__m512i, __m256> (A, B, C, unquant_mult, A_rows, width, B_cols); diff --git a/avx512_gemm.h b/avx512_gemm.h index 2bc0358..21b49cf 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -34,6 +34,8 @@ struct AVX512_16bit { static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols); + static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end); + static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); static const char *const kName; @@ -59,6 +61,8 @@ struct AVX512_8bit { static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols); + static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end); + static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); static const char *const kName; diff --git a/interleave.h b/interleave.h index 5e26f57..1c449ca 100644 --- a/interleave.h +++ b/interleave.h @@ -44,14 +44,25 @@ inline void Interleave64(type &first, type &second) { \ first = temp; \ } + +template <class Register> inline Register setzero_si(); #ifdef __SSE2__ INTGEMM_INTERLEAVE(__m128i, ) +template <> inline __m128i setzero_si<__m128i>() { + return _mm_setzero_si128(); +} #endif #ifdef __AVX2__ INTGEMM_INTERLEAVE(__m256i, 256) +template <> inline __m256i setzero_si<__m256i>() { + return _mm256_setzero_si256(); +} #endif #ifdef __AVX512F__ INTGEMM_INTERLEAVE(__m512i, 512) +template <> inline __m512i setzero_si<__m512i>() { + return _mm512_setzero_si512(); +} #endif template <class Register> inline void Swap(Register &a, Register &b) { @@ -237,4 +248,24 @@ template <class Quantizer> inline void PrepareBFor16(const float *input, int16_t } } +/* Select columns of B from PrepareB format to PrepareB format. + */ +template <class Register> inline void SelectColumnsOfB(const Register *input, Register *output, int rows_bytes /* number of bytes in a row */, const int *cols_begin, const int *cols_end) { + // Do columns for multiples of 8. + int register_rows = rows_bytes / sizeof(Register); + const int *cols_end8 = cols_begin + ((cols_end - cols_begin) & ~7); + const Register *starts[8]; + for (; cols_begin != cols_end8; cols_begin += 8) { + for (int k = 0; k < 8; ++k) { + starts[k] = input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows; + } + for (int r = 0; r < register_rows; ++r) { + for (int k = 0; k < 8; ++k) { + *(output++) = *starts[k]; + starts[k] += 8; + } + } + } +} + } // namespace intgemm @@ -25,6 +25,9 @@ struct Unsupported_16bit { static void PrepareB(const float *, int16_t *, float, int, int) { throw UnsupportedCPU(); } + static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) { + throw UnsupportedCPU(); + } static void Multiply(const int16_t *, const int16_t *, float *C, float, int, int, int) { throw UnsupportedCPU(); } @@ -39,6 +42,9 @@ struct Unsupported_8bit { static void PrepareB(const float *, int8_t *, float, int, int) { throw UnsupportedCPU(); } + static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) { + throw UnsupportedCPU(); + } static void Multiply(const int8_t *, const int8_t *, float *C, float, int, int, int) { throw UnsupportedCPU(); } @@ -77,11 +83,13 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported) void (*Int16::Quantize)(const float *input, int16_t *output, float quant_mult, int size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize); void (*Int16::PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB); +void (*Int16::SelectColumnsB)(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) = ChooseCPU(AVX512_16bit::SelectColumnsB, AVX2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, Unsupported_16bit::SelectColumnsB); void (*Int16::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply); const char *const Int16::kName = ChooseCPU(AVX512_16bit::kName, AVX2_16bit::kName, SSE2_16bit::kName, SSE2_16bit::kName, Unsupported_16bit::kName); void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, int size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize); void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB); +void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB); void (*Int8::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply); const char *const Int8::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit::kName, Unsupported_8bit::kName); @@ -83,6 +83,9 @@ struct Int16 { // It will match the Multiply function on the same CPU though. static void (*PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols); + // Select columns from a prepared B matrix. The number of selected columns must be a multiple of 8. + static void (*SelectColumnsB)(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end); + // Multiply C = A * B, presuming A and B have been prepared. static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); @@ -114,6 +117,9 @@ struct Int8 { // It will match the Multiply function on the same CPU though. static void (*PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols); + // Select columns from a prepared B matrix. The number of selected columns must be a multiple of 8. + static void (*SelectColumnsB)(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end); + // Multiply C = A * B, presuming A and B have been prepared. static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); @@ -14,7 +14,6 @@ namespace intgemm { */ template <class Register> inline Register set1_epi16(int16_t to); template <class Register> inline Register set1_ps(float to); -template <class Register> inline Register setzero_si(); #ifdef __SSE2__ inline __m128i add_epi32(__m128i first, __m128i second) { return _mm_add_epi32(first, second); @@ -28,9 +27,6 @@ template <> inline __m128i set1_epi16<__m128i>(int16_t to) { template <> inline __m128 set1_ps<__m128>(float to) { return _mm_set1_ps(to); } -template <> inline __m128i setzero_si<__m128i>() { - return _mm_setzero_si128(); -} inline __m128i madd_epi16(__m128i first, __m128i second) { return _mm_madd_epi16(first, second); } @@ -64,9 +60,6 @@ template <> inline __m256i set1_epi16<__m256i>(int16_t to) { template <> inline __m256 set1_ps<__m256>(float to) { return _mm256_set1_ps(to); } -template <> inline __m256i setzero_si<__m256i>() { - return _mm256_setzero_si256(); -} inline __m256i madd_epi16(__m256i first, __m256i second) { return _mm256_madd_epi16(first, second); } @@ -100,9 +93,6 @@ template <> inline __m512i set1_epi16<__m512i>(int16_t to) { template <> inline __m512 set1_ps<__m512>(float to) { return _mm512_set1_ps(to); } -template <> inline __m512i setzero_si<__m512i>() { - return _mm512_setzero_si512(); -} inline __m512i madd_epi16(__m512i first, __m512i second) { return _mm512_madd_epi16(first, second); } diff --git a/sse2_gemm.cc b/sse2_gemm.cc index 05a4080..3e0a0d8 100644 --- a/sse2_gemm.cc +++ b/sse2_gemm.cc @@ -61,6 +61,10 @@ void SSE2_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols); } +void SSE2_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) { + SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end); +} + void SSE2_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) { Multiply16<__m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols); } diff --git a/sse2_gemm.h b/sse2_gemm.h index 23997ed..84823fa 100644 --- a/sse2_gemm.h +++ b/sse2_gemm.h @@ -22,6 +22,8 @@ struct SSE2_16bit { static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols); + static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end); + static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); static const char *const kName; diff --git a/ssse3_gemm.cc b/ssse3_gemm.cc index 4ccecec..c044910 100644 --- a/ssse3_gemm.cc +++ b/ssse3_gemm.cc @@ -78,6 +78,10 @@ void SSSE3_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols); } +void SSSE3_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) { + SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end); +} + void SSSE3_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) { Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols); } diff --git a/ssse3_gemm.h b/ssse3_gemm.h index 9897d94..d904a3e 100644 --- a/ssse3_gemm.h +++ b/ssse3_gemm.h @@ -23,6 +23,8 @@ struct SSSE3_8bit { static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols); + static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end); + static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); static const char *const kName; |