github.com/marian-nmt/intgemm.git
-rw-r--r--  avx2_gemm.h     2
-rw-r--r--  avx512_gemm.h   4
-rw-r--r--  cops.h         12
-rw-r--r--  intrinsics.h   10
-rw-r--r--  sse2_gemm.h     2
-rw-r--r--  ssse3_gemm.h    2
6 files changed, 21 insertions(+), 11 deletions(-)
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 9040551..c5ca0bc 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -12,7 +12,7 @@ namespace intgemm {
namespace avx2 {
INTGEMM_AVX2 inline __m256i QuantizerGrab(const float *input, const __m256 quant_mult_reg) {
- return quantize(*reinterpret_cast<const __m256*>(input), quant_mult_reg);
+ return quantize(loadu_ps<__m256>(input), quant_mult_reg);
}
INTGEMM_SELECT_COL_B(INTGEMM_AVX2, __m256i)
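The pattern changed in every gemm header is the same: dereferencing a reinterpret_cast<const __m256*> is treated as a 32-byte-aligned access (typically a vmovaps), which is undefined behavior whenever the caller's float pointer is not aligned, while loadu_ps wraps the unaligned load intrinsic. A minimal standalone sketch of the two forms (hypothetical names, not part of the repository; the quantize body is assumed to be multiply-then-convert, as the AVX512 hunk below spells out):

// Sketch only; compile with -mavx2. quantize_old is UB (and crashes in
// practice) when input is not 32-byte aligned; quantize_new accepts any
// float pointer.
#include <immintrin.h>

__m256i quantize_old(const float *input, __m256 quant_mult) {
  __m256 loaded = *reinterpret_cast<const __m256*>(input);   // aligned access assumed
  return _mm256_cvtps_epi32(_mm256_mul_ps(loaded, quant_mult));
}

__m256i quantize_new(const float *input, __m256 quant_mult) {
  __m256 loaded = _mm256_loadu_ps(input);                    // unaligned load (vmovups)
  return _mm256_cvtps_epi32(_mm256_mul_ps(loaded, quant_mult));
}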
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 4719fae..c9233a6 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -36,7 +36,7 @@ namespace avx512f {
// Load from memory, multiply, and convert to int32_t.
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_AVX512BW inline __m512i QuantizerGrab(const float *input, const __m512 quant_mult_reg) {
- return quantize(*reinterpret_cast<const __m512*>(input), quant_mult_reg);
+ return quantize(loadu_ps<__m512>(input), quant_mult_reg);
}
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
@@ -56,7 +56,7 @@ INTGEMM_AVX512DQ inline __m512 Concat(const __m256 first, const __m256 second) {
// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be controlled independently.
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_AVX512BW inline __m512i QuantizerGrabHalves(const float *input0, const float *input1, const __m512 quant_mult_reg) {
- __m512 appended = avx512f::Concat(*reinterpret_cast<const __m256*>(input0), *reinterpret_cast<const __m256*>(input1));
+ __m512 appended = avx512f::Concat(loadu_ps<__m256>(input0), loadu_ps<__m256>(input1));
appended = _mm512_mul_ps(appended, quant_mult_reg);
return _mm512_cvtps_epi32(appended);
}
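QuantizerGrabHalves pulls 8 floats from each of two independent addresses, so neither pointer can be assumed 32-byte aligned; the Concat inputs therefore also move to loadu_ps. Spelled out with raw intrinsics (a sketch: Concat's body is not part of this diff, and _mm512_insertf32x8 is only an assumption about how it joins the halves):

INTGEMM_AVX512DQ inline __m512i QuantizerGrabHalvesSketch(const float *input0, const float *input1, const __m512 quant_mult_reg) {
  __m256 lower = _mm256_loadu_ps(input0);   // 8 columns from one address, any alignment
  __m256 upper = _mm256_loadu_ps(input1);   // 8 columns from an unrelated address
  __m512 appended = _mm512_insertf32x8(_mm512_castps256_ps512(lower), upper, 1);
  appended = _mm512_mul_ps(appended, quant_mult_reg);
  return _mm512_cvtps_epi32(appended);
}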
diff --git a/cops.h b/cops.h
index 67796f5..6d8c771 100644
--- a/cops.h
+++ b/cops.h
@@ -61,10 +61,10 @@ class BiasAddUnquantizeC {
}
INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
- const __m128* biasSection = reinterpret_cast<const __m128*>(bias_ + colIDX);
- const __m128* biasSection2 = reinterpret_cast<const __m128*>(bias_ + colIDX + 4);
- storeu_ps(C_ + rowIDX*cols + colIDX, add_ps(unquantize(result.pack0123, unquant_mult_), *biasSection));
- storeu_ps(C_ + rowIDX*cols + colIDX + 4, add_ps(unquantize(result.pack4567, unquant_mult_), *biasSection2));
+ auto biasSection0123 = loadu_ps<__m128>(bias_ + colIDX);
+ auto biasSection4567 = loadu_ps<__m128>(bias_ + colIDX + 4);
+ storeu_ps(C_ + rowIDX*cols + colIDX, add_ps(unquantize(result.pack0123, unquant_mult_), biasSection0123));
+ storeu_ps(C_ + rowIDX*cols + colIDX + 4, add_ps(unquantize(result.pack4567, unquant_mult_), biasSection4567));
}
private:
float *C_;
@@ -80,8 +80,8 @@ class BiasAddUnquantizeC {
}
INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
- const __m256* biasSection = reinterpret_cast<const __m256*>(bias_ + colIDX);
- storeu_ps(C_ + rowIDX*cols + colIDX, add_ps(unquantize(result, unquant_mult_), *biasSection));
+ auto biasSection = loadu_ps<__m256>(bias_ + colIDX);
+ storeu_ps(C_ + rowIDX*cols + colIDX, add_ps(unquantize(result, unquant_mult_), biasSection));
}
private:
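In cops.h the motivation is the same but easier to trip over: bias_ + colIDX steps through the bias vector 4 or 8 floats at a time from a caller-provided base, so the old reinterpret_cast form silently required the bias array to be 16- or 32-byte aligned. With the templated helpers expanded to raw SSE2 intrinsics, one 4-column block of the new body looks roughly like this (a sketch; unquantize is assumed to be convert-then-multiply, mirroring the quantize direction above):

INTGEMM_SSE2 inline void WriteBlockSketch(float *C_out, const float *bias, __m128i pack0123, __m128 unquant_mult) {
  __m128 bias0123 = _mm_loadu_ps(bias);                      // bias + colIDX may be unaligned
  __m128 unquantized = _mm_mul_ps(_mm_cvtepi32_ps(pack0123), unquant_mult);
  _mm_storeu_ps(C_out, _mm_add_ps(unquantized, bias0123));   // C is also written unaligned
}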
diff --git a/intrinsics.h b/intrinsics.h
index 5cdaa8d..7c36d6b 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -15,6 +15,7 @@ namespace intgemm {
* Define a bunch of intrinsics as overloaded functions so they work with
* templates.
*/
+template <class Register> static inline Register loadu_ps(const float* mem_addr);
template <class Register> static inline Register set1_epi16(int16_t to);
template <class Register> static inline Register set1_epi32(int32_t to);
template <class Register> static inline Register set1_ps(float to);
@@ -44,6 +45,9 @@ INTGEMM_SSE2 static inline __m128 cvtepi32_ps(__m128i arg) {
INTGEMM_SSE2 static inline __m128i cvtps_epi32(__m128 arg) {
return _mm_cvtps_epi32(arg);
}
+template <> INTGEMM_SSE2 inline __m128 loadu_ps(const float* mem_addr) {
+ return _mm_loadu_ps(mem_addr);
+}
INTGEMM_SSE2 static inline __m128i madd_epi16(__m128i first, __m128i second) {
return _mm_madd_epi16(first, second);
}
@@ -104,6 +108,9 @@ INTGEMM_AVX2 static inline __m256 cvtepi32_ps(__m256i arg) {
INTGEMM_AVX2 static inline __m256i cvtps_epi32(__m256 arg) {
return _mm256_cvtps_epi32(arg);
}
+template <> INTGEMM_AVX2 inline __m256 loadu_ps(const float* mem_addr) {
+ return _mm256_loadu_ps(mem_addr);
+}
INTGEMM_AVX2 static inline __m256i madd_epi16(__m256i first, __m256i second) {
return _mm256_madd_epi16(first, second);
}
@@ -166,6 +173,9 @@ INTGEMM_AVX512BW static inline __m512 cvtepi32_ps(__m512i arg) {
INTGEMM_AVX512BW static inline __m512i cvtps_epi32(__m512 arg) {
return _mm512_cvtps_epi32(arg);
}
+template <> INTGEMM_AVX512BW inline __m512 loadu_ps(const float* mem_addr) {
+ return _mm512_loadu_ps(mem_addr);
+}
INTGEMM_AVX512BW static inline __m512i madd_epi16(__m512i first, __m512i second) {
return _mm512_madd_epi16(first, second);
}
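The intrinsics.h additions follow the file's existing pattern: declare one primary template up front, then give an explicit specialization per register width under the matching target macro. That is what lets the gemm headers call loadu_ps<__m128>/<__m256>/<__m512> generically. A hypothetical width-generic caller (not in this diff) built on the same mechanism, using the add_ps overloads already exercised by cops.h:

template <class Register> static inline Register AddBiasSketch(const float *bias, Register value) {
  // loadu_ps<Register> resolves to _mm_loadu_ps, _mm256_loadu_ps or
  // _mm512_loadu_ps through the specializations added above.
  return add_ps(value, loadu_ps<Register>(bias));
}
// e.g. AddBiasSketch<__m256>(bias_ + colIDX, unquantized) inside an AVX2 kernel.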
diff --git a/sse2_gemm.h b/sse2_gemm.h
index 32d5e85..3ca263f 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -11,7 +11,7 @@ namespace intgemm {
namespace sse2 {
INTGEMM_SSE2 inline __m128i QuantizerGrab(const float *input, const __m128 quant_mult_reg) {
- return quantize(*reinterpret_cast<const __m128*>(input), quant_mult_reg);
+ return quantize(loadu_ps<__m128>(input), quant_mult_reg);
}
INTGEMM_SELECT_COL_B(INTGEMM_SSE2, __m128i)
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index f7abbda..9c21467 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -14,7 +14,7 @@ namespace intgemm {
namespace ssse3 {
INTGEMM_SSSE3 inline __m128i QuantizerGrab(const float *input, const __m128 quant_mult_reg) {
- return quantize(*reinterpret_cast<const __m128*>(input), quant_mult_reg);
+ return quantize(loadu_ps<__m128>(input), quant_mult_reg);
}
INTGEMM_SELECT_COL_B(INTGEMM_SSSE3, __m128i)