Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'intgemm/avx2_gemm.h')
-rw-r--r--intgemm/avx2_gemm.h74
1 files changed, 33 insertions, 41 deletions
diff --git a/intgemm/avx2_gemm.h b/intgemm/avx2_gemm.h
index d111b32..5e81475 100644
--- a/intgemm/avx2_gemm.h
+++ b/intgemm/avx2_gemm.h
@@ -19,34 +19,30 @@ INTGEMM_SELECT_COL_B(INTGEMM_AVX2, __m256i)
class QuantizeTile16 {
public:
- INTGEMM_AVX2 explicit QuantizeTile16(float mult) : mult_(_mm256_set1_ps(mult)) {}
-
- INTGEMM_AVX2 Register Consecutive(const float *input) const {
- return Tile(input, input + 8);
+ INTGEMM_AVX2 static inline Register Consecutive(FRegister mult_reg, const float *input) {
+ return Tile(mult_reg, input, input + 8);
}
- INTGEMM_AVX2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
- return Tile(
+ INTGEMM_AVX2 static inline Register ConsecutiveWithWrapping(FRegister mult_reg, const float *input, Index cols_left, Index cols, Index row_step) {
+ return Tile(mult_reg,
input,
input + 8 + (cols_left <= 8 ? cols * (row_step - 1) : 0));
}
- INTGEMM_AVX2 Register ForReshape(const float *input, Index cols) const {
+ INTGEMM_AVX2 static inline Register ForReshape(FRegister mult_reg, const float *input, Index cols) {
// 8 rows in the first 128-bit register, 8 in the second register.
- return Tile(input, input + 8 * cols);
+ return Tile(mult_reg, input, input + 8 * cols);
}
private:
- INTGEMM_AVX2 __m256i Tile(const float *input0, const float *input1) const {
- Register g0 = QuantizerGrab(input0, mult_);
- Register g1 = QuantizerGrab(input1, mult_);
+ INTGEMM_AVX2 static inline Register Tile(FRegister mult_reg, const float *input0, const float *input1) {
+ Register g0 = QuantizerGrab(input0, mult_reg);
+ Register g1 = QuantizerGrab(input1, mult_reg);
Register packed = _mm256_packs_epi32(g0, g1);
// Reorder the packed values because Intel does 0 1 2 3 8 9 10 11 4 5 6 7 12 13 14 15.
// Technically this could be removed if the PrepareB did the same reordering internally.
return _mm256_permute4x64_epi64(packed, 0xd8 /* 0, 2, 1, 3 */);
}
-
- const FRegister mult_;
};
struct Kernels16 {
@@ -61,10 +57,10 @@ struct Kernels16 {
INTGEMM_AVX2 static void Quantize(const float *input, int16_t *output, float quant_mult, Index size) {
assert(size % 16 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 32 == 0);
- avx2::QuantizeTile16 q(quant_mult);
+ FRegister q = set1_ps<FRegister>(quant_mult);
const float *end = input + size;
for (; input != end; input += 16, output += 16) {
- *reinterpret_cast<__m256i*>(output) = q.Consecutive(input);
+ *reinterpret_cast<__m256i*>(output) = QuantizeTile16::Consecutive(q, input);
}
}
@@ -96,17 +92,15 @@ struct Kernels16 {
*/
class QuantizeTile8 {
public:
- INTGEMM_AVX2 explicit QuantizeTile8(float quant_mult) : mult_(_mm256_set1_ps(quant_mult)) {}
-
- INTGEMM_AVX2 inline __m256i Consecutive(const float *input) const {
- return Tile(input, input + 8, input + 16, input + 24);
+ INTGEMM_AVX2 static inline Register Consecutive(FRegister quant_mult, const float *input) {
+ return Tile(quant_mult, input, input + 8, input + 16, input + 24);
}
- INTGEMM_AVX2 inline __m256i ConsecutiveU(const float *input) const {
- return TileU(input, input + 8, input + 16, input + 24);
+ INTGEMM_AVX2 static inline Register ConsecutiveU(FRegister quant_mult, const float *input) {
+ return TileU(quant_mult, input, input + 8, input + 16, input + 24);
}
- INTGEMM_AVX2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
+ INTGEMM_AVX2 static inline Register ConsecutiveWithWrapping(FRegister quant_mult, const float *input, Index cols_left, Index cols, Index row_step) {
const float* inputs[4];
for (Index i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
while (cols_left < sizeof(Register) / sizeof(float)) {
@@ -117,24 +111,24 @@ class QuantizeTile8 {
input += sizeof(Register) / sizeof(float);
cols_left -= sizeof(Register) / sizeof(float);
}
- return Tile(inputs[0], inputs[1], inputs[2], inputs[3]);
+ return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]);
}
- INTGEMM_AVX2 inline __m256i ForReshape(const float *input, Index cols) const {
+ INTGEMM_AVX2 static inline Register ForReshape(FRegister quant_mult, const float *input, Index cols) {
// Put higher rows in the second half of the register. These will jumble
// around in the same way then conveniently land in the right place.
- return Tile(input, input + 2 * cols, input + 16 * cols, input + 18 * cols);
+ return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols, input + 18 * cols);
}
- INTGEMM_AVX2 inline __m256i Tile(const float *input0, const float *input1, const float *input2, const float *input3) const {
+ INTGEMM_AVX2 static inline __m256i Tile(FRegister quant_mult, const float *input0, const float *input1, const float *input2, const float *input3) {
// Looking at the assembly, gcc has pulled this outside the loops calling this.
const __m256i neg127 = _mm256_set1_epi8(-127);
const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
// Grab 4 registers at a time in 32-bit format.
- __m256i g0 = avx2::QuantizerGrab(input0, mult_);
- __m256i g1 = avx2::QuantizerGrab(input1, mult_);
- __m256i g2 = avx2::QuantizerGrab(input2, mult_);
- __m256i g3 = avx2::QuantizerGrab(input3, mult_);
+ __m256i g0 = avx2::QuantizerGrab(input0, quant_mult);
+ __m256i g1 = avx2::QuantizerGrab(input1, quant_mult);
+ __m256i g2 = avx2::QuantizerGrab(input2, quant_mult);
+ __m256i g3 = avx2::QuantizerGrab(input3, quant_mult);
// Pack 32-bit to 16-bit.
__m256i packed0 = _mm256_packs_epi32(g0, g1);
__m256i packed1 = _mm256_packs_epi32(g2, g3);
@@ -151,16 +145,16 @@ class QuantizeTile8 {
private:
//A version that produces uint8_ts
- INTGEMM_AVX2 inline __m256i TileU(const float *input0, const float *input1, const float *input2, const float *input3) const {
+ INTGEMM_AVX2 static inline Register TileU(FRegister quant_mult, const float *input0, const float *input1, const float *input2, const float *input3) {
// Looking at the assembly, gcc has pulled this outside the loops calling this.
const __m256i neg127 = _mm256_set1_epi8(-127);
const __m256i pos127 = _mm256_set1_epi8(127);
const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
// Grab 4 registers at a time in 32-bit format.
- __m256i g0 = avx2::QuantizerGrab(input0, mult_);
- __m256i g1 = avx2::QuantizerGrab(input1, mult_);
- __m256i g2 = avx2::QuantizerGrab(input2, mult_);
- __m256i g3 = avx2::QuantizerGrab(input3, mult_);
+ __m256i g0 = avx2::QuantizerGrab(input0, quant_mult);
+ __m256i g1 = avx2::QuantizerGrab(input1, quant_mult);
+ __m256i g2 = avx2::QuantizerGrab(input2, quant_mult);
+ __m256i g3 = avx2::QuantizerGrab(input3, quant_mult);
// Pack 32-bit to 16-bit.
__m256i packed0 = _mm256_packs_epi32(g0, g1);
__m256i packed1 = _mm256_packs_epi32(g2, g3);
@@ -175,8 +169,6 @@ class QuantizeTile8 {
// and the values are only used for GEMM.
return _mm256_permutevar8x32_epi32(packed, shuffle_param);
}
-
- const __m256 mult_;
};
struct Kernels8 {
@@ -187,9 +179,9 @@ struct Kernels8 {
Quantize(input, output, quant_mult, rows * cols);
}
private:
- INTGEMM_QUANTIZE_THREAD(INTGEMM_AVX2, __m256i, avx2)
+ INTGEMM_QUANTIZE_THREAD(INTGEMM_AVX2)
public:
- INTGEMM_QUANTIZE(INTGEMM_AVX2, __m256i, avx2)
+ INTGEMM_QUANTIZE(INTGEMM_AVX2)
// Currently A is prepared by quantization but this could theoretically change.
INTGEMM_AVX2 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) {
@@ -200,10 +192,10 @@ struct Kernels8 {
INTGEMM_AVX2 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) {
assert(size % 32 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 32 == 0);
- avx2::QuantizeTile8 q(quant_mult);
+ FRegister q = set1_ps<FRegister>(quant_mult);
const float *end = input + size;
for (; input != end; input += 32, output += 32) {
- *reinterpret_cast<__m256i*>(output) = q.ConsecutiveU(input);
+ *reinterpret_cast<__m256i*>(output) = QuantizeTile8::ConsecutiveU(q, input);
}
}