From a5b1e8631bdeda4d984659bda2f020cbc273ce14 Mon Sep 17 00:00:00 2001 From: Mateusz Chudyk Date: Tue, 18 Feb 2020 16:58:17 +0000 Subject: Solve #67: Disambiguate name Integer for dependent types --- avx2_gemm.h | 18 ++--- avx512_gemm.h | 56 +++++++-------- avx512vnni_gemm.h | 98 ++++++++++++------------- interleave.h | 12 ++-- multiply.h | 208 +++++++++++++++++++++++++++--------------------------- sse2_gemm.h | 4 +- ssse3_gemm.h | 10 +-- 7 files changed, 203 insertions(+), 203 deletions(-) diff --git a/avx2_gemm.h b/avx2_gemm.h index 93709e4..1addd1e 100644 --- a/avx2_gemm.h +++ b/avx2_gemm.h @@ -20,21 +20,21 @@ INTGEMM_SELECT_COL_B(INTGEMM_AVX2, __m256i) class QuantizeTile16 { public: - typedef __m256i Integer; + typedef __m256i Register; INTGEMM_AVX2 explicit QuantizeTile16(float mult) : mult_(_mm256_set1_ps(mult)) {} - INTGEMM_AVX2 Integer Consecutive(const float *input) const { + INTGEMM_AVX2 Register Consecutive(const float *input) const { return Tile(input, input + 8); } - INTGEMM_AVX2 Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { + INTGEMM_AVX2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { return Tile( input, input + 8 + (cols_left <= 8 ? cols * (row_step - 1) : 0)); } - INTGEMM_AVX2 Integer ForReshape(const float *input, Index cols) const { + INTGEMM_AVX2 Register ForReshape(const float *input, Index cols) const { // 8 rows in the first 128-bit register, 8 in the second register. return Tile(input, input + 8 * cols); } @@ -103,7 +103,7 @@ namespace avx2 { */ class QuantizeTile8 { public: - typedef __m256i Integer; + typedef __m256i Register; INTGEMM_AVX2 explicit QuantizeTile8(float quant_mult) : mult_(_mm256_set1_ps(quant_mult)) {} @@ -115,16 +115,16 @@ class QuantizeTile8 { return TileU(input, input + 8, input + 16, input + 24); } - INTGEMM_AVX2 Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { + INTGEMM_AVX2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { const float* inputs[4]; for (int i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) { - while (cols_left < sizeof(Integer) / sizeof(float)) { + while (cols_left < sizeof(Register) / sizeof(float)) { input += cols * (row_step - 1); cols_left += cols; } inputs[i] = input; - input += sizeof(Integer) / sizeof(float); - cols_left -= sizeof(Integer) / sizeof(float); + input += sizeof(Register) / sizeof(float); + cols_left -= sizeof(Register) / sizeof(float); } return Tile(inputs[0], inputs[1], inputs[2], inputs[3]); } diff --git a/avx512_gemm.h b/avx512_gemm.h index 91fdd8a..2a5fff1 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -70,12 +70,12 @@ INTGEMM_AVX512BW inline __m512i QuantizerGrabHalves(const float *input0, const f // being used for the quantizer. class QuantizeTile16 { public: - typedef __m512i Integer; + typedef __m512i Register; /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ INTGEMM_AVX512BW explicit QuantizeTile16(float mult) : mult_reg_(_mm512_set1_ps(mult)) {} - INTGEMM_AVX512BW Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { + INTGEMM_AVX512BW Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { auto input0 = input; auto input1 = input + 16 + (cols_left <= 16 ? 
cols * (row_step - 1) : 0); auto g0 = QuantizerGrabHalves(input0, input1, mult_reg_); @@ -98,24 +98,24 @@ class QuantizeTile16 { class QuantizeTile8 { public: - typedef __m512i Integer; + typedef __m512i Register; /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ INTGEMM_AVX512BW explicit QuantizeTile8(float mult) : mult_reg_(_mm512_set1_ps(mult)) {} - INTGEMM_AVX512BW Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { + INTGEMM_AVX512BW Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { static const __m512i neg127 = _mm512_set1_epi8(-127); static const __m512i shuffle_param = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); const float* inputs[4]; for (int i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) { - while (cols_left < sizeof(Integer) / sizeof(float)) { + while (cols_left < sizeof(Register) / sizeof(float)) { input += cols * (row_step - 1); cols_left += cols; } inputs[i] = input; - input += sizeof(Integer) / sizeof(float); - cols_left -= sizeof(Integer) / sizeof(float); + input += sizeof(Register) / sizeof(float); + cols_left -= sizeof(Register) / sizeof(float); } auto g0 = QuantizerGrab(inputs[0], mult_reg_); @@ -292,41 +292,41 @@ struct AVX512_8bit { // allocate registers manually) and no sign instruction. template INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - typedef __m512i Integer; + typedef __m512i Register; //typedef __m256 Float; // For quantization we only do 8 at a time. // This is copy-paste from Multiply8_SSE2OrAVX2. - assert(width % sizeof(Integer) == 0); + assert(width % sizeof(Register) == 0); assert(B_cols % 8 == 0); - assert(reinterpret_cast(A) % sizeof(Integer) == 0); - assert(reinterpret_cast(B) % sizeof(Integer) == 0); + assert(reinterpret_cast(A) % sizeof(Register) == 0); + assert(reinterpret_cast(B) % sizeof(Register) == 0); // There's 8 results for INTGEMM_AVX2 to handle. auto callback_impl = callbacks::CallbackImpl(callback); - const int simd_width = width / sizeof(Integer); - const Integer *B0_col = reinterpret_cast(B); + const int simd_width = width / sizeof(Register); + const Register *B0_col = reinterpret_cast(B); // Added for AVX512. - Integer zeros = setzero_si(); + Register zeros = setzero_si(); // Go over 8 columns of B at a time. for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. - const Integer *A_live = reinterpret_cast(A + A_rowidx * width); - const Integer *A_end = A_live + simd_width; - const Integer *B_live = B0_col; + const Register *A_live = reinterpret_cast(A + A_rowidx * width); + const Register *A_end = A_live + simd_width; + const Register *B_live = B0_col; // Do the first iteration to initialize the sums. __m512i a = *A_live; __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128)); __m512i a_positive = _mm512_abs_epi8(a); // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A. 
- Integer sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0])); - Integer sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1])); - Integer sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2])); - Integer sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3])); - Integer sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4])); - Integer sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5])); - Integer sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6])); - Integer sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7])); + Register sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0])); + Register sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1])); + Register sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2])); + Register sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3])); + Register sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4])); + Register sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5])); + Register sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6])); + Register sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7])); ++A_live; B_live += 8; @@ -384,7 +384,7 @@ struct AVX512_8bit { // Unique code ends: can we do an inline function? } // Upcast to 32-bit and horizontally add. 
- Integer ones = set1_epi16(1); + Register ones = set1_epi16(1); sum0 = madd_epi16(sum0, ones); sum1 = madd_epi16(sum1, ones); sum2 = madd_epi16(sum2, ones); @@ -393,8 +393,8 @@ struct AVX512_8bit { sum5 = madd_epi16(sum5, ones); sum6 = madd_epi16(sum6, ones); sum7 = madd_epi16(sum7, ones); - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h index 3f616a6..59f6405 100644 --- a/avx512vnni_gemm.h +++ b/avx512vnni_gemm.h @@ -11,39 +11,39 @@ namespace intgemm { struct AVX512VNNI_8bit : public AVX512_8bit { template INTGEMM_AVX512VNNI static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - typedef __m512i Integer; - assert(width % sizeof(Integer) == 0); + typedef __m512i Register; + assert(width % sizeof(Register) == 0); assert(B_cols % 8 == 0); - assert(reinterpret_cast(A) % sizeof(Integer) == 0); - assert(reinterpret_cast(B) % sizeof(Integer) == 0); + assert(reinterpret_cast(A) % sizeof(Register) == 0); + assert(reinterpret_cast(B) % sizeof(Register) == 0); auto callback_impl = callbacks::CallbackImpl(callback); - const int simd_width = width / sizeof(Integer); - const Integer *B0_col = reinterpret_cast(B); - Integer zeros = setzero_si(); + const int simd_width = width / sizeof(Register); + const Register *B0_col = reinterpret_cast(B); + Register zeros = setzero_si(); // Go over 8 columns of B at a time. for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. - const Integer *A_live = reinterpret_cast(A + A_rowidx * width); - const Integer *A_end = A_live + simd_width; - const Integer *B_live = B0_col; + const Register *A_live = reinterpret_cast(A + A_rowidx * width); + const Register *A_end = A_live + simd_width; + const Register *B_live = B0_col; // TODO: separate first step. - Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; A_live != A_end; ++A_live, B_live += 8) { - Integer a = *A_live; + Register a = *A_live; // Retrieve the conveniently consecutive values of B. - Integer b0 = *B_live; - Integer b1 = *(B_live + 1); - Integer b2 = *(B_live + 2); - Integer b3 = *(B_live + 3); - Integer b4 = *(B_live + 4); - Integer b5 = *(B_live + 5); - Integer b6 = *(B_live + 6); - Integer b7 = *(B_live + 7); + Register b0 = *B_live; + Register b1 = *(B_live + 1); + Register b2 = *(B_live + 2); + Register b3 = *(B_live + 3); + Register b4 = *(B_live + 4); + Register b5 = *(B_live + 5); + Register b6 = *(B_live + 6); + Register b7 = *(B_live + 7); // Get a mask where a is negative. __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128)); - Integer a_positive = _mm512_abs_epi8(a); + Register a_positive = _mm512_abs_epi8(a); // Negate by subtracting from zero with a mask. 
b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0); b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1); @@ -62,8 +62,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit { sum6 = _mm512_dpbusds_epi32(sum6, a_positive, b6); sum7 = _mm512_dpbusds_epi32(sum7, a_positive, b7); } - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); } @@ -72,27 +72,27 @@ struct AVX512VNNI_8bit : public AVX512_8bit { template INTGEMM_AVX512VNNI static void Multiply8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - typedef __m512i Integer; - assert(width % sizeof(Integer) == 0); + typedef __m512i Register; + assert(width % sizeof(Register) == 0); assert(B_cols % 8 == 0); - assert(reinterpret_cast(A) % sizeof(Integer) == 0); - assert(reinterpret_cast(B) % sizeof(Integer) == 0); + assert(reinterpret_cast(A) % sizeof(Register) == 0); + assert(reinterpret_cast(B) % sizeof(Register) == 0); auto callback_impl = callbacks::CallbackImpl(callback); - const int simd_width = width / sizeof(Integer); - const Integer *B0_col = reinterpret_cast(B); - Integer zeros = setzero_si(); + const int simd_width = width / sizeof(Register); + const Register *B0_col = reinterpret_cast(B); + Register zeros = setzero_si(); // Go over 8 columns of B at a time. for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. - const Integer *A_live = reinterpret_cast(A + A_rowidx * width); - const Integer *A_end = A_live + simd_width; - const Integer *B_live = B0_col; + const Register *A_live = reinterpret_cast(A + A_rowidx * width); + const Register *A_end = A_live + simd_width; + const Register *B_live = B0_col; // TODO: separate first step. 
- Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; A_live != A_end; ++A_live, B_live += 8) { - Integer a = *A_live; + Register a = *A_live; //MultiplyAdd sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live); sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1)); @@ -103,8 +103,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit { sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6)); sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7)); } - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); } @@ -113,22 +113,22 @@ struct AVX512VNNI_8bit : public AVX512_8bit { template INTGEMM_AVX512VNNI static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) { - typedef __m512i Integer; - assert(width % sizeof(Integer) == 0); + typedef __m512i Register; + assert(width % sizeof(Register) == 0); assert(B_cols % 8 == 0); - assert(reinterpret_cast(B) % sizeof(Integer) == 0); + assert(reinterpret_cast(B) % sizeof(Register) == 0); auto callback_impl = callbacks::CallbackImpl(callback); - const int simd_width = width / sizeof(Integer); - const Integer *B0_col = reinterpret_cast(B); - Integer zeros = setzero_si(); - const Integer a = set1_epi8(1); + const int simd_width = width / sizeof(Register); + const Register *B0_col = reinterpret_cast(B); + Register zeros = setzero_si(); + const Register a = set1_epi8(1); // Go over 8 columns of B at a time. for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { - const Integer *B_live = B0_col; //In order to make the code look as much as possible as the above function - const Integer *B_end = B_live + simd_width*8; + const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function + const Register *B_end = B_live + simd_width*8; // TODO: separate first step. - Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; B_live != B_end; B_live += 8) { // Retrieve the conveniently consecutive values of B. 
sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live); @@ -140,8 +140,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit { sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6)); sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7)); } - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols)); } diff --git a/interleave.h b/interleave.h index 41ac8b7..79f3163 100644 --- a/interleave.h +++ b/interleave.h @@ -181,7 +181,7 @@ template static inline void Transpose8InLane( #define INTGEMM_PREPARE_B_8(target, QuantClass) \ target static inline void PrepareB(const float *input, int8_t *output_shadow, float quant_mult, Index rows, Index cols) { \ typedef typename QuantClass Quantizer; \ - typedef typename Quantizer::Integer Register; \ + typedef typename Quantizer::Register Register; \ Quantizer q = Quantizer(quant_mult); \ /* Currently all multipliers have a stride of 8 columns.*/ \ const int kColStride = 8; \ @@ -216,7 +216,7 @@ target static inline void PrepareB(const float *input, int8_t *output_shadow, fl #define INTGEMM_PREPARE_B_16(target, QuantClass) \ target static inline void PrepareB(const float *input, int16_t *output_shadow, float quant_mult, Index rows, Index cols) { \ typedef typename QuantClass Quantizer; \ - typedef typename Quantizer::Integer Register; \ + typedef typename Quantizer::Register Register; \ Quantizer q = Quantizer(quant_mult); \ assert(cols % 8 == 0); \ assert(rows % (sizeof(Register) / sizeof(int16_t)) == 0); \ @@ -266,10 +266,10 @@ target static inline void PrepareBQuantizedTransposed(const Integer* input, Inte * * cols and rows describe size of transposed B. 
*/ -#define INTGEMM_PREPARE_B_TRANSPOSED(target, Quantizer, integer) \ -target static inline void PrepareBTransposed(const float* input, integer* output, float quant_mult, Index cols, Index rows) { \ - using Register = typename Quantizer::Integer; \ - const Index RegisterElemsInt = sizeof(Register) / sizeof(integer); \ +#define INTGEMM_PREPARE_B_TRANSPOSED(target, Quantizer, Integer) \ +target static inline void PrepareBTransposed(const float* input, Integer* output, float quant_mult, Index cols, Index rows) { \ + using Register = typename Quantizer::Register; \ + const Index RegisterElemsInt = sizeof(Register) / sizeof(Integer); \ const Index RegisterElemsFloat = sizeof(Register) / sizeof(float); \ const Index kColStride = 8; \ \ diff --git a/multiply.h b/multiply.h index abc224c..0aa86aa 100644 --- a/multiply.h +++ b/multiply.h @@ -139,42 +139,42 @@ INTGEMM_AVX2 static inline void RunCallback(Callback& callback_impl, vector_t target static void Multiply(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ - assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \ + assert(width % (sizeof(Register) / sizeof(int16_t)) == 0); \ assert(B_cols % 8 == 0); \ - assert(reinterpret_cast(A) % sizeof(Integer) == 0); \ - assert(reinterpret_cast(B) % sizeof(Integer) == 0); \ - const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \ + assert(reinterpret_cast(A) % sizeof(Register) == 0); \ + assert(reinterpret_cast(B) % sizeof(Register) == 0); \ + const int simd_width = width / (sizeof(Register) / sizeof(int16_t)); \ auto callback_impl = callbacks::CallbackImpl(callback); \ - const Integer *B0_col = reinterpret_cast(B); \ + const Register *B0_col = reinterpret_cast(B); \ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ - const Integer *A_row = reinterpret_cast(A + A_rowidx * width); \ + const Register *A_row = reinterpret_cast(A + A_rowidx * width); \ /* These will be packed 32-bit integers containing sums for each row of B multiplied by the row of A. 
\ Iterate over shared (inner) dimension.*/ \ int k = 0; \ - Integer a = *(A_row + k); \ - Integer sum0 = madd_epi16(a, *(B0_col + k * 8)); \ - Integer sum1 = madd_epi16(a, *(B0_col + k * 8 + 1)); \ - Integer sum2 = madd_epi16(a, *(B0_col + k * 8 + 2)); \ - Integer sum3 = madd_epi16(a, *(B0_col + k * 8 + 3)); \ - Integer sum4 = madd_epi16(a, *(B0_col + k * 8 + 4)); \ - Integer sum5 = madd_epi16(a, *(B0_col + k * 8 + 5)); \ - Integer sum6 = madd_epi16(a, *(B0_col + k * 8 + 6)); \ - Integer sum7 = madd_epi16(a, *(B0_col + k * 8 + 7)); \ + Register a = *(A_row + k); \ + Register sum0 = madd_epi16(a, *(B0_col + k * 8)); \ + Register sum1 = madd_epi16(a, *(B0_col + k * 8 + 1)); \ + Register sum2 = madd_epi16(a, *(B0_col + k * 8 + 2)); \ + Register sum3 = madd_epi16(a, *(B0_col + k * 8 + 3)); \ + Register sum4 = madd_epi16(a, *(B0_col + k * 8 + 4)); \ + Register sum5 = madd_epi16(a, *(B0_col + k * 8 + 5)); \ + Register sum6 = madd_epi16(a, *(B0_col + k * 8 + 6)); \ + Register sum7 = madd_epi16(a, *(B0_col + k * 8 + 7)); \ for (int k = 1; k < simd_width; ++k) { \ - Integer a = *(A_row + k); \ + Register a = *(A_row + k); \ /* Multiply 16-bit, horizontally add to packed 32-bit integers.*/ \ - Integer mult0 = madd_epi16(a, *(B0_col + k * 8)); \ - Integer mult1 = madd_epi16(a, *(B0_col + k * 8 + 1)); \ - Integer mult2 = madd_epi16(a, *(B0_col + k * 8 + 2)); \ - Integer mult3 = madd_epi16(a, *(B0_col + k * 8 + 3)); \ - Integer mult4 = madd_epi16(a, *(B0_col + k * 8 + 4)); \ - Integer mult5 = madd_epi16(a, *(B0_col + k * 8 + 5)); \ - Integer mult6 = madd_epi16(a, *(B0_col + k * 8 + 6)); \ - Integer mult7 = madd_epi16(a, *(B0_col + k * 8 + 7)); \ + Register mult0 = madd_epi16(a, *(B0_col + k * 8)); \ + Register mult1 = madd_epi16(a, *(B0_col + k * 8 + 1)); \ + Register mult2 = madd_epi16(a, *(B0_col + k * 8 + 2)); \ + Register mult3 = madd_epi16(a, *(B0_col + k * 8 + 3)); \ + Register mult4 = madd_epi16(a, *(B0_col + k * 8 + 4)); \ + Register mult5 = madd_epi16(a, *(B0_col + k * 8 + 5)); \ + Register mult6 = madd_epi16(a, *(B0_col + k * 8 + 6)); \ + Register mult7 = madd_epi16(a, *(B0_col + k * 8 + 7)); \ /* Sum packed 32-bit integers with danger of overflow. 
TODO: accumulate in 64-bit every so often.*/ \ sum0 = add_epi32(sum0, mult0); \ sum1 = add_epi32(sum1, mult1); \ @@ -186,8 +186,8 @@ template target static void Multiply(const int16_t *A, const sum7 = add_epi32(sum7, mult7); \ } \ /* Reduce sums within 128-bit lanes.*/ \ - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ /*The specific implementation may need to reduce further.*/ \ auto total = PermuteSummer(pack0123, pack4567); \ RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \ @@ -196,30 +196,30 @@ template target static void Multiply(const int16_t *A, const } \ //An int8_prepbias version of the above code, using the add 127 technique -#define INTGEMM_PREPAREBIASFOR8(Integer, target, cpu_type) \ +#define INTGEMM_PREPAREBIASFOR8(Register, target, cpu_type) \ template target static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) { \ - assert(width % (sizeof(Integer) / sizeof(int8_t)) == 0); \ + assert(width % (sizeof(Register) / sizeof(int8_t)) == 0); \ assert(B_cols % 8 == 0); \ - assert(reinterpret_cast(B) % sizeof(Integer) == 0); \ - const int simd_width = width / (sizeof(Integer) / sizeof(int8_t)); \ + assert(reinterpret_cast(B) % sizeof(Register) == 0); \ + const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \ auto callback_impl = callbacks::CallbackImpl(callback); \ - const Integer *B0_col = reinterpret_cast(B); \ - const Integer a = set1_epi8(1); \ + const Register *B0_col = reinterpret_cast(B); \ + const Register a = set1_epi8(1); \ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ - /*const Integer *A_row = reinterpret_cast(A + A_rowidx * width);*/ \ + /*const Register *A_row = reinterpret_cast(A + A_rowidx * width);*/ \ /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \ Iterate over shared (inner) dimension.*/ \ int k = 0; \ - Integer sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \ - Integer sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ - Integer sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ - Integer sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ - Integer sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ - Integer sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ - Integer sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ - Integer sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + Register sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Register sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Register sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Register sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Register sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Register sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Register sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Register sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ /* Upcast to 32-bit and horizontally add. 
Seems a bit faster if this is declared here.*/ \ - Integer ones = set1_epi16(1); \ + Register ones = set1_epi16(1); \ sum0 = madd_epi16(sum0, ones); \ sum1 = madd_epi16(sum1, ones); \ sum2 = madd_epi16(sum2, ones); \ @@ -229,16 +229,16 @@ template target static void Multiply(const int16_t *A, const sum6 = madd_epi16(sum6, ones); \ sum7 = madd_epi16(sum7, ones); \ for (int k = 1; k < simd_width; ++k) { \ - /*Integer a = *(A_row + k);*/ \ + /*Register a = *(A_row + k);*/ \ /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ \ - Integer mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \ - Integer mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ - Integer mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ - Integer mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ - Integer mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ - Integer mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ - Integer mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ - Integer mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + Register mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Register mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Register mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Register mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Register mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Register mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Register mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Register mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ /* Upcast to 32-bit and horizontally add.*/ \ mult0 = madd_epi16(mult0, ones); \ mult1 = madd_epi16(mult1, ones); \ @@ -260,8 +260,8 @@ template target static void Multiply(const int16_t *A, const \ } \ /* Reduce sums within 128-bit lanes.*/ \ - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ /*The specific implementation may need to reduce further.*/ \ auto total = PermuteSummer(pack0123, pack4567); \ RunCallback(callback_impl, total, 0, B0_colidx, 1, B_cols); \ @@ -269,33 +269,33 @@ template target static void Multiply(const int16_t *A, const } \ //An int8 version of the above code, using the add 127 technique -#define INTGEMM_MULTIPLY8SHIFT(Integer, target, cpu_type) \ +#define INTGEMM_MULTIPLY8SHIFT(Register, target, cpu_type) \ template target static void Multiply8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ - assert(width % (sizeof(Integer) / sizeof(int8_t)) == 0); \ + assert(width % (sizeof(Register) / sizeof(int8_t)) == 0); \ assert(B_cols % 8 == 0); \ - assert(reinterpret_cast(A) % sizeof(Integer) == 0); \ - assert(reinterpret_cast(B) % sizeof(Integer) == 0); \ - const int simd_width = width / (sizeof(Integer) / sizeof(int8_t)); \ + assert(reinterpret_cast(A) % sizeof(Register) == 0); \ + assert(reinterpret_cast(B) % sizeof(Register) == 0); \ + const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \ auto callback_impl = callbacks::CallbackImpl(callback); \ - const Integer *B0_col = reinterpret_cast(B); \ + const Register *B0_col = reinterpret_cast(B); \ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ /* Process one row of A at a time. 
Doesn't seem to be faster to do multiple rows of A at once.*/ \ for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ - const Integer *A_row = reinterpret_cast(A + A_rowidx * width); \ + const Register *A_row = reinterpret_cast(A + A_rowidx * width); \ /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \ Iterate over shared (inner) dimension.*/ \ int k = 0; \ - Integer a = *(A_row + k); \ - Integer sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \ - Integer sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ - Integer sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ - Integer sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ - Integer sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ - Integer sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ - Integer sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ - Integer sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + Register a = *(A_row + k); \ + Register sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Register sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Register sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Register sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Register sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Register sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Register sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Register sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is declared here.*/ \ - Integer ones = set1_epi16(1); \ + Register ones = set1_epi16(1); \ sum0 = madd_epi16(sum0, ones); \ sum1 = madd_epi16(sum1, ones); \ sum2 = madd_epi16(sum2, ones); \ @@ -305,16 +305,16 @@ template target static void Multiply(const int16_t *A, const sum6 = madd_epi16(sum6, ones); \ sum7 = madd_epi16(sum7, ones); \ for (int k = 1; k < simd_width; ++k) { \ - Integer a = *(A_row + k); \ + Register a = *(A_row + k); \ /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ \ - Integer mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \ - Integer mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ - Integer mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ - Integer mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ - Integer mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ - Integer mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ - Integer mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ - Integer mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + Register mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Register mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Register mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Register mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Register mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Register mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Register mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Register mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ /* Upcast to 32-bit and horizontally add.*/ \ mult0 = madd_epi16(mult0, ones); \ mult1 = madd_epi16(mult1, ones); \ @@ -336,8 +336,8 @@ template target static void Multiply(const int16_t *A, const \ } \ /* Reduce sums within 128-bit lanes.*/ \ - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ /*The specific implementation may need to reduce further.*/ \ auto total = 
PermuteSummer(pack0123, pack4567); \ RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \ @@ -493,35 +493,35 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a))); } //INTGEMM_AVX2 or INTGEMM_SSSE3 multiply -#define INTGEMM_MULTIPLY8(Integer, target, cpu_type) \ +#define INTGEMM_MULTIPLY8(Register, target, cpu_type) \ template target static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ - assert(width % sizeof(Integer) == 0); \ + assert(width % sizeof(Register) == 0); \ assert(B_cols % 8 == 0); \ - assert(reinterpret_cast(A) % sizeof(Integer) == 0); \ - assert(reinterpret_cast(B) % sizeof(Integer) == 0); \ - const int simd_width = width / sizeof(Integer); \ + assert(reinterpret_cast(A) % sizeof(Register) == 0); \ + assert(reinterpret_cast(B) % sizeof(Register) == 0); \ + const int simd_width = width / sizeof(Register); \ auto callback_impl = callbacks::CallbackImpl(callback); \ - const Integer *B0_col = reinterpret_cast(B); \ + const Register *B0_col = reinterpret_cast(B); \ /*Go over 8 columns of B at a time.*/ \ for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ /*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ /*Iterate over shared (inner) dimension.*/ \ - const Integer *A_live = reinterpret_cast(A + A_rowidx * width); \ - const Integer *A_end = A_live + simd_width; \ - const Integer *B_live = B0_col; \ + const Register *A_live = reinterpret_cast(A + A_rowidx * width); \ + const Register *A_end = A_live + simd_width; \ + const Register *B_live = B0_col; \ /* Rather than initializing as zeros and adding, just initialize the first.*/ \ - Integer a = *(A_live++); \ - Integer a_positive = abs_epi8(a); \ + Register a = *(A_live++); \ + Register a_positive = abs_epi8(a); \ /* These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.*/ \ - Integer sum0 = maddubs_epi16(a_positive, sign_epi8(B_live[0], a)); \ - Integer sum1 = maddubs_epi16(a_positive, sign_epi8(B_live[1], a)); \ - Integer sum2 = maddubs_epi16(a_positive, sign_epi8(B_live[2], a)); \ - Integer sum3 = maddubs_epi16(a_positive, sign_epi8(B_live[3], a)); \ - Integer sum4 = maddubs_epi16(a_positive, sign_epi8(B_live[4], a)); \ - Integer sum5 = maddubs_epi16(a_positive, sign_epi8(B_live[5], a)); \ - Integer sum6 = maddubs_epi16(a_positive, sign_epi8(B_live[6], a)); \ - Integer sum7 = maddubs_epi16(a_positive, sign_epi8(B_live[7], a)); \ + Register sum0 = maddubs_epi16(a_positive, sign_epi8(B_live[0], a)); \ + Register sum1 = maddubs_epi16(a_positive, sign_epi8(B_live[1], a)); \ + Register sum2 = maddubs_epi16(a_positive, sign_epi8(B_live[2], a)); \ + Register sum3 = maddubs_epi16(a_positive, sign_epi8(B_live[3], a)); \ + Register sum4 = maddubs_epi16(a_positive, sign_epi8(B_live[4], a)); \ + Register sum5 = maddubs_epi16(a_positive, sign_epi8(B_live[5], a)); \ + Register sum6 = maddubs_epi16(a_positive, sign_epi8(B_live[6], a)); \ + Register sum7 = maddubs_epi16(a_positive, sign_epi8(B_live[7], a)); \ B_live += 8; \ /* Use A as the loop variable so the add can be done where gcc likes it for branch prediction.*/ \ for (; A_live != A_end; ++A_live, B_live += 8) { \ @@ -544,7 +544,7 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( * _mm512_srai_epi32(_mm512_slli_epi32(sum, 16), 16), * 
_mm512_srai_epi32(sum, 16)); */ \ - Integer ones = set1_epi16(1); \ + Register ones = set1_epi16(1); \ sum0 = madd_epi16(sum0, ones); \ sum1 = madd_epi16(sum1, ones); \ sum2 = madd_epi16(sum2, ones); \ @@ -553,8 +553,8 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( sum5 = madd_epi16(sum5, ones); \ sum6 = madd_epi16(sum6, ones); \ sum7 = madd_epi16(sum7, ones); \ - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ auto total = PermuteSummer(pack0123, pack4567); \ RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \ } \ diff --git a/sse2_gemm.h b/sse2_gemm.h index a2bae35..34de052 100644 --- a/sse2_gemm.h +++ b/sse2_gemm.h @@ -21,7 +21,7 @@ INTGEMM_SELECT_COL_B(INTGEMM_SSE2, __m128i) class QuantizeTile16 { public: - typedef __m128i Integer; + typedef __m128i Register; INTGEMM_SSE2 explicit QuantizeTile16(float mult) : mult_reg_(_mm_set1_ps(mult)) {} @@ -29,7 +29,7 @@ class QuantizeTile16 { return Tile(input, input + 4); } - INTGEMM_SSE2 Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { + INTGEMM_SSE2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { return Tile( input, input + 4 + (cols_left <= 4 ? cols * (row_step - 1) : 0)); diff --git a/ssse3_gemm.h b/ssse3_gemm.h index 40a26f4..44e2a4d 100644 --- a/ssse3_gemm.h +++ b/ssse3_gemm.h @@ -22,7 +22,7 @@ INTGEMM_SELECT_COL_B(INTGEMM_SSSE3, __m128i) class QuantizeTile8 { public: - typedef __m128i Integer; + typedef __m128i Register; INTGEMM_SSSE3 explicit QuantizeTile8(float mult) : mult_reg_(_mm_set1_ps(mult)) {} @@ -39,16 +39,16 @@ class QuantizeTile8 { return TileU(input, input + 4, input + 8, input + 12); } - INTGEMM_SSSE3 Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { + INTGEMM_SSSE3 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { const float* inputs[4]; for (int i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) { - while (cols_left < sizeof(Integer) / sizeof(float)) { + while (cols_left < sizeof(Register) / sizeof(float)) { input += cols * (row_step - 1); cols_left += cols; } inputs[i] = input; - input += sizeof(Integer) / sizeof(float); - cols_left -= sizeof(Integer) / sizeof(float); + input += sizeof(Register) / sizeof(float); + cols_left -= sizeof(Register) / sizeof(float); } return Tile(inputs[0], inputs[1], inputs[2], inputs[3]); } -- cgit v1.2.3