diff options
author | Kenneth Heafield <github@kheafield.com> | 2019-07-05 14:25:47 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2019-07-05 14:25:47 +0300 |
commit | b7a7fe77a4e659f785588585526e06721ffcdd08 (patch) | |
tree | 0a975c5c7fb48349a0c348bbbd5c8527caef17ff | |
parent | 2807988a70c59169c1ea223bd734562351508f47 (diff) | |
parent | ce292be1138ecce0ec127ed59fe79d0091be7d11 (diff) |
Merge branch 'master' into 4bit4bit
-rw-r--r-- | CMakeLists.txt | 10 | ||||
-rw-r--r-- | aligned.h | 3 | ||||
-rw-r--r-- | interleave.h | 1 | ||||
-rw-r--r-- | intrinsics.h | 77 | ||||
-rw-r--r-- | postprocess.h | 252 | ||||
-rw-r--r-- | postprocess_pipeline.h | 4 | ||||
-rw-r--r-- | test/multiply_test.cc | 31 | ||||
-rw-r--r-- | test/pipeline_test.cc | 70 | ||||
-rw-r--r-- | test/postprocess/add_bias_test.cc | 95 | ||||
-rw-r--r-- | test/postprocess/pipeline_test.cc | 63 | ||||
-rw-r--r-- | test/postprocess/relu_test.cc | 213 | ||||
-rw-r--r-- | test/postprocess/sigmoid_test.cc | 33 | ||||
-rw-r--r-- | test/postprocess/tanh_test.cc | 33 | ||||
-rw-r--r-- | test/postprocess/unquantize_test.cc | 88 | ||||
-rw-r--r-- | test/quantize_test.cc | 12 | ||||
-rw-r--r-- | test/relu_test.cc | 89 | ||||
-rw-r--r-- | test/test.cc | 6 | ||||
-rw-r--r-- | test/test.h | 12 | ||||
-rw-r--r-- | test/utils_test.cc | 38 | ||||
-rw-r--r-- | utils.h | 20 | ||||
-rw-r--r-- | vec_utils.h | 80 |
21 files changed, 984 insertions, 246 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index d39e09c..c6fc8d0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,9 +33,15 @@ endforeach() include_directories(.) add_executable(tests test/multiply_test.cc - test/pipeline_test.cc + test/postprocess/add_bias_test.cc + test/postprocess/pipeline_test.cc + test/postprocess/relu_test.cc + test/postprocess/sigmoid_test.cc + test/postprocess/tanh_test.cc + test/postprocess/unquantize_test.cc test/quantize_test.cc - test/relu_test.cc + test/test.cc + test/utils_test.cc test/log4_test.cc intgemm.cc ) @@ -22,6 +22,9 @@ template <class T> class AlignedVector { T *end() { return mem_ + size_; } const T *end() const { return mem_ + size_; } + template <typename ReturnType> + ReturnType *as() { return reinterpret_cast<ReturnType*>(mem_); } + private: T *mem_; std::size_t size_; diff --git a/interleave.h b/interleave.h index 4c4e956..d9ade05 100644 --- a/interleave.h +++ b/interleave.h @@ -3,6 +3,7 @@ #include "intrinsics.h" #include "types.h" +#include <algorithm> #include <cassert> #include <stdint.h> diff --git a/intrinsics.h b/intrinsics.h index 7c36d6b..293efc3 100644 --- a/intrinsics.h +++ b/intrinsics.h @@ -36,6 +36,9 @@ INTGEMM_SSE2 static inline __m128i add_epi32(__m128i first, __m128i second) { INTGEMM_SSE2 static inline __m128i adds_epi16(__m128i first, __m128i second) { return _mm_adds_epi16(first, second); } +INTGEMM_SSE2 static inline __m128 add_ps(__m128 a, __m128 b) { + return _mm_add_ps(a, b); +} INTGEMM_SSE2 static inline __m128 and_ps(__m128 first, __m128 second) { return _mm_and_ps(first, second); } @@ -45,6 +48,15 @@ INTGEMM_SSE2 static inline __m128 cvtepi32_ps(__m128i arg) { INTGEMM_SSE2 static inline __m128i cvtps_epi32(__m128 arg) { return _mm_cvtps_epi32(arg); } +INTGEMM_SSE2 static inline __m128i cvttps_epi32(__m128 a) { + return _mm_cvttps_epi32(a); +} +INTGEMM_SSE2 static inline __m128 div_ps(__m128 a, __m128 b) { + return _mm_div_ps(a, b); +} +/* + * Missing i32gather_ps for SSE2 + */ template <> INTGEMM_SSE2 inline __m128 loadu_ps(const float* mem_addr) { return _mm_loadu_ps(mem_addr); } @@ -54,9 +66,18 @@ INTGEMM_SSE2 static inline __m128i madd_epi16(__m128i first, __m128i second) { INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second) { return _mm_maddubs_epi16(first, second); } +/* + * Missing max_epi8 for SSE2 + */ +INTGEMM_SSE2 static inline __m128i max_epi16(__m128i first, __m128i second) { + return _mm_max_epi16(first, second); +} INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) { return _mm_max_ps(first, second); } +INTGEMM_SSE2 static inline __m128 min_ps(__m128 a, __m128 b) { + return _mm_min_ps(a, b); +} INTGEMM_SSE2 static inline __m128 mul_ps(__m128 a, __m128 b) { return _mm_mul_ps(a, b); } @@ -81,8 +102,8 @@ INTGEMM_SSSE3 static inline __m128i sign_epi8(__m128i first, __m128i second) { INTGEMM_SSE2 static inline void storeu_ps(float* mem_addr, __m128 a) { _mm_storeu_ps(mem_addr, a); } -INTGEMM_SSE2 static inline __m128 add_ps (__m128 a, __m128 b) { - return _mm_add_ps(a, b); +INTGEMM_SSE2 static inline __m128 sub_ps(__m128 a, __m128 b) { + return _mm_sub_ps(a, b); } /* @@ -99,6 +120,9 @@ INTGEMM_AVX2 static inline __m256i add_epi32(__m256i first, __m256i second) { INTGEMM_AVX2 static inline __m256i adds_epi16(__m256i first, __m256i second) { return _mm256_adds_epi16(first, second); } +INTGEMM_AVX2 static inline __m256 add_ps(__m256 a, __m256 b) { + return _mm256_add_ps(a, b); +} INTGEMM_AVX2 static inline __m256 and_ps(__m256 first, __m256 second) { return _mm256_and_ps(first, second); } @@ -108,6 +132,16 @@ INTGEMM_AVX2 static inline __m256 cvtepi32_ps(__m256i arg) { INTGEMM_AVX2 static inline __m256i cvtps_epi32(__m256 arg) { return _mm256_cvtps_epi32(arg); } +INTGEMM_AVX2 static inline __m256i cvttps_epi32(__m256 a) { + return _mm256_cvttps_epi32(a); +} +INTGEMM_AVX2 static inline __m256 div_ps(__m256 a, __m256 b) { + return _mm256_div_ps(a, b); +} +template <unsigned Scale> +INTGEMM_AVX2 static inline __m256 i32gather_ps(float const *base_addr, __m256i vindex) { + return _mm256_i32gather_ps(base_addr, vindex, Scale); +} template <> INTGEMM_AVX2 inline __m256 loadu_ps(const float* mem_addr) { return _mm256_loadu_ps(mem_addr); } @@ -117,9 +151,18 @@ INTGEMM_AVX2 static inline __m256i madd_epi16(__m256i first, __m256i second) { INTGEMM_AVX2 static inline __m256i maddubs_epi16(__m256i first, __m256i second) { return _mm256_maddubs_epi16(first, second); } +INTGEMM_AVX2 static inline __m256i max_epi8(__m256i first, __m256i second) { + return _mm256_max_epi8(first, second); +} +INTGEMM_AVX2 static inline __m256i max_epi16(__m256i first, __m256i second) { + return _mm256_max_epi16(first, second); +} INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) { return _mm256_max_ps(first, second); } +INTGEMM_AVX2 static inline __m256 min_ps(__m256 a, __m256 b) { + return _mm256_min_ps(a, b); +} INTGEMM_AVX2 static inline __m256 mul_ps(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); } @@ -144,8 +187,8 @@ INTGEMM_AVX2 static inline __m256i sign_epi8(__m256i first, __m256i second) { INTGEMM_AVX2 static inline void storeu_ps(float* mem_addr, __m256 a) { _mm256_storeu_ps(mem_addr, a); } -INTGEMM_AVX2 static inline __m256 add_ps (__m256 a, __m256 b) { - return _mm256_add_ps(a, b); +INTGEMM_AVX2 static inline __m256 sub_ps(__m256 a, __m256 b) { + return _mm256_sub_ps(a, b); } /* @@ -164,6 +207,9 @@ INTGEMM_AVX512BW static inline __m512i add_epi32(__m512i first, __m512i second) INTGEMM_AVX512BW static inline __m512i adds_epi16(__m512i first, __m512i second) { return _mm512_adds_epi16(first, second); } +INTGEMM_AVX512BW static inline __m512 add_ps(__m512 a, __m512 b) { + return _mm512_add_ps(a, b); +} INTGEMM_AVX512DQ static inline __m512 and_ps(__m512 first, __m512 second) { return _mm512_and_ps(first, second); } @@ -173,6 +219,16 @@ INTGEMM_AVX512BW static inline __m512 cvtepi32_ps(__m512i arg) { INTGEMM_AVX512BW static inline __m512i cvtps_epi32(__m512 arg) { return _mm512_cvtps_epi32(arg); } +INTGEMM_AVX512BW static inline __m512i cvttps_epi32(__m512 a) { + return _mm512_cvttps_epi32(a); +} +INTGEMM_AVX512BW static inline __m512 div_ps(__m512 a, __m512 b) { + return _mm512_div_ps(a, b); +} +template <unsigned Scale> +INTGEMM_AVX512BW static inline __m512 i32gather_ps(float const *base_addr, __m512i vindex) { + return _mm512_i32gather_ps(vindex, base_addr, Scale); +} template <> INTGEMM_AVX512BW inline __m512 loadu_ps(const float* mem_addr) { return _mm512_loadu_ps(mem_addr); } @@ -182,11 +238,17 @@ INTGEMM_AVX512BW static inline __m512i madd_epi16(__m512i first, __m512i second) INTGEMM_AVX512BW static inline __m512i maddubs_epi16(__m512i first, __m512i second) { return _mm512_maddubs_epi16(first, second); } +INTGEMM_AVX512BW static inline __m512i max_epi8(__m512i first, __m512i second) { + return _mm512_max_epi8(first, second); +} +INTGEMM_AVX512BW static inline __m512i max_epi16(__m512i first, __m512i second) { + return _mm512_max_epi16(first, second); +} INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) { return _mm512_max_ps(first, second); } -INTGEMM_AVX512BW static inline __m512 add_ps(__m512 first, __m512 second) { - return _mm512_add_ps(first, second); +INTGEMM_AVX512BW static inline __m512 min_ps(__m512 a, __m512 b) { + return _mm512_min_ps(a, b); } INTGEMM_AVX512BW static inline __m512 mul_ps(__m512 a, __m512 b) { return _mm512_mul_ps(a, b); @@ -212,6 +274,9 @@ template <> INTGEMM_AVX512BW inline __m512i setzero_si<__m512i>() { INTGEMM_AVX512BW static inline void storeu_ps(float* mem_addr, __m512 a) { _mm512_storeu_ps(mem_addr, a); } +INTGEMM_AVX512BW static inline __m512 sub_ps(__m512 a, __m512 b) { + return _mm512_sub_ps(a, b); +} #endif diff --git a/postprocess.h b/postprocess.h index 0855548..7835b2b 100644 --- a/postprocess.h +++ b/postprocess.h @@ -5,6 +5,10 @@ #include "types.h" #include "vec_utils.h" +// TODO: We support some postprocess in few variations e.g. we support ReLU for +// float -> float, int8 -> int8, int16 -> int16. Maybe it would be a good idea +// to pass input type and output type as a template parameter of postprocess? + namespace intgemm { /* @@ -56,6 +60,8 @@ private: __m256 unquantize_multiplier; }; +#ifndef INTGEMM_NO_AVX512 + template <> class PostprocessImpl<Unquantize, CPUType::AVX512BW> { public: @@ -74,49 +80,7 @@ private: __m512 unquantize_multiplier; }; -/* - * Identity - */ -class Identity {}; - -template <> -class PostprocessImpl<Identity, CPUType::SSE2> { -public: - using InputRegister = RegisterPair128i; - using OutputRegister = RegisterPair128i; - - PostprocessImpl(const Identity& config) {} - - INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { - return input; - } -}; - -template <> -class PostprocessImpl<Identity, CPUType::AVX2> { -public: - using InputRegister = __m256i; - using OutputRegister = __m256i; - - PostprocessImpl(const Identity& config) {} - - INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { - return input; - } -}; - -template <> -class PostprocessImpl<Identity, CPUType::AVX512BW> { -public: - using InputRegister = __m512i; - using OutputRegister = __m512i; - - PostprocessImpl(const Identity& config) {} - - INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { - return input; - } -}; +#endif /* * Add a bias term @@ -167,6 +131,27 @@ private: const AddBias config; }; +#ifndef INTGEMM_NO_AVX512 + +template <> +class PostprocessImpl<AddBias, CPUType::AVX512BW> { +public: + using InputRegister = __m512; + using OutputRegister = __m512; + + PostprocessImpl(const AddBias& config) : config(config) {} + + INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { + auto bias_term = *reinterpret_cast<const __m512*>(config.bias + (offset % config.length)); + return add_ps(input, bias_term); + } + +private: + const AddBias config; +}; + +#endif + /* * ReLU */ @@ -206,6 +191,8 @@ public: } }; +#ifndef INTGEMM_NO_AVX512 + template <> class PostprocessImpl<ReLU, CPUType::AVX512BW> { public: @@ -220,4 +207,183 @@ public: } }; +#endif + +/* + * ReLU_int8 + */ +class ReLU_int8 {}; + +template <> +class PostprocessImpl<ReLU_int8, CPUType::SSE2> { +public: + using InputRegister = __m128i; + using OutputRegister = __m128i; + + PostprocessImpl(const ReLU_int8& config) {} + + INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = setzero_si<__m128i>(); + return _mm_and_si128(_mm_cmplt_epi8(const_zero, input), input); + } +}; + +template <> +class PostprocessImpl<ReLU_int8, CPUType::AVX2> { +public: + using InputRegister = __m256i; + using OutputRegister = __m256i; + + PostprocessImpl(const ReLU_int8& config) {} + + INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = setzero_si<__m256i>(); + return max_epi8(const_zero, input); + } +}; + +#ifndef INTGEMM_NO_AVX512 + +template <> +class PostprocessImpl<ReLU_int8, CPUType::AVX512BW> { +public: + using InputRegister = __m512i; + using OutputRegister = __m512i; + + PostprocessImpl(const ReLU_int8& config) {} + + INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = setzero_si<__m512i>(); + return max_epi8(const_zero, input); + } +}; + +#endif + +/* + * ReLU_int16 + */ +class ReLU_int16 {}; + +template <> +class PostprocessImpl<ReLU_int16, CPUType::SSE2> { +public: + using InputRegister = __m128i; + using OutputRegister = __m128i; + + PostprocessImpl(const ReLU_int16& config) {} + + INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = setzero_si<__m128i>(); + return max_epi16(const_zero, input); + } +}; + +template <> +class PostprocessImpl<ReLU_int16, CPUType::AVX2> { +public: + using InputRegister = __m256i; + using OutputRegister = __m256i; + + PostprocessImpl(const ReLU_int16& config) {} + + INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = setzero_si<__m256i>(); + return max_epi16(const_zero, input); + } +}; + +#ifndef INTGEMM_NO_AVX512 + +template <> +class PostprocessImpl<ReLU_int16, CPUType::AVX512BW> { +public: + using InputRegister = __m512i; + using OutputRegister = __m512i; + + PostprocessImpl(const ReLU_int16& config) {} + + INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = setzero_si<__m512i>(); + return max_epi16(const_zero, input); + } +}; + +#endif + +/* + * Sigmoid (uses Taylor series approximation of e^x) + */ +class Sigmoid {}; + +template <> +class PostprocessImpl<Sigmoid, CPUType::AVX2> { +public: + using InputRegister = __m256; + using OutputRegister = __m256; + + PostprocessImpl(const Sigmoid& config) {} + + INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = set1_ps<__m256>(0.f); + static const auto const_one = set1_ps<__m256>(1.f); + + auto x = input; + auto minus_x = sub_ps(const_zero, x); + auto e_x = exp_approx_taylor(x); + auto e_minus_x = exp_approx_taylor(minus_x); + + auto sigmoid_case1 = _mm256_rcp_ps(add_ps(const_one, e_minus_x)); + auto sigmoid_case2 = mul_ps(e_x, _mm256_rcp_ps(add_ps(const_one, e_x))); + + auto nonnegative_x_mask = _mm256_cmp_ps(const_zero, x, _CMP_LT_OS); + return _mm256_blendv_ps(sigmoid_case1, sigmoid_case2, nonnegative_x_mask); + } +}; + +/* + * Tanh (uses Taylor series approximation of e^x) + */ +class Tanh {}; + +template <> +class PostprocessImpl<Tanh, CPUType::AVX2> { +public: + using InputRegister = __m256; + using OutputRegister = __m256; + + PostprocessImpl(const Tanh& config) {} + + INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { + const static auto const_zero = setzero_ps<__m256>(); + + auto e_x = exp_approx_taylor(input); + auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input)); + + return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x)); + } +}; + +#ifndef INTGEMM_NO_AVX512 + +template <> +class PostprocessImpl<Tanh, CPUType::AVX512BW> { +public: + using InputRegister = __m512; + using OutputRegister = __m512; + + PostprocessImpl(const Tanh& config) {} + + INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { + const static auto const_zero = setzero_ps<__m512>(); + + auto e_x = exp_approx_taylor(input); + auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input)); + + return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x)); + } +}; + +#endif + } diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h index ad26ac5..361ff2b 100644 --- a/postprocess_pipeline.h +++ b/postprocess_pipeline.h @@ -12,8 +12,8 @@ template <typename... Stages> using PostprocessPipeline = std::tuple<Stages...>; template <typename... Stages> -constexpr std::tuple<Stages...> CreatePostprocessPipeline(const Stages&... stages) { - return std::make_tuple(stages...); +constexpr std::tuple<Stages...> CreatePostprocessPipeline(Stages&&... stages) { + return std::make_tuple(std::forward<Stages>(stages)...); } template <typename Postprocess, CPUType CpuType> diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 82062fe..93d7127 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -1,22 +1,16 @@ +#include "test/test.h" #include "aligned.h" #include "interleave.h" #include "intgemm.h" #include "multiply.h" #include "postprocess.h" -#define CATCH_CONFIG_RUNNER -#include "3rd_party/catch.hpp" -#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while((void)0, 0) -#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while((void)0, 0) -#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while((void)0, 0) -#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while((void)0, 0) - #include <algorithm> #include <cassert> #include <cmath> -#include <cstring> #include <cstdio> #include <cstdlib> +#include <cstring> #include <iomanip> #include <iostream> #include <memory> @@ -61,7 +55,7 @@ INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") { SlowTranspose(input.begin(), ref.begin(), N, N); // Overwrite input. - __m128i *t = reinterpret_cast<__m128i*>(input.begin()); + __m128i *t = input.as<__m128i>(); Transpose16InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]); for (int16_t i = 0; i < input.size(); ++i) { @@ -79,7 +73,7 @@ INTGEMM_SSSE3 TEST_CASE("Transpose 8", "[transpose]") { SlowTranspose(input.begin(), ref.begin(), N, N); // Overwrite input. - __m128i *t = reinterpret_cast<__m128i*>(input.begin()); + __m128i *t = input.as<__m128i>(); Transpose8InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15]); for (int i = 0; i < input.size(); ++i) { @@ -554,20 +548,3 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { #endif } // namespace intgemm - -int main(int argc, char ** argv) { - return Catch::Session().run(argc, argv); -} - -/* - // Top matrix sizes from Marian - TestBoth(8, 256, 256); - TestBoth(8, 2048, 256); - TestBoth(8, 2048, 256); - TestBoth(320, 256, 256); - TestBoth(472, 256, 256); - TestBoth(248, 256, 256); - TestBoth(200, 256, 256); - return 0; -} -*/ diff --git a/test/pipeline_test.cc b/test/pipeline_test.cc deleted file mode 100644 index 1b8c21d..0000000 --- a/test/pipeline_test.cc +++ /dev/null @@ -1,70 +0,0 @@ -#include "3rd_party/catch.hpp" -#include "postprocess.h" - -#include <numeric> - -namespace intgemm { - -INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") { - if (kCPU < CPUType::AVX2) - return; - - __m256i input; - __m256 output; - - auto raw_input = reinterpret_cast<int*>(&input); - std::iota(raw_input, raw_input + 8, -2); - - auto raw_output = reinterpret_cast<float*>(&output); - std::fill(raw_output, raw_output + 8, 42); - - auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU()); - auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); - output = inited_pipeline.run(input, 0); - - CHECK(raw_output[0] == 0.0f); // input = -2 - CHECK(raw_output[1] == 0.0f); // input = -1 - CHECK(raw_output[2] == 0.0f); // input = 0 - CHECK(raw_output[3] == 0.5f); // input = 1 - CHECK(raw_output[4] == 1.0f); // input = 2 - CHECK(raw_output[5] == 1.5f); // input = 3 - CHECK(raw_output[6] == 2.0f); // input = 4 - CHECK(raw_output[7] == 2.5f); // input = 5 -} - -INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-ReLU") { - if (kCPU < CPUType::AVX2) - return; - - __m256i input[2]; - __m256 output[2]; - - auto raw_input = reinterpret_cast<int*>(input); - std::iota(raw_input, raw_input + 16, -8); - - auto raw_output = reinterpret_cast<float*>(output); - std::fill(raw_output, raw_output + 16, 42); - - auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU()); - auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); - inited_pipeline.run(input, 2, output); - - CHECK(raw_output[0] == 0.f); // input = -8 - CHECK(raw_output[1] == 0.f); // input = -7 - CHECK(raw_output[2] == 0.f); // input = -6 - CHECK(raw_output[3] == 0.f); // input = -5 - CHECK(raw_output[4] == 0.f); // input = -4 - CHECK(raw_output[5] == 0.f); // input = -3 - CHECK(raw_output[6] == 0.f); // input = -2 - CHECK(raw_output[7] == 0.f); // input = -1 - CHECK(raw_output[8] == 0.0f); // input = 0 - CHECK(raw_output[9] == 0.5f); // input = 1 - CHECK(raw_output[10] == 1.0f); // input = 2 - CHECK(raw_output[11] == 1.5f); // input = 3 - CHECK(raw_output[12] == 2.0f); // input = 4 - CHECK(raw_output[13] == 2.5f); // input = 5 - CHECK(raw_output[14] == 3.0f); // input = 6 - CHECK(raw_output[15] == 3.5f); // input = 7 -} - -} diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc new file mode 100644 index 0000000..5e893ea --- /dev/null +++ b/test/postprocess/add_bias_test.cc @@ -0,0 +1,95 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) { + if (kCPU < CPUType::SSE2) + return; + + AlignedVector<float> input(8); + AlignedVector<float> bias(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -2); + std::iota(bias.begin(), bias.end(), 0); + + auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size())); + auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0); + output.as<__m128>()[0] = output_tmp.pack0123; + output.as<__m128>()[1] = output_tmp.pack4567; + + CHECK(output[0] == -2.f); // input = -2, bias = 0 + CHECK(output[1] == 0.f); // input = -1, bias = 1 + CHECK(output[2] == 2.f); // input = 0, bias = 2 + CHECK(output[3] == 4.f); // input = 1, bias = 3 + CHECK(output[4] == 6.f); // input = 2, bias = 4 + CHECK(output[5] == 8.f); // input = 3, bias = 5 + CHECK(output[6] == 10.f); // input = 4, bias = 6 + CHECK(output[7] == 12.f); // input = 5, bias = 7 +} + +INTGEMM_AVX2 TEST_CASE("AddBias AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + AlignedVector<float> input(8); + AlignedVector<float> bias(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -4); + std::iota(bias.begin(), bias.end(), 0); + + auto postproc = PostprocessImpl<AddBias, CPUType::AVX2>(AddBias(bias.begin(), bias.size())); + *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); + + CHECK(output[0] == -4.f); // input = -4, bias = 0 + CHECK(output[1] == -2.f); // input = -3, bias = 1 + CHECK(output[2] == 0.f); // input = -2, bias = 2 + CHECK(output[3] == 2.f); // input = -1, bias = 3 + CHECK(output[4] == 4.f); // input = 0, bias = 4 + CHECK(output[5] == 6.f); // input = 1, bias = 5 + CHECK(output[6] == 8.f); // input = 2, bias = 6 + CHECK(output[7] == 10.f); // input = 3, bias = 7 +} + +#ifndef INTGEMM_NO_AVX512 + +INTGEMM_AVX512BW TEST_CASE("AddBias AVX512",) { + if (kCPU < CPUType::AVX512BW) + return; + + AlignedVector<float> input(16); + AlignedVector<float> bias(16); + AlignedVector<float> output(16); + + std::iota(input.begin(), input.end(), -8); + std::iota(bias.begin(), bias.end(), 0); + + auto postproc = PostprocessImpl<AddBias, CPUType::AVX512BW>(AddBias(bias.begin(), bias.size())); + *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0); + + CHECK(output[0] == -8.f); // input = -8, bias = 0 + CHECK(output[1] == -6.f); // input = -7, bias = 1 + CHECK(output[2] == -4.f); // input = -6, bias = 2 + CHECK(output[3] == -2.f); // input = -5, bias = 3 + CHECK(output[4] == 0.f); // input = -4, bias = 4 + CHECK(output[5] == 2.f); // input = -3, bias = 5 + CHECK(output[6] == 4.f); // input = -2, bias = 6 + CHECK(output[7] == 6.f); // input = -1, bias = 7 + CHECK(output[8] == 8.f); // input = 0, bias = 8 + CHECK(output[9] == 10.f); // input = 1, bias = 9 + CHECK(output[10] == 12.f); // input = 2, bias = 10 + CHECK(output[11] == 14.f); // input = 3, bias = 11 + CHECK(output[12] == 16.f); // input = 4, bias = 12 + CHECK(output[13] == 18.f); // input = 5, bias = 13 + CHECK(output[14] == 20.f); // input = 6, bias = 14 + CHECK(output[15] == 22.f); // input = 7, bias = 15 +} + +#endif + +} diff --git a/test/postprocess/pipeline_test.cc b/test/postprocess/pipeline_test.cc new file mode 100644 index 0000000..144ee48 --- /dev/null +++ b/test/postprocess/pipeline_test.cc @@ -0,0 +1,63 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") { + if (kCPU < CPUType::AVX2) + return; + + AlignedVector<int32_t> input(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -2); + + auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU()); + auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); + *output.as<__m256>() = inited_pipeline.run(*input.as<__m256i>(), 0); + + CHECK(output[0] == 0.0f); // input = -2 + CHECK(output[1] == 0.0f); // input = -1 + CHECK(output[2] == 0.0f); // input = 0 + CHECK(output[3] == 0.5f); // input = 1 + CHECK(output[4] == 1.0f); // input = 2 + CHECK(output[5] == 1.5f); // input = 3 + CHECK(output[6] == 2.0f); // input = 4 + CHECK(output[7] == 2.5f); // input = 5 +} + +INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-ReLU") { + if (kCPU < CPUType::AVX2) + return; + + AlignedVector<int32_t> input(16); + AlignedVector<float> output(16); + + std::iota(input.begin(), input.end(), -8); + + auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU()); + auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); + inited_pipeline.run(input.as<__m256i>(), 2, output.as<__m256>()); + + CHECK(output[0] == 0.f); // input = -8 + CHECK(output[1] == 0.f); // input = -7 + CHECK(output[2] == 0.f); // input = -6 + CHECK(output[3] == 0.f); // input = -5 + CHECK(output[4] == 0.f); // input = -4 + CHECK(output[5] == 0.f); // input = -3 + CHECK(output[6] == 0.f); // input = -2 + CHECK(output[7] == 0.f); // input = -1 + CHECK(output[8] == 0.0f); // input = 0 + CHECK(output[9] == 0.5f); // input = 1 + CHECK(output[10] == 1.0f); // input = 2 + CHECK(output[11] == 1.5f); // input = 3 + CHECK(output[12] == 2.0f); // input = 4 + CHECK(output[13] == 2.5f); // input = 5 + CHECK(output[14] == 3.0f); // input = 6 + CHECK(output[15] == 3.5f); // input = 7 +} + +} diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc new file mode 100644 index 0000000..af6677e --- /dev/null +++ b/test/postprocess/relu_test.cc @@ -0,0 +1,213 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +/* + * ReLU: float -> float + */ +INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) { + if (kCPU < CPUType::SSE2) + return; + + AlignedVector<float> input(8); + AlignedVector<float> output(8); + std::iota(input.begin(), input.end(), -2); + + auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU()); + auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0); + output.as<__m128>()[0] = output_tmp.pack0123; + output.as<__m128>()[1] = output_tmp.pack4567; + + CHECK(output[0] == 0.f); // input = -2 + CHECK(output[1] == 0.f); // input = -1 + CHECK(output[2] == 0.f); // input = 0 + CHECK(output[3] == 1.f); // input = 1 + CHECK(output[4] == 2.f); // input = 2 + CHECK(output[5] == 3.f); // input = 3 + CHECK(output[6] == 4.f); // input = 4 + CHECK(output[7] == 5.f); // input = 5 +} + +INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + AlignedVector<float> input(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -4); + + auto postproc = PostprocessImpl<ReLU, CPUType::AVX2>(ReLU()); + *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); + + CHECK(output[0] == 0.f); // input = -4 + CHECK(output[1] == 0.f); // input = -3 + CHECK(output[2] == 0.f); // input = -2 + CHECK(output[3] == 0.f); // input = -1 + CHECK(output[4] == 0.f); // input = 0 + CHECK(output[5] == 1.f); // input = 1 + CHECK(output[6] == 2.f); // input = 2 + CHECK(output[7] == 3.f); // input = 3 +} + +#ifndef INTGEMM_NO_AVX512 + +INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) { + if (kCPU < CPUType::AVX512BW) + return; + + AlignedVector<float> input(16); + AlignedVector<float> output(16); + + std::iota(input.begin(), input.end(), -8); + + auto postproc = PostprocessImpl<ReLU, CPUType::AVX512BW>(ReLU()); + *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0); + + CHECK(output[0] == 0.f); // input = -8 + CHECK(output[1] == 0.f); // input = -7 + CHECK(output[2] == 0.f); // input = -6 + CHECK(output[3] == 0.f); // input = -5 + CHECK(output[4] == 0.f); // input = -4 + CHECK(output[5] == 0.f); // input = -3 + CHECK(output[6] == 0.f); // input = -2 + CHECK(output[7] == 0.f); // input = -1 + CHECK(output[8] == 0.f); // input = 0 + CHECK(output[9] == 1.f); // input = 1 + CHECK(output[10] == 2.f); // input = 2 + CHECK(output[11] == 3.f); // input = 3 + CHECK(output[12] == 4.f); // input = 4 + CHECK(output[13] == 5.f); // input = 5 + CHECK(output[14] == 6.f); // input = 6 + CHECK(output[15] == 7.f); // input = 7 +} + +#endif + +/* + * ReLU: int8 -> int8 + */ +INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) { + if (kCPU < CPUType::SSE2) + return; + + const unsigned NEGATIVE_NUMBERS = 10; + + AlignedVector<int8_t> input(16); + AlignedVector<int8_t> output(16); + + std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); + + auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8()); + *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0); + + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); +} + +INTGEMM_AVX2 TEST_CASE("ReLU_int8 AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + const unsigned NEGATIVE_NUMBERS = 10; + + AlignedVector<int8_t> input(32); + AlignedVector<int8_t> output(32); + + std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); + + auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX2>(ReLU_int8()); + *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0); + + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); +} + +#ifndef INTGEMM_NO_AVX512 + +INTGEMM_AVX512BW TEST_CASE("ReLU_int8 AVX512",) { + if (kCPU < CPUType::AVX512BW) + return; + + const unsigned NEGATIVE_NUMBERS = 30; + + AlignedVector<int8_t> input(64); + AlignedVector<int8_t> output(64); + + std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); + + auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX512BW>(ReLU_int8()); + *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0); + + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); +} + +#endif + +/* + * ReLU: int16 -> int16 + */ +INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) { + if (kCPU < CPUType::SSE2) + return; + + const unsigned NEGATIVE_NUMBERS = 5; + + AlignedVector<int16_t> input(8); + AlignedVector<int16_t> output(8); + + std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); + + auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16()); + *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0); + + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); +} + +INTGEMM_AVX2 TEST_CASE("ReLU_int16 AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + const unsigned NEGATIVE_NUMBERS = 10; + + AlignedVector<int16_t> input(16); + AlignedVector<int16_t> output(16); + + std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); + + auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX2>(ReLU_int16()); + *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0); + + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); +} + +#ifndef INTGEMM_NO_AVX512 + +INTGEMM_AVX512BW TEST_CASE("ReLU_int16 AVX512",) { + if (kCPU < CPUType::AVX512BW) + return; + + const unsigned NEGATIVE_NUMBERS = 15; + + AlignedVector<int16_t> input(32); + AlignedVector<int16_t> output(32); + + std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); + + auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX512BW>(ReLU_int16()); + *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0); + + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); +} + +#endif + +} diff --git a/test/postprocess/sigmoid_test.cc b/test/postprocess/sigmoid_test.cc new file mode 100644 index 0000000..43c713c --- /dev/null +++ b/test/postprocess/sigmoid_test.cc @@ -0,0 +1,33 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + const float error_tolerance = 0.001f; + + AlignedVector<float> input(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -4); + + auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid()); + *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); + + CHECK_EPS(output[0], 0.0179862f, error_tolerance); // input = -4 + CHECK_EPS(output[1], 0.0474259f, error_tolerance); // input = -3 + CHECK_EPS(output[2], 0.1192029f, error_tolerance); // input = -2 + CHECK_EPS(output[3], 0.2689414f, error_tolerance); // input = -1 + CHECK_EPS(output[4], 0.5f , error_tolerance); // input = 0 + CHECK_EPS(output[5], 0.7310586f, error_tolerance); // input = 1 + CHECK_EPS(output[6], 0.8807970f, error_tolerance); // input = 2 + CHECK_EPS(output[7], 0.9525740f, error_tolerance); // input = 3 +} + +} diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc new file mode 100644 index 0000000..f0e4dc2 --- /dev/null +++ b/test/postprocess/tanh_test.cc @@ -0,0 +1,33 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + const float error_tolerance = 0.001f; + + AlignedVector<float> input(8); + AlignedVector<float> output(8); + + std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; }); + + auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh()); + *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); + + CHECK_EPS(output[0], -0.7615942f, error_tolerance); // input = -1 + CHECK_EPS(output[1], -0.6351490f, error_tolerance); // input = -0.75 + CHECK_EPS(output[2], -0.4621172f, error_tolerance); // input = -0.5 + CHECK_EPS(output[3], -0.2449187f, error_tolerance); // input = -0.25 + CHECK_EPS(output[4], 0.0f , error_tolerance); // input = 0 + CHECK_EPS(output[5], 0.2449187f, error_tolerance); // input = 0.25 + CHECK_EPS(output[6], 0.4621172f, error_tolerance); // input = 0.5 + CHECK_EPS(output[7], 0.6351490f, error_tolerance); // input = 0.75 +} + +} diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc new file mode 100644 index 0000000..c33b909 --- /dev/null +++ b/test/postprocess/unquantize_test.cc @@ -0,0 +1,88 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +INTGEMM_SSE2 TEST_CASE("Unquantize SSE2",) { + if (kCPU < CPUType::SSE2) + return; + + AlignedVector<int32_t> input(8); + AlignedVector<float> output(8); + std::iota(input.begin(), input.end(), -2); + + auto postproc = PostprocessImpl<Unquantize, CPUType::SSE2>(Unquantize(0.5f)); + auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0); + output.as<__m128>()[0] = output_tmp.pack0123; + output.as<__m128>()[1] = output_tmp.pack4567; + + CHECK(output[0] == -1.0f); // input = -2 + CHECK(output[1] == -0.5f); // input = -1 + CHECK(output[2] == 0.0f); // input = 0 + CHECK(output[3] == 0.5f); // input = 1 + CHECK(output[4] == 1.0f); // input = 2 + CHECK(output[5] == 1.5f); // input = 3 + CHECK(output[6] == 2.0f); // input = 4 + CHECK(output[7] == 2.5f); // input = 5 +} + +INTGEMM_AVX2 TEST_CASE("Unquantize AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + AlignedVector<int32_t> input(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -4); + + auto postproc = PostprocessImpl<Unquantize, CPUType::AVX2>(Unquantize(0.5f)); + *output.as<__m256>() = postproc.run(*input.as<__m256i>(), 0); + + CHECK(output[0] == -2.0f); // input = -4 + CHECK(output[1] == -1.5f); // input = -3 + CHECK(output[2] == -1.0f); // input = -2 + CHECK(output[3] == -0.5f); // input = -1 + CHECK(output[4] == 0.0f); // input = 0 + CHECK(output[5] == 0.5f); // input = 1 + CHECK(output[6] == 1.0f); // input = 2 + CHECK(output[7] == 1.5f); // input = 3 +} + +#ifndef INTGEMM_NO_AVX512 + +INTGEMM_AVX512BW TEST_CASE("Unquantize AVX512",) { + if (kCPU < CPUType::AVX512BW) + return; + + AlignedVector<int32_t> input(16); + AlignedVector<float> output(16); + + std::iota(input.begin(), input.end(), -8); + + auto postproc = PostprocessImpl<Unquantize, CPUType::AVX512BW>(Unquantize(0.5f)); + *output.as<__m512>() = postproc.run(*input.as<__m512i>(), 0); + + CHECK(output[0] == -4.0f); // input = -8 + CHECK(output[1] == -3.5f); // input = -7 + CHECK(output[2] == -3.0f); // input = -6 + CHECK(output[3] == -2.5f); // input = -5 + CHECK(output[4] == -2.0f); // input = -4 + CHECK(output[5] == -1.5f); // input = -3 + CHECK(output[6] == -1.0f); // input = -2 + CHECK(output[7] == -0.5f); // input = -1 + CHECK(output[8] == 0.0f); // input = 0 + CHECK(output[9] == 0.5f); // input = 1 + CHECK(output[10] == 1.0f); // input = 2 + CHECK(output[11] == 1.5f); // input = 3 + CHECK(output[12] == 2.0f); // input = 4 + CHECK(output[13] == 2.5f); // input = 5 + CHECK(output[14] == 3.0f); // input = 6 + CHECK(output[15] == 3.5f); // input = 7 +} + +#endif + +} diff --git a/test/quantize_test.cc b/test/quantize_test.cc index fb866f1..fd7f0a4 100644 --- a/test/quantize_test.cc +++ b/test/quantize_test.cc @@ -1,15 +1,13 @@ -#include "avx512_gemm.h" +#include "test/test.h" +#include "aligned.h" #include "avx2_gemm.h" -#include "ssse3_gemm.h" +#include "avx512_gemm.h" #include "sse2_gemm.h" -#include "aligned.h" - -#include "3rd_party/catch.hpp" +#include "ssse3_gemm.h" #include <cstring> -#include <math.h> - #include <iostream> +#include <math.h> namespace intgemm { namespace { diff --git a/test/relu_test.cc b/test/relu_test.cc deleted file mode 100644 index 183f415..0000000 --- a/test/relu_test.cc +++ /dev/null @@ -1,89 +0,0 @@ -#include "3rd_party/catch.hpp" -#include "postprocess.h" - -#include <numeric> - -namespace intgemm { - -INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) { - if (kCPU < CPUType::SSE2) - return; - - float raw_input[8]; - std::iota(raw_input, raw_input + 8, -2); - - RegisterPair128 input; - input.pack0123 = *reinterpret_cast<__m128*>(raw_input); - input.pack4567 = *reinterpret_cast<__m128*>(raw_input + 4); - - auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU()); - auto output = postproc.run(input, 0); - auto raw_output = reinterpret_cast<float*>(&output); - - CHECK(raw_output[0] == 0.f); // input = -2 - CHECK(raw_output[1] == 0.f); // input = -1 - CHECK(raw_output[2] == 0.f); // input = 0 - CHECK(raw_output[3] == 1.f); // input = 1 - CHECK(raw_output[4] == 2.f); // input = 2 - CHECK(raw_output[5] == 3.f); // input = 3 - CHECK(raw_output[6] == 4.f); // input = 4 - CHECK(raw_output[7] == 5.f); // input = 5 -} - -INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - float raw_input[8]; - std::iota(raw_input, raw_input + 8, -4); - - auto input = *reinterpret_cast<__m256*>(raw_input); - auto postproc = PostprocessImpl<ReLU, CPUType::AVX2>(ReLU()); - auto output = postproc.run(input, 0); - auto raw_output = reinterpret_cast<float*>(&output); - - CHECK(raw_output[0] == 0.f); // input = -4 - CHECK(raw_output[1] == 0.f); // input = -3 - CHECK(raw_output[2] == 0.f); // input = -2 - CHECK(raw_output[3] == 0.f); // input = -1 - CHECK(raw_output[4] == 0.f); // input = 0 - CHECK(raw_output[5] == 1.f); // input = 1 - CHECK(raw_output[6] == 2.f); // input = 2 - CHECK(raw_output[7] == 3.f); // input = 3 -} - -#ifndef INTGEMM_NO_AVX512 - -INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) { - if (kCPU < CPUType::AVX512BW) - return; - - float raw_input[16]; - std::iota(raw_input, raw_input + 16, -8); - - auto input = *reinterpret_cast<__m512*>(raw_input); - auto postproc = PostprocessImpl<ReLU, CPUType::AVX512BW>(ReLU()); - auto output = postproc.run(input, 0); - auto raw_output = reinterpret_cast<float*>(&output); - - CHECK(raw_output[0] == 0.f); // input = -8 - CHECK(raw_output[1] == 0.f); // input = -7 - CHECK(raw_output[2] == 0.f); // input = -6 - CHECK(raw_output[3] == 0.f); // input = -5 - CHECK(raw_output[4] == 0.f); // input = -4 - CHECK(raw_output[5] == 0.f); // input = -3 - CHECK(raw_output[6] == 0.f); // input = -2 - CHECK(raw_output[7] == 0.f); // input = -1 - CHECK(raw_output[8] == 0.f); // input = 0 - CHECK(raw_output[9] == 1.f); // input = 1 - CHECK(raw_output[10] == 2.f); // input = 2 - CHECK(raw_output[11] == 3.f); // input = 3 - CHECK(raw_output[12] == 4.f); // input = 4 - CHECK(raw_output[13] == 5.f); // input = 5 - CHECK(raw_output[14] == 6.f); // input = 6 - CHECK(raw_output[15] == 7.f); // input = 7 -} - -#endif - -} diff --git a/test/test.cc b/test/test.cc new file mode 100644 index 0000000..58c62f8 --- /dev/null +++ b/test/test.cc @@ -0,0 +1,6 @@ +#define CATCH_CONFIG_RUNNER +#include "test/test.h" + +int main(int argc, char ** argv) { + return Catch::Session().run(argc, argv); +} diff --git a/test/test.h b/test/test.h new file mode 100644 index 0000000..572a529 --- /dev/null +++ b/test/test.h @@ -0,0 +1,12 @@ +#include "3rd_party/catch.hpp" + +#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while(0) +#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while(0) +#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while(0) +#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while(0) + +#define CHECK_EPS(actual, expected, epsilon) \ + do { \ + if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ + else { CHECK((actual) == (expected)); } \ + } while(0) diff --git a/test/utils_test.cc b/test/utils_test.cc new file mode 100644 index 0000000..580a872 --- /dev/null +++ b/test/utils_test.cc @@ -0,0 +1,38 @@ +#include "test/test.h" +#include "utils.h" + +namespace intgemm { +namespace { + +TEST_CASE("Factorial",) { + CHECK(factorial(0) == 1); + CHECK(factorial(1) == 1); + CHECK(factorial(2) == 2); + CHECK(factorial(3) == 6); + CHECK(factorial(4) == 24); + + // Maximum result that fits in unsinged long long + CHECK(factorial(20) == 2432902008176640000); +} + +TEST_CASE("Expi (negative)",) { + const double eps = 0.0000001; + CHECK_EPS(expi(-1), 0.3678794411714423, eps); + CHECK_EPS(expi(-2), 0.1353352832366127, eps); + CHECK_EPS(expi(-10), 0.0000453999297625, eps); +} + +TEST_CASE("Expi (zero)",) { + const double eps = 0.0000001; + CHECK_EPS(expi(0), 1.0, eps); +} + +TEST_CASE("Expi (positive)",) { + const double eps = 0.0000001; + CHECK_EPS(expi(1), 2.7182818284590452, eps); + CHECK_EPS(expi(2), 7.3890560989306502, eps); + CHECK_EPS(expi(10), 22026.4657948067165170, eps); +} + +} +} @@ -49,4 +49,24 @@ constexpr subtuple_t<Tuple, Indices...> make_subtuple(const Tuple& tuple, sequen return std::make_tuple(std::get<Indices>(tuple)...); } +/* + * Factorial + */ +constexpr unsigned long long factorial(unsigned n) { + return n <= 1 ? 1 : n * factorial(n - 1); +} + +/* + * e^n, where n is integer + */ +namespace { // anonymous namespace +constexpr double expi_nonnegative(unsigned n) { + return n == 0 ? 1.0 : (n == 1 ? 2.718281828459045 : expi_nonnegative(n / 2) * expi_nonnegative((n + 1) / 2)); +} +} // anonymous namespace + +constexpr double expi(int n) { + return (n >= 0 ? expi_nonnegative(n) : 1.0 / expi_nonnegative(-n)); +} + } diff --git a/vec_utils.h b/vec_utils.h index fb6aea4..acb7d6e 100644 --- a/vec_utils.h +++ b/vec_utils.h @@ -46,4 +46,84 @@ INTGEMM_AVX512BW static inline __m512 unquantize(__m512i input, __m512 unquantiz } #endif +/* + * + * Calculate floor: float -> float + * + */ +INTGEMM_SSE2 static inline __m128 floor_ff(__m128 a) { + return cvtepi32_ps(_mm_cvttps_epi32(a)); +} +INTGEMM_AVX2 static inline __m256 floor_ff(__m256 a) { + return _mm256_floor_ps(a); +} +#ifndef INTGEMM_NO_AVX512 +INTGEMM_AVX512BW static inline __m512 floor_ff(__m512 a) { + return cvtepi32_ps(cvttps_epi32(a)); // TODO: Is there any better way to do that? +} +#endif + +/* + * + * Calculate approximation of e^x using Taylor series and lookup table + * + */ + +template <typename Register> +Register exp_approx_taylor(Register x) { + static constexpr int EXP_MIN = -20; + static constexpr int EXP_MAX = 20; + static constexpr float EXP_LOOKUP[EXP_MAX - EXP_MIN + 1] = { + expi(-20), expi(-19), expi(-18), expi(-17), expi(-16), expi(-15), + expi(-14), expi(-13), expi(-12), expi(-11), expi(-10), expi(-9), + expi(-8), expi(-7), expi(-6), expi(-5), expi(-4), expi(-3), expi(-2), + expi(-1), expi(0), expi(1), expi(2), expi(3), expi(4), expi(5), + expi(6), expi(7), expi(8), expi(9), expi(10), expi(11), expi(12), + expi(13), expi(14), expi(15), expi(16), expi(17), expi(18), expi(19), + expi(20), + }; + + static const Register dividers[] = { + set1_ps<Register>(1.f / factorial(7)), + set1_ps<Register>(1.f / factorial(6)), + set1_ps<Register>(1.f / factorial(5)), + set1_ps<Register>(1.f / factorial(4)), + set1_ps<Register>(1.f / factorial(3)), + set1_ps<Register>(1.f / factorial(2)), + set1_ps<Register>(1.f / factorial(1)), + }; + static const auto const_one = set1_ps<Register>(1.f); + static const auto const_min_x = set1_ps<Register>(EXP_MIN); + static const auto const_max_x = set1_ps<Register>(EXP_MAX); + + x = max_ps(x, const_min_x); + x = min_ps(x, const_max_x); + + auto a = floor_ff(x); + auto xa = sub_ps(x, a); + + auto result = mul_ps(dividers[0], xa); + + result = add_ps(result, dividers[1]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[2]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[3]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[4]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[5]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[6]); + result = mul_ps(result, xa); + + result = add_ps(result, const_one); + + auto ea = i32gather_ps<4>(EXP_LOOKUP + EXP_MAX, cvtps_epi32(a)); + return mul_ps(ea, result); +} + +template INTGEMM_AVX2 static __m256 exp_approx_taylor(__m256 x); +template INTGEMM_AVX512BW static __m512 exp_approx_taylor(__m512 x); + } |