Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Kenneth Heafield <github@kheafield.com> 2019-04-18 16:38:59 +0300
committer Kenneth Heafield <github@kheafield.com> 2019-04-18 16:38:59 +0300
commit fba6fb3b78611b262cac3022fc585efc020991ff (patch)
tree 25b9cd6679d3854651630777a2c814493d8416ef
parent e75044e331d905f0b01adfe219852582579d2872 (diff)
Look ma, no gcc compiler errors
-rw-r--r-- avx2_gemm.h 10
-rw-r--r-- avx512_gemm.h 17
-rw-r--r-- cops.h 78
-rw-r--r-- multiply.h 16
-rw-r--r-- sse2_gemm.h 12
-rw-r--r-- ssse3_gemm.h 2
6 files changed, 57 insertions, 78 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h
index c326976..393eebb 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -81,12 +81,8 @@ struct AVX2_16bit {
AVX2 static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
}
-/*
- AVX2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
- Multiply16__m256i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
- }
- */
- MULTIPLY16_define(__m256i, AVX2)
+
+ MULTIPLY16_define(__m256i, AVX2, OnAVX2)
constexpr static const char *const kName = "16-bit AVX2";
@@ -184,7 +180,7 @@ struct AVX2_8bit {
//Multiply8_SSE2OrAVX2<Multiply8_AVXAVX2, __m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols);
Multiply8_SSE2OrAVX2__m256i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
}*/
- MULTIPLY8_define(__m256i, AVX2)
+ MULTIPLY8_define(__m256i, AVX2, OnAVX2)
constexpr static const char *const kName = "8-bit AVX2";
diff --git a/avx512_gemm.h b/avx512_gemm.h
index e48ed9a..3806929 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -156,12 +156,8 @@ struct AVX512_16bit {
AVX512F static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
avx512f::SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end);
}
-/*
- AVX512F static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
- // The unquantization is only 256-bit wide because there are 8 results.
- Multiply16__m512i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
- }*/
- MULTIPLY16_define(__m512i, AVX512F)
+
+ MULTIPLY16_define(__m512i, AVX512F, OnAVX2)
constexpr static const char *const kName = "16-bit AVX512";
@@ -209,7 +205,7 @@ struct AVX512_8bit {
// Special AVX512 implementation due to having 32 registers (so I don't have to
// allocate registers manually) and no sign instruction.
template <class WriteC>
- AVX512BW static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
+ AVX512BW static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) {
typedef __m512i Integer;
//typedef __m256 Float; // For quantization we only do 8 at a time.
// This is copy-paste from Multiply8_SSE2OrAVX2.
@@ -217,8 +213,8 @@ struct AVX512_8bit {
assert(B_cols % 8 == 0);
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
- //assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0);
- //Float unquant_reg = set1_ps<Float>(unquant_mult);
+ // There's 8 results for AVX2 to handle.
+ typename WriteC::OnAVX2 write_C(C);
const int simd_width = width / sizeof(Integer);
const Integer *B0_col = reinterpret_cast<const Integer*>(B);
// Added for AVX512.
@@ -315,8 +311,7 @@ struct AVX512_8bit {
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
- //WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);
- functor(A_rowidx, B_cols, B0_colidx, total);
+ write_C(A_rowidx, B_cols, B0_colidx, total);
}
}
}
diff --git a/cops.h b/cops.h
index 196f705..aa9f0ff 100644
--- a/cops.h
+++ b/cops.h
@@ -6,46 +6,44 @@
namespace intgemm {
class JustUnquantizeC {
-public:
- JustUnquantizeC(float *C, float unquant_mult);
-
- SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result);
- AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result);
-
-private:
- SSE2 void InitRegisterSSE(float unquant_mult);
- AVX2 void InitRegisterAVX2(float unquant_mult);
-
- float *C_;
- __m128 unquant_mult_128_; // Registers
- __m256 unquant_mult_256_;
+ public:
+ JustUnquantizeC(float *C, float unquant_mult) : C_(C), unquant_mult_(unquant_mult) {}
+
+ class OnSSE2 {
+ public:
+ SSE2 explicit OnSSE2(const JustUnquantizeC &from)
+ : C_(from.C_), unquant_mult_(_mm_set1_ps(from.unquant_mult_)) {
+ assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
+ }
+
+ SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
+ *reinterpret_cast<__m128*>(C_ + rowIDX*cols + colIDX) = mul_ps(cvtepi32_ps(result.pack0123), unquant_mult_);
+ *reinterpret_cast<__m128*>(C_ + rowIDX*cols + colIDX + 4) = mul_ps(cvtepi32_ps(result.pack4567), unquant_mult_);
+ }
+ private:
+ float *C_;
+ __m128 unquant_mult_;
+ };
+
+ class OnAVX2 {
+ public:
+ AVX2 explicit OnAVX2(const JustUnquantizeC &from)
+ : C_(from.C_), unquant_mult_(_mm256_set1_ps(from.unquant_mult_)) {
+ assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
+ }
+
+ AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
+ *reinterpret_cast<__m256*>(C_ + rowIDX*cols + colIDX) = mul_ps(cvtepi32_ps(result), unquant_mult_);
+ }
+
+ private:
+ float *C_;
+ __m256 unquant_mult_;
+ };
+
+ private:
+ float *C_;
+ float unquant_mult_;
};
-SSE2 void JustUnquantizeC::InitRegisterSSE(float unquant_mult) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128) == 0);
- unquant_mult_128_ = _mm_set1_ps(unquant_mult);
-}
-
-AVX2 void JustUnquantizeC::InitRegisterAVX2(float unquant_mult) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256) == 0);
- unquant_mult_256_ = _mm256_set1_ps(unquant_mult);
-}
-
-JustUnquantizeC::JustUnquantizeC(float *C, float unquant_mult) : C_(C) {
- //We need both to make sure our tests pass
- //Some of the assertions might give false positives on SSE2/3
- InitRegisterSSE(unquant_mult);
- if (__builtin_cpu_supports("avx2")) {
- InitRegisterAVX2(unquant_mult);
- }
-}
-
-
-SSE2 inline void JustUnquantizeC::operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result){
- *reinterpret_cast<__m128*>(C_ + rowIDX*cols + colIDX) = mul_ps(cvtepi32_ps(result.pack0123), unquant_mult_128_);
- *reinterpret_cast<__m128*>(C_ + rowIDX*cols + colIDX + 4) = mul_ps(cvtepi32_ps(result.pack4567), unquant_mult_128_);
-}
-AVX2 inline void JustUnquantizeC::operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
- *reinterpret_cast<__m256*>(C_ + rowIDX*cols + colIDX) = mul_ps(cvtepi32_ps(result), unquant_mult_256_);
-}
} //Namespace
diff --git a/multiply.h b/multiply.h
index e27f31f..406358e 100644
--- a/multiply.h
+++ b/multiply.h
@@ -135,15 +135,14 @@ PACK_DEFINE(AVX512F, __m512i)
// width must be a multiple of the register size.
// B_cols must be a multiple of 8.
// Multiply16
-#define MULTIPLY16_define(Integer, target) \
- template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \
+#define MULTIPLY16_define(Integer, target, WriteCSubType) \
+ template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \
assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
- /*assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0); Moved to WriteC*/ \
const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \
- /*const Float unquant_reg = set1_ps<Float>(unquant_mult); moved to WriteC*/ \
+ typename WriteC::WriteCSubType write_C(C); \
const Integer *B0_col = reinterpret_cast<const Integer *>(B); \
for (int B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -187,7 +186,7 @@ PACK_DEFINE(AVX512F, __m512i)
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
/*The specific implementation may need to reduce further.*/ \
auto total = PermuteSummer(pack0123, pack4567); \
- functor(A_rowidx, B_cols, B0_colidx, total); \
+ write_C(A_rowidx, B_cols, B0_colidx, total); \
} \
} \
} \
@@ -340,14 +339,15 @@ SSSE3 inline static void InnerSSSE3(
sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a)));
}
//AVX2 or SSSE3 multiply
-#define MULTIPLY8_define(Integer, target) \
-template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \
+#define MULTIPLY8_define(Integer, target, WriteCSubType) \
+template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \
assert(width % sizeof(Integer) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
const int simd_width = width / sizeof(Integer); \
const Integer *B0_col = reinterpret_cast<const Integer*>(B); \
+ typename WriteC::WriteCSubType c_writer(C); \
/*Go over 8 columns of B at a time.*/ \
for (int B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
/*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -403,7 +403,7 @@ template <class WriteC> target static void Multiply(const int8_t *A, const int8_
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
auto total = PermuteSummer(pack0123, pack4567); \
/*WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);*/ \
- functor(A_rowidx, B_cols, B0_colidx, total); \
+ c_writer(A_rowidx, B_cols, B0_colidx, total); \
} \
} \
} \
diff --git a/sse2_gemm.h b/sse2_gemm.h
index a3e8233..8717150 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -65,11 +65,6 @@ struct SSE2_16bit {
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 8;
static const Index kBTileCol = 8;
-/*
- SSE2 static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
- //TODO #DEFINE
- PrepareBFor16(input, output, sse2::QuantizeTile16(quant_mult), rows, cols);
- }*/
PREPARE_B_16_DEF(SSE2, sse2::QuantizeTile16)
@@ -77,12 +72,7 @@ struct SSE2_16bit {
//TODO #DEFINE
sse2::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end);
}
-/*
- SSE2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
- //TODO #DEFINE
- Multiply16__m128i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
- }*/
- MULTIPLY16_define(__m128i, SSE2)
+ MULTIPLY16_define(__m128i, SSE2, OnSSE2)
constexpr static const char *const kName = "16-bit SSE2";
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index d384500..b19e552 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -102,7 +102,7 @@ struct SSSE3_8bit {
//Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
Multiply8_SSE2OrAVX2__m128i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
}*/
- MULTIPLY8_define(__m128i, SSSE3)
+ MULTIPLY8_define(__m128i, SSSE3, OnSSE2)
constexpr static const char *const kName = "8-bit SSSE3";