Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikolay Bogoychev <nheart@gmail.com>2020-04-20 16:19:47 +0300
committerNikolay Bogoychev <nheart@gmail.com>2020-04-20 16:19:47 +0300
commite2404d9a05148379fbc3e7cc9904b3157e7d73f9 (patch)
treef11ec9d4237d6c8b4b403f90b6aa1e54b87bc94b
parent1b262a42d7ff978310c335842203af8c8b47cb2a (diff)
parentec396d1b8d6f29e3a70924df4225cfd4050a1c2b (diff)
Merge branch 'master' into absolute_std
-rw-r--r--aligned.h9
-rw-r--r--avx512_gemm.h5
-rw-r--r--avx512vnni_gemm.h72
-rw-r--r--intgemm.h11
-rw-r--r--multiply.h41
-rw-r--r--test/add127_test.cc6
-rw-r--r--test/multiply_test.cc6
-rw-r--r--test/test.h38
-rw-r--r--test/utils_test.cc13
-rw-r--r--utils.h24
10 files changed, 139 insertions, 86 deletions
diff --git a/aligned.h b/aligned.h
index 6e72afb..6af3c31 100644
--- a/aligned.h
+++ b/aligned.h
@@ -1,5 +1,6 @@
#pragma once
#include <cstdlib>
+#include <new>
#include <stdlib.h>
// 64-byte aligned simple vector.
@@ -10,11 +11,9 @@ template <class T> class AlignedVector {
public:
explicit AlignedVector(std::size_t size)
: size_(size) {
- #ifdef __APPLE__
- posix_memalign(reinterpret_cast<void **>(&mem_), 64, size * sizeof(T));
- #else
- mem_ = reinterpret_cast<T*>(aligned_alloc(64, (size * sizeof(T) + 63) & ~63)); // pedantic requirements for memory size on aligned_alloc in case it's not just a call to posix_memalign
- #endif
+ if (posix_memalign(reinterpret_cast<void **>(&mem_), 64, size * sizeof(T))) {
+ throw std::bad_alloc();
+ }
}
AlignedVector(const AlignedVector&) = delete;
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 623e21a..c6a473e 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -331,11 +331,12 @@ struct AVX512_8bit {
// There's 8 results for INTGEMM_AVX2 to handle.
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
// Added for AVX512.
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h
index 59f6405..22c5c4e 100644
--- a/avx512vnni_gemm.h
+++ b/avx512vnni_gemm.h
@@ -8,6 +8,15 @@
namespace intgemm {
+// Workaround extra vmovdqa64 https://gcc.gnu.org/bugzilla/show_bug.cgi?id=94663
+INTGEMM_AVX512VNNI static inline void VNNI8(__m512i &c, __m512i a, __m512i b) {
+#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+ asm ("vpdpbusds %2, %1, %0" : "+x"(c) : "x"(a), "mx"(b));
+#else
+ c = _mm512_dpbusds_epi32(c, a, b);
+#endif
+}
+
struct AVX512VNNI_8bit : public AVX512_8bit {
template <typename Callback>
INTGEMM_AVX512VNNI static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
@@ -18,10 +27,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
@@ -53,14 +63,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5);
b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6);
b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7);
- sum0 = _mm512_dpbusds_epi32(sum0, a_positive, b0);
- sum1 = _mm512_dpbusds_epi32(sum1, a_positive, b1);
- sum2 = _mm512_dpbusds_epi32(sum2, a_positive, b2);
- sum3 = _mm512_dpbusds_epi32(sum3, a_positive, b3);
- sum4 = _mm512_dpbusds_epi32(sum4, a_positive, b4);
- sum5 = _mm512_dpbusds_epi32(sum5, a_positive, b5);
- sum6 = _mm512_dpbusds_epi32(sum6, a_positive, b6);
- sum7 = _mm512_dpbusds_epi32(sum7, a_positive, b7);
+ VNNI8(sum0, a_positive, b0);
+ VNNI8(sum1, a_positive, b1);
+ VNNI8(sum2, a_positive, b2);
+ VNNI8(sum3, a_positive, b3);
+ VNNI8(sum4, a_positive, b4);
+ VNNI8(sum5, a_positive, b5);
+ VNNI8(sum6, a_positive, b6);
+ VNNI8(sum7, a_positive, b7);
}
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
@@ -79,10 +89,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
@@ -94,14 +105,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
for (; A_live != A_end; ++A_live, B_live += 8) {
Register a = *A_live;
//MultiplyAdd
- sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
- sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
- sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2));
- sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3));
- sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4));
- sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5));
- sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
- sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
+ VNNI8(sum0, a, *B_live);
+ VNNI8(sum1, a, *(B_live + 1));
+ VNNI8(sum2, a, *(B_live + 2));
+ VNNI8(sum3, a, *(B_live + 3));
+ VNNI8(sum4, a, *(B_live + 4));
+ VNNI8(sum5, a, *(B_live + 5));
+ VNNI8(sum6, a, *(B_live + 6));
+ VNNI8(sum7, a, *(B_live + 7));
}
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
@@ -119,11 +130,12 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
Register zeros = setzero_si<Register>();
const Register a = set1_epi8<Register>(1);
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function
const Register *B_end = B_live + simd_width*8;
@@ -131,14 +143,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
for (; B_live != B_end; B_live += 8) {
// Retrieve the conveniently consecutive values of B.
- sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
- sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
- sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2));
- sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3));
- sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4));
- sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5));
- sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
- sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
+ VNNI8(sum0, a, *B_live);
+ VNNI8(sum1, a, *(B_live + 1));
+ VNNI8(sum2, a, *(B_live + 2));
+ VNNI8(sum3, a, *(B_live + 3));
+ VNNI8(sum4, a, *(B_live + 4));
+ VNNI8(sum5, a, *(B_live + 5));
+ VNNI8(sum6, a, *(B_live + 6));
+ VNNI8(sum7, a, *(B_live + 7));
}
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
diff --git a/intgemm.h b/intgemm.h
index 6f2bc1c..f4aa957 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -285,7 +285,7 @@ private:
};
template <typename Callback>
-void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply<Callback>, AVX512_8bit::Multiply<Callback>, AVX2_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, Unsupported_8bit::Multiply);
+void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI_8bit>, OMPParallelWrap<Callback, AVX512_8bit>, OMPParallelWrap<Callback, AVX2_8bit>, OMPParallelWrap<Callback, SSSE3_8bit>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>);
/*
* 8-bit matrix multiplication with shifting A by 127
@@ -348,7 +348,12 @@ private:
};
template <class Callback>
-void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply8Shift<Callback>, AVX512_8bit::Multiply8Shift<Callback>, AVX2_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift);
+void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(
+ OMPParallelWrap8Shift<Callback, AVX512VNNI_8bit>,
+ OMPParallelWrap8Shift<Callback, AVX512_8bit>,
+ OMPParallelWrap8Shift<Callback, AVX2_8bit>,
+ OMPParallelWrap8Shift<Callback, SSSE3_8bit>,
+ Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>);
template <class Callback>
void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::PrepareBias<Callback>, AVX512_8bit::PrepareBias<Callback>, AVX2_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, Unsupported_8bit::PrepareBias);
@@ -407,7 +412,7 @@ private:
};
template <typename Callback>
-void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_16bit::Multiply<Callback> /*TODO VNNI 16-bit. */, AVX512_16bit::Multiply<Callback>, AVX2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, Unsupported_16bit::Multiply);
+void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512_16bit> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512_16bit>, OMPParallelWrap<Callback, AVX2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, Unsupported_16bit::Multiply<Callback>);
extern const CPUType kCPU;
diff --git a/multiply.h b/multiply.h
index 0fcf7e1..b1158ab 100644
--- a/multiply.h
+++ b/multiply.h
@@ -205,8 +205,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / (sizeof(Register) / sizeof(int16_t)); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register *>(B); \
- for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \
@@ -261,9 +262,10 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register *>(B); \
const Register a = set1_epi8<Register>(1); \
- for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/*const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width);*/ \
/* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \
Iterate over shared (inner) dimension.*/ \
@@ -335,8 +337,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register *>(B); \
- for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \
@@ -559,9 +562,9 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / sizeof(Register); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register*>(B); \
- /*Go over 8 columns of B at a time.*/ \
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
/*Iterate over shared (inner) dimension.*/ \
@@ -617,7 +620,25 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \
} \
} \
-} \
+}
+
+/* Wrap a multiply call in OMP parallelism. Here it launches threads then
+ * inside the implementation there is a pragma omp for. In gcc >= 8 these
+ * could have been the same but older compilers don't imbue target attributes
+ * on the hidden function created by pragma omp parallel.
+ *
+ * Also, gcc 7 is unable to deduce the function pointer type (for ChooseCPU) if
+ * I use typename Backend::Integer directly in the arguments. As a workaround,
+ * have a default template argument Integer then use that so it's resolved.
+ */
+template <class Callback, class Backend, class Integer = typename Backend::Integer> static inline void OMPParallelWrap(const Integer *A, const Integer *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+#pragma omp parallel
+ Backend::template Multiply<Callback>(A, B, A_rows, width, B_cols, callback);
+}
+template <class Callback, class Backend> static inline void OMPParallelWrap8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+#pragma omp parallel
+ Backend::template Multiply8Shift<Callback>(A, B, A_rows, width, B_cols, callback);
+}
#define INTGEMM_MAXABSOLUTE(Register, target) \
target static inline float MaxAbsolute(const float *begin_float, const float *end_float) { \
diff --git a/test/add127_test.cc b/test/add127_test.cc
index d959b14..cec20c2 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -127,7 +127,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
@@ -184,7 +184,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt
Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin()));
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
@@ -245,7 +245,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
// });
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
/*
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 260dd76..a054753 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -278,7 +278,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
+ OMPParallelWrap<callbacks::UnquantizeAndWrite, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
// Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
// callbacks::Unquantize(unquant_mult),
// callbacks::Write<float>(test_C.begin())
@@ -293,7 +293,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) {
return sum;
});
@@ -346,7 +346,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
diff --git a/test/test.h b/test/test.h
index 7de38e9..f145681 100644
--- a/test/test.h
+++ b/test/test.h
@@ -76,30 +76,28 @@ void Quantize(const float* input, Type* output, float quant_mult, Index size) {
}
}
-// Multiply A(float) x B(float)
-template <typename LambdaCallback>
-void MultiplyFF(const float* A, const float* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
- for (Index r = 0; r < A_rows; ++r) {
- for (Index c = 0; c < B_cols; ++c) {
- float sum = 0.0f;
- for (Index k = 0; k < width; ++k) {
- sum += A[r * width + k] * B[k * B_cols + c];
- }
- C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols});
- }
- }
-}
+/*
+ * Multiply C = A x B
+ *
+ * Notes: A and B has to be both integers or both floating points.
+ *
+ * Callback takes two arguments:
+ * - Intermediate value of multiplication 1 row times 1 column - it's int32_t or double based on types A and B.
+ * - Object containing information about position in output matrix - callbacks::OutputBufferInfo.
+ */
+template <typename TypeA, typename TypeB, typename TypeC, typename LambdaCallback,
+ typename std::enable_if<
+ (std::is_integral<TypeA>::value && std::is_integral<TypeB>::value) ||
+ (std::is_floating_point<TypeA>::value && std::is_floating_point<TypeB>::value)
+ >::type* = nullptr>
+void Multiply(const TypeA* A, const TypeB* B, TypeC* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
+ using IntermediateType = typename std::conditional<std::is_integral<TypeA>::value, int32_t, double>::type;
-// Multiply A(int) x B(int)
-template <typename TypeA, typename TypeB, typename LambdaCallback,
- typename std::enable_if<std::is_integral<TypeA>::value>::type* = nullptr,
- typename std::enable_if<std::is_integral<TypeB>::value>::type* = nullptr>
-void Multiply(const TypeA* A, const TypeB* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
for (Index r = 0; r < A_rows; ++r) {
for (Index c = 0; c < B_cols; ++c) {
- int32_t sum = 0;
+ IntermediateType sum = 0;
for (Index k = 0; k < width; ++k) {
- sum += int32_t(A[r * width + k]) * int32_t(B[k * B_cols + c]);
+ sum += IntermediateType(A[r * width + k]) * IntermediateType(B[k * B_cols + c]);
}
C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols});
}
diff --git a/test/utils_test.cc b/test/utils_test.cc
index 8596104..0281802 100644
--- a/test/utils_test.cc
+++ b/test/utils_test.cc
@@ -41,6 +41,12 @@ struct StaticLoopTest {
}
};
+TEST_CASE("Static loop (N = 0)",) {
+ Index result = 128;
+ StaticLoop<StaticLoopTest, MakeStaticLoopIterator<0>>(result);
+ CHECK(result == 128);
+}
+
TEST_CASE("Static loop (N = 1)",) {
Index result = 128;
StaticLoop<StaticLoopTest, MakeStaticLoopIterator<1>>(result);
@@ -78,5 +84,12 @@ TEST_CASE("Static loop with mult-dim iterator (Iterator<5, 2>)",) {
CHECK(result == 11223344);
}
+TEST_CASE("Round up",) {
+ CHECK(round_up(0, 5) == 0);
+ CHECK(round_up(1, 5) == 5);
+ CHECK(round_up(4, 5) == 5);
+ CHECK(round_up(6, 5) == 10);
+}
+
}
}
diff --git a/utils.h b/utils.h
index 7fa2f6e..0cda979 100644
--- a/utils.h
+++ b/utils.h
@@ -52,20 +52,18 @@ constexpr subtuple_t<Tuple, Indices...> make_subtuple(const Tuple& tuple, sequen
/*
* Factorial
*/
-constexpr unsigned long long factorial(unsigned n) {
+static constexpr unsigned long long factorial(unsigned n) {
return n <= 1 ? 1 : n * factorial(n - 1);
}
/*
* e^n, where n is integer
*/
-namespace { // anonymous namespace
-constexpr double expi_nonnegative(unsigned n) {
+static constexpr double expi_nonnegative(unsigned n) {
return n == 0 ? 1.0 : (n == 1 ? 2.718281828459045 : expi_nonnegative(n / 2) * expi_nonnegative((n + 1) / 2));
}
-} // anonymous namespace
-constexpr double expi(int n) {
+static constexpr double expi(int n) {
return (n >= 0 ? expi_nonnegative(n) : 1.0 / expi_nonnegative(-n));
}
@@ -143,7 +141,7 @@ public:
/*
* Last iterator
*/
- using last = StaticLoopIterator<total_iterations - 1, Ns...>;
+ using end = StaticLoopIterator<total_iterations, Ns...>;
};
/*
@@ -190,15 +188,21 @@ using MakeStaticLoopIterator = StaticLoopIterator<0, Ns...>;
* [4, 1] Test 1
*
*/
-template <typename Body, typename StaticLoopIterator, typename std::enable_if<std::is_same<StaticLoopIterator, typename StaticLoopIterator::last>::value>::type* = nullptr, typename... Args>
-__attribute__((always_inline)) static inline void StaticLoop(Args&&... args) {
- Body::template body<StaticLoopIterator>(std::forward<Args>(args)...);
+template <typename Body, typename StaticLoopIterator, typename std::enable_if<std::is_same<StaticLoopIterator, typename StaticLoopIterator::end>::value>::type* = nullptr, typename... Args>
+__attribute__((always_inline)) static inline void StaticLoop(Args&&...) {
}
-template <typename Body, typename StaticLoopIterator, typename std::enable_if<!std::is_same<StaticLoopIterator, typename StaticLoopIterator::last>::value>::type* = nullptr, typename... Args>
+template <typename Body, typename StaticLoopIterator, typename std::enable_if<!std::is_same<StaticLoopIterator, typename StaticLoopIterator::end>::value>::type* = nullptr, typename... Args>
__attribute__((always_inline)) static inline void StaticLoop(Args&&... args) {
Body::template body<StaticLoopIterator>(std::forward<Args>(args)...);
StaticLoop<Body, typename StaticLoopIterator::next>(std::forward<Args>(args)...);
}
+/*
+ * Round up
+ */
+static constexpr Index round_up(Index value, Index factor) {
+ return (value + factor - 1) / factor * factor;
+}
+
}