diff options
author | Kenneth Heafield <github@kheafield.com> | 2019-04-18 16:38:59 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2019-04-18 16:38:59 +0300 |
commit | fba6fb3b78611b262cac3022fc585efc020991ff (patch) | |
tree | 25b9cd6679d3854651630777a2c814493d8416ef | |
parent | e75044e331d905f0b01adfe219852582579d2872 (diff) |
Look ma, no gcc compiler errors
-rw-r--r-- | avx2_gemm.h | 10 | ||||
-rw-r--r-- | avx512_gemm.h | 17 | ||||
-rw-r--r-- | cops.h | 78 | ||||
-rw-r--r-- | multiply.h | 16 | ||||
-rw-r--r-- | sse2_gemm.h | 12 | ||||
-rw-r--r-- | ssse3_gemm.h | 2 |
6 files changed, 57 insertions, 78 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h index c326976..393eebb 100644 --- a/avx2_gemm.h +++ b/avx2_gemm.h @@ -81,12 +81,8 @@ struct AVX2_16bit { AVX2 static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end); } -/* - AVX2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { - Multiply16__m256i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); - } - */ - MULTIPLY16_define(__m256i, AVX2) + + MULTIPLY16_define(__m256i, AVX2, OnAVX2) constexpr static const char *const kName = "16-bit AVX2"; @@ -184,7 +180,7 @@ struct AVX2_8bit { //Multiply8_SSE2OrAVX2<Multiply8_AVXAVX2, __m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols); Multiply8_SSE2OrAVX2__m256i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); }*/ - MULTIPLY8_define(__m256i, AVX2) + MULTIPLY8_define(__m256i, AVX2, OnAVX2) constexpr static const char *const kName = "8-bit AVX2"; diff --git a/avx512_gemm.h b/avx512_gemm.h index e48ed9a..3806929 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -156,12 +156,8 @@ struct AVX512_16bit { AVX512F static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { avx512f::SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end); } -/* - AVX512F static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { - // The unquantization is only 256-bit wide because there are 8 results. - Multiply16__m512i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); - }*/ - MULTIPLY16_define(__m512i, AVX512F) + + MULTIPLY16_define(__m512i, AVX512F, OnAVX2) constexpr static const char *const kName = "16-bit AVX512"; @@ -209,7 +205,7 @@ struct AVX512_8bit { // Special AVX512 implementation due to having 32 registers (so I don't have to // allocate registers manually) and no sign instruction. template <class WriteC> - AVX512BW static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { + AVX512BW static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { typedef __m512i Integer; //typedef __m256 Float; // For quantization we only do 8 at a time. // This is copy-paste from Multiply8_SSE2OrAVX2. @@ -217,8 +213,8 @@ struct AVX512_8bit { assert(B_cols % 8 == 0); assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); - //assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0); - //Float unquant_reg = set1_ps<Float>(unquant_mult); + // There's 8 results for AVX2 to handle. + typename WriteC::OnAVX2 write_C(C); const int simd_width = width / sizeof(Integer); const Integer *B0_col = reinterpret_cast<const Integer*>(B); // Added for AVX512. @@ -315,8 +311,7 @@ struct AVX512_8bit { Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); - //WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg); - functor(A_rowidx, B_cols, B0_colidx, total); + write_C(A_rowidx, B_cols, B0_colidx, total); } } } @@ -6,46 +6,44 @@ namespace intgemm { class JustUnquantizeC { -public: - JustUnquantizeC(float *C, float unquant_mult); - - SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result); - AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result); - -private: - SSE2 void InitRegisterSSE(float unquant_mult); - AVX2 void InitRegisterAVX2(float unquant_mult); - - float *C_; - __m128 unquant_mult_128_; // Registers - __m256 unquant_mult_256_; + public: + JustUnquantizeC(float *C, float unquant_mult) : C_(C), unquant_mult_(unquant_mult) {} + + class OnSSE2 { + public: + SSE2 explicit OnSSE2(const JustUnquantizeC &from) + : C_(from.C_), unquant_mult_(_mm_set1_ps(from.unquant_mult_)) { + assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0); + } + + SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) { + *reinterpret_cast<__m128*>(C_ + rowIDX*cols + colIDX) = mul_ps(cvtepi32_ps(result.pack0123), unquant_mult_); + *reinterpret_cast<__m128*>(C_ + rowIDX*cols + colIDX + 4) = mul_ps(cvtepi32_ps(result.pack4567), unquant_mult_); + } + private: + float *C_; + __m128 unquant_mult_; + }; + + class OnAVX2 { + public: + AVX2 explicit OnAVX2(const JustUnquantizeC &from) + : C_(from.C_), unquant_mult_(_mm256_set1_ps(from.unquant_mult_)) { + assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0); + } + + AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) { + *reinterpret_cast<__m256*>(C_ + rowIDX*cols + colIDX) = mul_ps(cvtepi32_ps(result), unquant_mult_); + } + + private: + float *C_; + __m256 unquant_mult_; + }; + + private: + float *C_; + float unquant_mult_; }; -SSE2 void JustUnquantizeC::InitRegisterSSE(float unquant_mult) { - assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128) == 0); - unquant_mult_128_ = _mm_set1_ps(unquant_mult); -} - -AVX2 void JustUnquantizeC::InitRegisterAVX2(float unquant_mult) { - assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256) == 0); - unquant_mult_256_ = _mm256_set1_ps(unquant_mult); -} - -JustUnquantizeC::JustUnquantizeC(float *C, float unquant_mult) : C_(C) { - //We need both to make sure our tests pass - //Some of the assertions might give false positives on SSE2/3 - InitRegisterSSE(unquant_mult); - if (__builtin_cpu_supports("avx2")) { - InitRegisterAVX2(unquant_mult); - } -} - - -SSE2 inline void JustUnquantizeC::operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result){ - *reinterpret_cast<__m128*>(C_ + rowIDX*cols + colIDX) = mul_ps(cvtepi32_ps(result.pack0123), unquant_mult_128_); - *reinterpret_cast<__m128*>(C_ + rowIDX*cols + colIDX + 4) = mul_ps(cvtepi32_ps(result.pack4567), unquant_mult_128_); -} -AVX2 inline void JustUnquantizeC::operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) { - *reinterpret_cast<__m256*>(C_ + rowIDX*cols + colIDX) = mul_ps(cvtepi32_ps(result), unquant_mult_256_); -} } //Namespace @@ -135,15 +135,14 @@ PACK_DEFINE(AVX512F, __m512i) // width must be a multiple of the register size. // B_cols must be a multiple of 8. // Multiply16 -#define MULTIPLY16_define(Integer, target) \ - template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \ +#define MULTIPLY16_define(Integer, target, WriteCSubType) \ + template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \ assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \ assert(B_cols % 8 == 0); \ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \ - /*assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0); Moved to WriteC*/ \ const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \ - /*const Float unquant_reg = set1_ps<Float>(unquant_mult); moved to WriteC*/ \ + typename WriteC::WriteCSubType write_C(C); \ const Integer *B0_col = reinterpret_cast<const Integer *>(B); \ for (int B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ @@ -187,7 +186,7 @@ PACK_DEFINE(AVX512F, __m512i) Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ /*The specific implementation may need to reduce further.*/ \ auto total = PermuteSummer(pack0123, pack4567); \ - functor(A_rowidx, B_cols, B0_colidx, total); \ + write_C(A_rowidx, B_cols, B0_colidx, total); \ } \ } \ } \ @@ -340,14 +339,15 @@ SSSE3 inline static void InnerSSSE3( sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a))); } //AVX2 or SSSE3 multiply -#define MULTIPLY8_define(Integer, target) \ -template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \ +#define MULTIPLY8_define(Integer, target, WriteCSubType) \ +template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \ assert(width % sizeof(Integer) == 0); \ assert(B_cols % 8 == 0); \ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \ const int simd_width = width / sizeof(Integer); \ const Integer *B0_col = reinterpret_cast<const Integer*>(B); \ + typename WriteC::WriteCSubType c_writer(C); \ /*Go over 8 columns of B at a time.*/ \ for (int B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ /*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ @@ -403,7 +403,7 @@ template <class WriteC> target static void Multiply(const int8_t *A, const int8_ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ auto total = PermuteSummer(pack0123, pack4567); \ /*WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);*/ \ - functor(A_rowidx, B_cols, B0_colidx, total); \ + c_writer(A_rowidx, B_cols, B0_colidx, total); \ } \ } \ } \ diff --git a/sse2_gemm.h b/sse2_gemm.h index a3e8233..8717150 100644 --- a/sse2_gemm.h +++ b/sse2_gemm.h @@ -65,11 +65,6 @@ struct SSE2_16bit { // Tile size for B; B must be a multiple of this block size. static const Index kBTileRow = 8; static const Index kBTileCol = 8; -/* - SSE2 static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) { - //TODO #DEFINE - PrepareBFor16(input, output, sse2::QuantizeTile16(quant_mult), rows, cols); - }*/ PREPARE_B_16_DEF(SSE2, sse2::QuantizeTile16) @@ -77,12 +72,7 @@ struct SSE2_16bit { //TODO #DEFINE sse2::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end); } -/* - SSE2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { - //TODO #DEFINE - Multiply16__m128i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); - }*/ - MULTIPLY16_define(__m128i, SSE2) + MULTIPLY16_define(__m128i, SSE2, OnSSE2) constexpr static const char *const kName = "16-bit SSE2"; diff --git a/ssse3_gemm.h b/ssse3_gemm.h index d384500..b19e552 100644 --- a/ssse3_gemm.h +++ b/ssse3_gemm.h @@ -102,7 +102,7 @@ struct SSSE3_8bit { //Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols); Multiply8_SSE2OrAVX2__m128i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); }*/ - MULTIPLY8_define(__m128i, SSSE3) + MULTIPLY8_define(__m128i, SSSE3, OnSSE2) constexpr static const char *const kName = "8-bit SSSE3"; |