|  |  |  |
|---|---|---|
| author | Nikolay Bogoychev <nheart@gmail.com> | 2020-04-20 16:19:47 +0300 |
| committer | Nikolay Bogoychev <nheart@gmail.com> | 2020-04-20 16:19:47 +0300 |
| commit | e2404d9a05148379fbc3e7cc9904b3157e7d73f9 (patch) | |
| tree | f11ec9d4237d6c8b4b403f90b6aa1e54b87bc94b | |
| parent | 1b262a42d7ff978310c335842203af8c8b47cb2a (diff) | |
| parent | ec396d1b8d6f29e3a70924df4225cfd4050a1c2b (diff) | |
Merge branch 'master' into absolute_std
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | aligned.h | 9 |
| -rw-r--r-- | avx512_gemm.h | 5 |
| -rw-r--r-- | avx512vnni_gemm.h | 72 |
| -rw-r--r-- | intgemm.h | 11 |
| -rw-r--r-- | multiply.h | 41 |
| -rw-r--r-- | test/add127_test.cc | 6 |
| -rw-r--r-- | test/multiply_test.cc | 6 |
| -rw-r--r-- | test/test.h | 38 |
| -rw-r--r-- | test/utils_test.cc | 13 |
| -rw-r--r-- | utils.h | 24 |
10 files changed, 139 insertions, 86 deletions
diff --git a/aligned.h b/aligned.h
--- a/aligned.h
+++ b/aligned.h
@@ -1,5 +1,6 @@
 #pragma once
 #include <cstdlib>
+#include <new>
 #include <stdlib.h>
 
 // 64-byte aligned simple vector.
@@ -10,11 +11,9 @@ template <class T> class AlignedVector {
   public:
     explicit AlignedVector(std::size_t size)
       : size_(size) {
-#ifdef __APPLE__
-      posix_memalign(reinterpret_cast<void **>(&mem_), 64, size * sizeof(T));
-#else
-      mem_ = reinterpret_cast<T*>(aligned_alloc(64, (size * sizeof(T) + 63) & ~63)); // pedantic requirements for memory size on aligned_alloc in case it's not just a call to posix_memalign
-#endif
+      if (posix_memalign(reinterpret_cast<void **>(&mem_), 64, size * sizeof(T))) {
+        throw std::bad_alloc();
+      }
     }
 
     AlignedVector(const AlignedVector&) = delete;
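The aligned.h change drops the Apple/Linux split in favour of posix_memalign everywhere and, unlike the old code, checks the return value. A minimal self-contained sketch of the same pattern — `AllocateAligned` is an illustrative name, not intgemm API:

```cpp
#include <cstdlib>
#include <new>

// 64 bytes covers both a cache line and the widest (AVX-512) register.
template <class T> T *AllocateAligned(std::size_t count) {
  void *mem = nullptr;
  // posix_memalign returns non-zero on failure and is available on both
  // Linux and macOS, which is what makes the old #ifdef unnecessary.
  if (posix_memalign(&mem, 64, count * sizeof(T))) {
    throw std::bad_alloc();
  }
  return static_cast<T*>(mem);  // release with free(), not delete[]
}
```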
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 623e21a..c6a473e 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -331,11 +331,12 @@ struct AVX512_8bit {
     // There's 8 results for INTGEMM_AVX2 to handle.
     auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
     const int simd_width = width / sizeof(Register);
-    const Register *B0_col = reinterpret_cast<const Register*>(B);
     // Added for AVX512.
     Register zeros = setzero_si<Register>();
     // Go over 8 columns of B at a time.
-    for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+    for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+      const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
       // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
       for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
         // Iterate over shared (inner) dimension.
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h
index 59f6405..22c5c4e 100644
--- a/avx512vnni_gemm.h
+++ b/avx512vnni_gemm.h
@@ -8,6 +8,15 @@
 
 namespace intgemm {
 
+// Workaround extra vmovdqa64 https://gcc.gnu.org/bugzilla/show_bug.cgi?id=94663
+INTGEMM_AVX512VNNI static inline void VNNI8(__m512i &c, __m512i a, __m512i b) {
+#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+  asm ("vpdpbusds %2, %1, %0" : "+x"(c) : "x"(a), "mx"(b));
+#else
+  c = _mm512_dpbusds_epi32(c, a, b);
+#endif
+}
+
 struct AVX512VNNI_8bit : public AVX512_8bit {
   template <typename Callback>
   INTGEMM_AVX512VNNI static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
@@ -18,10 +27,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
     assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
     auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
     const int simd_width = width / sizeof(Register);
-    const Register *B0_col = reinterpret_cast<const Register*>(B);
     Register zeros = setzero_si<Register>();
     // Go over 8 columns of B at a time.
-    for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+    for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+      const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
       // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
       for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
         // Iterate over shared (inner) dimension.
@@ -53,14 +63,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
         b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5);
         b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6);
         b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7);
-        sum0 = _mm512_dpbusds_epi32(sum0, a_positive, b0);
-        sum1 = _mm512_dpbusds_epi32(sum1, a_positive, b1);
-        sum2 = _mm512_dpbusds_epi32(sum2, a_positive, b2);
-        sum3 = _mm512_dpbusds_epi32(sum3, a_positive, b3);
-        sum4 = _mm512_dpbusds_epi32(sum4, a_positive, b4);
-        sum5 = _mm512_dpbusds_epi32(sum5, a_positive, b5);
-        sum6 = _mm512_dpbusds_epi32(sum6, a_positive, b6);
-        sum7 = _mm512_dpbusds_epi32(sum7, a_positive, b7);
+        VNNI8(sum0, a_positive, b0);
+        VNNI8(sum1, a_positive, b1);
+        VNNI8(sum2, a_positive, b2);
+        VNNI8(sum3, a_positive, b3);
+        VNNI8(sum4, a_positive, b4);
+        VNNI8(sum5, a_positive, b5);
+        VNNI8(sum6, a_positive, b6);
+        VNNI8(sum7, a_positive, b7);
       }
       Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
       Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
@@ -79,10 +89,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
     assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
     auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
     const int simd_width = width / sizeof(Register);
-    const Register *B0_col = reinterpret_cast<const Register*>(B);
     Register zeros = setzero_si<Register>();
     // Go over 8 columns of B at a time.
-    for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+    for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+      const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
       // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
       for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
         // Iterate over shared (inner) dimension.
@@ -94,14 +105,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
       for (; A_live != A_end; ++A_live, B_live += 8) {
         Register a = *A_live;
         //MultiplyAdd
-        sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
-        sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
-        sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2));
-        sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3));
-        sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4));
-        sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5));
-        sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
-        sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
+        VNNI8(sum0, a, *B_live);
+        VNNI8(sum1, a, *(B_live + 1));
+        VNNI8(sum2, a, *(B_live + 2));
+        VNNI8(sum3, a, *(B_live + 3));
+        VNNI8(sum4, a, *(B_live + 4));
+        VNNI8(sum5, a, *(B_live + 5));
+        VNNI8(sum6, a, *(B_live + 6));
+        VNNI8(sum7, a, *(B_live + 7));
       }
       Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
       Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
@@ -119,11 +130,12 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
     assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
     auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
     const int simd_width = width / sizeof(Register);
-    const Register *B0_col = reinterpret_cast<const Register*>(B);
     Register zeros = setzero_si<Register>();
     const Register a = set1_epi8<Register>(1);
     // Go over 8 columns of B at a time.
-    for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+    for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+      const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
       const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function
       const Register *B_end = B_live + simd_width*8;
@@ -131,14 +143,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
       Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
       for (; B_live != B_end; B_live += 8) {
         // Retrieve the conveniently consecutive values of B.
-        sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
-        sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
-        sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2));
-        sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3));
-        sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4));
-        sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5));
-        sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
-        sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
+        VNNI8(sum0, a, *B_live);
+        VNNI8(sum1, a, *(B_live + 1));
+        VNNI8(sum2, a, *(B_live + 2));
+        VNNI8(sum3, a, *(B_live + 3));
+        VNNI8(sum4, a, *(B_live + 4));
+        VNNI8(sum5, a, *(B_live + 5));
+        VNNI8(sum6, a, *(B_live + 6));
+        VNNI8(sum7, a, *(B_live + 7));
       }
       Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
       Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
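The VNNI8 helper above pins the vpdpbusds instruction with inline assembly on GCC because the intrinsic tempts GCC into an extra vmovdqa64 register copy (bug 94663). A rough standalone illustration of the two code paths, assuming an AVX-512 VNNI target; `DotAccumulate` is a made-up name, not intgemm's:

```cpp
#include <immintrin.h>

// Accumulate unsigned(a) * signed(b) byte products into 32-bit lanes of sum.
__attribute__((target("avx512vnni,avx512f")))
static inline __m512i DotAccumulate(__m512i sum, __m512i a, __m512i b) {
#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
  // "+x": sum is read and written in place, so GCC cannot insert the spare
  // register copy; "mx" lets b come straight from memory when convenient.
  asm("vpdpbusds %2, %1, %0" : "+x"(sum) : "x"(a), "mx"(b));
#else
  sum = _mm512_dpbusds_epi32(sum, a, b);
#endif
  return sum;
}
```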
diff --git a/intgemm.h b/intgemm.h
--- a/intgemm.h
+++ b/intgemm.h
@@ -285,7 +285,7 @@ private:
 };
 
 template <typename Callback>
-void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply<Callback>, AVX512_8bit::Multiply<Callback>, AVX2_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, Unsupported_8bit::Multiply);
+void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI_8bit>, OMPParallelWrap<Callback, AVX512_8bit>, OMPParallelWrap<Callback, AVX2_8bit>, OMPParallelWrap<Callback, SSSE3_8bit>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>);
 
 /*
  * 8-bit matrix multiplication with shifting A by 127
@@ -348,7 +348,12 @@ private:
 };
 
 template <class Callback>
-void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply8Shift<Callback>, AVX512_8bit::Multiply8Shift<Callback>, AVX2_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift);
+void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(
+  OMPParallelWrap8Shift<Callback, AVX512VNNI_8bit>,
+  OMPParallelWrap8Shift<Callback, AVX512_8bit>,
+  OMPParallelWrap8Shift<Callback, AVX2_8bit>,
+  OMPParallelWrap8Shift<Callback, SSSE3_8bit>,
+  Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>);
 
 template <class Callback>
 void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::PrepareBias<Callback>, AVX512_8bit::PrepareBias<Callback>, AVX2_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, Unsupported_8bit::PrepareBias);
@@ -407,7 +412,7 @@ private:
 };
 
 template <typename Callback>
-void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_16bit::Multiply<Callback> /*TODO VNNI 16-bit. */, AVX512_16bit::Multiply<Callback>, AVX2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, Unsupported_16bit::Multiply);
+void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512_16bit> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512_16bit>, OMPParallelWrap<Callback, AVX2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, Unsupported_16bit::Multiply<Callback>);
 
 extern const CPUType kCPU;
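For readers unfamiliar with the ChooseCPU calls being edited here: each `run` pointer is initialized once, at static-initialization time, with the best kernel the host CPU supports, so the hot path is a plain indirect call. A simplified model of that dispatch — `DetectCPU` is a stub standing in for the library's real cpuid-based detection, and all names below are illustrative:

```cpp
#include <cstdint>

enum class CPU { AVX512VNNI, AVX512BW, AVX2, SSSE3, SSE2, Unsupported };

// Stand-in: the real library inspects cpuid; this just picks a fixed level.
static CPU DetectCPU() { return CPU::AVX2; }

// Return the first candidate the running CPU can execute. Arguments are
// ordered best to worst, matching the ChooseCPU calls in the diff.
template <class F>
static F ChooseCPUSketch(F vnni, F avx512, F avx2, F ssse3, F sse2, F fallback) {
  switch (DetectCPU()) {
    case CPU::AVX512VNNI: return vnni;
    case CPU::AVX512BW:   return avx512;
    case CPU::AVX2:       return avx2;
    case CPU::SSSE3:      return ssse3;
    case CPU::SSE2:       return sse2;
    default:              return fallback;
  }
}

// Usage mirroring the diff: a static function pointer bound once.
static void MultiplyFast(const int8_t*, const int8_t*, int) {}
static void MultiplySlow(const int8_t*, const int8_t*, int) {}
static void (*run)(const int8_t*, const int8_t*, int) =
    ChooseCPUSketch(MultiplyFast, MultiplyFast, MultiplyFast,
                    MultiplySlow, MultiplySlow, MultiplySlow);
```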
diff --git a/multiply.h b/multiply.h
--- a/multiply.h
+++ b/multiply.h
@@ -205,8 +205,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
   assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
   const int simd_width = width / (sizeof(Register) / sizeof(int16_t)); \
   auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
-  const Register *B0_col = reinterpret_cast<const Register *>(B); \
-  for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+  _Pragma("omp for") \
+  for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+    const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
     /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
     for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
       const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \
@@ -261,9 +262,10 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
   assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
   const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \
   auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
-  const Register *B0_col = reinterpret_cast<const Register *>(B); \
   const Register a = set1_epi8<Register>(1); \
-  for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+  _Pragma("omp for") \
+  for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+    const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
     /*const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width);*/ \
     /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \
        Iterate over shared (inner) dimension.*/ \
@@ -335,8 +337,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
   assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
   const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \
   auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
-  const Register *B0_col = reinterpret_cast<const Register *>(B); \
-  for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+  _Pragma("omp for") \
+  for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+    const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
     /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
     for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
       const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \
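The recurring edit in these macros — hoisting `B0_col` out of the loop header and recomputing it from `B0_colidx` — is what makes the `_Pragma("omp for")` legal: OpenMP needs a canonical loop with independent iterations, and the old `B0_col += 8 * simd_width` carried state from one iteration to the next (the `!=` test elsewhere also had to become `<`). A toy version of the transformation, not intgemm code; the remaining multiply.h hunks below apply the same restructuring:

```cpp
// Before: the pointer advances across iterations, so iterations cannot be
// reordered or handed to different threads:
//   for (Index i = 0; i != cols; ptr += 8 * w, i += 8) { ... }
//
// After: each iteration derives its pointer from the index alone.
void ProcessColumns(const float *B, int cols, int w) {
  #pragma omp for
  for (int i = 0; i < cols; i += 8) {
    const float *col = B + i * w;  // no loop-carried dependence
    (void)col;                     // ... work on 8 columns starting here ...
  }
}
```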
@@ -559,9 +562,9 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
   assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
   const int simd_width = width / sizeof(Register); \
   auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
-  const Register *B0_col = reinterpret_cast<const Register*>(B); \
-  /*Go over 8 columns of B at a time.*/ \
-  for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+  _Pragma("omp for") \
+  for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+    const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
     /*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
     for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
       /*Iterate over shared (inner) dimension.*/ \
@@ -617,7 +620,25 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
       RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \
     } \
   } \
-} \
+}
+
+/* Wrap a multiply call in OMP parallelism. Here it launches threads then
+ * inside the implementation there is a pragma omp for. In gcc >= 8 these
+ * could have been the same but older compilers don't imbue target attributes
+ * on the hidden function created by pragma omp parallel.
+ *
+ * Also, gcc 7 is unable to deduce the function pointer type (for ChooseCPU) if
+ * I use typename Backend::Integer directly in the arguments. As a workaround,
+ * have a default template argument Integer then use that so it's resolved.
+ */
+template <class Callback, class Backend, class Integer = typename Backend::Integer> static inline void OMPParallelWrap(const Integer *A, const Integer *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+#pragma omp parallel
+  Backend::template Multiply<Callback>(A, B, A_rows, width, B_cols, callback);
+}
+template <class Callback, class Backend> static inline void OMPParallelWrap8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+#pragma omp parallel
+  Backend::template Multiply8Shift<Callback>(A, B, A_rows, width, B_cols, callback);
+}
 
 #define INTGEMM_MAXABSOLUTE(Register, target) \
 target static inline float MaxAbsolute(const float *begin_float, const float *end_float) { \
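OMPParallelWrap and the kernels split the OpenMP duties: the wrapper opens the thread team with `#pragma omp parallel`, while each Multiply body contains only the work-sharing `#pragma omp for` (the comment in the diff explains why the two cannot be fused on older GCC). A self-contained sketch of that split with stand-in names:

```cpp
#include <cstdio>

struct BackendSketch {
  // Called by every thread of the enclosing parallel region; "omp for"
  // then divides the column loop among those threads.
  static void Multiply(int cols) {
    #pragma omp for
    for (int c = 0; c < cols; c += 8) {
      std::printf("handling columns %d..%d\n", c, c + 7);
    }
  }
};

template <class Backend> void ParallelWrapSketch(int cols) {
  // One parallel region per multiply; inside it, only iterations are shared.
  #pragma omp parallel
  Backend::Multiply(cols);
}

int main() {
  ParallelWrapSketch<BackendSketch>(32);  // compile with -fopenmp
}
```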
diff --git a/test/add127_test.cc b/test/add127_test.cc
index d959b14..cec20c2 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -127,7 +127,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
   });
 
   AlignedVector<float> float_C(test_C.size());
-  references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
     return sum + bias[info.col_idx];
   });
 
@@ -184,7 +184,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt
   Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin()));
 
   AlignedVector<float> float_C(test_C.size());
-  references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
     return sum + bias[info.col_idx];
   });
 
@@ -245,7 +245,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
   // });
 
   AlignedVector<float> float_C(test_C.size());
-  references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
     return sum + bias[info.col_idx];
   });
 /*
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 260dd76..a054753 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -278,7 +278,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
   Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
 
   AlignedVector<float> test_C(A_rows * B_cols);
-  Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
+  OMPParallelWrap<callbacks::UnquantizeAndWrite, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
   // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
   //   callbacks::Unquantize(unquant_mult),
   //   callbacks::Write<float>(test_C.begin())
@@ -293,7 +293,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
   });
 
   AlignedVector<float> float_C(test_C.size());
-  references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) {
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) {
     return sum;
   });
 
@@ -346,7 +346,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
   });
 
   AlignedVector<float> float_C(test_C.size());
-  references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
     return sum + bias[info.col_idx];
   });
 
diff --git a/test/test.h b/test/test.h
index 7de38e9..f145681 100644
--- a/test/test.h
+++ b/test/test.h
@@ -76,30 +76,28 @@ void Quantize(const float* input, Type* output, float quant_mult, Index size) {
   }
 }
 
-// Multiply A(float) x B(float)
-template <typename LambdaCallback>
-void MultiplyFF(const float* A, const float* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
-  for (Index r = 0; r < A_rows; ++r) {
-    for (Index c = 0; c < B_cols; ++c) {
-      float sum = 0.0f;
-      for (Index k = 0; k < width; ++k) {
-        sum += A[r * width + k] * B[k * B_cols + c];
-      }
-      C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols});
-    }
-  }
-}
+/*
+ * Multiply C = A x B
+ *
+ * Notes: A and B has to be both integers or both floating points.
+ *
+ * Callback takes two arguments:
+ *   - Intermediate value of multiplication 1 row times 1 column - it's int32_t or double based on types A and B.
+ *   - Object containing information about position in output matrix - callbacks::OutputBufferInfo.
+ */
+template <typename TypeA, typename TypeB, typename TypeC, typename LambdaCallback,
+          typename std::enable_if<
+            (std::is_integral<TypeA>::value && std::is_integral<TypeB>::value) ||
+            (std::is_floating_point<TypeA>::value && std::is_floating_point<TypeB>::value)
+          >::type* = nullptr>
+void Multiply(const TypeA* A, const TypeB* B, TypeC* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
+  using IntermediateType = typename std::conditional<std::is_integral<TypeA>::value, int32_t, double>::type;
 
-// Multiply A(int) x B(int)
-template <typename TypeA, typename TypeB, typename LambdaCallback,
-          typename std::enable_if<std::is_integral<TypeA>::value>::type* = nullptr,
-          typename std::enable_if<std::is_integral<TypeB>::value>::type* = nullptr>
-void Multiply(const TypeA* A, const TypeB* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
   for (Index r = 0; r < A_rows; ++r) {
     for (Index c = 0; c < B_cols; ++c) {
-      int32_t sum = 0;
+      IntermediateType sum = 0;
       for (Index k = 0; k < width; ++k) {
-        sum += int32_t(A[r * width + k]) * int32_t(B[k * B_cols + c]);
+        sum += IntermediateType(A[r * width + k]) * IntermediateType(B[k * B_cols + c]);
       }
       C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols});
     }
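The test.h change folds MultiplyFF and the integer Multiply into one template that selects double or int32_t as the accumulator from the input types. Here is how the merged shape behaves in isolation — a sketch that mirrors, rather than includes, the test helper (the real callback also receives an OutputBufferInfo):

```cpp
#include <cstdint>
#include <cstdio>
#include <type_traits>

using Index = unsigned int;

template <typename TypeA, typename TypeB, typename TypeC, typename Callback,
          typename std::enable_if<
            (std::is_integral<TypeA>::value && std::is_integral<TypeB>::value) ||
            (std::is_floating_point<TypeA>::value && std::is_floating_point<TypeB>::value)
          >::type* = nullptr>
void RefMultiply(const TypeA *A, const TypeB *B, TypeC *C,
                 Index A_rows, Index width, Index B_cols, Callback cb) {
  // int/int inputs accumulate in int32_t; float/float inputs in double.
  using Intermediate = typename std::conditional<
      std::is_integral<TypeA>::value, int32_t, double>::type;
  for (Index r = 0; r < A_rows; ++r)
    for (Index c = 0; c < B_cols; ++c) {
      Intermediate sum = 0;
      for (Index k = 0; k < width; ++k)
        sum += Intermediate(A[r * width + k]) * Intermediate(B[k * B_cols + c]);
      C[r * B_cols + c] = cb(sum);
    }
}

int main() {
  const int8_t A[4] = {1, 2, 3, 4}, B[4] = {5, 6, 7, 8};  // two 2x2 matrices
  float C[4];
  RefMultiply(A, B, C, 2, 2, 2, [](int32_t s) { return float(s); });
  std::printf("%g %g %g %g\n", C[0], C[1], C[2], C[3]);   // 19 22 43 50
}
```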
diff --git a/test/utils_test.cc b/test/utils_test.cc
index 8596104..0281802 100644
--- a/test/utils_test.cc
+++ b/test/utils_test.cc
@@ -41,6 +41,12 @@ struct StaticLoopTest {
   }
 };
 
+TEST_CASE("Static loop (N = 0)",) {
+  Index result = 128;
+  StaticLoop<StaticLoopTest, MakeStaticLoopIterator<0>>(result);
+  CHECK(result == 128);
+}
+
 TEST_CASE("Static loop (N = 1)",) {
   Index result = 128;
   StaticLoop<StaticLoopTest, MakeStaticLoopIterator<1>>(result);
@@ -78,5 +84,12 @@ TEST_CASE("Static loop with mult-dim iterator (Iterator<5, 2>)",) {
   CHECK(result == 11223344);
 }
 
+TEST_CASE("Round up",) {
+  CHECK(round_up(0, 5) == 0);
+  CHECK(round_up(1, 5) == 5);
+  CHECK(round_up(4, 5) == 5);
+  CHECK(round_up(6, 5) == 10);
+}
+
 }
 }
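The new "Static loop (N = 0)" test passes because of the utils.h change shown next: the recursion now terminates at a past-the-end iterator (`end`) whose specialization does nothing, instead of a `last` iterator that still executed the body once. The same half-open idea in a compact C++17 sketch (intgemm itself uses enable_if overloads rather than `if constexpr`):

```cpp
#include <cstdio>

// Half-open static loop: I == N is "end", so a zero-trip loop runs nothing.
template <int I, int N> void StaticLoopSketch() {
  if constexpr (I < N) {
    std::printf("iteration %d\n", I);
    StaticLoopSketch<I + 1, N>();
  }
}

int main() {
  StaticLoopSketch<0, 0>();  // prints nothing: begin == end
  StaticLoopSketch<0, 3>();  // prints iterations 0, 1, 2
}
```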
diff --git a/utils.h b/utils.h
--- a/utils.h
+++ b/utils.h
@@ -52,20 +52,18 @@ constexpr subtuple_t<Tuple, Indices...> make_subtuple(const Tuple& tuple, sequen
 /*
  * Factorial
  */
-constexpr unsigned long long factorial(unsigned n) {
+static constexpr unsigned long long factorial(unsigned n) {
   return n <= 1 ? 1 : n * factorial(n - 1);
 }
 
 /*
  * e^n, where n is integer
  */
-namespace { // anonymous namespace
-constexpr double expi_nonnegative(unsigned n) {
+static constexpr double expi_nonnegative(unsigned n) {
   return n == 0 ? 1.0 : (n == 1 ? 2.718281828459045 : expi_nonnegative(n / 2) * expi_nonnegative((n + 1) / 2));
 }
-} // anonymous namespace
 
-constexpr double expi(int n) {
+static constexpr double expi(int n) {
   return (n >= 0 ? expi_nonnegative(n) : 1.0 / expi_nonnegative(-n));
 }
 
@@ -143,7 +141,7 @@ public:
   /*
    * Last iterator
    */
-  using last = StaticLoopIterator<total_iterations - 1, Ns...>;
+  using end = StaticLoopIterator<total_iterations, Ns...>;
 };
 
 /*
@@ -190,15 +188,21 @@ using MakeStaticLoopIterator = StaticLoopIterator<0, Ns...>;
  * [4, 1] Test 1
  *
  */
-template <typename Body, typename StaticLoopIterator, typename std::enable_if<std::is_same<StaticLoopIterator, typename StaticLoopIterator::last>::value>::type* = nullptr, typename... Args>
-__attribute__((always_inline)) static inline void StaticLoop(Args&&... args) {
-  Body::template body<StaticLoopIterator>(std::forward<Args>(args)...);
+template <typename Body, typename StaticLoopIterator, typename std::enable_if<std::is_same<StaticLoopIterator, typename StaticLoopIterator::end>::value>::type* = nullptr, typename... Args>
+__attribute__((always_inline)) static inline void StaticLoop(Args&&...) {
 }
 
-template <typename Body, typename StaticLoopIterator, typename std::enable_if<!std::is_same<StaticLoopIterator, typename StaticLoopIterator::last>::value>::type* = nullptr, typename... Args>
+template <typename Body, typename StaticLoopIterator, typename std::enable_if<!std::is_same<StaticLoopIterator, typename StaticLoopIterator::end>::value>::type* = nullptr, typename... Args>
 __attribute__((always_inline)) static inline void StaticLoop(Args&&... args) {
   Body::template body<StaticLoopIterator>(std::forward<Args>(args)...);
   StaticLoop<Body, typename StaticLoopIterator::next>(std::forward<Args>(args)...);
 }
 
+/*
+ * Round up
+ */
+static constexpr Index round_up(Index value, Index factor) {
+  return (value + factor - 1) / factor * factor;
+}
+
 }
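Finally, the round_up helper that the new tests exercise rounds an Index up to the next multiple of factor using pure integer arithmetic. The identity is easy to pin down with static_asserts — a standalone sketch using unsigned in place of intgemm's Index:

```cpp
// (value + factor - 1) / factor truncates to the number of factor-sized
// blocks needed to cover value; multiplying back gives the rounded size.
constexpr unsigned round_up(unsigned value, unsigned factor) {
  return (value + factor - 1) / factor * factor;
}

static_assert(round_up(0, 5) == 0, "zero needs no blocks");
static_assert(round_up(1, 5) == 5, "a partial block rounds up");
static_assert(round_up(6, 5) == 10, "one full block plus a partial");

int main() { return 0; }
```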