diff options
Diffstat (limited to 'avx2_gemm.h')
-rw-r--r-- | avx2_gemm.h | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 93709e4..1addd1e 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -20,21 +20,21 @@ INTGEMM_SELECT_COL_B(INTGEMM_AVX2, __m256i)
 
 class QuantizeTile16 {
   public:
-    typedef __m256i Integer;
+    typedef __m256i Register;
 
     INTGEMM_AVX2 explicit QuantizeTile16(float mult) : mult_(_mm256_set1_ps(mult)) {}
 
-    INTGEMM_AVX2 Integer Consecutive(const float *input) const {
+    INTGEMM_AVX2 Register Consecutive(const float *input) const {
       return Tile(input, input + 8);
     }
 
-    INTGEMM_AVX2 Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
+    INTGEMM_AVX2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
       return Tile(
         input,
         input + 8 + (cols_left <= 8 ? cols * (row_step - 1) : 0));
     }
 
-    INTGEMM_AVX2 Integer ForReshape(const float *input, Index cols) const {
+    INTGEMM_AVX2 Register ForReshape(const float *input, Index cols) const {
       // 8 rows in the first 128-bit register, 8 in the second register.
       return Tile(input, input + 8 * cols);
     }
@@ -103,7 +103,7 @@ namespace avx2 {
  */
 class QuantizeTile8 {
   public:
-    typedef __m256i Integer;
+    typedef __m256i Register;
 
     INTGEMM_AVX2 explicit QuantizeTile8(float quant_mult) : mult_(_mm256_set1_ps(quant_mult)) {}
 
@@ -115,16 +115,16 @@ class QuantizeTile8 {
       return TileU(input, input + 8, input + 16, input + 24);
     }
 
-    INTGEMM_AVX2 Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
+    INTGEMM_AVX2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
       const float* inputs[4];
       for (int i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
-        while (cols_left < sizeof(Integer) / sizeof(float)) {
+        while (cols_left < sizeof(Register) / sizeof(float)) {
           input += cols * (row_step - 1);
           cols_left += cols;
         }
         inputs[i] = input;
-        input += sizeof(Integer) / sizeof(float);
-        cols_left -= sizeof(Integer) / sizeof(float);
+        input += sizeof(Register) / sizeof(float);
+        cols_left -= sizeof(Register) / sizeof(float);
       }
       return Tile(inputs[0], inputs[1], inputs[2], inputs[3]);
     }