Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2020-03-16 17:47:06 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2020-03-16 19:40:21 +0300
commitc351bd5793ccc36738ecfe921479edd588f723cf (patch)
tree8e4940e766c9b20be888e9d9dbf8b5d16acfad20
parentfbe8a3c382171f6cb1cf0c5113c56cbcb29a44f1 (diff)
Fix AVX512 16bit
-rw-r--r--avx2_gemm.h4
-rw-r--r--avx512_gemm.h2
-rw-r--r--multiply.h4
-rw-r--r--sse2_gemm.h2
-rw-r--r--ssse3_gemm.h2
5 files changed, 7 insertions, 7 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 786f135..5243b36 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -89,7 +89,7 @@ struct AVX2_16bit {
avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY(INTGEMM_AVX2, CPUType::AVX2, int16_t)
+ INTGEMM_MULTIPLY(INTGEMM_AVX2, __m256i, CPUType::AVX2, int16_t)
constexpr static const char *const kName = "16-bit AVX2";
@@ -240,7 +240,7 @@ struct AVX2_8bit {
avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY(INTGEMM_AVX2, CPUType::AVX2, int8_t)
+ INTGEMM_MULTIPLY(INTGEMM_AVX2, __m256i, CPUType::AVX2, int8_t)
INTGEMM_MULTIPLY8SHIFT(__m256i, INTGEMM_AVX2, CPUType::AVX2)
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 51b846b..f031942 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -208,7 +208,7 @@ struct AVX512_16bit {
}
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
- INTGEMM_MULTIPLY(INTGEMM_AVX512BW, CPUType::AVX2, int16_t)
+ INTGEMM_MULTIPLY(INTGEMM_AVX512BW, __m512i, CPUType::AVX2, int16_t)
constexpr static const char *const kName = "16-bit AVX512";
diff --git a/multiply.h b/multiply.h
index 239254d..e3b70b1 100644
--- a/multiply.h
+++ b/multiply.h
@@ -263,10 +263,10 @@ struct Multiply_MakeFinalOutputAndRunCallback<int16_t> {
INTGEMM_MULTIPLY16_MAKE_FINAL_OUTPUT_AND_RUN_CALLBACK_IMPL(INTGEMM_AVX512BW, __m512i)
};
-#define INTGEMM_MULTIPLY(target, cpu_type, integer) \
+#define INTGEMM_MULTIPLY(target, regsiter, cpu_type, integer) \
template <Index TileRows, Index TileColumnsMultiplier, typename Callback> \
target static void Multiply(const integer *A, const integer *B, Index A_rows, Index width, Index B_cols, Callback callback) { \
- using Register = vector_t<cpu_type, integer>; \
+ using Register = regsiter; \
static constexpr Index TileColumns = 8 * TileColumnsMultiplier; \
assert(A_rows % TileRows == 0); \
assert(width % (sizeof(Register) / sizeof(integer)) == 0); \
diff --git a/sse2_gemm.h b/sse2_gemm.h
index 62df0a1..c782fc2 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -86,7 +86,7 @@ struct SSE2_16bit {
//TODO #DEFINE
sse2::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY(INTGEMM_SSE2, CPUType::SSE2, int16_t)
+ INTGEMM_MULTIPLY(INTGEMM_SSE2, __m128i, CPUType::SSE2, int16_t)
constexpr static const char *const kName = "16-bit SSE2";
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 9756df5..b7427e3 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -156,7 +156,7 @@ struct SSSE3_8bit {
ssse3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY(INTGEMM_SSSE3, CPUType::SSE2, int8_t)
+ INTGEMM_MULTIPLY(INTGEMM_SSSE3, __m128i, CPUType::SSE2, int8_t)
INTGEMM_MULTIPLY8SHIFT(__m128i, INTGEMM_SSSE3, CPUType::SSE2)