Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2019-07-05 14:25:47 +0300
committerKenneth Heafield <github@kheafield.com>2019-07-05 14:25:47 +0300
commitb7a7fe77a4e659f785588585526e06721ffcdd08 (patch)
tree0a975c5c7fb48349a0c348bbbd5c8527caef17ff
parent2807988a70c59169c1ea223bd734562351508f47 (diff)
parentce292be1138ecce0ec127ed59fe79d0091be7d11 (diff)
Merge branch 'master' into 4bit4bit
-rw-r--r--CMakeLists.txt10
-rw-r--r--aligned.h3
-rw-r--r--interleave.h1
-rw-r--r--intrinsics.h77
-rw-r--r--postprocess.h252
-rw-r--r--postprocess_pipeline.h4
-rw-r--r--test/multiply_test.cc31
-rw-r--r--test/pipeline_test.cc70
-rw-r--r--test/postprocess/add_bias_test.cc95
-rw-r--r--test/postprocess/pipeline_test.cc63
-rw-r--r--test/postprocess/relu_test.cc213
-rw-r--r--test/postprocess/sigmoid_test.cc33
-rw-r--r--test/postprocess/tanh_test.cc33
-rw-r--r--test/postprocess/unquantize_test.cc88
-rw-r--r--test/quantize_test.cc12
-rw-r--r--test/relu_test.cc89
-rw-r--r--test/test.cc6
-rw-r--r--test/test.h12
-rw-r--r--test/utils_test.cc38
-rw-r--r--utils.h20
-rw-r--r--vec_utils.h80
21 files changed, 984 insertions, 246 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d39e09c..c6fc8d0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,9 +33,15 @@ endforeach()
include_directories(.)
add_executable(tests
test/multiply_test.cc
- test/pipeline_test.cc
+ test/postprocess/add_bias_test.cc
+ test/postprocess/pipeline_test.cc
+ test/postprocess/relu_test.cc
+ test/postprocess/sigmoid_test.cc
+ test/postprocess/tanh_test.cc
+ test/postprocess/unquantize_test.cc
test/quantize_test.cc
- test/relu_test.cc
+ test/test.cc
+ test/utils_test.cc
test/log4_test.cc
intgemm.cc
)
diff --git a/aligned.h b/aligned.h
index 7514000..6795788 100644
--- a/aligned.h
+++ b/aligned.h
@@ -22,6 +22,9 @@ template <class T> class AlignedVector {
T *end() { return mem_ + size_; }
const T *end() const { return mem_ + size_; }
+ template <typename ReturnType>
+ ReturnType *as() { return reinterpret_cast<ReturnType*>(mem_); }
+
private:
T *mem_;
std::size_t size_;
diff --git a/interleave.h b/interleave.h
index 4c4e956..d9ade05 100644
--- a/interleave.h
+++ b/interleave.h
@@ -3,6 +3,7 @@
#include "intrinsics.h"
#include "types.h"
+#include <algorithm>
#include <cassert>
#include <stdint.h>
diff --git a/intrinsics.h b/intrinsics.h
index 7c36d6b..293efc3 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -36,6 +36,9 @@ INTGEMM_SSE2 static inline __m128i add_epi32(__m128i first, __m128i second) {
INTGEMM_SSE2 static inline __m128i adds_epi16(__m128i first, __m128i second) {
return _mm_adds_epi16(first, second);
}
+INTGEMM_SSE2 static inline __m128 add_ps(__m128 a, __m128 b) {
+ return _mm_add_ps(a, b);
+}
INTGEMM_SSE2 static inline __m128 and_ps(__m128 first, __m128 second) {
return _mm_and_ps(first, second);
}
@@ -45,6 +48,15 @@ INTGEMM_SSE2 static inline __m128 cvtepi32_ps(__m128i arg) {
INTGEMM_SSE2 static inline __m128i cvtps_epi32(__m128 arg) {
return _mm_cvtps_epi32(arg);
}
+INTGEMM_SSE2 static inline __m128i cvttps_epi32(__m128 a) {
+ return _mm_cvttps_epi32(a);
+}
+INTGEMM_SSE2 static inline __m128 div_ps(__m128 a, __m128 b) {
+ return _mm_div_ps(a, b);
+}
+/*
+ * Missing i32gather_ps for SSE2
+ */
template <> INTGEMM_SSE2 inline __m128 loadu_ps(const float* mem_addr) {
return _mm_loadu_ps(mem_addr);
}
@@ -54,9 +66,18 @@ INTGEMM_SSE2 static inline __m128i madd_epi16(__m128i first, __m128i second) {
INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second) {
return _mm_maddubs_epi16(first, second);
}
+/*
+ * Missing max_epi8 for SSE2
+ */
+INTGEMM_SSE2 static inline __m128i max_epi16(__m128i first, __m128i second) {
+ return _mm_max_epi16(first, second);
+}
INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) {
return _mm_max_ps(first, second);
}
+INTGEMM_SSE2 static inline __m128 min_ps(__m128 a, __m128 b) {
+ return _mm_min_ps(a, b);
+}
INTGEMM_SSE2 static inline __m128 mul_ps(__m128 a, __m128 b) {
return _mm_mul_ps(a, b);
}
@@ -81,8 +102,8 @@ INTGEMM_SSSE3 static inline __m128i sign_epi8(__m128i first, __m128i second) {
INTGEMM_SSE2 static inline void storeu_ps(float* mem_addr, __m128 a) {
_mm_storeu_ps(mem_addr, a);
}
-INTGEMM_SSE2 static inline __m128 add_ps (__m128 a, __m128 b) {
- return _mm_add_ps(a, b);
+INTGEMM_SSE2 static inline __m128 sub_ps(__m128 a, __m128 b) {
+ return _mm_sub_ps(a, b);
}
/*
@@ -99,6 +120,9 @@ INTGEMM_AVX2 static inline __m256i add_epi32(__m256i first, __m256i second) {
INTGEMM_AVX2 static inline __m256i adds_epi16(__m256i first, __m256i second) {
return _mm256_adds_epi16(first, second);
}
+INTGEMM_AVX2 static inline __m256 add_ps(__m256 a, __m256 b) {
+ return _mm256_add_ps(a, b);
+}
INTGEMM_AVX2 static inline __m256 and_ps(__m256 first, __m256 second) {
return _mm256_and_ps(first, second);
}
@@ -108,6 +132,16 @@ INTGEMM_AVX2 static inline __m256 cvtepi32_ps(__m256i arg) {
INTGEMM_AVX2 static inline __m256i cvtps_epi32(__m256 arg) {
return _mm256_cvtps_epi32(arg);
}
+INTGEMM_AVX2 static inline __m256i cvttps_epi32(__m256 a) {
+ return _mm256_cvttps_epi32(a);
+}
+INTGEMM_AVX2 static inline __m256 div_ps(__m256 a, __m256 b) {
+ return _mm256_div_ps(a, b);
+}
+template <unsigned Scale>
+INTGEMM_AVX2 static inline __m256 i32gather_ps(float const *base_addr, __m256i vindex) {
+ return _mm256_i32gather_ps(base_addr, vindex, Scale);
+}
template <> INTGEMM_AVX2 inline __m256 loadu_ps(const float* mem_addr) {
return _mm256_loadu_ps(mem_addr);
}
@@ -117,9 +151,18 @@ INTGEMM_AVX2 static inline __m256i madd_epi16(__m256i first, __m256i second) {
INTGEMM_AVX2 static inline __m256i maddubs_epi16(__m256i first, __m256i second) {
return _mm256_maddubs_epi16(first, second);
}
+INTGEMM_AVX2 static inline __m256i max_epi8(__m256i first, __m256i second) {
+ return _mm256_max_epi8(first, second);
+}
+INTGEMM_AVX2 static inline __m256i max_epi16(__m256i first, __m256i second) {
+ return _mm256_max_epi16(first, second);
+}
INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) {
return _mm256_max_ps(first, second);
}
+INTGEMM_AVX2 static inline __m256 min_ps(__m256 a, __m256 b) {
+ return _mm256_min_ps(a, b);
+}
INTGEMM_AVX2 static inline __m256 mul_ps(__m256 a, __m256 b) {
return _mm256_mul_ps(a, b);
}
@@ -144,8 +187,8 @@ INTGEMM_AVX2 static inline __m256i sign_epi8(__m256i first, __m256i second) {
INTGEMM_AVX2 static inline void storeu_ps(float* mem_addr, __m256 a) {
_mm256_storeu_ps(mem_addr, a);
}
-INTGEMM_AVX2 static inline __m256 add_ps (__m256 a, __m256 b) {
- return _mm256_add_ps(a, b);
+INTGEMM_AVX2 static inline __m256 sub_ps(__m256 a, __m256 b) {
+ return _mm256_sub_ps(a, b);
}
/*
@@ -164,6 +207,9 @@ INTGEMM_AVX512BW static inline __m512i add_epi32(__m512i first, __m512i second)
INTGEMM_AVX512BW static inline __m512i adds_epi16(__m512i first, __m512i second) {
return _mm512_adds_epi16(first, second);
}
+INTGEMM_AVX512BW static inline __m512 add_ps(__m512 a, __m512 b) {
+ return _mm512_add_ps(a, b);
+}
INTGEMM_AVX512DQ static inline __m512 and_ps(__m512 first, __m512 second) {
return _mm512_and_ps(first, second);
}
@@ -173,6 +219,16 @@ INTGEMM_AVX512BW static inline __m512 cvtepi32_ps(__m512i arg) {
INTGEMM_AVX512BW static inline __m512i cvtps_epi32(__m512 arg) {
return _mm512_cvtps_epi32(arg);
}
+INTGEMM_AVX512BW static inline __m512i cvttps_epi32(__m512 a) {
+ return _mm512_cvttps_epi32(a);
+}
+INTGEMM_AVX512BW static inline __m512 div_ps(__m512 a, __m512 b) {
+ return _mm512_div_ps(a, b);
+}
+template <unsigned Scale>
+INTGEMM_AVX512BW static inline __m512 i32gather_ps(float const *base_addr, __m512i vindex) {
+ return _mm512_i32gather_ps(vindex, base_addr, Scale);
+}
template <> INTGEMM_AVX512BW inline __m512 loadu_ps(const float* mem_addr) {
return _mm512_loadu_ps(mem_addr);
}
@@ -182,11 +238,17 @@ INTGEMM_AVX512BW static inline __m512i madd_epi16(__m512i first, __m512i second)
INTGEMM_AVX512BW static inline __m512i maddubs_epi16(__m512i first, __m512i second) {
return _mm512_maddubs_epi16(first, second);
}
+INTGEMM_AVX512BW static inline __m512i max_epi8(__m512i first, __m512i second) {
+ return _mm512_max_epi8(first, second);
+}
+INTGEMM_AVX512BW static inline __m512i max_epi16(__m512i first, __m512i second) {
+ return _mm512_max_epi16(first, second);
+}
INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) {
return _mm512_max_ps(first, second);
}
-INTGEMM_AVX512BW static inline __m512 add_ps(__m512 first, __m512 second) {
- return _mm512_add_ps(first, second);
+INTGEMM_AVX512BW static inline __m512 min_ps(__m512 a, __m512 b) {
+ return _mm512_min_ps(a, b);
}
INTGEMM_AVX512BW static inline __m512 mul_ps(__m512 a, __m512 b) {
return _mm512_mul_ps(a, b);
@@ -212,6 +274,9 @@ template <> INTGEMM_AVX512BW inline __m512i setzero_si<__m512i>() {
INTGEMM_AVX512BW static inline void storeu_ps(float* mem_addr, __m512 a) {
_mm512_storeu_ps(mem_addr, a);
}
+INTGEMM_AVX512BW static inline __m512 sub_ps(__m512 a, __m512 b) {
+ return _mm512_sub_ps(a, b);
+}
#endif
diff --git a/postprocess.h b/postprocess.h
index 0855548..7835b2b 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -5,6 +5,10 @@
#include "types.h"
#include "vec_utils.h"
+// TODO: We support some postprocess in few variations e.g. we support ReLU for
+// float -> float, int8 -> int8, int16 -> int16. Maybe it would be a good idea
+// to pass input type and output type as a template parameter of postprocess?
+
namespace intgemm {
/*
@@ -56,6 +60,8 @@ private:
__m256 unquantize_multiplier;
};
+#ifndef INTGEMM_NO_AVX512
+
template <>
class PostprocessImpl<Unquantize, CPUType::AVX512BW> {
public:
@@ -74,49 +80,7 @@ private:
__m512 unquantize_multiplier;
};
-/*
- * Identity
- */
-class Identity {};
-
-template <>
-class PostprocessImpl<Identity, CPUType::SSE2> {
-public:
- using InputRegister = RegisterPair128i;
- using OutputRegister = RegisterPair128i;
-
- PostprocessImpl(const Identity& config) {}
-
- INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
- return input;
- }
-};
-
-template <>
-class PostprocessImpl<Identity, CPUType::AVX2> {
-public:
- using InputRegister = __m256i;
- using OutputRegister = __m256i;
-
- PostprocessImpl(const Identity& config) {}
-
- INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
- return input;
- }
-};
-
-template <>
-class PostprocessImpl<Identity, CPUType::AVX512BW> {
-public:
- using InputRegister = __m512i;
- using OutputRegister = __m512i;
-
- PostprocessImpl(const Identity& config) {}
-
- INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
- return input;
- }
-};
+#endif
/*
* Add a bias term
@@ -167,6 +131,27 @@ private:
const AddBias config;
};
+#ifndef INTGEMM_NO_AVX512
+
+template <>
+class PostprocessImpl<AddBias, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512;
+ using OutputRegister = __m512;
+
+ PostprocessImpl(const AddBias& config) : config(config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ auto bias_term = *reinterpret_cast<const __m512*>(config.bias + (offset % config.length));
+ return add_ps(input, bias_term);
+ }
+
+private:
+ const AddBias config;
+};
+
+#endif
+
/*
* ReLU
*/
@@ -206,6 +191,8 @@ public:
}
};
+#ifndef INTGEMM_NO_AVX512
+
template <>
class PostprocessImpl<ReLU, CPUType::AVX512BW> {
public:
@@ -220,4 +207,183 @@ public:
}
};
+#endif
+
+/*
+ * ReLU_int8
+ */
+class ReLU_int8 {};
+
+template <>
+class PostprocessImpl<ReLU_int8, CPUType::SSE2> {
+public:
+ using InputRegister = __m128i;
+ using OutputRegister = __m128i;
+
+ PostprocessImpl(const ReLU_int8& config) {}
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m128i>();
+ return _mm_and_si128(_mm_cmplt_epi8(const_zero, input), input);
+ }
+};
+
+template <>
+class PostprocessImpl<ReLU_int8, CPUType::AVX2> {
+public:
+ using InputRegister = __m256i;
+ using OutputRegister = __m256i;
+
+ PostprocessImpl(const ReLU_int8& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m256i>();
+ return max_epi8(const_zero, input);
+ }
+};
+
+#ifndef INTGEMM_NO_AVX512
+
+template <>
+class PostprocessImpl<ReLU_int8, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512i;
+ using OutputRegister = __m512i;
+
+ PostprocessImpl(const ReLU_int8& config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m512i>();
+ return max_epi8(const_zero, input);
+ }
+};
+
+#endif
+
+/*
+ * ReLU_int16
+ */
+class ReLU_int16 {};
+
+template <>
+class PostprocessImpl<ReLU_int16, CPUType::SSE2> {
+public:
+ using InputRegister = __m128i;
+ using OutputRegister = __m128i;
+
+ PostprocessImpl(const ReLU_int16& config) {}
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m128i>();
+ return max_epi16(const_zero, input);
+ }
+};
+
+template <>
+class PostprocessImpl<ReLU_int16, CPUType::AVX2> {
+public:
+ using InputRegister = __m256i;
+ using OutputRegister = __m256i;
+
+ PostprocessImpl(const ReLU_int16& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m256i>();
+ return max_epi16(const_zero, input);
+ }
+};
+
+#ifndef INTGEMM_NO_AVX512
+
+template <>
+class PostprocessImpl<ReLU_int16, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512i;
+ using OutputRegister = __m512i;
+
+ PostprocessImpl(const ReLU_int16& config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m512i>();
+ return max_epi16(const_zero, input);
+ }
+};
+
+#endif
+
+/*
+ * Sigmoid (uses Taylor series approximation of e^x)
+ */
+class Sigmoid {};
+
+template <>
+class PostprocessImpl<Sigmoid, CPUType::AVX2> {
+public:
+ using InputRegister = __m256;
+ using OutputRegister = __m256;
+
+ PostprocessImpl(const Sigmoid& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = set1_ps<__m256>(0.f);
+ static const auto const_one = set1_ps<__m256>(1.f);
+
+ auto x = input;
+ auto minus_x = sub_ps(const_zero, x);
+ auto e_x = exp_approx_taylor(x);
+ auto e_minus_x = exp_approx_taylor(minus_x);
+
+ auto sigmoid_case1 = _mm256_rcp_ps(add_ps(const_one, e_minus_x));
+ auto sigmoid_case2 = mul_ps(e_x, _mm256_rcp_ps(add_ps(const_one, e_x)));
+
+ auto nonnegative_x_mask = _mm256_cmp_ps(const_zero, x, _CMP_LT_OS);
+ return _mm256_blendv_ps(sigmoid_case1, sigmoid_case2, nonnegative_x_mask);
+ }
+};
+
+/*
+ * Tanh (uses Taylor series approximation of e^x)
+ */
+class Tanh {};
+
+template <>
+class PostprocessImpl<Tanh, CPUType::AVX2> {
+public:
+ using InputRegister = __m256;
+ using OutputRegister = __m256;
+
+ PostprocessImpl(const Tanh& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
+ const static auto const_zero = setzero_ps<__m256>();
+
+ auto e_x = exp_approx_taylor(input);
+ auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input));
+
+ return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
+ }
+};
+
+#ifndef INTGEMM_NO_AVX512
+
+template <>
+class PostprocessImpl<Tanh, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512;
+ using OutputRegister = __m512;
+
+ PostprocessImpl(const Tanh& config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ const static auto const_zero = setzero_ps<__m512>();
+
+ auto e_x = exp_approx_taylor(input);
+ auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input));
+
+ return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
+ }
+};
+
+#endif
+
}
diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h
index ad26ac5..361ff2b 100644
--- a/postprocess_pipeline.h
+++ b/postprocess_pipeline.h
@@ -12,8 +12,8 @@ template <typename... Stages>
using PostprocessPipeline = std::tuple<Stages...>;
template <typename... Stages>
-constexpr std::tuple<Stages...> CreatePostprocessPipeline(const Stages&... stages) {
- return std::make_tuple(stages...);
+constexpr std::tuple<Stages...> CreatePostprocessPipeline(Stages&&... stages) {
+ return std::make_tuple(std::forward<Stages>(stages)...);
}
template <typename Postprocess, CPUType CpuType>
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 82062fe..93d7127 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -1,22 +1,16 @@
+#include "test/test.h"
#include "aligned.h"
#include "interleave.h"
#include "intgemm.h"
#include "multiply.h"
#include "postprocess.h"
-#define CATCH_CONFIG_RUNNER
-#include "3rd_party/catch.hpp"
-#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while((void)0, 0)
-#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while((void)0, 0)
-#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while((void)0, 0)
-#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while((void)0, 0)
-
#include <algorithm>
#include <cassert>
#include <cmath>
-#include <cstring>
#include <cstdio>
#include <cstdlib>
+#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
@@ -61,7 +55,7 @@ INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") {
SlowTranspose(input.begin(), ref.begin(), N, N);
// Overwrite input.
- __m128i *t = reinterpret_cast<__m128i*>(input.begin());
+ __m128i *t = input.as<__m128i>();
Transpose16InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]);
for (int16_t i = 0; i < input.size(); ++i) {
@@ -79,7 +73,7 @@ INTGEMM_SSSE3 TEST_CASE("Transpose 8", "[transpose]") {
SlowTranspose(input.begin(), ref.begin(), N, N);
// Overwrite input.
- __m128i *t = reinterpret_cast<__m128i*>(input.begin());
+ __m128i *t = input.as<__m128i>();
Transpose8InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15]);
for (int i = 0; i < input.size(); ++i) {
@@ -554,20 +548,3 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
#endif
} // namespace intgemm
-
-int main(int argc, char ** argv) {
- return Catch::Session().run(argc, argv);
-}
-
-/*
- // Top matrix sizes from Marian
- TestBoth(8, 256, 256);
- TestBoth(8, 2048, 256);
- TestBoth(8, 2048, 256);
- TestBoth(320, 256, 256);
- TestBoth(472, 256, 256);
- TestBoth(248, 256, 256);
- TestBoth(200, 256, 256);
- return 0;
-}
-*/
diff --git a/test/pipeline_test.cc b/test/pipeline_test.cc
deleted file mode 100644
index 1b8c21d..0000000
--- a/test/pipeline_test.cc
+++ /dev/null
@@ -1,70 +0,0 @@
-#include "3rd_party/catch.hpp"
-#include "postprocess.h"
-
-#include <numeric>
-
-namespace intgemm {
-
-INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
- if (kCPU < CPUType::AVX2)
- return;
-
- __m256i input;
- __m256 output;
-
- auto raw_input = reinterpret_cast<int*>(&input);
- std::iota(raw_input, raw_input + 8, -2);
-
- auto raw_output = reinterpret_cast<float*>(&output);
- std::fill(raw_output, raw_output + 8, 42);
-
- auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
- auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
- output = inited_pipeline.run(input, 0);
-
- CHECK(raw_output[0] == 0.0f); // input = -2
- CHECK(raw_output[1] == 0.0f); // input = -1
- CHECK(raw_output[2] == 0.0f); // input = 0
- CHECK(raw_output[3] == 0.5f); // input = 1
- CHECK(raw_output[4] == 1.0f); // input = 2
- CHECK(raw_output[5] == 1.5f); // input = 3
- CHECK(raw_output[6] == 2.0f); // input = 4
- CHECK(raw_output[7] == 2.5f); // input = 5
-}
-
-INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-ReLU") {
- if (kCPU < CPUType::AVX2)
- return;
-
- __m256i input[2];
- __m256 output[2];
-
- auto raw_input = reinterpret_cast<int*>(input);
- std::iota(raw_input, raw_input + 16, -8);
-
- auto raw_output = reinterpret_cast<float*>(output);
- std::fill(raw_output, raw_output + 16, 42);
-
- auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
- auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
- inited_pipeline.run(input, 2, output);
-
- CHECK(raw_output[0] == 0.f); // input = -8
- CHECK(raw_output[1] == 0.f); // input = -7
- CHECK(raw_output[2] == 0.f); // input = -6
- CHECK(raw_output[3] == 0.f); // input = -5
- CHECK(raw_output[4] == 0.f); // input = -4
- CHECK(raw_output[5] == 0.f); // input = -3
- CHECK(raw_output[6] == 0.f); // input = -2
- CHECK(raw_output[7] == 0.f); // input = -1
- CHECK(raw_output[8] == 0.0f); // input = 0
- CHECK(raw_output[9] == 0.5f); // input = 1
- CHECK(raw_output[10] == 1.0f); // input = 2
- CHECK(raw_output[11] == 1.5f); // input = 3
- CHECK(raw_output[12] == 2.0f); // input = 4
- CHECK(raw_output[13] == 2.5f); // input = 5
- CHECK(raw_output[14] == 3.0f); // input = 6
- CHECK(raw_output[15] == 3.5f); // input = 7
-}
-
-}
diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc
new file mode 100644
index 0000000..5e893ea
--- /dev/null
+++ b/test/postprocess/add_bias_test.cc
@@ -0,0 +1,95 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> bias(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -2);
+ std::iota(bias.begin(), bias.end(), 0);
+
+ auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size()));
+ auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0);
+ output.as<__m128>()[0] = output_tmp.pack0123;
+ output.as<__m128>()[1] = output_tmp.pack4567;
+
+ CHECK(output[0] == -2.f); // input = -2, bias = 0
+ CHECK(output[1] == 0.f); // input = -1, bias = 1
+ CHECK(output[2] == 2.f); // input = 0, bias = 2
+ CHECK(output[3] == 4.f); // input = 1, bias = 3
+ CHECK(output[4] == 6.f); // input = 2, bias = 4
+ CHECK(output[5] == 8.f); // input = 3, bias = 5
+ CHECK(output[6] == 10.f); // input = 4, bias = 6
+ CHECK(output[7] == 12.f); // input = 5, bias = 7
+}
+
+INTGEMM_AVX2 TEST_CASE("AddBias AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> bias(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
+ std::iota(bias.begin(), bias.end(), 0);
+
+ auto postproc = PostprocessImpl<AddBias, CPUType::AVX2>(AddBias(bias.begin(), bias.size()));
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK(output[0] == -4.f); // input = -4, bias = 0
+ CHECK(output[1] == -2.f); // input = -3, bias = 1
+ CHECK(output[2] == 0.f); // input = -2, bias = 2
+ CHECK(output[3] == 2.f); // input = -1, bias = 3
+ CHECK(output[4] == 4.f); // input = 0, bias = 4
+ CHECK(output[5] == 6.f); // input = 1, bias = 5
+ CHECK(output[6] == 8.f); // input = 2, bias = 6
+ CHECK(output[7] == 10.f); // input = 3, bias = 7
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("AddBias AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ AlignedVector<float> input(16);
+ AlignedVector<float> bias(16);
+ AlignedVector<float> output(16);
+
+ std::iota(input.begin(), input.end(), -8);
+ std::iota(bias.begin(), bias.end(), 0);
+
+ auto postproc = PostprocessImpl<AddBias, CPUType::AVX512BW>(AddBias(bias.begin(), bias.size()));
+ *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0);
+
+ CHECK(output[0] == -8.f); // input = -8, bias = 0
+ CHECK(output[1] == -6.f); // input = -7, bias = 1
+ CHECK(output[2] == -4.f); // input = -6, bias = 2
+ CHECK(output[3] == -2.f); // input = -5, bias = 3
+ CHECK(output[4] == 0.f); // input = -4, bias = 4
+ CHECK(output[5] == 2.f); // input = -3, bias = 5
+ CHECK(output[6] == 4.f); // input = -2, bias = 6
+ CHECK(output[7] == 6.f); // input = -1, bias = 7
+ CHECK(output[8] == 8.f); // input = 0, bias = 8
+ CHECK(output[9] == 10.f); // input = 1, bias = 9
+ CHECK(output[10] == 12.f); // input = 2, bias = 10
+ CHECK(output[11] == 14.f); // input = 3, bias = 11
+ CHECK(output[12] == 16.f); // input = 4, bias = 12
+ CHECK(output[13] == 18.f); // input = 5, bias = 13
+ CHECK(output[14] == 20.f); // input = 6, bias = 14
+ CHECK(output[15] == 22.f); // input = 7, bias = 15
+}
+
+#endif
+
+}
diff --git a/test/postprocess/pipeline_test.cc b/test/postprocess/pipeline_test.cc
new file mode 100644
index 0000000..144ee48
--- /dev/null
+++ b/test/postprocess/pipeline_test.cc
@@ -0,0 +1,63 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ AlignedVector<int32_t> input(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -2);
+
+ auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
+ auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
+ *output.as<__m256>() = inited_pipeline.run(*input.as<__m256i>(), 0);
+
+ CHECK(output[0] == 0.0f); // input = -2
+ CHECK(output[1] == 0.0f); // input = -1
+ CHECK(output[2] == 0.0f); // input = 0
+ CHECK(output[3] == 0.5f); // input = 1
+ CHECK(output[4] == 1.0f); // input = 2
+ CHECK(output[5] == 1.5f); // input = 3
+ CHECK(output[6] == 2.0f); // input = 4
+ CHECK(output[7] == 2.5f); // input = 5
+}
+
+INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-ReLU") {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ AlignedVector<int32_t> input(16);
+ AlignedVector<float> output(16);
+
+ std::iota(input.begin(), input.end(), -8);
+
+ auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
+ auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
+ inited_pipeline.run(input.as<__m256i>(), 2, output.as<__m256>());
+
+ CHECK(output[0] == 0.f); // input = -8
+ CHECK(output[1] == 0.f); // input = -7
+ CHECK(output[2] == 0.f); // input = -6
+ CHECK(output[3] == 0.f); // input = -5
+ CHECK(output[4] == 0.f); // input = -4
+ CHECK(output[5] == 0.f); // input = -3
+ CHECK(output[6] == 0.f); // input = -2
+ CHECK(output[7] == 0.f); // input = -1
+ CHECK(output[8] == 0.0f); // input = 0
+ CHECK(output[9] == 0.5f); // input = 1
+ CHECK(output[10] == 1.0f); // input = 2
+ CHECK(output[11] == 1.5f); // input = 3
+ CHECK(output[12] == 2.0f); // input = 4
+ CHECK(output[13] == 2.5f); // input = 5
+ CHECK(output[14] == 3.0f); // input = 6
+ CHECK(output[15] == 3.5f); // input = 7
+}
+
+}
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
new file mode 100644
index 0000000..af6677e
--- /dev/null
+++ b/test/postprocess/relu_test.cc
@@ -0,0 +1,213 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+/*
+ * ReLU: float -> float
+ */
+INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+ std::iota(input.begin(), input.end(), -2);
+
+ auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU());
+ auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0);
+ output.as<__m128>()[0] = output_tmp.pack0123;
+ output.as<__m128>()[1] = output_tmp.pack4567;
+
+ CHECK(output[0] == 0.f); // input = -2
+ CHECK(output[1] == 0.f); // input = -1
+ CHECK(output[2] == 0.f); // input = 0
+ CHECK(output[3] == 1.f); // input = 1
+ CHECK(output[4] == 2.f); // input = 2
+ CHECK(output[5] == 3.f); // input = 3
+ CHECK(output[6] == 4.f); // input = 4
+ CHECK(output[7] == 5.f); // input = 5
+}
+
+INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
+
+ auto postproc = PostprocessImpl<ReLU, CPUType::AVX2>(ReLU());
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK(output[0] == 0.f); // input = -4
+ CHECK(output[1] == 0.f); // input = -3
+ CHECK(output[2] == 0.f); // input = -2
+ CHECK(output[3] == 0.f); // input = -1
+ CHECK(output[4] == 0.f); // input = 0
+ CHECK(output[5] == 1.f); // input = 1
+ CHECK(output[6] == 2.f); // input = 2
+ CHECK(output[7] == 3.f); // input = 3
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ AlignedVector<float> input(16);
+ AlignedVector<float> output(16);
+
+ std::iota(input.begin(), input.end(), -8);
+
+ auto postproc = PostprocessImpl<ReLU, CPUType::AVX512BW>(ReLU());
+ *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0);
+
+ CHECK(output[0] == 0.f); // input = -8
+ CHECK(output[1] == 0.f); // input = -7
+ CHECK(output[2] == 0.f); // input = -6
+ CHECK(output[3] == 0.f); // input = -5
+ CHECK(output[4] == 0.f); // input = -4
+ CHECK(output[5] == 0.f); // input = -3
+ CHECK(output[6] == 0.f); // input = -2
+ CHECK(output[7] == 0.f); // input = -1
+ CHECK(output[8] == 0.f); // input = 0
+ CHECK(output[9] == 1.f); // input = 1
+ CHECK(output[10] == 2.f); // input = 2
+ CHECK(output[11] == 3.f); // input = 3
+ CHECK(output[12] == 4.f); // input = 4
+ CHECK(output[13] == 5.f); // input = 5
+ CHECK(output[14] == 6.f); // input = 6
+ CHECK(output[15] == 7.f); // input = 7
+}
+
+#endif
+
+/*
+ * ReLU: int8 -> int8
+ */
+INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int8_t> input(16);
+ AlignedVector<int8_t> output(16);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8());
+ *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+INTGEMM_AVX2 TEST_CASE("ReLU_int8 AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int8_t> input(32);
+ AlignedVector<int8_t> output(32);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX2>(ReLU_int8());
+ *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("ReLU_int8 AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 30;
+
+ AlignedVector<int8_t> input(64);
+ AlignedVector<int8_t> output(64);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX512BW>(ReLU_int8());
+ *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#endif
+
+/*
+ * ReLU: int16 -> int16
+ */
+INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 5;
+
+ AlignedVector<int16_t> input(8);
+ AlignedVector<int16_t> output(8);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16());
+ *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+INTGEMM_AVX2 TEST_CASE("ReLU_int16 AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int16_t> input(16);
+ AlignedVector<int16_t> output(16);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX2>(ReLU_int16());
+ *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("ReLU_int16 AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 15;
+
+ AlignedVector<int16_t> input(32);
+ AlignedVector<int16_t> output(32);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX512BW>(ReLU_int16());
+ *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#endif
+
+}
diff --git a/test/postprocess/sigmoid_test.cc b/test/postprocess/sigmoid_test.cc
new file mode 100644
index 0000000..43c713c
--- /dev/null
+++ b/test/postprocess/sigmoid_test.cc
@@ -0,0 +1,33 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const float error_tolerance = 0.001f;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
+
+ auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid());
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK_EPS(output[0], 0.0179862f, error_tolerance); // input = -4
+ CHECK_EPS(output[1], 0.0474259f, error_tolerance); // input = -3
+ CHECK_EPS(output[2], 0.1192029f, error_tolerance); // input = -2
+ CHECK_EPS(output[3], 0.2689414f, error_tolerance); // input = -1
+ CHECK_EPS(output[4], 0.5f , error_tolerance); // input = 0
+ CHECK_EPS(output[5], 0.7310586f, error_tolerance); // input = 1
+ CHECK_EPS(output[6], 0.8807970f, error_tolerance); // input = 2
+ CHECK_EPS(output[7], 0.9525740f, error_tolerance); // input = 3
+}
+
+}
diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc
new file mode 100644
index 0000000..f0e4dc2
--- /dev/null
+++ b/test/postprocess/tanh_test.cc
@@ -0,0 +1,33 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const float error_tolerance = 0.001f;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+
+ std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; });
+
+ auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh());
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK_EPS(output[0], -0.7615942f, error_tolerance); // input = -1
+ CHECK_EPS(output[1], -0.6351490f, error_tolerance); // input = -0.75
+ CHECK_EPS(output[2], -0.4621172f, error_tolerance); // input = -0.5
+ CHECK_EPS(output[3], -0.2449187f, error_tolerance); // input = -0.25
+ CHECK_EPS(output[4], 0.0f , error_tolerance); // input = 0
+ CHECK_EPS(output[5], 0.2449187f, error_tolerance); // input = 0.25
+ CHECK_EPS(output[6], 0.4621172f, error_tolerance); // input = 0.5
+ CHECK_EPS(output[7], 0.6351490f, error_tolerance); // input = 0.75
+}
+
+}
diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc
new file mode 100644
index 0000000..c33b909
--- /dev/null
+++ b/test/postprocess/unquantize_test.cc
@@ -0,0 +1,88 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_SSE2 TEST_CASE("Unquantize SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ AlignedVector<int32_t> input(8);
+ AlignedVector<float> output(8);
+ std::iota(input.begin(), input.end(), -2);
+
+ auto postproc = PostprocessImpl<Unquantize, CPUType::SSE2>(Unquantize(0.5f));
+ auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
+ output.as<__m128>()[0] = output_tmp.pack0123;
+ output.as<__m128>()[1] = output_tmp.pack4567;
+
+ CHECK(output[0] == -1.0f); // input = -2
+ CHECK(output[1] == -0.5f); // input = -1
+ CHECK(output[2] == 0.0f); // input = 0
+ CHECK(output[3] == 0.5f); // input = 1
+ CHECK(output[4] == 1.0f); // input = 2
+ CHECK(output[5] == 1.5f); // input = 3
+ CHECK(output[6] == 2.0f); // input = 4
+ CHECK(output[7] == 2.5f); // input = 5
+}
+
+INTGEMM_AVX2 TEST_CASE("Unquantize AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ AlignedVector<int32_t> input(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
+
+ auto postproc = PostprocessImpl<Unquantize, CPUType::AVX2>(Unquantize(0.5f));
+ *output.as<__m256>() = postproc.run(*input.as<__m256i>(), 0);
+
+ CHECK(output[0] == -2.0f); // input = -4
+ CHECK(output[1] == -1.5f); // input = -3
+ CHECK(output[2] == -1.0f); // input = -2
+ CHECK(output[3] == -0.5f); // input = -1
+ CHECK(output[4] == 0.0f); // input = 0
+ CHECK(output[5] == 0.5f); // input = 1
+ CHECK(output[6] == 1.0f); // input = 2
+ CHECK(output[7] == 1.5f); // input = 3
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("Unquantize AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ AlignedVector<int32_t> input(16);
+ AlignedVector<float> output(16);
+
+ std::iota(input.begin(), input.end(), -8);
+
+ auto postproc = PostprocessImpl<Unquantize, CPUType::AVX512BW>(Unquantize(0.5f));
+ *output.as<__m512>() = postproc.run(*input.as<__m512i>(), 0);
+
+ CHECK(output[0] == -4.0f); // input = -8
+ CHECK(output[1] == -3.5f); // input = -7
+ CHECK(output[2] == -3.0f); // input = -6
+ CHECK(output[3] == -2.5f); // input = -5
+ CHECK(output[4] == -2.0f); // input = -4
+ CHECK(output[5] == -1.5f); // input = -3
+ CHECK(output[6] == -1.0f); // input = -2
+ CHECK(output[7] == -0.5f); // input = -1
+ CHECK(output[8] == 0.0f); // input = 0
+ CHECK(output[9] == 0.5f); // input = 1
+ CHECK(output[10] == 1.0f); // input = 2
+ CHECK(output[11] == 1.5f); // input = 3
+ CHECK(output[12] == 2.0f); // input = 4
+ CHECK(output[13] == 2.5f); // input = 5
+ CHECK(output[14] == 3.0f); // input = 6
+ CHECK(output[15] == 3.5f); // input = 7
+}
+
+#endif
+
+}
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index fb866f1..fd7f0a4 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -1,15 +1,13 @@
-#include "avx512_gemm.h"
+#include "test/test.h"
+#include "aligned.h"
#include "avx2_gemm.h"
-#include "ssse3_gemm.h"
+#include "avx512_gemm.h"
#include "sse2_gemm.h"
-#include "aligned.h"
-
-#include "3rd_party/catch.hpp"
+#include "ssse3_gemm.h"
#include <cstring>
-#include <math.h>
-
#include <iostream>
+#include <math.h>
namespace intgemm {
namespace {
diff --git a/test/relu_test.cc b/test/relu_test.cc
deleted file mode 100644
index 183f415..0000000
--- a/test/relu_test.cc
+++ /dev/null
@@ -1,89 +0,0 @@
-#include "3rd_party/catch.hpp"
-#include "postprocess.h"
-
-#include <numeric>
-
-namespace intgemm {
-
-INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
- if (kCPU < CPUType::SSE2)
- return;
-
- float raw_input[8];
- std::iota(raw_input, raw_input + 8, -2);
-
- RegisterPair128 input;
- input.pack0123 = *reinterpret_cast<__m128*>(raw_input);
- input.pack4567 = *reinterpret_cast<__m128*>(raw_input + 4);
-
- auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU());
- auto output = postproc.run(input, 0);
- auto raw_output = reinterpret_cast<float*>(&output);
-
- CHECK(raw_output[0] == 0.f); // input = -2
- CHECK(raw_output[1] == 0.f); // input = -1
- CHECK(raw_output[2] == 0.f); // input = 0
- CHECK(raw_output[3] == 1.f); // input = 1
- CHECK(raw_output[4] == 2.f); // input = 2
- CHECK(raw_output[5] == 3.f); // input = 3
- CHECK(raw_output[6] == 4.f); // input = 4
- CHECK(raw_output[7] == 5.f); // input = 5
-}
-
-INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- float raw_input[8];
- std::iota(raw_input, raw_input + 8, -4);
-
- auto input = *reinterpret_cast<__m256*>(raw_input);
- auto postproc = PostprocessImpl<ReLU, CPUType::AVX2>(ReLU());
- auto output = postproc.run(input, 0);
- auto raw_output = reinterpret_cast<float*>(&output);
-
- CHECK(raw_output[0] == 0.f); // input = -4
- CHECK(raw_output[1] == 0.f); // input = -3
- CHECK(raw_output[2] == 0.f); // input = -2
- CHECK(raw_output[3] == 0.f); // input = -1
- CHECK(raw_output[4] == 0.f); // input = 0
- CHECK(raw_output[5] == 1.f); // input = 1
- CHECK(raw_output[6] == 2.f); // input = 2
- CHECK(raw_output[7] == 3.f); // input = 3
-}
-
-#ifndef INTGEMM_NO_AVX512
-
-INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
- if (kCPU < CPUType::AVX512BW)
- return;
-
- float raw_input[16];
- std::iota(raw_input, raw_input + 16, -8);
-
- auto input = *reinterpret_cast<__m512*>(raw_input);
- auto postproc = PostprocessImpl<ReLU, CPUType::AVX512BW>(ReLU());
- auto output = postproc.run(input, 0);
- auto raw_output = reinterpret_cast<float*>(&output);
-
- CHECK(raw_output[0] == 0.f); // input = -8
- CHECK(raw_output[1] == 0.f); // input = -7
- CHECK(raw_output[2] == 0.f); // input = -6
- CHECK(raw_output[3] == 0.f); // input = -5
- CHECK(raw_output[4] == 0.f); // input = -4
- CHECK(raw_output[5] == 0.f); // input = -3
- CHECK(raw_output[6] == 0.f); // input = -2
- CHECK(raw_output[7] == 0.f); // input = -1
- CHECK(raw_output[8] == 0.f); // input = 0
- CHECK(raw_output[9] == 1.f); // input = 1
- CHECK(raw_output[10] == 2.f); // input = 2
- CHECK(raw_output[11] == 3.f); // input = 3
- CHECK(raw_output[12] == 4.f); // input = 4
- CHECK(raw_output[13] == 5.f); // input = 5
- CHECK(raw_output[14] == 6.f); // input = 6
- CHECK(raw_output[15] == 7.f); // input = 7
-}
-
-#endif
-
-}
diff --git a/test/test.cc b/test/test.cc
new file mode 100644
index 0000000..58c62f8
--- /dev/null
+++ b/test/test.cc
@@ -0,0 +1,6 @@
+#define CATCH_CONFIG_RUNNER
+#include "test/test.h"
+
+int main(int argc, char ** argv) {
+ return Catch::Session().run(argc, argv);
+}
diff --git a/test/test.h b/test/test.h
new file mode 100644
index 0000000..572a529
--- /dev/null
+++ b/test/test.h
@@ -0,0 +1,12 @@
+#include "3rd_party/catch.hpp"
+
+#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while(0)
+#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while(0)
+#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while(0)
+#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while(0)
+
+#define CHECK_EPS(actual, expected, epsilon) \
+ do { \
+ if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
+ else { CHECK((actual) == (expected)); } \
+ } while(0)
diff --git a/test/utils_test.cc b/test/utils_test.cc
new file mode 100644
index 0000000..580a872
--- /dev/null
+++ b/test/utils_test.cc
@@ -0,0 +1,38 @@
+#include "test/test.h"
+#include "utils.h"
+
+namespace intgemm {
+namespace {
+
+TEST_CASE("Factorial",) {
+ CHECK(factorial(0) == 1);
+ CHECK(factorial(1) == 1);
+ CHECK(factorial(2) == 2);
+ CHECK(factorial(3) == 6);
+ CHECK(factorial(4) == 24);
+
+ // Maximum result that fits in unsinged long long
+ CHECK(factorial(20) == 2432902008176640000);
+}
+
+TEST_CASE("Expi (negative)",) {
+ const double eps = 0.0000001;
+ CHECK_EPS(expi(-1), 0.3678794411714423, eps);
+ CHECK_EPS(expi(-2), 0.1353352832366127, eps);
+ CHECK_EPS(expi(-10), 0.0000453999297625, eps);
+}
+
+TEST_CASE("Expi (zero)",) {
+ const double eps = 0.0000001;
+ CHECK_EPS(expi(0), 1.0, eps);
+}
+
+TEST_CASE("Expi (positive)",) {
+ const double eps = 0.0000001;
+ CHECK_EPS(expi(1), 2.7182818284590452, eps);
+ CHECK_EPS(expi(2), 7.3890560989306502, eps);
+ CHECK_EPS(expi(10), 22026.4657948067165170, eps);
+}
+
+}
+}
diff --git a/utils.h b/utils.h
index a403995..2927693 100644
--- a/utils.h
+++ b/utils.h
@@ -49,4 +49,24 @@ constexpr subtuple_t<Tuple, Indices...> make_subtuple(const Tuple& tuple, sequen
return std::make_tuple(std::get<Indices>(tuple)...);
}
+/*
+ * Factorial
+ */
+constexpr unsigned long long factorial(unsigned n) {
+ return n <= 1 ? 1 : n * factorial(n - 1);
+}
+
+/*
+ * e^n, where n is integer
+ */
+namespace { // anonymous namespace
+constexpr double expi_nonnegative(unsigned n) {
+ return n == 0 ? 1.0 : (n == 1 ? 2.718281828459045 : expi_nonnegative(n / 2) * expi_nonnegative((n + 1) / 2));
+}
+} // anonymous namespace
+
+constexpr double expi(int n) {
+ return (n >= 0 ? expi_nonnegative(n) : 1.0 / expi_nonnegative(-n));
+}
+
}
diff --git a/vec_utils.h b/vec_utils.h
index fb6aea4..acb7d6e 100644
--- a/vec_utils.h
+++ b/vec_utils.h
@@ -46,4 +46,84 @@ INTGEMM_AVX512BW static inline __m512 unquantize(__m512i input, __m512 unquantiz
}
#endif
+/*
+ *
+ * Calculate floor: float -> float
+ *
+ */
+INTGEMM_SSE2 static inline __m128 floor_ff(__m128 a) {
+ return cvtepi32_ps(_mm_cvttps_epi32(a));
+}
+INTGEMM_AVX2 static inline __m256 floor_ff(__m256 a) {
+ return _mm256_floor_ps(a);
+}
+#ifndef INTGEMM_NO_AVX512
+INTGEMM_AVX512BW static inline __m512 floor_ff(__m512 a) {
+ return cvtepi32_ps(cvttps_epi32(a)); // TODO: Is there any better way to do that?
+}
+#endif
+
+/*
+ *
+ * Calculate approximation of e^x using Taylor series and lookup table
+ *
+ */
+
+template <typename Register>
+Register exp_approx_taylor(Register x) {
+ static constexpr int EXP_MIN = -20;
+ static constexpr int EXP_MAX = 20;
+ static constexpr float EXP_LOOKUP[EXP_MAX - EXP_MIN + 1] = {
+ expi(-20), expi(-19), expi(-18), expi(-17), expi(-16), expi(-15),
+ expi(-14), expi(-13), expi(-12), expi(-11), expi(-10), expi(-9),
+ expi(-8), expi(-7), expi(-6), expi(-5), expi(-4), expi(-3), expi(-2),
+ expi(-1), expi(0), expi(1), expi(2), expi(3), expi(4), expi(5),
+ expi(6), expi(7), expi(8), expi(9), expi(10), expi(11), expi(12),
+ expi(13), expi(14), expi(15), expi(16), expi(17), expi(18), expi(19),
+ expi(20),
+ };
+
+ static const Register dividers[] = {
+ set1_ps<Register>(1.f / factorial(7)),
+ set1_ps<Register>(1.f / factorial(6)),
+ set1_ps<Register>(1.f / factorial(5)),
+ set1_ps<Register>(1.f / factorial(4)),
+ set1_ps<Register>(1.f / factorial(3)),
+ set1_ps<Register>(1.f / factorial(2)),
+ set1_ps<Register>(1.f / factorial(1)),
+ };
+ static const auto const_one = set1_ps<Register>(1.f);
+ static const auto const_min_x = set1_ps<Register>(EXP_MIN);
+ static const auto const_max_x = set1_ps<Register>(EXP_MAX);
+
+ x = max_ps(x, const_min_x);
+ x = min_ps(x, const_max_x);
+
+ auto a = floor_ff(x);
+ auto xa = sub_ps(x, a);
+
+ auto result = mul_ps(dividers[0], xa);
+
+ result = add_ps(result, dividers[1]);
+ result = mul_ps(result, xa);
+ result = add_ps(result, dividers[2]);
+ result = mul_ps(result, xa);
+ result = add_ps(result, dividers[3]);
+ result = mul_ps(result, xa);
+ result = add_ps(result, dividers[4]);
+ result = mul_ps(result, xa);
+ result = add_ps(result, dividers[5]);
+ result = mul_ps(result, xa);
+ result = add_ps(result, dividers[6]);
+ result = mul_ps(result, xa);
+
+ result = add_ps(result, const_one);
+
+ auto ea = i32gather_ps<4>(EXP_LOOKUP + EXP_MAX, cvtps_epi32(a));
+ return mul_ps(ea, result);
+}
+
+template INTGEMM_AVX2 static __m256 exp_approx_taylor(__m256 x);
+template INTGEMM_AVX512BW static __m512 exp_approx_taylor(__m512 x);
+
}