Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'multiply.h')
-rw-r--r-- multiply.h 43
1 files changed, 32 insertions, 11 deletions
diff --git a/multiply.h b/multiply.h
index 4642616..3b05252 100644
--- a/multiply.h
+++ b/multiply.h
@@ -2,10 +2,30 @@
#include "interleave.h"
#include "intrinsics.h"
+#include "postprocess_pipeline.h"
#include "vec_utils.h"
namespace intgemm {
+INTGEMM_SSE2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m128 result) {
+ *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX) = result;
+}
+
+INTGEMM_SSE2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, RegisterPair128 result) {
+ *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX) = result.pack0123;
+ *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX + 4) = result.pack4567;
+}
+
+INTGEMM_AVX2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m256 result) {
+ *reinterpret_cast<__m256*>(C + rowIDX*cols + colIDX) = result;
+}
+
+#ifndef INTGEMM_NO_AVX512
+INTGEMM_AVX512BW static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m512 result) {
+ *reinterpret_cast<__m512*>(C + rowIDX*cols + colIDX) = result;
+}
+#endif
+
INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
// Fold to just using the first 64 bits.
__m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
@@ -17,9 +37,9 @@ INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
return *reinterpret_cast<float*>(&a);
}
-INTGEMM_SSE2 static inline MultiplyResult128 PermuteSummer(__m128i pack0123, __m128i pack4567) {
+INTGEMM_SSE2 static inline RegisterPair128i PermuteSummer(__m128i pack0123, __m128i pack4567) {
// No op for 128 bits: already reduced fully.
- MultiplyResult128 ret;
+ RegisterPair128i ret;
ret.pack0123 = pack0123;
ret.pack4567 = pack4567;
return ret;
@@ -126,14 +146,14 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i)
// width must be a multiple of the register size.
// B_cols must be a multiple of 8.
// Multiply16
-#define INTGEMM_MULTIPLY16(Integer, target, WriteCSubType) \
- template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \
+#define INTGEMM_MULTIPLY16(Integer, target, cpu_type) \
+template <typename PostprocessPipeline> target static void Multiply(const int16_t *A, const int16_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \
assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \
- typename WriteC::WriteCSubType write_C(C); \
+ auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \
const Integer *B0_col = reinterpret_cast<const Integer *>(B); \
for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -177,7 +197,8 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i)
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
/*The specific implementation may need to reduce further.*/ \
auto total = PermuteSummer(pack0123, pack4567); \
- write_C(A_rowidx, B_cols, B0_colidx, total); \
+ auto result = inited_pipeline.run(total); \
+ writer(C, A_rowidx, B_cols, B0_colidx, result); \
} \
} \
} \
@@ -330,15 +351,15 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a)));
}
//INTGEMM_AVX2 or INTGEMM_SSSE3 multiply
-#define INTGEMM_MULTIPLY8(Integer, target, WriteCSubType) \
-template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \
+#define INTGEMM_MULTIPLY8(Integer, target, cpu_type) \
+ template <typename PostprocessPipeline> target static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \
assert(width % sizeof(Integer) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
const int simd_width = width / sizeof(Integer); \
const Integer *B0_col = reinterpret_cast<const Integer*>(B); \
- typename WriteC::WriteCSubType c_writer(C); \
+ auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \
/*Go over 8 columns of B at a time.*/ \
for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
/*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -393,8 +414,8 @@ template <class WriteC> target static void Multiply(const int8_t *A, const int8_
Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
auto total = PermuteSummer(pack0123, pack4567); \
- /*WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);*/ \
- c_writer(A_rowidx, B_cols, B0_colidx, total); \
+ auto result = inited_pipeline.run(total); \
+ writer(C, A_rowidx, B_cols, B0_colidx, result); \
} \
} \
} \