diff options
Diffstat (limited to 'multiply.h')
-rw-r--r-- | multiply.h | 43 |
1 files changed, 32 insertions, 11 deletions
@@ -2,10 +2,30 @@ #include "interleave.h" #include "intrinsics.h" +#include "postprocess_pipeline.h" #include "vec_utils.h" namespace intgemm { +INTGEMM_SSE2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m128 result) { + *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX) = result; +} + +INTGEMM_SSE2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, RegisterPair128 result) { + *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX) = result.pack0123; + *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX + 4) = result.pack4567; +} + +INTGEMM_AVX2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m256 result) { + *reinterpret_cast<__m256*>(C + rowIDX*cols + colIDX) = result; +} + +#ifndef INTGEMM_NO_AVX512 +INTGEMM_AVX512BW static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m512 result) { + *reinterpret_cast<__m512*>(C + rowIDX*cols + colIDX) = result; +} +#endif + INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) { // Fold to just using the first 64 bits. __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2); @@ -17,9 +37,9 @@ INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) { return *reinterpret_cast<float*>(&a); } -INTGEMM_SSE2 static inline MultiplyResult128 PermuteSummer(__m128i pack0123, __m128i pack4567) { +INTGEMM_SSE2 static inline RegisterPair128i PermuteSummer(__m128i pack0123, __m128i pack4567) { // No op for 128 bits: already reduced fully. - MultiplyResult128 ret; + RegisterPair128i ret; ret.pack0123 = pack0123; ret.pack4567 = pack4567; return ret; @@ -126,14 +146,14 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i) // width must be a multiple of the register size. // B_cols must be a multiple of 8. // Multiply16 -#define INTGEMM_MULTIPLY16(Integer, target, WriteCSubType) \ - template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \ +#define INTGEMM_MULTIPLY16(Integer, target, cpu_type) \ +template <typename PostprocessPipeline> target static void Multiply(const int16_t *A, const int16_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \ assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \ assert(B_cols % 8 == 0); \ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \ const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \ - typename WriteC::WriteCSubType write_C(C); \ + auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \ const Integer *B0_col = reinterpret_cast<const Integer *>(B); \ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ @@ -177,7 +197,8 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i) Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ /*The specific implementation may need to reduce further.*/ \ auto total = PermuteSummer(pack0123, pack4567); \ - write_C(A_rowidx, B_cols, B0_colidx, total); \ + auto result = inited_pipeline.run(total); \ + writer(C, A_rowidx, B_cols, B0_colidx, result); \ } \ } \ } \ @@ -330,15 +351,15 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a))); } //INTGEMM_AVX2 or INTGEMM_SSSE3 multiply -#define INTGEMM_MULTIPLY8(Integer, target, WriteCSubType) \ -template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \ +#define INTGEMM_MULTIPLY8(Integer, target, cpu_type) \ + template <typename PostprocessPipeline> target static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \ assert(width % sizeof(Integer) == 0); \ assert(B_cols % 8 == 0); \ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \ const int simd_width = width / sizeof(Integer); \ const Integer *B0_col = reinterpret_cast<const Integer*>(B); \ - typename WriteC::WriteCSubType c_writer(C); \ + auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \ /*Go over 8 columns of B at a time.*/ \ for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ /*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ @@ -393,8 +414,8 @@ template <class WriteC> target static void Multiply(const int8_t *A, const int8_ Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ auto total = PermuteSummer(pack0123, pack4567); \ - /*WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);*/ \ - c_writer(A_rowidx, B_cols, B0_colidx, total); \ + auto result = inited_pipeline.run(total); \ + writer(C, A_rowidx, B_cols, B0_colidx, result); \ } \ } \ } \ |