github.com/marian-nmt/intgemm.git
-rw-r--r--  avx2_gemm.h    |  23
-rw-r--r--  avx512_gemm.h  |  15
-rw-r--r--  interleave.h   | 212
-rw-r--r--  multiply.h     | 270
-rw-r--r--  sse2_gemm.h    |  14
-rw-r--r--  ssse3_gemm.h   |   8
6 files changed, 302 insertions(+), 240 deletions(-)
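
This commit replaces the function templates that were shared across instruction sets (PrepareBFor8/PrepareBFor16, Multiply16, Multiply8_SSE2OrAVX2, MaxAbsoluteBackend, plus Swap, Transpose16InLane, and Pack0123) with preprocessor macros that stamp out one concrete function per target, each prefixed with the SSE2, SSSE3, AVX2, or AVX512F token. Presumably those tokens expand to per-function target attributes (e.g. __attribute__((target("avx2")))), and the point of the conversion is that a single template parameterized on the register type can carry only one such attribute, whereas one separately attributed function per register width lets SSE2, AVX2, and AVX-512 code coexist in a single translation unit. A minimal, self-contained sketch of the pattern under that assumption; SSE2_TARGET, AVX2_TARGET, and ADD32_DEFINE are illustrative names, not from the repository:

#include <immintrin.h>

// Illustrative stand-ins for the repository's SSE2/AVX2 tokens (assumption).
#define SSE2_TARGET __attribute__((target("sse2")))
#define AVX2_TARGET __attribute__((target("avx2")))

// One macro call emits one concrete, target-attributed function per register
// type; the register type is pasted into the function name with ##.
#define ADD32_DEFINE(target, Register, add_intrinsic) \
  target static inline Register Add32##Register(Register a, Register b) { \
    return add_intrinsic(a, b); \
  }

ADD32_DEFINE(SSE2_TARGET, __m128i, _mm_add_epi32)    /* defines Add32__m128i */
ADD32_DEFINE(AVX2_TARGET, __m256i, _mm256_add_epi32) /* defines Add32__m256i */
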
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 32b1a5e..c931cd8 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -22,7 +22,7 @@ class QuantizeTile16 {
public:
typedef __m256i Integer;
- explicit QuantizeTile16(float mult) : mult_(_mm256_set1_ps(mult)) {}
+ AVX2 explicit QuantizeTile16(float mult) : mult_(_mm256_set1_ps(mult)) {}
AVX2 Integer Consecutive(const float *input) {
return Tile(input, input + 8);
@@ -71,17 +71,18 @@ struct AVX2_16bit {
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 16;
static const Index kBTileCol = 8;
-
+/*
AVX2 static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor16(input, output, avx2::QuantizeTile16(quant_mult), rows, cols);
- }
+ }*/
+ PREPARE_B_16_DEF(AVX2, avx2::QuantizeTile16)
AVX2 static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
}
AVX2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
- Multiply16<__m256i, JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
+ Multiply16__m256i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
}
constexpr static const char *const kName = "16-bit AVX2";
@@ -98,7 +99,7 @@ class QuantizeTile8 {
public:
typedef __m256i Integer;
- explicit QuantizeTile8(float quant_mult) : mult_(_mm256_set1_ps(quant_mult)) {}
+ AVX2 explicit QuantizeTile8(float quant_mult) : mult_(_mm256_set1_ps(quant_mult)) {}
AVX2 inline __m256i Consecutive(const float *input) {
return Tile(input, input + 8, input + 16, input + 24);
@@ -161,16 +162,20 @@ struct AVX2_8bit {
static const Index kBTileRow = 32;
static const Index kBTileCol = 8;
+/*
AVX2 static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor8(input, output, avx2::QuantizeTile8(quant_mult), rows, cols);
- }
+ }*/
+
+ PREPARE_B_8_DEF(AVX2, avx2::QuantizeTile8)
AVX2 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
}
AVX2 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
- Multiply8_SSE2OrAVX2<Multiply8_AVXAVX2, __m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols);
+ //Multiply8_SSE2OrAVX2<Multiply8_AVXAVX2, __m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols);
+ Multiply8_SSE2OrAVX2__m256i<Multiply8_AVXAVX2>(A, B, C, unquant_mult, A_rows, width, B_cols);
}
constexpr static const char *const kName = "8-bit AVX2";
@@ -179,8 +184,8 @@ struct AVX2_8bit {
};
// Technically only requires AVX
-AVX2 float AVX2_MaxAbsolute(const float *begin, const float *end) {
- return MaxAbsoluteBackend<__m256>(begin, end);
+AVX2 float AVX2_MaxAbsolute(const float *begin_float, const float *end_float) {
+ MAXABS_DEFINE(__m256)
}
} // namespace intgemm
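
In the Multiply wrappers above, Multiply16<__m256i, JustUnquantizeC> becomes Multiply16__m256i<JustUnquantizeC>: the register type moves out of the template parameter list and into the function name, which MULTIPLY16_define in multiply.h builds by token pasting, while the output functor (WriteC) stays a template parameter. A stripped-down, self-contained sketch of that naming scheme; DOT16_DEFINE, Dot16__m256i, and SomeWriter are illustrative names only:

#include <immintrin.h>
#include <cstdint>

// Illustrative macro: the register type is both a code parameter and part of
// the generated name, while the output functor remains a template parameter.
#define DOT16_DEFINE(Integer, target, madd) \
  template <class WriteC> target inline void Dot16##Integer( \
      const int16_t *A, const int16_t *B, WriteC write) { \
    Integer a = *reinterpret_cast<const Integer*>(A); /* aligned input assumed */ \
    Integer b = *reinterpret_cast<const Integer*>(B); \
    write(madd(a, b)); /* multiply int16 pairs, horizontally add into int32 */ \
  }

DOT16_DEFINE(__m128i, __attribute__((target("sse2"))), _mm_madd_epi16)
DOT16_DEFINE(__m256i, __attribute__((target("avx2"))), _mm256_madd_epi16)

// A call site then spells the register type into the name, e.g.
// Dot16__m256i<SomeWriter>(A, B, SomeWriter{}), matching the shape of
// Multiply16__m256i<JustUnquantizeC>(...) above.
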
diff --git a/avx512_gemm.h b/avx512_gemm.h
index e226686..755bddf 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -143,10 +143,12 @@ struct AVX512_16bit {
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 32;
static const Index kBTileCol = 8;
-
+/*
AVX512F static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor16(input, output, avx512f::QuantizeTile16(quant_mult), rows, cols);
}
+*/
+ PREPARE_B_16_DEF(AVX512F, avx512f::QuantizeTile16)
AVX512F static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end);
@@ -154,7 +156,7 @@ struct AVX512_16bit {
AVX512F static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
// The unquantization is only 256-bit wide because there are 8 results.
- Multiply16<__m512i, JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
+ Multiply16__m512i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
}
constexpr static const char *const kName = "16-bit AVX512";
@@ -190,10 +192,11 @@ struct AVX512_8bit {
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 64;
static const Index kBTileCol = 8;
-
+/*
AVX512F static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor8(input, output, avx512f::QuantizeTile8(quant_mult), rows, cols);
- }
+ }*/
+ PREPARE_B_8_DEF(AVX512F, avx512f::QuantizeTile8)
AVX512F static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows, cols_begin, cols_end);
@@ -317,8 +320,8 @@ struct AVX512_8bit {
static const CPUType kUses = CPU_AVX512BW;
};
-AVX512F float AVX512_MaxAbsolute(const float *begin, const float *end) {
- return MaxAbsoluteBackend<__m512>(begin, end);
+AVX512F float AVX512_MaxAbsolute(const float *begin_float, const float *end_float) {
+ MAXABS_DEFINE(__m512)
}
} // namespace intgemm
diff --git a/interleave.h b/interleave.h
index a2a8fb9..81d6957 100644
--- a/interleave.h
+++ b/interleave.h
@@ -65,60 +65,70 @@ template <> AVX512F inline __m512i setzero_si<__m512i>() {
}
#endif
-template <class Register> static inline void Swap(Register &a, Register &b) {
- Register tmp = a;
- a = b;
- b = tmp;
-}
+#define SWAP_DEFINE(target, Register) \
+target static inline void Swap(Register &a, Register &b) { \
+ Register tmp = a; \
+ a = b; \
+ b = tmp; \
+} \
+
+SWAP_DEFINE(SSE2, __m128i)
+SWAP_DEFINE(AVX2, __m256i)
+#ifndef INTGEMM_NO_AVX512
+SWAP_DEFINE(AVX512F, __m512i)
+#endif
/* Transpose registers containing 8 packed 16-bit integers.
* Each 128-bit lane is handled independently.
*/
-template <class Register> static inline void Transpose16InLane(Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7) {
- // r0: columns 0 1 2 3 4 5 6 7 from row 0
- // r1: columns 0 1 2 3 4 5 6 7 from row 1
-
- Interleave16(r0, r1);
- Interleave16(r2, r3);
- Interleave16(r4, r5);
- Interleave16(r6, r7);
- // r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
- // r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
- // r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
- // r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
- // r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
- // r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
- // r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
- // r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7
-
- Interleave32(r0, r2);
- Interleave32(r1, r3);
- Interleave32(r4, r6);
- Interleave32(r5, r7);
- // r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
- // r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
- // r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
- // r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
- // r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
- // r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
- // r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
- // r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7
+#define TRANSPOSE16_DEFINE(target, Register) \
+target static inline void Transpose16InLane(Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7) { \
+ /* r0: columns 0 1 2 3 4 5 6 7 from row 0
+ r1: columns 0 1 2 3 4 5 6 7 from row 1*/ \
+ Interleave16(r0, r1); \
+ Interleave16(r2, r3); \
+ Interleave16(r4, r5); \
+ Interleave16(r6, r7); \
+ /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
+ r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
+ r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
+ r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
+ r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
+ r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
+ r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
+ r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/ \
+ Interleave32(r0, r2); \
+ Interleave32(r1, r3); \
+ Interleave32(r4, r6); \
+ Interleave32(r5, r7); \
+ /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
+ r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
+ r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
+ r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
+ r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
+ r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
+ r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
+ r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/ \
+ Interleave64(r0, r4); \
+ Interleave64(r1, r5); \
+ Interleave64(r2, r6); \
+ Interleave64(r3, r7); \
+ /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
+ r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
+ r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
+ r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
+ r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
+ r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7*/ \
+ /* Empirically gcc is able to remove these movs and just rename the outputs of Interleave64. */ \
+ Swap(r1, r4); \
+ Swap(r3, r6); \
+} \
- Interleave64(r0, r4);
- Interleave64(r1, r5);
- Interleave64(r2, r6);
- Interleave64(r3, r7);
- // r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
- // r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
- // r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
- // r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
- // r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
- // r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7
-
- // Empirically gcc is able to remove these movs and just rename the outputs of Interleave64.
- Swap(r1, r4);
- Swap(r3, r6);
-}
+TRANSPOSE16_DEFINE(SSE2, __m128i)
+TRANSPOSE16_DEFINE(AVX2, __m256i)
+#ifndef INTGEMM_NO_AVX512
+TRANSPOSE16_DEFINE(AVX512F, __m512i)
+#endif
/* Transpose registers containing 16 packed 8-bit integers.
* Each 128-bit lane is handled independently.
@@ -196,57 +206,61 @@ template <class Register> static inline void Transpose8InLane(
// 256 272
// 257 273
// ... ...
-template <class Quantizer> static inline void PrepareBFor8(const float *input, int8_t *output_shadow, Quantizer q, Index rows, Index cols) {
- typedef typename Quantizer::Integer Register;
- // Currently all multipliers have a stride of 8 columns.
- const int kColStride = 8;
- assert(cols % kColStride == 0);
- assert(rows % sizeof(Register) == 0);
- assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0);
- Register *output = reinterpret_cast<Register*>(output_shadow);
- assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0);
-
- for (int c = 0; c < cols; c += kColStride) {
- for (int r = 0; r < rows; r += sizeof(Register), output += 8) {
- // Quantize and perform a transpose with height sizeof(Register) and width 8.
- // This isn't quite Transpose8InLane because it's half the number of columns,
- // so each register starts with two rows instead of being one row.
- // The quantizers know to skip a row.
- output[0] = q.ForReshape(input + cols * (r ) + c, cols);
- output[1] = q.ForReshape(input + cols * (r + 1) + c, cols);
- output[2] = q.ForReshape(input + cols * (r + 4) + c, cols);
- output[3] = q.ForReshape(input + cols * (r + 5) + c, cols);
- output[4] = q.ForReshape(input + cols * (r + 8) + c, cols);
- output[5] = q.ForReshape(input + cols * (r + 9) + c, cols);
- output[6] = q.ForReshape(input + cols * (r + 12) + c, cols);
- output[7] = q.ForReshape(input + cols * (r + 13) + c, cols);
- Interleave8(output[0], output[1]);
- Interleave8(output[2], output[3]);
- Interleave8(output[4], output[5]);
- Interleave8(output[6], output[7]);
- Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]);
- }
- }
-}
-
-template <class Quantizer> static inline void PrepareBFor16(const float *input, int16_t *output_shadow, Quantizer q, Index rows, Index cols) {
- typedef typename Quantizer::Integer Register;
- assert(cols % 8 == 0);
- assert(rows % (sizeof(Register) / sizeof(int16_t)) == 0);
- assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0);
- Register *output = reinterpret_cast<Register*>(output_shadow);
- assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0);
+#define PREPARE_B_8_DEF(target, QuantClass) \
+target static inline void PrepareB(const float *input, int8_t *output_shadow, float quant_mult, Index rows, Index cols) { \
+ typedef typename QuantClass Quantizer; \
+ typedef typename Quantizer::Integer Register; \
+ Quantizer q = Quantizer(quant_mult); \
+ /* Currently all multipliers have a stride of 8 columns.*/ \
+ const int kColStride = 8; \
+ assert(cols % kColStride == 0); \
+ assert(rows % sizeof(Register) == 0); \
+ assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
+ Register *output = reinterpret_cast<Register*>(output_shadow); \
+ assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
+ for (int c = 0; c < cols; c += kColStride) { \
+ for (int r = 0; r < rows; r += sizeof(Register), output += 8) { \
+ /* Quantize and perform a transpose with height sizeof(Register) and width 8. \
+ This isn't quite Transpose8InLane because it's half the number of columns, \
+ so each register starts with two rows instead of being one row. \
+ The quantizers know to skip a row.*/ \
+ output[0] = q.ForReshape(input + cols * (r ) + c, cols); \
+ output[1] = q.ForReshape(input + cols * (r + 1) + c, cols); \
+ output[2] = q.ForReshape(input + cols * (r + 4) + c, cols); \
+ output[3] = q.ForReshape(input + cols * (r + 5) + c, cols); \
+ output[4] = q.ForReshape(input + cols * (r + 8) + c, cols); \
+ output[5] = q.ForReshape(input + cols * (r + 9) + c, cols); \
+ output[6] = q.ForReshape(input + cols * (r + 12) + c, cols); \
+ output[7] = q.ForReshape(input + cols * (r + 13) + c, cols); \
+ Interleave8(output[0], output[1]); \
+ Interleave8(output[2], output[3]); \
+ Interleave8(output[4], output[5]); \
+ Interleave8(output[6], output[7]); \
+ Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]); \
+ } \
+ } \
+} \
- for (int c = 0; c < cols; c += 8) {
- for (int r = 0; r < rows; r += (sizeof(Register) / sizeof(int16_t)), output += 8) {
- // gcc unrolls this loop and uses registers for output[k]
- for (int k = 0; k < 8; ++k) {
- output[k] = q.ForReshape(input + cols * (r + k) + c, cols);
- }
- Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]);
- }
- }
-}
+#define PREPARE_B_16_DEF(target, QuantClass) \
+target static inline void PrepareB(const float *input, int16_t *output_shadow, float quant_mult, Index rows, Index cols) { \
+ typedef typename QuantClass Quantizer; \
+ typedef typename Quantizer::Integer Register; \
+ Quantizer q = Quantizer(quant_mult); \
+ assert(cols % 8 == 0); \
+ assert(rows % (sizeof(Register) / sizeof(int16_t)) == 0); \
+ assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
+ Register *output = reinterpret_cast<Register*>(output_shadow); \
+ assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
+ for (int c = 0; c < cols; c += 8) { \
+ for (int r = 0; r < rows; r += (sizeof(Register) / sizeof(int16_t)), output += 8) { \
+ /* gcc unrolls this loop and uses registers for output[k]*/ \
+ for (int k = 0; k < 8; ++k) { \
+ output[k] = q.ForReshape(input + cols * (r + k) + c, cols); \
+ } \
+ Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]); \
+ } \
+ } \
+} \
/* Select columns of B from PrepareB format to PrepareB format.
*/
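
The PREPARE_B_8_DEF / PREPARE_B_16_DEF macros above are expanded inside the per-ISA structs (see the avx2_gemm.h and avx512_gemm.h hunks), so each struct gains its own target-attributed PrepareB that now constructs the quantizer from quant_mult inside the body instead of receiving it as an argument. A self-contained sketch of that shape, where the macro leaves a static member behind; SCALE_DEF, Sse2Kernels, and Avx2Kernels are illustrative names:

#include <immintrin.h>

// Illustrative only: expanding the macro inside a struct yields a
// target-attributed static member, the way PREPARE_B_16_DEF(AVX2, ...) yields
// an AVX2-attributed PrepareB inside AVX2_16bit.
#define SCALE_DEF(target, Register, set1, mul, width) \
  target static void Scale(const float *in, float *out, float mult, int n) { \
    /* n assumed to be a multiple of the register width; in/out assumed aligned. */ \
    Register m = set1(mult); \
    for (int i = 0; i < n; i += (width)) { \
      Register v = *reinterpret_cast<const Register*>(in + i); \
      *reinterpret_cast<Register*>(out + i) = mul(v, m); \
    } \
  }

struct Sse2Kernels {
  SCALE_DEF(__attribute__((target("sse2"))), __m128, _mm_set1_ps, _mm_mul_ps, 4)
};
struct Avx2Kernels {
  SCALE_DEF(__attribute__((target("avx2"))), __m256, _mm256_set1_ps, _mm256_mul_ps, 8)
};
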
diff --git a/multiply.h b/multiply.h
index eed4e14..0d3854e 100644
--- a/multiply.h
+++ b/multiply.h
@@ -69,7 +69,7 @@ static inline float MaxFloat32(__m512 a) {
/* Take 4 registers with 32-bit values to be horizontally added. Reduce them
* to one register with 32-bit values in the pattern 1 2 3 4 1 2 3 4, leaving
- * the final addition (which crosses 128-bit lanes) to the caller. */
+ * the final addition (which crosses 128-bit lanes) to the caller.
template <class Register> inline Register Pack0123(Register sum0, Register sum1, Register sum2, Register sum3) {
// 1 2 1 2 1 2 1 2
Interleave32(sum0, sum1);
@@ -81,6 +81,22 @@ template <class Register> inline Register Pack0123(Register sum0, Register sum1,
// 1 2 3 4 1 2 3 4
return add_epi32(pack01, pack23);
}
+ */
+#define PACK_DEFINE(target, Register) \
+target inline Register Pack0123(Register sum0, Register sum1, Register sum2, Register sum3) { \
+ Interleave32(sum0, sum1); \
+ Register pack01 = add_epi32(sum0, sum1); \
+ Interleave32(sum2, sum3); \
+ Register pack23 = add_epi32(sum2, sum3); \
+ Interleave64(pack01, pack23); \
+ return add_epi32(pack01, pack23); \
+} \
+
+PACK_DEFINE(SSE2, __m128i)
+PACK_DEFINE(AVX2, __m256i)
+#ifndef INTGEMM_NO_AVX512
+PACK_DEFINE(AVX512F, __m512i)
+#endif
// 16-bit multiplier for SSE2, AVX2, and AVX512.
// C = A * B * unquant_mult
@@ -118,67 +134,71 @@ template <class Register> inline Register Pack0123(Register sum0, Register sum1,
// A_rows can be anything non-negative.
// width must be a multiple of the register size.
// B_cols must be a multiple of 8.
-//#define Multiply16(Integer, Annotate) \ //fd
-// template <class WriteC> Annotate inline void Multiply16(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
-//
-template <class Integer, class WriteC> inline void Multiply16(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
- assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0);
- assert(B_cols % 8 == 0);
- assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
- assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
- //assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0);
- const int simd_width = width / (sizeof(Integer) / sizeof(int16_t));
- //const Float unquant_reg = set1_ps<Float>(unquant_mult);
- const Integer *B0_col = reinterpret_cast<const Integer *>(B);
- for (int B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
- // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
- for (int A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
- const Integer *A_row = reinterpret_cast<const Integer*>(A + A_rowidx * width);
- // These will be packed 32-bit integers containing sums for each row of B multiplied by the row of A.
- // Iterate over shared (inner) dimension.
- int k = 0;
- Integer a = *(A_row + k);
- Integer sum0 = madd_epi16(a, *(B0_col + k * 8));
- Integer sum1 = madd_epi16(a, *(B0_col + k * 8 + 1));
- Integer sum2 = madd_epi16(a, *(B0_col + k * 8 + 2));
- Integer sum3 = madd_epi16(a, *(B0_col + k * 8 + 3));
- Integer sum4 = madd_epi16(a, *(B0_col + k * 8 + 4));
- Integer sum5 = madd_epi16(a, *(B0_col + k * 8 + 5));
- Integer sum6 = madd_epi16(a, *(B0_col + k * 8 + 6));
- Integer sum7 = madd_epi16(a, *(B0_col + k * 8 + 7));
- for (int k = 1; k < simd_width; ++k) {
- Integer a = *(A_row + k);
- // Multiply 16-bit, horizontally add to packed 32-bit integers.
- Integer mult0 = madd_epi16(a, *(B0_col + k * 8));
- Integer mult1 = madd_epi16(a, *(B0_col + k * 8 + 1));
- Integer mult2 = madd_epi16(a, *(B0_col + k * 8 + 2));
- Integer mult3 = madd_epi16(a, *(B0_col + k * 8 + 3));
- Integer mult4 = madd_epi16(a, *(B0_col + k * 8 + 4));
- Integer mult5 = madd_epi16(a, *(B0_col + k * 8 + 5));
- Integer mult6 = madd_epi16(a, *(B0_col + k * 8 + 6));
- Integer mult7 = madd_epi16(a, *(B0_col + k * 8 + 7));
- // Sum packed 32-bit integers with danger of overflow. TODO: accumulate in 64-bit every so often.
- sum0 = add_epi32(sum0, mult0);
- sum1 = add_epi32(sum1, mult1);
- sum2 = add_epi32(sum2, mult2);
- sum3 = add_epi32(sum3, mult3);
- sum4 = add_epi32(sum4, mult4);
- sum5 = add_epi32(sum5, mult5);
- sum6 = add_epi32(sum6, mult6);
- sum7 = add_epi32(sum7, mult7);
- }
- // Reduce sums within 128-bit lanes.
- Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
- Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
- // The specific implementation may need to reduce further.
+//template <class Integer, class WriteC> inline void Multiply16(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
+#define MULTIPLY16_define(Integer, target) \
+ template <class WriteC> target inline void Multiply16##Integer(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \
+ assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \
+ assert(B_cols % 8 == 0); \
+ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
+ /*assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0); Moved to WriteC*/ \
+ const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \
+ /*const Float unquant_reg = set1_ps<Float>(unquant_mult); moved to WriteC*/ \
+ const Integer *B0_col = reinterpret_cast<const Integer *>(B); \
+ for (int B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
+ for (int A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
+ const Integer *A_row = reinterpret_cast<const Integer*>(A + A_rowidx * width); \
+ /* These will be packed 32-bit integers containing sums for each row of B multiplied by the row of A. \
+ Iterate over shared (inner) dimension.*/ \
+ int k = 0; \
+ Integer a = *(A_row + k); \
+ Integer sum0 = madd_epi16(a, *(B0_col + k * 8)); \
+ Integer sum1 = madd_epi16(a, *(B0_col + k * 8 + 1)); \
+ Integer sum2 = madd_epi16(a, *(B0_col + k * 8 + 2)); \
+ Integer sum3 = madd_epi16(a, *(B0_col + k * 8 + 3)); \
+ Integer sum4 = madd_epi16(a, *(B0_col + k * 8 + 4)); \
+ Integer sum5 = madd_epi16(a, *(B0_col + k * 8 + 5)); \
+ Integer sum6 = madd_epi16(a, *(B0_col + k * 8 + 6)); \
+ Integer sum7 = madd_epi16(a, *(B0_col + k * 8 + 7)); \
+ for (int k = 1; k < simd_width; ++k) { \
+ Integer a = *(A_row + k); \
+ /* Multiply 16-bit, horizontally add to packed 32-bit integers.*/ \
+ Integer mult0 = madd_epi16(a, *(B0_col + k * 8)); \
+ Integer mult1 = madd_epi16(a, *(B0_col + k * 8 + 1)); \
+ Integer mult2 = madd_epi16(a, *(B0_col + k * 8 + 2)); \
+ Integer mult3 = madd_epi16(a, *(B0_col + k * 8 + 3)); \
+ Integer mult4 = madd_epi16(a, *(B0_col + k * 8 + 4)); \
+ Integer mult5 = madd_epi16(a, *(B0_col + k * 8 + 5)); \
+ Integer mult6 = madd_epi16(a, *(B0_col + k * 8 + 6)); \
+ Integer mult7 = madd_epi16(a, *(B0_col + k * 8 + 7)); \
+ /* Sum packed 32-bit integers with danger of overflow. TODO: accumulate in 64-bit every so often.*/ \
+ sum0 = add_epi32(sum0, mult0); \
+ sum1 = add_epi32(sum1, mult1); \
+ sum2 = add_epi32(sum2, mult2); \
+ sum3 = add_epi32(sum3, mult3); \
+ sum4 = add_epi32(sum4, mult4); \
+ sum5 = add_epi32(sum5, mult5); \
+ sum6 = add_epi32(sum6, mult6); \
+ sum7 = add_epi32(sum7, mult7); \
+ } \
+ /* Reduce sums within 128-bit lanes.*/ \
+ Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
+ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
+ /*The specific implementation may need to reduce further.*/ \
+ auto total = PermuteSummer(pack0123, pack4567); \
+ functor(A_rowidx, B_cols, B0_colidx, total); \
+ } \
+ } \
+} \
- auto total = PermuteSummer(pack0123, pack4567);
- functor(A_rowidx, B_cols, B0_colidx, total);
- //WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);
- }
- }
-}
+MULTIPLY16_define(__m128i, SSE2)
+MULTIPLY16_define(__m256i, AVX2)
+#ifndef INTGEMM_NO_AVX512
+MULTIPLY16_define(__m512i, AVX512F)
+#endif
+//MULTIPLY16_define(__m256i, AVX2)
/* 8-bit matrix multiply used by AVX and AVX2.
* These have two peculiar properties:
* 1. The sign instructions don't exist in AVX512.
@@ -192,7 +212,7 @@ template <class Integer, class WriteC> inline void Multiply16(const int16_t *A,
* vpmaddubsw. That's why this code is generic over 128-bit or 256-bit.
*/
struct Multiply8_AVXAVX2 {
- template <class Integer> inline static void Inner(
+ template <class Integer> AVX2 inline static void Inner(
Integer a, const Integer *b,
Integer &sum0, Integer &sum1, Integer &sum2, Integer &sum3,
Integer &sum4, Integer &sum5, Integer &sum6, Integer &sum7) {
@@ -328,44 +348,41 @@ struct Multiply8_C {
sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a)));
}
};
-
-template <class Algo, class Integer, class Float> inline void Multiply8_SSE2OrAVX2(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
- assert(width % sizeof(Integer) == 0);
- assert(B_cols % 8 == 0);
- assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
- assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
- assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0);
- Float unquant_reg = set1_ps<Float>(unquant_mult);
- const int simd_width = width / sizeof(Integer);
- const Integer *B0_col = reinterpret_cast<const Integer*>(B);
- // Go over 8 columns of B at a time.
- for (int B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
- // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
- for (int A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
- // Iterate over shared (inner) dimension.
- const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width);
- const Integer *A_end = A_live + simd_width;
- const Integer *B_live = B0_col;
-
- // Rather than initializing as zeros and adding, just initialize the first.
- Integer a = *(A_live++);
- Integer a_positive = abs_epi8(a);
- // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.
- Integer sum0 = maddubs_epi16(a_positive, sign_epi8(B_live[0], a));
- Integer sum1 = maddubs_epi16(a_positive, sign_epi8(B_live[1], a));
- Integer sum2 = maddubs_epi16(a_positive, sign_epi8(B_live[2], a));
- Integer sum3 = maddubs_epi16(a_positive, sign_epi8(B_live[3], a));
- Integer sum4 = maddubs_epi16(a_positive, sign_epi8(B_live[4], a));
- Integer sum5 = maddubs_epi16(a_positive, sign_epi8(B_live[5], a));
- Integer sum6 = maddubs_epi16(a_positive, sign_epi8(B_live[6], a));
- Integer sum7 = maddubs_epi16(a_positive, sign_epi8(B_live[7], a));
- B_live += 8;
-
- // Use A as the loop variable so the add can be done where gcc likes it
- // for branch prediction.
- for (; A_live != A_end; ++A_live, B_live += 8) {
- Algo::Inner(*A_live, B_live, sum0, sum1, sum2, sum3, sum4, sum5, sum6, sum7);
- }
+#define MULTIPLY8_define(Integer, Float, target) \
+template <class Algo> target inline void Multiply8_SSE2OrAVX2##Integer(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { \
+ assert(width % sizeof(Integer) == 0); \
+ assert(B_cols % 8 == 0); \
+ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
+ assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0); \
+ Float unquant_reg = set1_ps<Float>(unquant_mult); \
+ const int simd_width = width / sizeof(Integer); \
+ const Integer *B0_col = reinterpret_cast<const Integer*>(B); \
+ /*Go over 8 columns of B at a time.*/ \
+ for (int B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ /*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
+ for (int A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
+ /*Iterate over shared (inner) dimension.*/ \
+ const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width); \
+ const Integer *A_end = A_live + simd_width; \
+ const Integer *B_live = B0_col; \
+ /* Rather than initializing as zeros and adding, just initialize the first.*/ \
+ Integer a = *(A_live++); \
+ Integer a_positive = abs_epi8(a); \
+ /* These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.*/ \
+ Integer sum0 = maddubs_epi16(a_positive, sign_epi8(B_live[0], a)); \
+ Integer sum1 = maddubs_epi16(a_positive, sign_epi8(B_live[1], a)); \
+ Integer sum2 = maddubs_epi16(a_positive, sign_epi8(B_live[2], a)); \
+ Integer sum3 = maddubs_epi16(a_positive, sign_epi8(B_live[3], a)); \
+ Integer sum4 = maddubs_epi16(a_positive, sign_epi8(B_live[4], a)); \
+ Integer sum5 = maddubs_epi16(a_positive, sign_epi8(B_live[5], a)); \
+ Integer sum6 = maddubs_epi16(a_positive, sign_epi8(B_live[6], a)); \
+ Integer sum7 = maddubs_epi16(a_positive, sign_epi8(B_live[7], a)); \
+ B_live += 8; \
+ /* Use A as the loop variable so the add can be done where gcc likes it for branch prediction.*/ \
+ for (; A_live != A_end; ++A_live, B_live += 8) { \
+ Algo::Inner(*A_live, B_live, sum0, sum1, sum2, sum3, sum4, sum5, sum6, sum7); \
+ } \
/* Convert 16-bit to 32-bit and add, not caring what parts are added.
* Implementations:
* 1. https://github.com/tesseract-ocr/tesseract/blob/master/src/arch/intsimdmatrixavx2.cpp#L67 under Apache license:
@@ -382,26 +399,31 @@ template <class Algo, class Integer, class Float> inline void Multiply8_SSE2OrAV
* sum = _mm512_add_epi32(
* _mm512_srai_epi32(_mm512_slli_epi32(sum, 16), 16),
* _mm512_srai_epi32(sum, 16));
- */
- Integer ones = set1_epi16<Integer>(1);
- sum0 = madd_epi16(sum0, ones);
- sum1 = madd_epi16(sum1, ones);
- sum2 = madd_epi16(sum2, ones);
- sum3 = madd_epi16(sum3, ones);
- sum4 = madd_epi16(sum4, ones);
- sum5 = madd_epi16(sum5, ones);
- sum6 = madd_epi16(sum6, ones);
- sum7 = madd_epi16(sum7, ones);
- Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
- Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+ */ \
+ Integer ones = set1_epi16<Integer>(1); \
+ sum0 = madd_epi16(sum0, ones); \
+ sum1 = madd_epi16(sum1, ones); \
+ sum2 = madd_epi16(sum2, ones); \
+ sum3 = madd_epi16(sum3, ones); \
+ sum4 = madd_epi16(sum4, ones); \
+ sum5 = madd_epi16(sum5, ones); \
+ sum6 = madd_epi16(sum6, ones); \
+ sum7 = madd_epi16(sum7, ones); \
+ Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
+ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
+ auto total = PermuteSummer(pack0123, pack4567); \
+ WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg); \
+ } \
+ } \
+} \
+
+MULTIPLY8_define(__m128i, __m128, SSSE3)
+
+MULTIPLY8_define(__m256i, __m256, AVX2)
- auto total = PermuteSummer(pack0123, pack4567);
- WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);
- }
- }
-}
// Find the maximum absolute value of packed float32s.
+/*
template <class Register> inline static float MaxAbsoluteBackend(const float *begin_float, const float *end_float) {
assert(end_float > begin_float);
assert((end_float - begin_float) % (sizeof(Register) / sizeof(float)) == 0);
@@ -418,6 +440,20 @@ template <class Register> inline static float MaxAbsoluteBackend(const float *be
}
return MaxFloat32(highest);
-}
+}*/
+#define MAXABS_DEFINE(Register) \
+ assert(end_float > begin_float); \
+ assert((end_float - begin_float) % (sizeof(Register) / sizeof(float)) == 0); \
+ const Register *begin = reinterpret_cast<const Register*>(begin_float); \
+ const Register *end = reinterpret_cast<const Register*>(end_float); \
+ union {float f; int32_t i;} float_convert; \
+ float_convert.i = 0x7fffffff; \
+ Register and_me = set1_ps<Register>(float_convert.f); \
+ Register highest = and_ps(and_me, *begin); \
+ for (++begin; begin != end; ++begin) { \
+ Register reg = and_ps(and_me, *begin); \
+ highest = max_ps(highest, reg); \
+ } \
+ return MaxFloat32(highest); \
} // namespace intgemm
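
MAXABS_DEFINE works differently from the other macros in this commit: it expands to a function body only, and each gemm header wraps it in its own target-attributed function, which is why the MaxAbsolute wrappers rename their parameters to begin_float and end_float, the names the body expects. A self-contained sketch of that shared-body/thin-wrapper shape; SUM_BODY, SumSSE2, and SumAVX are illustrative names:

#include <immintrin.h>
#include <cassert>

/* Illustrative body-only macro: it refers to begin_float/end_float, so every
   wrapper must declare parameters with exactly those names. */
#define SUM_BODY(Register, load, add, store, width) \
  assert(end_float > begin_float); \
  assert((end_float - begin_float) % (width) == 0); \
  Register acc = load(begin_float); \
  for (const float *p = begin_float + (width); p != end_float; p += (width)) \
    acc = add(acc, load(p)); \
  float lanes[(width)]; \
  store(lanes, acc); \
  float total = 0.0f; \
  for (int i = 0; i < (width); ++i) total += lanes[i]; \
  return total;

__attribute__((target("sse2")))
float SumSSE2(const float *begin_float, const float *end_float) {
  SUM_BODY(__m128, _mm_loadu_ps, _mm_add_ps, _mm_storeu_ps, 4)
}

__attribute__((target("avx")))
float SumAVX(const float *begin_float, const float *end_float) {
  SUM_BODY(__m256, _mm256_loadu_ps, _mm256_add_ps, _mm256_storeu_ps, 8)
}
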
diff --git a/sse2_gemm.h b/sse2_gemm.h
index 7a5e5ce..53b0033 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -19,7 +19,7 @@ class QuantizeTile16 {
public:
typedef __m128i Integer;
- explicit QuantizeTile16(float mult) : mult_reg_(_mm_set1_ps(mult)) {}
+ SSE2 explicit QuantizeTile16(float mult) : mult_reg_(_mm_set1_ps(mult)) {}
// Quantize 8xfloat into 8xint16_t
SSE2 inline __m128i Consecutive(const float *input) {
@@ -59,11 +59,13 @@ struct SSE2_16bit {
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 8;
static const Index kBTileCol = 8;
-
+/*
SSE2 static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
//TODO #DEFINE
PrepareBFor16(input, output, sse2::QuantizeTile16(quant_mult), rows, cols);
- }
+ }*/
+
+ PREPARE_B_16_DEF(SSE2, sse2::QuantizeTile16)
SSE2 static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
//TODO #DEFINE
@@ -72,7 +74,7 @@ struct SSE2_16bit {
SSE2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
//TODO #DEFINE
- Multiply16<__m128i, JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
+ Multiply16__m128i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
}
constexpr static const char *const kName = "16-bit SSE2";
@@ -81,8 +83,8 @@ struct SSE2_16bit {
};
// Technically only requires SSE
-SSE2 float SSE2_MaxAbsolute(const float *begin, const float *end) {
- return MaxAbsoluteBackend<__m128>(begin, end);
+SSE2 float SSE2_MaxAbsolute(const float *begin_float, const float *end_float) {
+ MAXABS_DEFINE(__m128)
}
} // namespace intgemm
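
sse2_gemm.h above and ssse3_gemm.h below follow the same conversion. Note that the instantiation lists in interleave.h and multiply.h keep their widest variants behind #ifndef INTGEMM_NO_AVX512, so the headers still build when AVX-512 support is excluded. A minimal sketch of guarding one macro instantiation that way; LANES_DEFINE and the Lanes* functions are illustrative names:

#include <immintrin.h>

// Illustrative only: the widest instantiation is simply omitted when the
// build opts out of AVX-512, mirroring the guards in interleave.h/multiply.h.
#define LANES_DEFINE(target, Register, width) \
  target static inline int Lanes##Register() { return (width); }

LANES_DEFINE(__attribute__((target("sse2"))), __m128i, 4)
LANES_DEFINE(__attribute__((target("avx2"))), __m256i, 8)
#ifndef INTGEMM_NO_AVX512
LANES_DEFINE(__attribute__((target("avx512f"))), __m512i, 16)
#endif
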
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 69ac298..2b830a9 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -86,17 +86,19 @@ struct SSSE3_8bit {
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 16;
static const Index kBTileCol = 8;
-
+/*
SSSE3 static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor8(input, output, ssse3::QuantizeTile8(quant_mult), rows, cols);
- }
+ }*/
+ PREPARE_B_8_DEF(SSSE3, ssse3::QuantizeTile8)
SSSE3 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
}
SSSE3 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
- Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
+ //Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
+ Multiply8_SSE2OrAVX2__m128i<Multiply8_C>(A, B, C, unquant_mult, A_rows, width, B_cols);
}
constexpr static const char *const kName = "8-bit SSSE3";