Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'avx512_gemm.cc')
-rw-r--r--avx512_gemm.cc24
1 files changed, 13 insertions, 11 deletions
diff --git a/avx512_gemm.cc b/avx512_gemm.cc
index e7df5ad..904b779 100644
--- a/avx512_gemm.cc
+++ b/avx512_gemm.cc
@@ -35,7 +35,7 @@ inline __m512i QuantizerGrab(const float *input, const __m512 quant_mult_reg) {
// rearranging B.
//
// Convert to 16-bit signed integers.
-void AVX512_16bit::Quantize(const float *input, int16_t *output, float quant_mult, int size) {
+void AVX512_16bit::Quantize(const float *input, int16_t *output, float quant_mult, Index size) {
assert(size % 16 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
// Fill with the quantization multiplier.
@@ -48,7 +48,7 @@ void AVX512_16bit::Quantize(const float *input, int16_t *output, float quant_mul
}
// Convert to 8-bit signed integers.
-void AVX512_8bit::Quantize(const float *input, int8_t *output, float quant_mult, int size) {
+void AVX512_8bit::Quantize(const float *input, int8_t *output, float quant_mult, Index size) {
assert(size % 16 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
const __m512i neg127 = _mm512_set1_epi32(-127);
@@ -91,7 +91,7 @@ class QuantizeTile16 {
explicit QuantizeTile16(float mult) : mult_reg_(_mm512_set1_ps(mult)) {}
- inline __m512i ForReshape(const float *input, int cols) {
+ inline __m512i ForReshape(const float *input, Index cols) {
__m512i g0 = QuantizerGrabHalves(input, input + 16 * cols, mult_reg_);
__m512i g1 = QuantizerGrabHalves(input + 8 * cols, input + 24 * cols, mult_reg_);
__m512i packed = _mm512_packs_epi32(g0, g1);
@@ -109,7 +109,7 @@ class QuantizeTile8 {
explicit QuantizeTile8(float mult) : mult_reg_(_mm512_set1_ps(mult)) {}
- inline __m512i ForReshape(const float *input, int cols) {
+ inline __m512i ForReshape(const float *input, Index cols) {
// TODO: try alternative: _mm512_cvtsepi32_epi8 ?
const __m512i neg127 = _mm512_set1_epi8(-127);
// In reverse order: grabbing the first 32-bit values from each 128-bit register, then the second 32-bit values, etc.
@@ -137,30 +137,30 @@ class QuantizeTile8 {
} // namespace
-void AVX512_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols) {
+void AVX512_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols);
}
-void AVX512_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+void AVX512_16bit::SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end);
}
-void AVX512_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+void AVX512_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols);
}
-void AVX512_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+void AVX512_8bit::SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows, cols_begin, cols_end);
}
-void AVX512_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+void AVX512_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
// The unquantization is only 256-bit wide because there are 8 results.
Multiply16<__m512i, __m256> (A, B, C, unquant_mult, A_rows, width, B_cols);
}
// Special AVX512 implementation due to having 32 registers (so I don't have to
// allocate registers manually) and no sign instruction.
-void AVX512_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+void AVX512_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
typedef __m512i Integer;
typedef __m256 Float; // For quantization we only do 8 at a time.
// This is copy-paste from Multiply8_SSE2OrAVX2.
@@ -264,7 +264,9 @@ void AVX512_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unq
sum7 = madd_epi16(sum7, ones);
Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
- WriteC(C + A_rowidx * B_cols + B0_colidx, pack0123, pack4567, unquant_reg);
+
+ auto total = PermuteSummer(pack0123, pack4567);
+ WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);
}
}
}