Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--  avx2_gemm.cc   8
-rw-r--r--  avx2_gemm.h    4
-rw-r--r--  avx512_gemm.cc 8
-rw-r--r--  avx512_gemm.h  4
-rw-r--r--  interleave.h   31
-rw-r--r--  intgemm.cc     8
-rw-r--r--  intgemm.h      6
-rw-r--r--  multiply.h     10
-rw-r--r--  sse2_gemm.cc   4
-rw-r--r--  sse2_gemm.h    2
-rw-r--r--  ssse3_gemm.cc  4
-rw-r--r--  ssse3_gemm.h   2
12 files changed, 81 insertions(+), 10 deletions(-)
diff --git a/avx2_gemm.cc b/avx2_gemm.cc
index 59a4c48..6ac500a 100644
--- a/avx2_gemm.cc
+++ b/avx2_gemm.cc
@@ -125,10 +125,18 @@ void AVX2_16bit::PrepareB(const float *input, int16_t *output, float quant_mult,
PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols);
}
+void AVX2_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) {
+ SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
+}
+
void AVX2_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols);
}
+void AVX2_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) {
+ SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
+}
+
void AVX2_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
Multiply16<__m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols);
}
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 23faa2a..31db7f6 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -20,6 +20,8 @@ struct AVX2_16bit {
static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols);
+ static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end);
+
static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
static const char *const kName;
@@ -43,6 +45,8 @@ struct AVX2_8bit {
static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols);
+ static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end);
+
static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
static const char *const kName;
diff --git a/avx512_gemm.cc b/avx512_gemm.cc
index 70a785b..efa43d8 100644
--- a/avx512_gemm.cc
+++ b/avx512_gemm.cc
@@ -142,10 +142,18 @@ void AVX512_16bit::PrepareB(const float *input, int16_t *output, float quant_mul
PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols);
}
+void AVX512_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) {
+ SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end);
+}
+
void AVX512_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols);
}
+void AVX512_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) {
+ SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows, cols_begin, cols_end);
+}
+
void AVX512_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
// The unquantization is only 256-bit wide because there are 8 results.
Multiply16<__m512i, __m256> (A, B, C, unquant_mult, A_rows, width, B_cols);
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 2bc0358..21b49cf 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -34,6 +34,8 @@ struct AVX512_16bit {
static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols);
+ static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end);
+
static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
static const char *const kName;
@@ -59,6 +61,8 @@ struct AVX512_8bit {
static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols);
+ static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end);
+
static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
static const char *const kName;
diff --git a/interleave.h b/interleave.h
index 5e26f57..1c449ca 100644
--- a/interleave.h
+++ b/interleave.h
@@ -44,14 +44,25 @@ inline void Interleave64(type &first, type &second) { \
first = temp; \
}
+
+template <class Register> inline Register setzero_si();
#ifdef __SSE2__
INTGEMM_INTERLEAVE(__m128i, )
+template <> inline __m128i setzero_si<__m128i>() {
+ return _mm_setzero_si128();
+}
#endif
#ifdef __AVX2__
INTGEMM_INTERLEAVE(__m256i, 256)
+template <> inline __m256i setzero_si<__m256i>() {
+ return _mm256_setzero_si256();
+}
#endif
#ifdef __AVX512F__
INTGEMM_INTERLEAVE(__m512i, 512)
+template <> inline __m512i setzero_si<__m512i>() {
+ return _mm512_setzero_si512();
+}
#endif
template <class Register> inline void Swap(Register &a, Register &b) {
@@ -237,4 +248,24 @@ template <class Quantizer> inline void PrepareBFor16(const float *input, int16_t
}
}
+/* Select columns of B from PrepareB format to PrepareB format.
+ */
+template <class Register> inline void SelectColumnsOfB(const Register *input, Register *output, int rows_bytes /* number of bytes in a row */, const int *cols_begin, const int *cols_end) {
+ // Do columns for multiples of 8.
+ int register_rows = rows_bytes / sizeof(Register);
+ const int *cols_end8 = cols_begin + ((cols_end - cols_begin) & ~7);
+ const Register *starts[8];
+ for (; cols_begin != cols_end8; cols_begin += 8) {
+ for (int k = 0; k < 8; ++k) {
+ starts[k] = input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows;
+ }
+ for (int r = 0; r < register_rows; ++r) {
+ for (int k = 0; k < 8; ++k) {
+ *(output++) = *starts[k];
+ starts[k] += 8;
+ }
+ }
+ }
+}
+
} // namespace intgemm
diff --git a/intgemm.cc b/intgemm.cc
index 6c5eeab..88f12e5 100644
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -25,6 +25,9 @@ struct Unsupported_16bit {
static void PrepareB(const float *, int16_t *, float, int, int) {
throw UnsupportedCPU();
}
+ static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) {
+ throw UnsupportedCPU();
+ }
static void Multiply(const int16_t *, const int16_t *, float *C, float, int, int, int) {
throw UnsupportedCPU();
}
@@ -39,6 +42,9 @@ struct Unsupported_8bit {
static void PrepareB(const float *, int8_t *, float, int, int) {
throw UnsupportedCPU();
}
+ static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) {
+ throw UnsupportedCPU();
+ }
static void Multiply(const int8_t *, const int8_t *, float *C, float, int, int, int) {
throw UnsupportedCPU();
}
@@ -77,11 +83,13 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported)
void (*Int16::Quantize)(const float *input, int16_t *output, float quant_mult, int size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize);
void (*Int16::PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB);
+void (*Int16::SelectColumnsB)(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) = ChooseCPU(AVX512_16bit::SelectColumnsB, AVX2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, Unsupported_16bit::SelectColumnsB);
void (*Int16::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply);
const char *const Int16::kName = ChooseCPU(AVX512_16bit::kName, AVX2_16bit::kName, SSE2_16bit::kName, SSE2_16bit::kName, Unsupported_16bit::kName);
void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, int size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize);
void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB);
+void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB);
void (*Int8::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply);
const char *const Int8::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit::kName, Unsupported_8bit::kName);
diff --git a/intgemm.h b/intgemm.h
index 1d198a8..e7ba370 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -83,6 +83,9 @@ struct Int16 {
// It will match the Multiply function on the same CPU though.
static void (*PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols);
+ // Select columns from a prepared B matrix. The number of selected columns must be a multiple of 8.
+ static void (*SelectColumnsB)(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end);
+
// Multiply C = A * B, presuming A and B have been prepared.
static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
@@ -114,6 +117,9 @@ struct Int8 {
// It will match the Multiply function on the same CPU though.
static void (*PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols);
+ // Select columns from a prepared B matrix. The number of selected columns must be a multiple of 8.
+ static void (*SelectColumnsB)(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end);
+
// Multiply C = A * B, presuming A and B have been prepared.
static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
diff --git a/multiply.h b/multiply.h
index 81a90c2..5c64adf 100644
--- a/multiply.h
+++ b/multiply.h
@@ -14,7 +14,6 @@ namespace intgemm {
*/
template <class Register> inline Register set1_epi16(int16_t to);
template <class Register> inline Register set1_ps(float to);
-template <class Register> inline Register setzero_si();
#ifdef __SSE2__
inline __m128i add_epi32(__m128i first, __m128i second) {
return _mm_add_epi32(first, second);
@@ -28,9 +27,6 @@ template <> inline __m128i set1_epi16<__m128i>(int16_t to) {
template <> inline __m128 set1_ps<__m128>(float to) {
return _mm_set1_ps(to);
}
-template <> inline __m128i setzero_si<__m128i>() {
- return _mm_setzero_si128();
-}
inline __m128i madd_epi16(__m128i first, __m128i second) {
return _mm_madd_epi16(first, second);
}
@@ -64,9 +60,6 @@ template <> inline __m256i set1_epi16<__m256i>(int16_t to) {
template <> inline __m256 set1_ps<__m256>(float to) {
return _mm256_set1_ps(to);
}
-template <> inline __m256i setzero_si<__m256i>() {
- return _mm256_setzero_si256();
-}
inline __m256i madd_epi16(__m256i first, __m256i second) {
return _mm256_madd_epi16(first, second);
}
@@ -100,9 +93,6 @@ template <> inline __m512i set1_epi16<__m512i>(int16_t to) {
template <> inline __m512 set1_ps<__m512>(float to) {
return _mm512_set1_ps(to);
}
-template <> inline __m512i setzero_si<__m512i>() {
- return _mm512_setzero_si512();
-}
inline __m512i madd_epi16(__m512i first, __m512i second) {
return _mm512_madd_epi16(first, second);
}
diff --git a/sse2_gemm.cc b/sse2_gemm.cc
index 05a4080..3e0a0d8 100644
--- a/sse2_gemm.cc
+++ b/sse2_gemm.cc
@@ -61,6 +61,10 @@ void SSE2_16bit::PrepareB(const float *input, int16_t *output, float quant_mult,
PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols);
}
+void SSE2_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end) {
+ SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end);
+}
+
void SSE2_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
Multiply16<__m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
}
diff --git a/sse2_gemm.h b/sse2_gemm.h
index 23997ed..84823fa 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -22,6 +22,8 @@ struct SSE2_16bit {
static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols);
+ static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const int *cols_begin, const int *cols_end);
+
static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
static const char *const kName;
diff --git a/ssse3_gemm.cc b/ssse3_gemm.cc
index 4ccecec..c044910 100644
--- a/ssse3_gemm.cc
+++ b/ssse3_gemm.cc
@@ -78,6 +78,10 @@ void SSSE3_8bit::PrepareB(const float *input, int8_t *output, float quant_mult,
PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols);
}
+void SSSE3_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end) {
+ SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
+}
+
void SSSE3_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
}
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 9897d94..d904a3e 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -23,6 +23,8 @@ struct SSSE3_8bit {
static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols);
+ static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const int *cols_begin, const int *cols_end);
+
static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
static const char *const kName;