Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Kenneth Heafield <kpu@users.noreply.github.com>, 2020-03-03 13:33:53 +0300
committer: GitHub <noreply@github.com>, 2020-03-03 13:33:53 +0300
commit: bd52e619f1585fe670ae614ceea2c473d929fd4d (patch)
tree: 558f1b67c9119fdff228c7a4a9e284c7c6624684
parent: 1db4a86d5736d09d5b2e7f1965a99057b03ba7af (diff)
parent: c7dc0b2670b7788b7832ef51fb3c0004f074b5d0 (diff)
Merge pull request #69 from kpuatamazon/master
Quantizer: arbitrary length and OpenMP support
-rw-r--r-- CMakeLists.txt 30
-rw-r--r-- avx2_gemm.h 14
-rw-r--r-- avx512_gemm.h 19
-rw-r--r-- benchmarks/benchmark_quantizer.cc 42
-rw-r--r-- multiply.h 66
-rw-r--r-- ssse3_gemm.h 15
-rw-r--r-- test/multiply_test.cc 24
-rw-r--r-- test/quantize_test.cc 36
8 files changed, 159 insertions, 87 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5b24410..022fa7f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,7 +18,7 @@ try_compile(INTGEMM_COMPILER_SUPPORTS_AVX512
COMPILE_DEFINITIONS -mavx512f -mavx512bw -mavx512dq)
if(NOT INTGEMM_COMPILER_SUPPORTS_AVX512)
- message("${Orange}Not building AVX512-based multiplication because your compiler is too old.\nFor details rerun cmake with --debug-trycompile then try to build in compile_tests/CMakeFiles/CMakeTmp.${ColourReset}")
+ message(WARNING "${Orange}Not building AVX512-based multiplication because your compiler is too old.\nFor details rerun cmake with --debug-trycompile then try to build in compile_tests/CMakeFiles/CMakeTmp.${ColourReset}")
endif()
try_compile(INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
@@ -26,13 +26,13 @@ try_compile(INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
${CMAKE_CURRENT_SOURCE_DIR}/compile_test_avx512vnni.cc)
#No compiler flags for this test; that's part of the test!
if(NOT INTGEMM_COMPILER_SUPPORTS_AVX512VNNI)
- message("${Orange}Not building AVX512VNNI-based multiplication because your compiler is too old.\nFor details rerun cmake with --debug-trycompile then try to build in compile_tests/CMakeFiles/CMakeTmp.${ColourReset}")
+ message(WARNING "${Orange}Not building AVX512VNNI-based multiplication because your compiler is too old.\nFor details rerun cmake with --debug-trycompile then try to build in compile_tests/CMakeFiles/CMakeTmp.${ColourReset}")
endif()
# Working around https://bugs.llvm.org/show_bug.cgi?id=41482
# Anything compiled with clang might not work properly in SSE2/SSSE3 world
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
- message("${Orange}Compiling with Clang and using -mavx due to https://bugs.llvm.org/show_bug.cgi?id=41482. Support for SSE2/SSSE3 hardware is likely broken at this point.${ColourReset}")
+ message(WARNING "${Orange}Compiling with Clang and using -mavx due to https://bugs.llvm.org/show_bug.cgi?id=41482. Support for SSE2/SSSE3 hardware is likely broken at this point.${ColourReset}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx")
endif()
@@ -41,15 +41,29 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/intgemm_config.h.in ${CMAKE_CURRENT_B
include_directories(${CMAKE_CURRENT_BINARY_DIR})
+add_library(intgemm STATIC intgemm.cc)
+
+# Optional OpenMP support, enabled by configuring with -DOPENMP=ON.
+if (OPENMP)
+  message(STATUS "Compiling with OpenMP")
+  find_package(OpenMP)
+  # Test the variable by name. An unquoted ${OpenMP_CXX_FOUND} expands to
+  # nothing when the variable is undefined, which turns the condition into the
+  # invalid `if (NOT )`.
+  if (OpenMP_CXX_FOUND)
+    add_compile_options(${OpenMP_CXX_FLAGS})
+    # PUBLIC: headers of intgemm contain OpenMP pragmas, so everything that
+    # links intgemm needs the OpenMP flags as well.
+    target_link_libraries(intgemm PUBLIC OpenMP::OpenMP_CXX)
+  else()
+    # Only reference the imported target when it exists; report the real
+    # problem instead of a follow-up "target not found" error.
+    message(SEND_ERROR "OpenMP requested but C++ support not found")
+  endif()
+endif()
+
if(INTGEMM_DONT_BUILD_TESTS)
return()
endif()
-foreach(exe benchmark biasmultiply)
- add_executable(${exe} benchmarks/${exe}.cc intgemm.cc)
+# Benchmarks link the intgemm static library instead of recompiling intgemm.cc.
+foreach(exe benchmark biasmultiply benchmark_quantizer)
+  add_executable(${exe} benchmarks/${exe}.cc)
+  # Explicit PRIVATE: executables have no consumers, and the keyword form
+  # avoids the legacy keyword-less signature.
+  target_link_libraries(${exe} PRIVATE intgemm)
endforeach()
-add_executable(example example.cc intgemm.cc)
+add_executable(example example.cc)
+target_link_libraries(example PRIVATE intgemm)
add_executable(tests
test/test.cc
@@ -78,10 +92,8 @@ add_executable(tests
test/kernels/unquantize_test.cc
test/kernels/upcast_test.cc
test/kernels/write_test.cc
-
- # Definitions
- intgemm.cc
)
+target_link_libraries(tests intgemm)
#CTest integration with Catch2
include(${CMAKE_CURRENT_SOURCE_DIR}/CMake/Catch.cmake)
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 1addd1e..54f3c6c 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -7,6 +7,7 @@
#include <cstdint>
#include <stdint.h>
+#include <cstring>
namespace intgemm {
@@ -135,7 +136,6 @@ class QuantizeTile8 {
return Tile(input, input + 2 * cols, input + 16 * cols, input + 18 * cols);
}
- private:
INTGEMM_AVX2 inline __m256i Tile(const float *input0, const float *input1, const float *input2, const float *input3) const {
// Looking at the assembly, gcc has pulled this outside the loops calling this.
const __m256i neg127 = _mm256_set1_epi8(-127);
@@ -159,6 +159,7 @@ class QuantizeTile8 {
return _mm256_permutevar8x32_epi32(packed, shuffle_param);
}
+ private:
//A version that produces uint8_ts
INTGEMM_AVX2 inline __m256i TileU(const float *input0, const float *input1, const float *input2, const float *input3) const {
// Looking at the assembly, gcc has pulled this outside the loops calling this.
@@ -201,16 +202,7 @@ struct AVX2_8bit {
Quantize(input, output, quant_mult, rows * cols);
}
- // Just quantize everything in order.
- INTGEMM_AVX2 static void Quantize(const float *input, int8_t *output, float quant_mult, Index size) {
- assert(size % 32 == 0);
- assert(reinterpret_cast<uintptr_t>(input) % 32 == 0);
- avx2::QuantizeTile8 q(quant_mult);
- const float *end = input + size;
- for (; input != end; input += 32, output += 32) {
- *reinterpret_cast<__m256i*>(output) = q.Consecutive(input);
- }
- }
+ INTGEMM_QUANTIZE(INTGEMM_AVX2, __m256i, avx2)
// Currently A is prepared by quantization but this could theoretically change.
INTGEMM_AVX2 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) {
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 2a5fff1..eba0322 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -229,17 +229,24 @@ struct AVX512_8bit {
// Convert to 8-bit signed integers.
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_AVX512BW static void Quantize(const float *input, int8_t *output, float quant_mult, Index size) {
- assert(size % 16 == 0);
- assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
+ assert(reinterpret_cast<uintptr_t>(input) % sizeof(__m512i) == 0); // only the input must be register-aligned; size is no longer restricted
const __m512i neg127 = _mm512_set1_epi32(-127);
const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
- const float *end = input + size;
- for (; input < end; input += 16, output += 16) {
- __m512i asint = avx512f::QuantizerGrab(input, quant_mult_reg);
+ const std::size_t kBatch = sizeof(__m512i) / sizeof(float); // floats handled per iteration
+ const float *fast_input_end = input + (size & ~(kBatch - 1)); // end of the part made of whole batches
+ int8_t *fast_output_end = output + (size & ~(kBatch - 1)); // matching position in the output
+#pragma omp parallel for
+ for (const float *input_it = input; input_it < fast_input_end; input_it += kBatch) { // no state is carried between iterations
+ __m512i asint = avx512f::QuantizerGrab(input_it, quant_mult_reg);
asint = _mm512_max_epi32(asint, neg127);
// There doesn't seem to be an unmasked version.
- _mm512_mask_cvtsepi32_storeu_epi8(output, 0xffff, asint);
+ _mm512_mask_cvtsepi32_storeu_epi8(output + (input_it - input), 0xffff, asint); // output offset mirrors the input offset
}
+ std::size_t overhang = size & (kBatch - 1); // elements left over after the whole batches
+ if (!overhang) return; // We needed a branch anyway for the empty case.
+ __m512i asint = avx512f::QuantizerGrab(fast_input_end, quant_mult_reg); // NOTE(review): presumably loads a whole register, i.e. past input + size; confirm this is safe given the alignment assert
+ asint = _mm512_max_epi32(asint, neg127);
+ _mm512_mask_cvtsepi32_storeu_epi8(fast_output_end, (1 << overhang) - 1, asint); // mask has the low `overhang` bits set, so only valid elements are stored
}
// Preparing A for the signed/unsigned multiplication. Using add 127
diff --git a/benchmarks/benchmark_quantizer.cc b/benchmarks/benchmark_quantizer.cc
new file mode 100644
index 0000000..0a6e6d8
--- /dev/null
+++ b/benchmarks/benchmark_quantizer.cc
@@ -0,0 +1,42 @@
+#include "../intgemm.h"
+#include "../aligned.h"
+#include "../stop_watch.h"
+#include "../ssse3_gemm.h"
+#include "../avx2_gemm.h"
+#include "../avx512_gemm.h"
+
+#include <chrono>
+#include <iomanip>
+#include <iostream>
+#include <random>
+#include <vector>
+
+namespace {
+// Time Backend::Quantize on `count` floats and print one row to stdout:
+// element count, mean seconds per call, backend name.
+// Returns without output when intgemm::kCPU is below Backend::kUses.
+template <class Backend> void QuantizerBench(const float *in, int8_t *out, std::size_t count) {
+  if (intgemm::kCPU < Backend::kUses) return;
+  // One call outside the timed region (presumably a warm-up).
+  Backend::Quantize(in, out, 1.0, count);
+  const std::size_t kTries = 60;
+  // steady_clock is monotonic. system_clock follows wall time and may be
+  // adjusted while the benchmark runs, which would distort the interval.
+  auto start = std::chrono::steady_clock::now();
+  for (std::size_t t = 0; t < kTries; ++t) {
+    Backend::Quantize(in, out, 1.0, count);
+  }
+  auto end = std::chrono::steady_clock::now();
+  // Mean wall time of one call, in seconds.
+  double took = std::chrono::duration<double>(end - start).count() / kTries;
+  std::cout << std::setw(9) << count << ' ' << std::fixed << std::setw(9) << std::setprecision(7) << took << ' ' << Backend::kName << std::endl;
+}
+} // namespace
+
+// Benchmark each 8-bit quantizer on sizes 1, 2, 4, ... (powers of two below 2^30).
+int main() {
+  const std::size_t kLimit = 1ULL << 30;
+  for (std::size_t n = 1; n < kLimit; n *= 2) {
+    intgemm::AlignedVector<float> source(n);
+    intgemm::AlignedVector<int8_t> quantized(n);
+    // Default-seeded generator, recreated per size, so every run sees the same data.
+    std::mt19937 rng;
+    std::uniform_real_distribution<float> uniform(-129.0, 129.0);
+    for (auto it = source.begin(); it != source.end(); ++it) {
+      *it = uniform(rng);
+    }
+    QuantizerBench<intgemm::SSSE3_8bit>(source.begin(), quantized.begin(), n);
+    QuantizerBench<intgemm::AVX2_8bit>(source.begin(), quantized.begin(), n);
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+    QuantizerBench<intgemm::AVX512_8bit>(source.begin(), quantized.begin(), n);
+#endif
+  }
+}
diff --git a/multiply.h b/multiply.h
index 0aa86aa..e2aa0fb 100644
--- a/multiply.h
+++ b/multiply.h
@@ -59,20 +59,38 @@ static inline INTGEMM_AVX512F float MaxFloat32(__m512 a) {
#endif
+// Quantize function used for SSSE3 and AVX2.
+// Expands to a static Quantize(input, output, quant_mult, size) that accepts
+// any size: whole batches of sizeof(Register) values go through q.Consecutive,
+// and a trailing partial batch is quantized with q.Tile and copied out
+// byte-wise. input and output must both be aligned to sizeof(Register)
+// (asserted).
+#define INTGEMM_QUANTIZE(target, Register, name) \
+target static void Quantize(const float *const input, int8_t *const output, float quant_mult, Index size) { \
+  assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
+  assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \
+  name::QuantizeTile8 q(quant_mult); \
+  /* One Register holds sizeof(Register) int8_t results, so that is the batch size. */ \
+  const std::size_t kBatch = sizeof(Register); \
+  /* Number of leading elements that form whole batches. */ \
+  const std::size_t fast_end = size & ~(kBatch - 1); \
+  _Pragma("omp parallel for") \
+  for (std::size_t i = 0; i < fast_end; i += kBatch) { \
+    *reinterpret_cast<Register*>(output + i) = q.Consecutive(input + i); \
+  } \
+  std::size_t overhang = size & (kBatch - 1); \
+  if (!overhang) return; \
+  /* Each Tile argument covers kBatch / 4 floats (one register of floats). \
+   * If we're allowed to read one of them, then we can read the whole register. */ \
+  const float *inputs[4]; \
+  std::size_t i; \
+  for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) { \
+    inputs[i] = &input[fast_end + i * (kBatch / 4)]; \
+  } \
+  /* These will be clipped off. */ \
+  for (; i < 4; ++i) { \
+    inputs[i] = &input[fast_end]; \
+  } \
+  Register result = q.Tile(inputs[0], inputs[1], inputs[2], inputs[3]); \
+  /* The tail belongs directly after the whole batches written above. Copying \
+   * it to the start of output would overwrite already-quantized values \
+   * whenever size > kBatch and leave the real tail unwritten. */ \
+  std::memcpy(output + fast_end, &result, overhang); \
+}
+
/* Take 4 registers with 32-bit values to be horizontally added. Reduce them
* to one register with 32-bit values in the pattern 1 2 3 4 1 2 3 4, leaving
* the final addition (which crosses 128-bit lanes) to the caller.
-template <class Register> inline Register Pack0123(Register sum0, Register sum1, Register sum2, Register sum3) {
- // 1 2 1 2 1 2 1 2
- Interleave32(sum0, sum1);
- Register pack01 = add_epi32(sum0, sum1);
- // 3 4 3 4 3 4 3 4
- Interleave32(sum2, sum3);
- Register pack23 = add_epi32(sum2, sum3);
- Interleave64(pack01, pack23);
- // 1 2 3 4 1 2 3 4
- return add_epi32(pack01, pack23);
-}
*/
#define INTGEMM_PACK0123(target, Register) \
target inline Register Pack0123(Register sum0, Register sum1, Register sum2, Register sum3) { \
@@ -562,20 +580,28 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
} \
#define INTGEMM_MAXABSOLUTE(Register, target) \
-target static float MaxAbsolute(const float *begin_float, const float *end_float) { \
+target static inline float MaxAbsolute(const float *begin_float, const float *end_float) { /* largest |x| over [begin_float, end_float) */ \
assert(end_float > begin_float); \
- assert((end_float - begin_float) % (sizeof(Register) / sizeof(float)) == 0); \
+ assert(reinterpret_cast<uintptr_t>(begin_float) % sizeof(Register) == 0); /* only the start must be register-aligned */ \
const Register *begin = reinterpret_cast<const Register*>(begin_float); \
- const Register *end = reinterpret_cast<const Register*>(end_float); \
- union {float f; int32_t i;} float_convert; \
- float_convert.i = 0x7fffffff; \
- Register and_me = set1_ps<Register>(float_convert.f); \
- Register highest = and_ps(and_me, *begin); \
- for (++begin; begin != end; ++begin) { \
+ const float *end_reg = end_float - (reinterpret_cast<uintptr_t>(end_float) % sizeof(Register)) / sizeof(float); /* end rounded down to a register boundary */ \
+ const Register *end = reinterpret_cast<const Register*>(end_reg); \
+ union {float f; int32_t i;} and_convert, float_convert; \
+ and_convert.i = 0x7fffffff; /* every bit set except the top (sign) bit */ \
+ Register and_me = set1_ps<Register>(and_convert.f); \
+ Register highest = setzero_ps<Register>(); /* start from zero so the loop may run zero times */ \
+ for (; begin < end; ++begin) { \
Register reg = and_ps(and_me, *begin); \
highest = max_ps(highest, reg); \
} \
- return MaxFloat32(highest); \
+ float ret = MaxFloat32(highest); /* reduce the register to a single float */ \
+ /* Overhang: this would be more efficient if done in a single SIMD operation with some zeroing */ \
+ for (const float *i = end_reg; i < end_float; ++i) { \
+ float_convert.f = *i; \
+ float_convert.i &= and_convert.i; /* scalar form of the and_ps masking above */ \
+ ret = std::max(ret, float_convert.f); \
+ } \
+ return ret; \
} \
} // namespace intgemm
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 44e2a4d..dc1ae83 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -7,6 +7,7 @@
#include <cstdint>
#include <stdint.h>
+#include <cstring>
// 16-bit is in sse2_gemm.h
@@ -53,7 +54,6 @@ class QuantizeTile8 {
return Tile(inputs[0], inputs[1], inputs[2], inputs[3]);
}
- private:
// Quantize 16xfloat into 16xint8_t
INTGEMM_SSSE3 inline __m128i Tile(const float *input0, const float *input1, const float *input2, const float *input3) const {
const __m128i neg128 = _mm_set1_epi8(-128);
@@ -77,6 +77,7 @@ class QuantizeTile8 {
// No permute needed. packs is in order for SSE.
}
+ private:
INTGEMM_SSSE3 inline __m128i TileU(const float *input0, const float *input1, const float *input2, const float *input3) const {
const __m128i neg128 = _mm_set1_epi8(-128);
const __m128i pos127 = _mm_set1_epi8(127);
@@ -106,7 +107,6 @@ class QuantizeTile8 {
} // namespace
-
// pmaddubsw (the 8-bit multiply) is INTGEMM_SSSE3, so pedantically that's the version we need.
struct SSSE3_8bit {
typedef int8_t Integer;
@@ -116,16 +116,7 @@ struct SSSE3_8bit {
Quantize(input, output, quant_mult, rows * cols);
}
- INTGEMM_SSSE3 static void Quantize(const float *input, int8_t *output, float quant_mult, Index size) {
- assert(size % 16 == 0);
- assert(reinterpret_cast<uintptr_t>(input) % 16 == 0);
- assert(reinterpret_cast<uintptr_t>(output) % 16 == 0);
- ssse3::QuantizeTile8 q(quant_mult);
- const float *end = input + size;
- for (; input != end; input += 16, output += 16) {
- *reinterpret_cast<__m128i*>(output) = q.Consecutive(input);
- }
- }
+ INTGEMM_QUANTIZE(INTGEMM_SSSE3, __m128i, ssse3)
// Version with unsigned int + 127
// Currently A is prepared by quantization but this could theoretically change.
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 59c62a9..97f68a3 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -194,18 +194,20 @@ void CompareMaxAbs(const float *begin, const float *end, float test) {
+// Check Backend (a MaxAbsolute implementation) against CompareMaxAbs for every
+// prefix length 1..64 of an aligned buffer. For each position t inside the
+// prefix it tests random data, then plants -32 and +32 at t so the extreme
+// value is known to sit at that position.
template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute() {
std::mt19937 gen;
std::uniform_real_distribution<float> dist(-8.0, 8.0);
- AlignedVector<float> test(64);
- // 64 tries.
- for (int t = 0; t < 64; ++t) {
- // Fill with [-8, 8).
- for (auto& it : test) {
- it = dist(gen);
+ const std::size_t kLengthMax = 65;
+ AlignedVector<float> test(kLengthMax);
+ for (std::size_t len = 1; len < kLengthMax; ++len) {
+ // t shares len's unsigned type so the comparison is not mixed-sign.
+ for (std::size_t t = 0; t < len; ++t) {
+ // Fill with [-8, 8).
+ for (auto& it : test) {
+ it = dist(gen);
+ }
+ CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len));
+ test[t] = -32.0;
+ CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len));
+ test[t] = 32.0;
+ CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len));
}
- CompareMaxAbs(test.begin(), test.end(), Backend(test.begin(), test.end()));
- test[t] = -32.0;
- CompareMaxAbs(test.begin(), test.end(), Backend(test.begin(), test.end()));
- test[t] = 32.0;
- CompareMaxAbs(test.begin(), test.end(), Backend(test.begin(), test.end()));
}
}
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index 83b1d20..3263812 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -52,45 +52,45 @@ template <class Backend> bool Test(const float *input_unaligned, float quant_mul
Backend::Quantize(input.begin(), test.begin(), quant_mult, size);
for (std::size_t i = 0; i < size; ++i) {
if (IsOff(input[i] * quant_mult, ref[i], test[i])) {
- UNSCOPED_INFO("Error at " << i << " from " << input[i] << '*' << quant_mult << '=' << (input[i]*quant_mult) << " ref = " << ref[i] << " test = " << test[i]);
+ UNSCOPED_INFO("Error at " << i << " from " << input[i] << '*' << quant_mult << '=' << (input[i]*quant_mult) << " ref = " << static_cast<int>(ref[i]) << " test = " << static_cast<int>(test[i]));
success = false;
}
}
return success;
}
-template <class Backend> bool TestMany() {
- bool success = true;
- float input[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
- success &= Test<Backend>(input, 1.0, 32);
- success &= Test<Backend>(input, 32.0, 32);
- float corners[32] = {-32769, -32768, -32767, -129, -128, -127, -1, 0, 1, 126, 127, 128, 129, 32766, 32768, 32769, -1.9, -1.5, -1.1, -1, -0.9, -0.5, -0.1, 0.0, 0.1, 0.5, 0.9, 1.0, 1.1, 1.5, 1.9, 16056.8};
- success &= Test<Backend>(corners, 1.0, sizeof(corners) / sizeof(float));
- success &= Test<Backend>(corners, -1.0, sizeof(corners) / sizeof(float));
- success &= Test<Backend>(corners, -0.49, sizeof(corners) / sizeof(float));
- return success;
+// Run Test<Backend> on prefixes of two fixed arrays, with lengths
+// 0, grow, 2 * grow, ... up to and including 33. Each call is CHECKed
+// separately so a failure reports the exact case. grow must be non-zero.
+template <class Backend> void TestMany(std::size_t grow) {
+  float input[33] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32};
+  float corners[33] = {-32769, -32768, -32767, -129, -128, -127, -1, 0, 1, 126, 127, 128, 129, 32766, 32768, 32769, -1.9, -1.5, -1.1, -1, -0.9, -0.5, -0.1, 0.0, 0.1, 0.5, 0.9, 1.0, 1.1, 1.5, 1.9, 16056.8, 2.5};
+  // Both arrays hold the same number of elements.
+  const std::size_t kElements = sizeof(input) / sizeof(float);
+  std::size_t len = 0;
+  while (len <= kElements) {
+    CHECK(Test<Backend>(input, 1.0, len));
+    CHECK(Test<Backend>(input, 32.0, len));
+    CHECK(Test<Backend>(corners, 1.0, len));
+    CHECK(Test<Backend>(corners, -1.0, len));
+    CHECK(Test<Backend>(corners, -0.49, len));
+    len += grow;
+  }
}
TEST_CASE ("Quantize SSE2", "[quantize]") {
if (kCPU < CPUType::SSE2) return;
- CHECK(TestMany<SSE2_16bit>());
+ TestMany<SSE2_16bit>(8); // lengths 0, 8, 16, 24, 32
}
-TEST_CASE ("Quantize SSE3", "[quantize]") {
+TEST_CASE ("Quantize SSSE3", "[quantize]") {
if (kCPU < CPUType::SSSE3) return;
- CHECK(TestMany<SSSE3_8bit>());
+ TestMany<SSSE3_8bit>(1); // every length 0..33, including partial registers
}
TEST_CASE ("Quantize AVX2", "[quantize]") {
if (kCPU < CPUType::AVX2) return;
- CHECK(TestMany<AVX2_8bit>());
- CHECK(TestMany<AVX2_16bit>());
+ TestMany<AVX2_8bit>(1); // every length 0..33, including partial registers
+ TestMany<AVX2_16bit>(16); // lengths 0, 16, 32
}
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
TEST_CASE ("Quantize AVX512", "[quantize]") {
if (kCPU < CPUType::AVX512BW) return;
- CHECK(TestMany<AVX512_8bit>());
- CHECK(TestMany<AVX512_16bit>());
+ TestMany<AVX512_8bit>(1); // every length 0..33, including partial registers
+ TestMany<AVX512_16bit>(16); // lengths 0, 16, 32
}
#endif