diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-08 17:57:41 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-09 22:39:50 +0300 |
commit | 03ec24b72137c785cd9c931c34bc50c9cbd3cae3 (patch) | |
tree | 32f37bf856726337456dffca2b2f1f43bc53e229 | |
parent | 7e514a4d1178ddeae0cf38fa29f5ca758abf8a9a (diff) |
Add code infrastructure for support callbacks
-rw-r--r-- | CMakeLists.txt | 6 | ||||
-rw-r--r-- | avx512_gemm.h | 10 | ||||
-rw-r--r-- | benchmark.cc | 5 | ||||
-rw-r--r-- | callbacks.h | 6 | ||||
-rw-r--r-- | callbacks/avx2.h | 13 | ||||
-rw-r--r-- | callbacks/avx512.h | 17 | ||||
-rw-r--r-- | callbacks/configs.h | 10 | ||||
-rw-r--r-- | callbacks/implementations.inl | 67 | ||||
-rw-r--r-- | callbacks/sse2.h | 13 | ||||
-rw-r--r-- | example.cc | 5 | ||||
-rw-r--r-- | intgemm.h | 37 | ||||
-rw-r--r-- | multiply.h | 18 | ||||
-rw-r--r-- | postprocess.h | 390 | ||||
-rw-r--r-- | postprocess_pipeline.h | 113 | ||||
-rw-r--r-- | test/multiply_test.cc | 6 | ||||
-rw-r--r-- | test/postprocess/add_bias_test.cc | 95 | ||||
-rw-r--r-- | test/postprocess/pipeline_test.cc | 63 | ||||
-rw-r--r-- | test/postprocess/relu_test.cc | 213 | ||||
-rw-r--r-- | test/postprocess/sigmoid_test.cc | 33 | ||||
-rw-r--r-- | test/postprocess/tanh_test.cc | 33 | ||||
-rw-r--r-- | test/postprocess/unquantize_test.cc | 88 | ||||
-rw-r--r-- | vec_utils.h | 1 |
22 files changed, 165 insertions, 1077 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 31b7b53..12300e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,12 +33,6 @@ endforeach() include_directories(.) add_executable(tests test/multiply_test.cc - test/postprocess/add_bias_test.cc - test/postprocess/pipeline_test.cc - test/postprocess/relu_test.cc - test/postprocess/sigmoid_test.cc - test/postprocess/tanh_test.cc - test/postprocess/unquantize_test.cc test/quantize_test.cc test/test.cc test/utils_test.cc diff --git a/avx512_gemm.h b/avx512_gemm.h index 8326a82..f474b35 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -217,8 +217,8 @@ struct AVX512_8bit { // Special AVX512 implementation due to having 32 registers (so I don't have to // allocate registers manually) and no sign instruction. - template <typename PostprocessPipeline> - INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { + template <typename Callback> + INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { typedef __m512i Integer; //typedef __m256 Float; // For quantization we only do 8 at a time. // This is copy-paste from Multiply8_SSE2OrAVX2. @@ -227,7 +227,7 @@ struct AVX512_8bit { assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); // There's 8 results for INTGEMM_AVX2 to handle. - auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); + auto callback_impl = callbacks::CallbackImpl<Callback, CPUType::AVX2>(callback); const int simd_width = width / sizeof(Integer); const Integer *B0_col = reinterpret_cast<const Integer*>(B); // Added for AVX512. @@ -324,9 +324,7 @@ struct AVX512_8bit { Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); - auto offset = A_rowidx * B_cols + B0_colidx; - auto result = inited_pipeline.run(total, offset); - writer(C, offset, result); + callback_impl(total, A_rowidx, B0_colidx, A_rows, width, B_cols); } } } diff --git a/benchmark.cc b/benchmark.cc index 6477792..669ea77 100644 --- a/benchmark.cc +++ b/benchmark.cc @@ -5,6 +5,7 @@ #include "sse2_gemm.h" #include "intgemm.h" #include "stop_watch.h" +#include "callbacks.h" #include <algorithm> #include <cassert> @@ -78,10 +79,10 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t> Backend::PrepareB(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols); AlignedVector<float> output(m.A_rows * m.B_cols); // Burn in - Backend::Multiply(A_prepared.begin(), B_prepared.begin(), output.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), m.A_rows, m.width, m.B_cols); + Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::Dummy()); { StopWatch w(stats); - Backend::Multiply(A_prepared.begin(), B_prepared.begin(), output.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), m.A_rows, m.width, m.B_cols); + Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::Dummy()); } } diff --git a/callbacks.h b/callbacks.h new file mode 100644 index 0000000..079ca0f --- /dev/null +++ b/callbacks.h @@ -0,0 +1,6 @@ +#pragma once + +#include "callbacks/configs.h" +#include "callbacks/sse2.h" +#include "callbacks/avx2.h" +#include "callbacks/avx512.h" diff --git a/callbacks/avx2.h b/callbacks/avx2.h new file mode 100644 index 0000000..c69875d --- /dev/null +++ b/callbacks/avx2.h @@ -0,0 +1,13 @@ +#pragma once + +#define THIS_IS_AVX2 +#include "callbacks/implementations.inl" +#undef THIS_IS_AVX2 + +namespace intgemm { +namespace callbacks { + +// Put here callbacks supported only by AVX2... + +} +} diff --git a/callbacks/avx512.h b/callbacks/avx512.h new file mode 100644 index 0000000..a224a84 --- /dev/null +++ b/callbacks/avx512.h @@ -0,0 +1,17 @@ +#pragma once + +#ifndef INTGEMM_NO_AVX512 + +#define THIS_IS_AVX512BW +#include "callbacks/implementations.inl" +#undef THIS_IS_AVX512BW + +namespace intgemm { +namespace callbacks { + +// Put here callbacks supported only by AVX512BW... + +} +} + +#endif diff --git a/callbacks/configs.h b/callbacks/configs.h new file mode 100644 index 0000000..fae60bd --- /dev/null +++ b/callbacks/configs.h @@ -0,0 +1,10 @@ +#pragma once + +namespace intgemm { +namespace callbacks { + +struct Dummy { +}; + +} +} diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl new file mode 100644 index 0000000..01195f1 --- /dev/null +++ b/callbacks/implementations.inl @@ -0,0 +1,67 @@ +#include "callbacks/configs.h" + +#include "intrinsics.h" +#include "types.h" +#include "vec_traits.h" + +#if defined(THIS_IS_SSE2) + #define CPU_NAME SSE2 + #define CPU_ATTR INTGEMM_SSE2 +#elif defined(THIS_IS_AVX2) + #define CPU_NAME AVX2 + #define CPU_ATTR INTGEMM_AVX2 +#elif defined(THIS_IS_AVX512BW) + #define CPU_NAME AVX512BW + #define CPU_ATTR INTGEMM_AVX512BW +#else + #error "Only SSE2, AVX2 and AVX512BW are supported" +#endif + +#define vi vector_t<CPUType::CPU_NAME, int> +#define vf vector_t<CPUType::CPU_NAME, float> +#define vd vector_t<CPUType::CPU_NAME, double> +#define dvi dvector_t<CPUType::CPU_NAME, int> +#define dvf dvector_t<CPUType::CPU_NAME, float> +#define dvd dvector_t<CPUType::CPU_NAME, double> + +#if defined(THIS_IS_SSE2) +#define vinput dvector_t<CPUType::SSE2, int> +#else +#define vinput vector_t<CPUType::AVX2, int> +#endif + +namespace intgemm { +namespace callbacks { + +template <typename CallbackConfig, CPUType CpuType> +class CallbackImpl; + +}} + +/* + * Callbacks implementations.... + */ +namespace intgemm { +namespace callbacks { + +/* + * Dummy + */ +template <> class CallbackImpl<Dummy, CPUType::CPU_NAME> { +public: + CPU_ATTR CallbackImpl(const Dummy&) {} + CPU_ATTR void operator()(vinput, Index, Index, Index, Index, Index) {} +}; + +} +} + +#undef CPU_NAME +#undef CPU_ATTR +#undef vi +#undef vf +#undef vd +#undef dvi +#undef dvf +#undef dvd +#undef vinput diff --git a/callbacks/sse2.h b/callbacks/sse2.h new file mode 100644 index 0000000..8758a20 --- /dev/null +++ b/callbacks/sse2.h @@ -0,0 +1,13 @@ +#pragma once + +#define THIS_IS_SSE2 +#include "callbacks/implementations.inl" +#undef THIS_IS_SSE2 + +namespace intgemm { +namespace callbacks { + +// Put here callbacks supported only by SSE2... + +} +} @@ -2,6 +2,7 @@ // This is just for AlignedVector, which helps managed 64-byte aligned memory. // Feel free to manage memory yourself. #include "aligned.h" +#include "callbacks.h" #include <cassert> #include <math.h> @@ -51,7 +52,7 @@ int main() { AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. - intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols); + intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::Dummy()); // Sanity check. C will be row major. assert(fabs(C[0] - top_left_reference) < 0.05); } @@ -70,7 +71,7 @@ int main() { AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. - intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols); + intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::Dummy()); // Sanity check. C will be row major. assert(fabs(C[0] - top_left_reference) < 0.05); } @@ -48,7 +48,6 @@ #include "sse2_gemm.h" #include "ssse3_gemm.h" #include "avx2_gemm.h" -#include "postprocess.h" #ifndef INTGEMM_NO_AVX512 #include "avx512_gemm.h" #endif @@ -67,8 +66,8 @@ struct Unsupported_16bit { static void SelectColumnsB(const int16_t *, int16_t *, Index, const Index *, const Index *) { throw UnsupportedCPU(); } - template <typename PostprocessPipeline> - static void Multiply(const int16_t *, const int16_t *, float *, PostprocessPipeline, Index, Index, Index) { + template <typename Callback> + static void Multiply(const int16_t *, const int16_t *, Index, Index, Index, Callback) { throw UnsupportedCPU(); } constexpr static const char *const kName = "16-bit Unsupported"; @@ -84,8 +83,8 @@ struct Unsupported_8bit { static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) { throw UnsupportedCPU(); } - template <typename PostprocessPipeline> - static void Multiply(const int8_t *, const int8_t *, float *, PostprocessPipeline, Index, Index, Index) { + template <typename Callback> + static void Multiply(const int8_t *, const int8_t *, Index, Index, Index, Callback) { throw UnsupportedCPU(); } constexpr static const char *const kName = "8-bit Unsupported"; @@ -133,15 +132,15 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported) } /* 16-bit matrix multiplication. */ -template <typename PostprocessPipeline> +template <typename Callback> class Int16Mult { public: // Multiply C = A * B, presuming A and B have been prepared. - static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols); + static void (*Multiply)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback); }; -template <typename PostprocessPipeline> -void (*Int16Mult<PostprocessPipeline>::Multiply)(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply<PostprocessPipeline>, AVX2_16bit::Multiply<PostprocessPipeline>, SSE2_16bit::Multiply<PostprocessPipeline>, SSE2_16bit::Multiply<PostprocessPipeline>, Unsupported_16bit::Multiply); +template <typename Callback> +void (*Int16Mult<Callback>::Multiply)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_16bit::Multiply<Callback>, AVX2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, Unsupported_16bit::Multiply); struct Int16 { typedef int16_t Integer; @@ -172,24 +171,24 @@ struct Int16 { static void (*SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end); // Multiply C = A * B, presuming A and B have been prepared. - template <typename PostprocessPipeline> - static void Multiply(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { - Int16Mult<PostprocessPipeline>::Multiply(A, B, C, pipeline, A_rows, width, B_cols); + template <typename Callback> + static void Multiply(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + Int16Mult<Callback>::Multiply(A, B, A_rows, width, B_cols, callback); } static const char *const kName; }; /* 8-bit matrix multiplication */ -template <typename PostprocessPipeline> +template <typename Callback> class Int8Mult { public: // Multiply C = A * B, presuming A and B have been prepared. - static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols); + static void (*Multiply)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback); }; -template <typename PostprocessPipeline> -void (*Int8Mult<PostprocessPipeline>::Multiply)(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply<PostprocessPipeline>, AVX2_8bit::Multiply<PostprocessPipeline>, SSSE3_8bit::Multiply<PostprocessPipeline>, SSSE3_8bit::Multiply<PostprocessPipeline>, Unsupported_8bit::Multiply); +template <typename Callback> +void (*Int8Mult<Callback>::Multiply)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_8bit::Multiply<Callback>, AVX2_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, Unsupported_8bit::Multiply); struct Int8 { typedef int8_t Integer; @@ -219,9 +218,9 @@ struct Int8 { static void (*SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end); // Multiply C = A * B, presuming A and B have been prepared. - template <typename PostprocessPipeline> - static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { - Int8Mult<PostprocessPipeline>::Multiply(A, B, C, pipeline, A_rows, width, B_cols); + template <typename Callback> + static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + Int8Mult<Callback>::Multiply(A, B, A_rows, width, B_cols, callback); } static const char *const kName; @@ -2,9 +2,9 @@ #include "interleave.h" #include "intrinsics.h" -#include "postprocess_pipeline.h" #include "vec_utils.h" #include "vec_traits.h" +#include "callbacks.h" namespace intgemm { @@ -144,13 +144,13 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i) // B_cols must be a multiple of 8. // Multiply16 #define INTGEMM_MULTIPLY16(Integer, target, cpu_type) \ -template <typename PostprocessPipeline> target static void Multiply(const int16_t *A, const int16_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \ +template <typename Callback> target static void Multiply(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \ assert(B_cols % 8 == 0); \ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \ const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \ - auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \ + auto callback_impl = callbacks::CallbackImpl<Callback, cpu_type>(callback); \ const Integer *B0_col = reinterpret_cast<const Integer *>(B); \ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ @@ -194,9 +194,7 @@ template <typename PostprocessPipeline> target static void Multiply(const int16_ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ /*The specific implementation may need to reduce further.*/ \ auto total = PermuteSummer(pack0123, pack4567); \ - auto offset = A_rowidx * B_cols + B0_colidx; \ - auto result = inited_pipeline.run(total, offset); \ - writer(C, offset, result); \ + callback_impl(total, A_rowidx, B0_colidx, A_rows, width, B_cols); \ } \ } \ } \ @@ -350,14 +348,14 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( } //INTGEMM_AVX2 or INTGEMM_SSSE3 multiply #define INTGEMM_MULTIPLY8(Integer, target, cpu_type) \ - template <typename PostprocessPipeline> target static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \ + template <typename Callback> target static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ assert(width % sizeof(Integer) == 0); \ assert(B_cols % 8 == 0); \ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \ const int simd_width = width / sizeof(Integer); \ + auto callback_impl = callbacks::CallbackImpl<Callback, cpu_type>(callback); \ const Integer *B0_col = reinterpret_cast<const Integer*>(B); \ - auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \ /*Go over 8 columns of B at a time.*/ \ for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ /*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ @@ -412,9 +410,7 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ auto total = PermuteSummer(pack0123, pack4567); \ - auto offset = A_rowidx * B_cols + B0_colidx; \ - auto result = inited_pipeline.run(total, offset); \ - writer(C, offset, result); \ + callback_impl(total, A_rowidx, B0_colidx, A_rows, width, B_cols); \ } \ } \ } \ diff --git a/postprocess.h b/postprocess.h deleted file mode 100644 index 53c5a3e..0000000 --- a/postprocess.h +++ /dev/null @@ -1,390 +0,0 @@ -#pragma once - -#include "intrinsics.h" -#include "postprocess_pipeline.h" -#include "types.h" -#include "vec_utils.h" -#include "vec_traits.h" - -// TODO: We support some postprocess in few variations e.g. we support ReLU for -// float -> float, int8 -> int8, int16 -> int16. Maybe it would be a good idea -// to pass input type and output type as a template parameter of postprocess? - -namespace intgemm { - -/* - * Unquantize - */ -class Unquantize { -public: - float unquantize_multiplier; - - Unquantize(float unquantize_multiplier) : unquantize_multiplier(unquantize_multiplier) {} -}; - -template <> -class PostprocessImpl<Unquantize, CPUType::SSE2> { -public: - using InputRegister = dvector_t<CPUType::SSE2, int>; - using OutputRegister = dvector_t<CPUType::SSE2, float>; - - INTGEMM_SSE2 PostprocessImpl(const Unquantize& config) { - unquantize_multiplier = set1_ps<__m128>(config.unquantize_multiplier); - } - - INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { - return { - mul_ps(cvtepi32_ps(input.first), unquantize_multiplier), - mul_ps(cvtepi32_ps(input.second), unquantize_multiplier), - }; - } - -private: - __m128 unquantize_multiplier; -}; - -template <> -class PostprocessImpl<Unquantize, CPUType::AVX2> { -public: - using InputRegister = __m256i; - using OutputRegister = __m256; - - INTGEMM_AVX2 PostprocessImpl(const Unquantize& config) { - unquantize_multiplier = set1_ps<__m256>(config.unquantize_multiplier); - } - - INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { - return mul_ps(cvtepi32_ps(input), unquantize_multiplier); - } - -private: - __m256 unquantize_multiplier; -}; - -#ifndef INTGEMM_NO_AVX512 - -template <> -class PostprocessImpl<Unquantize, CPUType::AVX512BW> { -public: - using InputRegister = __m512i; - using OutputRegister = __m512; - - INTGEMM_AVX512BW PostprocessImpl(const Unquantize& config) { - unquantize_multiplier = set1_ps<__m512>(config.unquantize_multiplier); - } - - INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { - return mul_ps(cvtepi32_ps(input), unquantize_multiplier); - } - -private: - __m512 unquantize_multiplier; -}; - -#endif - -/* - * Add a bias term - */ -class AddBias { -public: - const float* bias; - const Index length; - - AddBias(const float* bias, Index length) : bias(bias), length(length) {} -}; - -template <> -class PostprocessImpl<AddBias, CPUType::SSE2> { -public: - using InputRegister = dvector_t<CPUType::SSE2, float>; - using OutputRegister = dvector_t<CPUType::SSE2, float>; - - PostprocessImpl(const AddBias& config) : config(config) {} - - INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { - auto bias_term0123 = *reinterpret_cast<const __m128*>(config.bias + (offset % config.length)); - auto bias_term4567 = *reinterpret_cast<const __m128*>(config.bias + (offset % config.length) + 4); - return { - add_ps(input.first, bias_term0123), - add_ps(input.second, bias_term4567), - }; - } - -private: - const AddBias config; -}; - -template <> -class PostprocessImpl<AddBias, CPUType::AVX2> { -public: - using InputRegister = __m256; - using OutputRegister = __m256; - - PostprocessImpl(const AddBias& config) : config(config) {} - - INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { - auto bias_term = *reinterpret_cast<const __m256*>(config.bias + (offset % config.length)); - return add_ps(input, bias_term); - } - -private: - const AddBias config; -}; - -#ifndef INTGEMM_NO_AVX512 - -template <> -class PostprocessImpl<AddBias, CPUType::AVX512BW> { -public: - using InputRegister = __m512; - using OutputRegister = __m512; - - PostprocessImpl(const AddBias& config) : config(config) {} - - INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { - auto bias_term = *reinterpret_cast<const __m512*>(config.bias + (offset % config.length)); - return add_ps(input, bias_term); - } - -private: - const AddBias config; -}; - -#endif - -/* - * ReLU - */ -class ReLU {}; - -template <> -class PostprocessImpl<ReLU, CPUType::SSE2> { -public: - using InputRegister = dvector_t<CPUType::SSE2, float>; - using OutputRegister = dvector_t<CPUType::SSE2, float>; - - PostprocessImpl(const ReLU& config) {} - - INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = set1_ps<__m128>(0.f); - return { - max_ps(const_zero, input.first), - max_ps(const_zero, input.second), - }; - } -}; - -template <> -class PostprocessImpl<ReLU, CPUType::SSSE3> : public PostprocessImpl<ReLU, CPUType::SSE2> {}; - -template <> -class PostprocessImpl<ReLU, CPUType::AVX2> { -public: - using InputRegister = __m256; - using OutputRegister = __m256; - - PostprocessImpl(const ReLU& config) {} - - INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = set1_ps<__m256>(0.f); - return max_ps(const_zero, input); - } -}; - -#ifndef INTGEMM_NO_AVX512 - -template <> -class PostprocessImpl<ReLU, CPUType::AVX512BW> { -public: - using InputRegister = __m512; - using OutputRegister = __m512; - - PostprocessImpl(const ReLU& config) {} - - INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = set1_ps<__m512>(0.f); - return max_ps(const_zero, input); - } -}; - -#endif - -/* - * ReLU_int8 - */ -class ReLU_int8 {}; - -template <> -class PostprocessImpl<ReLU_int8, CPUType::SSE2> { -public: - using InputRegister = __m128i; - using OutputRegister = __m128i; - - PostprocessImpl(const ReLU_int8& config) {} - - INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = setzero_si<__m128i>(); - return _mm_and_si128(_mm_cmplt_epi8(const_zero, input), input); - } -}; - -template <> -class PostprocessImpl<ReLU_int8, CPUType::AVX2> { -public: - using InputRegister = __m256i; - using OutputRegister = __m256i; - - PostprocessImpl(const ReLU_int8& config) {} - - INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = setzero_si<__m256i>(); - return max_epi8(const_zero, input); - } -}; - -#ifndef INTGEMM_NO_AVX512 - -template <> -class PostprocessImpl<ReLU_int8, CPUType::AVX512BW> { -public: - using InputRegister = __m512i; - using OutputRegister = __m512i; - - PostprocessImpl(const ReLU_int8& config) {} - - INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = setzero_si<__m512i>(); - return max_epi8(const_zero, input); - } -}; - -#endif - -/* - * ReLU_int16 - */ -class ReLU_int16 {}; - -template <> -class PostprocessImpl<ReLU_int16, CPUType::SSE2> { -public: - using InputRegister = __m128i; - using OutputRegister = __m128i; - - PostprocessImpl(const ReLU_int16& config) {} - - INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = setzero_si<__m128i>(); - return max_epi16(const_zero, input); - } -}; - -template <> -class PostprocessImpl<ReLU_int16, CPUType::AVX2> { -public: - using InputRegister = __m256i; - using OutputRegister = __m256i; - - PostprocessImpl(const ReLU_int16& config) {} - - INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = setzero_si<__m256i>(); - return max_epi16(const_zero, input); - } -}; - -#ifndef INTGEMM_NO_AVX512 - -template <> -class PostprocessImpl<ReLU_int16, CPUType::AVX512BW> { -public: - using InputRegister = __m512i; - using OutputRegister = __m512i; - - PostprocessImpl(const ReLU_int16& config) {} - - INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = setzero_si<__m512i>(); - return max_epi16(const_zero, input); - } -}; - -#endif - -/* - * Sigmoid (uses Taylor series approximation of e^x) - */ -class Sigmoid {}; - -template <> -class PostprocessImpl<Sigmoid, CPUType::AVX2> { -public: - using InputRegister = __m256; - using OutputRegister = __m256; - - PostprocessImpl(const Sigmoid& config) {} - - INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { - static const auto const_zero = set1_ps<__m256>(0.f); - static const auto const_one = set1_ps<__m256>(1.f); - - auto x = input; - auto minus_x = sub_ps(const_zero, x); - auto e_x = exp_approx_taylor(x); - auto e_minus_x = exp_approx_taylor(minus_x); - - auto sigmoid_case1 = _mm256_rcp_ps(add_ps(const_one, e_minus_x)); - auto sigmoid_case2 = mul_ps(e_x, _mm256_rcp_ps(add_ps(const_one, e_x))); - - auto nonnegative_x_mask = _mm256_cmp_ps(const_zero, x, _CMP_LT_OS); - return _mm256_blendv_ps(sigmoid_case1, sigmoid_case2, nonnegative_x_mask); - } -}; - -/* - * Tanh (uses Taylor series approximation of e^x) - */ -class Tanh {}; - -template <> -class PostprocessImpl<Tanh, CPUType::AVX2> { -public: - using InputRegister = __m256; - using OutputRegister = __m256; - - PostprocessImpl(const Tanh& config) {} - - INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { - const static auto const_zero = setzero_ps<__m256>(); - - auto e_x = exp_approx_taylor(input); - auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input)); - - return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x)); - } -}; - -#ifndef INTGEMM_NO_AVX512 - -template <> -class PostprocessImpl<Tanh, CPUType::AVX512BW> { -public: - using InputRegister = __m512; - using OutputRegister = __m512; - - PostprocessImpl(const Tanh& config) {} - - INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { - const static auto const_zero = setzero_ps<__m512>(); - - auto e_x = exp_approx_taylor(input); - auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input)); - - return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x)); - } -}; - -#endif - -} diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h deleted file mode 100644 index 361ff2b..0000000 --- a/postprocess_pipeline.h +++ /dev/null @@ -1,113 +0,0 @@ -#pragma once - -#include "intrinsics.h" -#include "types.h" -#include "utils.h" - -#include <tuple> - -namespace intgemm { - -template <typename... Stages> -using PostprocessPipeline = std::tuple<Stages...>; - -template <typename... Stages> -constexpr std::tuple<Stages...> CreatePostprocessPipeline(Stages&&... stages) { - return std::make_tuple(std::forward<Stages>(stages)...); -} - -template <typename Postprocess, CPUType CpuType> -class PostprocessImpl; - -namespace { // anonymous namespace - -template <typename... Stages> -using input_register_type = typename std::tuple_element< - 0, - std::tuple<Stages...> - >::type::InputRegister; - -template <typename... Stages> -using output_register_type = typename std::tuple_element< - std::tuple_size<std::tuple<Stages...>>::value - 1, - std::tuple<Stages...> - >::type::OutputRegister; - -template <typename FirstStage, typename... RestStages> -constexpr std::tuple<RestStages...> DropFirstStage(const std::tuple<FirstStage, RestStages...>& pipeline) { - return make_subtuple(pipeline, sequence_popfront<make_sequence<sizeof...(RestStages) + 1>>()); -} - -template <CPUType CpuType> -constexpr std::tuple<> InitPostprocessPipelineImpl(std::tuple<> pipeline) { - return std::tuple<>(); -} - -template <CPUType CpuType, typename FirstStage, typename... RestStages> -constexpr std::tuple<PostprocessImpl<FirstStage, CpuType>, PostprocessImpl<RestStages, CpuType>...> InitPostprocessPipelineImpl(std::tuple<FirstStage, RestStages...> pipeline) { - return std::tuple_cat( - std::tuple<PostprocessImpl<FirstStage, CpuType>>(PostprocessImpl<FirstStage, CpuType>(std::get<0>(pipeline))), - InitPostprocessPipelineImpl<CpuType, RestStages...>(DropFirstStage(pipeline)) - ); -} - -template <CPUType CpuType> -struct RunPostprocessPipelineImpl; - -#define RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(attribute, cpu_type) \ - template <> \ - struct RunPostprocessPipelineImpl<cpu_type> { \ - template <typename Stage> \ - attribute static constexpr output_register_type<Stage> \ - run(std::tuple<Stage> pipeline, input_register_type<Stage> input, Index offset) { \ - return std::get<0>(pipeline).run(input, offset); \ - } \ - template <typename... Stages> \ - attribute static constexpr output_register_type<Stages...> \ - run(std::tuple<Stages...> pipeline, input_register_type<Stages...> input, Index offset) { \ - return run( \ - DropFirstStage(pipeline), \ - std::get<0>(pipeline).run(input, offset), offset); \ - } \ - }; - -RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSE2, CPUType::SSE2) -RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSSE3, CPUType::SSSE3) -RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX2, CPUType::AVX2) -RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::AVX512BW) - -} // anonymous namespace - -template <CPUType CpuType, typename... Stages> -class InitedPostprocessPipeline {}; - -template <CPUType CpuType, typename... Stages> -constexpr InitedPostprocessPipeline<CpuType, Stages...> InitPostprocessPipeline(std::tuple<Stages...> pipeline) { - return InitedPostprocessPipeline<CpuType, Stages...>(pipeline); -} - -#define INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(attribute, cpu_type) \ - template <typename... Stages> \ - class InitedPostprocessPipeline<cpu_type, Stages...> { \ - public: \ - using InputRegister = input_register_type<PostprocessImpl<Stages, cpu_type>...>; \ - using OutputRegister = output_register_type<PostprocessImpl<Stages, cpu_type>...>; \ - InitedPostprocessPipeline(std::tuple<Stages...> pipeline) \ - : inited_pipeline(InitPostprocessPipelineImpl<cpu_type, Stages...>(pipeline)) {} \ - attribute inline OutputRegister run(InputRegister input, Index offset) { \ - return RunPostprocessPipelineImpl<cpu_type>::run(inited_pipeline, input, offset); \ - } \ - attribute inline void run(const InputRegister* input, unsigned length, OutputRegister* output) { \ - for (unsigned i = 0; i < length; ++i) \ - output[i] = RunPostprocessPipelineImpl<cpu_type>::run(inited_pipeline, input[i], i * sizeof(InputRegister)); \ - } \ - private: \ - const std::tuple<PostprocessImpl<Stages, cpu_type>...> inited_pipeline; \ - }; - -INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSE2, CPUType::SSE2) -INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSSE3, CPUType::SSSE3) -INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX2, CPUType::AVX2) -INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::AVX512BW) - -} diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 93d7127..e7fcf77 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -3,7 +3,7 @@ #include "interleave.h" #include "intgemm.h" #include "multiply.h" -#include "postprocess.h" +#include "callbacks.h" #include <algorithm> #include <cassert> @@ -361,7 +361,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); - Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), A_rows, width, B_cols); + Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Dummy()); AlignedVector<Integer> B_quant(B.size()); Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size()); @@ -410,7 +410,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index AlignedVector<float> test_C(A_rows * B_cols); - Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult), AddBias(bias.begin(), B_cols)), A_rows, width, B_cols); + Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Dummy()); AlignedVector<Integer> B_quant(B.size()); Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size()); diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc deleted file mode 100644 index 3bc7f74..0000000 --- a/test/postprocess/add_bias_test.cc +++ /dev/null @@ -1,95 +0,0 @@ -#include "test/test.h" -#include "aligned.h" -#include "postprocess.h" - -#include <numeric> - -namespace intgemm { - -INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) { - if (kCPU < CPUType::SSE2) - return; - - AlignedVector<float> input(8); - AlignedVector<float> bias(8); - AlignedVector<float> output(8); - - std::iota(input.begin(), input.end(), -2); - std::iota(bias.begin(), bias.end(), 0); - - auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size())); - auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0); - output.as<__m128>()[0] = output_tmp.first; - output.as<__m128>()[1] = output_tmp.second; - - CHECK(output[0] == -2.f); // input = -2, bias = 0 - CHECK(output[1] == 0.f); // input = -1, bias = 1 - CHECK(output[2] == 2.f); // input = 0, bias = 2 - CHECK(output[3] == 4.f); // input = 1, bias = 3 - CHECK(output[4] == 6.f); // input = 2, bias = 4 - CHECK(output[5] == 8.f); // input = 3, bias = 5 - CHECK(output[6] == 10.f); // input = 4, bias = 6 - CHECK(output[7] == 12.f); // input = 5, bias = 7 -} - -INTGEMM_AVX2 TEST_CASE("AddBias AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - AlignedVector<float> input(8); - AlignedVector<float> bias(8); - AlignedVector<float> output(8); - - std::iota(input.begin(), input.end(), -4); - std::iota(bias.begin(), bias.end(), 0); - - auto postproc = PostprocessImpl<AddBias, CPUType::AVX2>(AddBias(bias.begin(), bias.size())); - *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); - - CHECK(output[0] == -4.f); // input = -4, bias = 0 - CHECK(output[1] == -2.f); // input = -3, bias = 1 - CHECK(output[2] == 0.f); // input = -2, bias = 2 - CHECK(output[3] == 2.f); // input = -1, bias = 3 - CHECK(output[4] == 4.f); // input = 0, bias = 4 - CHECK(output[5] == 6.f); // input = 1, bias = 5 - CHECK(output[6] == 8.f); // input = 2, bias = 6 - CHECK(output[7] == 10.f); // input = 3, bias = 7 -} - -#ifndef INTGEMM_NO_AVX512 - -INTGEMM_AVX512BW TEST_CASE("AddBias AVX512",) { - if (kCPU < CPUType::AVX512BW) - return; - - AlignedVector<float> input(16); - AlignedVector<float> bias(16); - AlignedVector<float> output(16); - - std::iota(input.begin(), input.end(), -8); - std::iota(bias.begin(), bias.end(), 0); - - auto postproc = PostprocessImpl<AddBias, CPUType::AVX512BW>(AddBias(bias.begin(), bias.size())); - *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0); - - CHECK(output[0] == -8.f); // input = -8, bias = 0 - CHECK(output[1] == -6.f); // input = -7, bias = 1 - CHECK(output[2] == -4.f); // input = -6, bias = 2 - CHECK(output[3] == -2.f); // input = -5, bias = 3 - CHECK(output[4] == 0.f); // input = -4, bias = 4 - CHECK(output[5] == 2.f); // input = -3, bias = 5 - CHECK(output[6] == 4.f); // input = -2, bias = 6 - CHECK(output[7] == 6.f); // input = -1, bias = 7 - CHECK(output[8] == 8.f); // input = 0, bias = 8 - CHECK(output[9] == 10.f); // input = 1, bias = 9 - CHECK(output[10] == 12.f); // input = 2, bias = 10 - CHECK(output[11] == 14.f); // input = 3, bias = 11 - CHECK(output[12] == 16.f); // input = 4, bias = 12 - CHECK(output[13] == 18.f); // input = 5, bias = 13 - CHECK(output[14] == 20.f); // input = 6, bias = 14 - CHECK(output[15] == 22.f); // input = 7, bias = 15 -} - -#endif - -} diff --git a/test/postprocess/pipeline_test.cc b/test/postprocess/pipeline_test.cc deleted file mode 100644 index 144ee48..0000000 --- a/test/postprocess/pipeline_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -#include "test/test.h" -#include "aligned.h" -#include "postprocess.h" - -#include <numeric> - -namespace intgemm { - -INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") { - if (kCPU < CPUType::AVX2) - return; - - AlignedVector<int32_t> input(8); - AlignedVector<float> output(8); - - std::iota(input.begin(), input.end(), -2); - - auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU()); - auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); - *output.as<__m256>() = inited_pipeline.run(*input.as<__m256i>(), 0); - - CHECK(output[0] == 0.0f); // input = -2 - CHECK(output[1] == 0.0f); // input = -1 - CHECK(output[2] == 0.0f); // input = 0 - CHECK(output[3] == 0.5f); // input = 1 - CHECK(output[4] == 1.0f); // input = 2 - CHECK(output[5] == 1.5f); // input = 3 - CHECK(output[6] == 2.0f); // input = 4 - CHECK(output[7] == 2.5f); // input = 5 -} - -INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-ReLU") { - if (kCPU < CPUType::AVX2) - return; - - AlignedVector<int32_t> input(16); - AlignedVector<float> output(16); - - std::iota(input.begin(), input.end(), -8); - - auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU()); - auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); - inited_pipeline.run(input.as<__m256i>(), 2, output.as<__m256>()); - - CHECK(output[0] == 0.f); // input = -8 - CHECK(output[1] == 0.f); // input = -7 - CHECK(output[2] == 0.f); // input = -6 - CHECK(output[3] == 0.f); // input = -5 - CHECK(output[4] == 0.f); // input = -4 - CHECK(output[5] == 0.f); // input = -3 - CHECK(output[6] == 0.f); // input = -2 - CHECK(output[7] == 0.f); // input = -1 - CHECK(output[8] == 0.0f); // input = 0 - CHECK(output[9] == 0.5f); // input = 1 - CHECK(output[10] == 1.0f); // input = 2 - CHECK(output[11] == 1.5f); // input = 3 - CHECK(output[12] == 2.0f); // input = 4 - CHECK(output[13] == 2.5f); // input = 5 - CHECK(output[14] == 3.0f); // input = 6 - CHECK(output[15] == 3.5f); // input = 7 -} - -} diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc deleted file mode 100644 index a560790..0000000 --- a/test/postprocess/relu_test.cc +++ /dev/null @@ -1,213 +0,0 @@ -#include "test/test.h" -#include "aligned.h" -#include "postprocess.h" - -#include <numeric> - -namespace intgemm { - -/* - * ReLU: float -> float - */ -INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) { - if (kCPU < CPUType::SSE2) - return; - - AlignedVector<float> input(8); - AlignedVector<float> output(8); - std::iota(input.begin(), input.end(), -2); - - auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU()); - auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0); - output.as<__m128>()[0] = output_tmp.first; - output.as<__m128>()[1] = output_tmp.second; - - CHECK(output[0] == 0.f); // input = -2 - CHECK(output[1] == 0.f); // input = -1 - CHECK(output[2] == 0.f); // input = 0 - CHECK(output[3] == 1.f); // input = 1 - CHECK(output[4] == 2.f); // input = 2 - CHECK(output[5] == 3.f); // input = 3 - CHECK(output[6] == 4.f); // input = 4 - CHECK(output[7] == 5.f); // input = 5 -} - -INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - AlignedVector<float> input(8); - AlignedVector<float> output(8); - - std::iota(input.begin(), input.end(), -4); - - auto postproc = PostprocessImpl<ReLU, CPUType::AVX2>(ReLU()); - *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); - - CHECK(output[0] == 0.f); // input = -4 - CHECK(output[1] == 0.f); // input = -3 - CHECK(output[2] == 0.f); // input = -2 - CHECK(output[3] == 0.f); // input = -1 - CHECK(output[4] == 0.f); // input = 0 - CHECK(output[5] == 1.f); // input = 1 - CHECK(output[6] == 2.f); // input = 2 - CHECK(output[7] == 3.f); // input = 3 -} - -#ifndef INTGEMM_NO_AVX512 - -INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) { - if (kCPU < CPUType::AVX512BW) - return; - - AlignedVector<float> input(16); - AlignedVector<float> output(16); - - std::iota(input.begin(), input.end(), -8); - - auto postproc = PostprocessImpl<ReLU, CPUType::AVX512BW>(ReLU()); - *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0); - - CHECK(output[0] == 0.f); // input = -8 - CHECK(output[1] == 0.f); // input = -7 - CHECK(output[2] == 0.f); // input = -6 - CHECK(output[3] == 0.f); // input = -5 - CHECK(output[4] == 0.f); // input = -4 - CHECK(output[5] == 0.f); // input = -3 - CHECK(output[6] == 0.f); // input = -2 - CHECK(output[7] == 0.f); // input = -1 - CHECK(output[8] == 0.f); // input = 0 - CHECK(output[9] == 1.f); // input = 1 - CHECK(output[10] == 2.f); // input = 2 - CHECK(output[11] == 3.f); // input = 3 - CHECK(output[12] == 4.f); // input = 4 - CHECK(output[13] == 5.f); // input = 5 - CHECK(output[14] == 6.f); // input = 6 - CHECK(output[15] == 7.f); // input = 7 -} - -#endif - -/* - * ReLU: int8 -> int8 - */ -INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) { - if (kCPU < CPUType::SSE2) - return; - - const unsigned NEGATIVE_NUMBERS = 10; - - AlignedVector<int8_t> input(16); - AlignedVector<int8_t> output(16); - - std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); - - auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8()); - *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0); - - for (auto i = 0; i < output.size(); ++i) - CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); -} - -INTGEMM_AVX2 TEST_CASE("ReLU_int8 AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - const unsigned NEGATIVE_NUMBERS = 10; - - AlignedVector<int8_t> input(32); - AlignedVector<int8_t> output(32); - - std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); - - auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX2>(ReLU_int8()); - *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0); - - for (auto i = 0; i < output.size(); ++i) - CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); -} - -#ifndef INTGEMM_NO_AVX512 - -INTGEMM_AVX512BW TEST_CASE("ReLU_int8 AVX512",) { - if (kCPU < CPUType::AVX512BW) - return; - - const unsigned NEGATIVE_NUMBERS = 30; - - AlignedVector<int8_t> input(64); - AlignedVector<int8_t> output(64); - - std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); - - auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX512BW>(ReLU_int8()); - *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0); - - for (auto i = 0; i < output.size(); ++i) - CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); -} - -#endif - -/* - * ReLU: int16 -> int16 - */ -INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) { - if (kCPU < CPUType::SSE2) - return; - - const unsigned NEGATIVE_NUMBERS = 5; - - AlignedVector<int16_t> input(8); - AlignedVector<int16_t> output(8); - - std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); - - auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16()); - *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0); - - for (auto i = 0; i < output.size(); ++i) - CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); -} - -INTGEMM_AVX2 TEST_CASE("ReLU_int16 AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - const unsigned NEGATIVE_NUMBERS = 10; - - AlignedVector<int16_t> input(16); - AlignedVector<int16_t> output(16); - - std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); - - auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX2>(ReLU_int16()); - *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0); - - for (auto i = 0; i < output.size(); ++i) - CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); -} - -#ifndef INTGEMM_NO_AVX512 - -INTGEMM_AVX512BW TEST_CASE("ReLU_int16 AVX512",) { - if (kCPU < CPUType::AVX512BW) - return; - - const unsigned NEGATIVE_NUMBERS = 15; - - AlignedVector<int16_t> input(32); - AlignedVector<int16_t> output(32); - - std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); - - auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX512BW>(ReLU_int16()); - *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0); - - for (auto i = 0; i < output.size(); ++i) - CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); -} - -#endif - -} diff --git a/test/postprocess/sigmoid_test.cc b/test/postprocess/sigmoid_test.cc deleted file mode 100644 index 43c713c..0000000 --- a/test/postprocess/sigmoid_test.cc +++ /dev/null @@ -1,33 +0,0 @@ -#include "test/test.h" -#include "aligned.h" -#include "postprocess.h" - -#include <numeric> - -namespace intgemm { - -INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - const float error_tolerance = 0.001f; - - AlignedVector<float> input(8); - AlignedVector<float> output(8); - - std::iota(input.begin(), input.end(), -4); - - auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid()); - *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); - - CHECK_EPS(output[0], 0.0179862f, error_tolerance); // input = -4 - CHECK_EPS(output[1], 0.0474259f, error_tolerance); // input = -3 - CHECK_EPS(output[2], 0.1192029f, error_tolerance); // input = -2 - CHECK_EPS(output[3], 0.2689414f, error_tolerance); // input = -1 - CHECK_EPS(output[4], 0.5f , error_tolerance); // input = 0 - CHECK_EPS(output[5], 0.7310586f, error_tolerance); // input = 1 - CHECK_EPS(output[6], 0.8807970f, error_tolerance); // input = 2 - CHECK_EPS(output[7], 0.9525740f, error_tolerance); // input = 3 -} - -} diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc deleted file mode 100644 index f0e4dc2..0000000 --- a/test/postprocess/tanh_test.cc +++ /dev/null @@ -1,33 +0,0 @@ -#include "test/test.h" -#include "aligned.h" -#include "postprocess.h" - -#include <numeric> - -namespace intgemm { - -INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - const float error_tolerance = 0.001f; - - AlignedVector<float> input(8); - AlignedVector<float> output(8); - - std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; }); - - auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh()); - *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); - - CHECK_EPS(output[0], -0.7615942f, error_tolerance); // input = -1 - CHECK_EPS(output[1], -0.6351490f, error_tolerance); // input = -0.75 - CHECK_EPS(output[2], -0.4621172f, error_tolerance); // input = -0.5 - CHECK_EPS(output[3], -0.2449187f, error_tolerance); // input = -0.25 - CHECK_EPS(output[4], 0.0f , error_tolerance); // input = 0 - CHECK_EPS(output[5], 0.2449187f, error_tolerance); // input = 0.25 - CHECK_EPS(output[6], 0.4621172f, error_tolerance); // input = 0.5 - CHECK_EPS(output[7], 0.6351490f, error_tolerance); // input = 0.75 -} - -} diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc deleted file mode 100644 index 45e6bc4..0000000 --- a/test/postprocess/unquantize_test.cc +++ /dev/null @@ -1,88 +0,0 @@ -#include "test/test.h" -#include "aligned.h" -#include "postprocess.h" - -#include <numeric> - -namespace intgemm { - -INTGEMM_SSE2 TEST_CASE("Unquantize SSE2",) { - if (kCPU < CPUType::SSE2) - return; - - AlignedVector<int32_t> input(8); - AlignedVector<float> output(8); - std::iota(input.begin(), input.end(), -2); - - auto postproc = PostprocessImpl<Unquantize, CPUType::SSE2>(Unquantize(0.5f)); - auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0); - output.as<__m128>()[0] = output_tmp.first; - output.as<__m128>()[1] = output_tmp.second; - - CHECK(output[0] == -1.0f); // input = -2 - CHECK(output[1] == -0.5f); // input = -1 - CHECK(output[2] == 0.0f); // input = 0 - CHECK(output[3] == 0.5f); // input = 1 - CHECK(output[4] == 1.0f); // input = 2 - CHECK(output[5] == 1.5f); // input = 3 - CHECK(output[6] == 2.0f); // input = 4 - CHECK(output[7] == 2.5f); // input = 5 -} - -INTGEMM_AVX2 TEST_CASE("Unquantize AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - AlignedVector<int32_t> input(8); - AlignedVector<float> output(8); - - std::iota(input.begin(), input.end(), -4); - - auto postproc = PostprocessImpl<Unquantize, CPUType::AVX2>(Unquantize(0.5f)); - *output.as<__m256>() = postproc.run(*input.as<__m256i>(), 0); - - CHECK(output[0] == -2.0f); // input = -4 - CHECK(output[1] == -1.5f); // input = -3 - CHECK(output[2] == -1.0f); // input = -2 - CHECK(output[3] == -0.5f); // input = -1 - CHECK(output[4] == 0.0f); // input = 0 - CHECK(output[5] == 0.5f); // input = 1 - CHECK(output[6] == 1.0f); // input = 2 - CHECK(output[7] == 1.5f); // input = 3 -} - -#ifndef INTGEMM_NO_AVX512 - -INTGEMM_AVX512BW TEST_CASE("Unquantize AVX512",) { - if (kCPU < CPUType::AVX512BW) - return; - - AlignedVector<int32_t> input(16); - AlignedVector<float> output(16); - - std::iota(input.begin(), input.end(), -8); - - auto postproc = PostprocessImpl<Unquantize, CPUType::AVX512BW>(Unquantize(0.5f)); - *output.as<__m512>() = postproc.run(*input.as<__m512i>(), 0); - - CHECK(output[0] == -4.0f); // input = -8 - CHECK(output[1] == -3.5f); // input = -7 - CHECK(output[2] == -3.0f); // input = -6 - CHECK(output[3] == -2.5f); // input = -5 - CHECK(output[4] == -2.0f); // input = -4 - CHECK(output[5] == -1.5f); // input = -3 - CHECK(output[6] == -1.0f); // input = -2 - CHECK(output[7] == -0.5f); // input = -1 - CHECK(output[8] == 0.0f); // input = 0 - CHECK(output[9] == 0.5f); // input = 1 - CHECK(output[10] == 1.0f); // input = 2 - CHECK(output[11] == 1.5f); // input = 3 - CHECK(output[12] == 2.0f); // input = 4 - CHECK(output[13] == 2.5f); // input = 5 - CHECK(output[14] == 3.0f); // input = 6 - CHECK(output[15] == 3.5f); // input = 7 -} - -#endif - -} diff --git a/vec_utils.h b/vec_utils.h index a5f1469..7254063 100644 --- a/vec_utils.h +++ b/vec_utils.h @@ -1,6 +1,7 @@ #pragma once #include "intrinsics.h" +#include "utils.h" namespace intgemm { |