Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-08 17:57:41 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-09 22:39:50 +0300
commit03ec24b72137c785cd9c931c34bc50c9cbd3cae3 (patch)
tree32f37bf856726337456dffca2b2f1f43bc53e229
parent7e514a4d1178ddeae0cf38fa29f5ca758abf8a9a (diff)
Add code infrastructure for support callbacks
-rw-r--r--CMakeLists.txt6
-rw-r--r--avx512_gemm.h10
-rw-r--r--benchmark.cc5
-rw-r--r--callbacks.h6
-rw-r--r--callbacks/avx2.h13
-rw-r--r--callbacks/avx512.h17
-rw-r--r--callbacks/configs.h10
-rw-r--r--callbacks/implementations.inl67
-rw-r--r--callbacks/sse2.h13
-rw-r--r--example.cc5
-rw-r--r--intgemm.h37
-rw-r--r--multiply.h18
-rw-r--r--postprocess.h390
-rw-r--r--postprocess_pipeline.h113
-rw-r--r--test/multiply_test.cc6
-rw-r--r--test/postprocess/add_bias_test.cc95
-rw-r--r--test/postprocess/pipeline_test.cc63
-rw-r--r--test/postprocess/relu_test.cc213
-rw-r--r--test/postprocess/sigmoid_test.cc33
-rw-r--r--test/postprocess/tanh_test.cc33
-rw-r--r--test/postprocess/unquantize_test.cc88
-rw-r--r--vec_utils.h1
22 files changed, 165 insertions, 1077 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 31b7b53..12300e0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,12 +33,6 @@ endforeach()
include_directories(.)
add_executable(tests
test/multiply_test.cc
- test/postprocess/add_bias_test.cc
- test/postprocess/pipeline_test.cc
- test/postprocess/relu_test.cc
- test/postprocess/sigmoid_test.cc
- test/postprocess/tanh_test.cc
- test/postprocess/unquantize_test.cc
test/quantize_test.cc
test/test.cc
test/utils_test.cc
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 8326a82..f474b35 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -217,8 +217,8 @@ struct AVX512_8bit {
// Special AVX512 implementation due to having 32 registers (so I don't have to
// allocate registers manually) and no sign instruction.
- template <typename PostprocessPipeline>
- INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
+ template <typename Callback>
+ INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
typedef __m512i Integer;
//typedef __m256 Float; // For quantization we only do 8 at a time.
// This is copy-paste from Multiply8_SSE2OrAVX2.
@@ -227,7 +227,7 @@ struct AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
// There's 8 results for INTGEMM_AVX2 to handle.
- auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
+ auto callback_impl = callbacks::CallbackImpl<Callback, CPUType::AVX2>(callback);
const int simd_width = width / sizeof(Integer);
const Integer *B0_col = reinterpret_cast<const Integer*>(B);
// Added for AVX512.
@@ -324,9 +324,7 @@ struct AVX512_8bit {
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
- auto offset = A_rowidx * B_cols + B0_colidx;
- auto result = inited_pipeline.run(total, offset);
- writer(C, offset, result);
+ callback_impl(total, A_rowidx, B0_colidx, A_rows, width, B_cols);
}
}
}
diff --git a/benchmark.cc b/benchmark.cc
index 6477792..669ea77 100644
--- a/benchmark.cc
+++ b/benchmark.cc
@@ -5,6 +5,7 @@
#include "sse2_gemm.h"
#include "intgemm.h"
#include "stop_watch.h"
+#include "callbacks.h"
#include <algorithm>
#include <cassert>
@@ -78,10 +79,10 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t>
Backend::PrepareB(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols);
AlignedVector<float> output(m.A_rows * m.B_cols);
// Burn in
- Backend::Multiply(A_prepared.begin(), B_prepared.begin(), output.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), m.A_rows, m.width, m.B_cols);
+ Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::Dummy());
{
StopWatch w(stats);
- Backend::Multiply(A_prepared.begin(), B_prepared.begin(), output.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), m.A_rows, m.width, m.B_cols);
+ Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::Dummy());
}
}
diff --git a/callbacks.h b/callbacks.h
new file mode 100644
index 0000000..079ca0f
--- /dev/null
+++ b/callbacks.h
@@ -0,0 +1,6 @@
+#pragma once
+
+#include "callbacks/configs.h"
+#include "callbacks/sse2.h"
+#include "callbacks/avx2.h"
+#include "callbacks/avx512.h"
diff --git a/callbacks/avx2.h b/callbacks/avx2.h
new file mode 100644
index 0000000..c69875d
--- /dev/null
+++ b/callbacks/avx2.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#define THIS_IS_AVX2
+#include "callbacks/implementations.inl"
+#undef THIS_IS_AVX2
+
+namespace intgemm {
+namespace callbacks {
+
+// Put here callbacks supported only by AVX2...
+
+}
+}
diff --git a/callbacks/avx512.h b/callbacks/avx512.h
new file mode 100644
index 0000000..a224a84
--- /dev/null
+++ b/callbacks/avx512.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#ifndef INTGEMM_NO_AVX512
+
+#define THIS_IS_AVX512BW
+#include "callbacks/implementations.inl"
+#undef THIS_IS_AVX512BW
+
+namespace intgemm {
+namespace callbacks {
+
+// Put here callbacks supported only by AVX512BW...
+
+}
+}
+
+#endif
diff --git a/callbacks/configs.h b/callbacks/configs.h
new file mode 100644
index 0000000..fae60bd
--- /dev/null
+++ b/callbacks/configs.h
@@ -0,0 +1,10 @@
+#pragma once
+
+namespace intgemm {
+namespace callbacks {
+
+struct Dummy {
+};
+
+}
+}
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl
new file mode 100644
index 0000000..01195f1
--- /dev/null
+++ b/callbacks/implementations.inl
@@ -0,0 +1,67 @@
+#include "callbacks/configs.h"
+
+#include "intrinsics.h"
+#include "types.h"
+#include "vec_traits.h"
+
+#if defined(THIS_IS_SSE2)
+ #define CPU_NAME SSE2
+ #define CPU_ATTR INTGEMM_SSE2
+#elif defined(THIS_IS_AVX2)
+ #define CPU_NAME AVX2
+ #define CPU_ATTR INTGEMM_AVX2
+#elif defined(THIS_IS_AVX512BW)
+ #define CPU_NAME AVX512BW
+ #define CPU_ATTR INTGEMM_AVX512BW
+#else
+ #error "Only SSE2, AVX2 and AVX512BW are supported"
+#endif
+
+#define vi vector_t<CPUType::CPU_NAME, int>
+#define vf vector_t<CPUType::CPU_NAME, float>
+#define vd vector_t<CPUType::CPU_NAME, double>
+#define dvi dvector_t<CPUType::CPU_NAME, int>
+#define dvf dvector_t<CPUType::CPU_NAME, float>
+#define dvd dvector_t<CPUType::CPU_NAME, double>
+
+#if defined(THIS_IS_SSE2)
+#define vinput dvector_t<CPUType::SSE2, int>
+#else
+#define vinput vector_t<CPUType::AVX2, int>
+#endif
+
+namespace intgemm {
+namespace callbacks {
+
+template <typename CallbackConfig, CPUType CpuType>
+class CallbackImpl;
+
+}}
+
+/*
+ * Callbacks implementations....
+ */
+namespace intgemm {
+namespace callbacks {
+
+/*
+ * Dummy
+ */
+template <> class CallbackImpl<Dummy, CPUType::CPU_NAME> {
+public:
+ CPU_ATTR CallbackImpl(const Dummy&) {}
+ CPU_ATTR void operator()(vinput, Index, Index, Index, Index, Index) {}
+};
+
+}
+}
+
+#undef CPU_NAME
+#undef CPU_ATTR
+#undef vi
+#undef vf
+#undef vd
+#undef dvi
+#undef dvf
+#undef dvd
+#undef vinput
diff --git a/callbacks/sse2.h b/callbacks/sse2.h
new file mode 100644
index 0000000..8758a20
--- /dev/null
+++ b/callbacks/sse2.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#define THIS_IS_SSE2
+#include "callbacks/implementations.inl"
+#undef THIS_IS_SSE2
+
+namespace intgemm {
+namespace callbacks {
+
+// Put here callbacks supported only by SSE2...
+
+}
+}
diff --git a/example.cc b/example.cc
index d01c820..3e5cce7 100644
--- a/example.cc
+++ b/example.cc
@@ -2,6 +2,7 @@
// This is just for AlignedVector, which helps managed 64-byte aligned memory.
// Feel free to manage memory yourself.
#include "aligned.h"
+#include "callbacks.h"
#include <cassert>
#include <math.h>
@@ -51,7 +52,7 @@ int main() {
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
- intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols);
+ intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::Dummy());
// Sanity check. C will be row major.
assert(fabs(C[0] - top_left_reference) < 0.05);
}
@@ -70,7 +71,7 @@ int main() {
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
- intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols);
+ intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::Dummy());
// Sanity check. C will be row major.
assert(fabs(C[0] - top_left_reference) < 0.05);
}
diff --git a/intgemm.h b/intgemm.h
index e380758..74d5e3d 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -48,7 +48,6 @@
#include "sse2_gemm.h"
#include "ssse3_gemm.h"
#include "avx2_gemm.h"
-#include "postprocess.h"
#ifndef INTGEMM_NO_AVX512
#include "avx512_gemm.h"
#endif
@@ -67,8 +66,8 @@ struct Unsupported_16bit {
static void SelectColumnsB(const int16_t *, int16_t *, Index, const Index *, const Index *) {
throw UnsupportedCPU();
}
- template <typename PostprocessPipeline>
- static void Multiply(const int16_t *, const int16_t *, float *, PostprocessPipeline, Index, Index, Index) {
+ template <typename Callback>
+ static void Multiply(const int16_t *, const int16_t *, Index, Index, Index, Callback) {
throw UnsupportedCPU();
}
constexpr static const char *const kName = "16-bit Unsupported";
@@ -84,8 +83,8 @@ struct Unsupported_8bit {
static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) {
throw UnsupportedCPU();
}
- template <typename PostprocessPipeline>
- static void Multiply(const int8_t *, const int8_t *, float *, PostprocessPipeline, Index, Index, Index) {
+ template <typename Callback>
+ static void Multiply(const int8_t *, const int8_t *, Index, Index, Index, Callback) {
throw UnsupportedCPU();
}
constexpr static const char *const kName = "8-bit Unsupported";
@@ -133,15 +132,15 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported)
}
/* 16-bit matrix multiplication. */
-template <typename PostprocessPipeline>
+template <typename Callback>
class Int16Mult {
public:
// Multiply C = A * B, presuming A and B have been prepared.
- static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols);
+ static void (*Multiply)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback);
};
-template <typename PostprocessPipeline>
-void (*Int16Mult<PostprocessPipeline>::Multiply)(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply<PostprocessPipeline>, AVX2_16bit::Multiply<PostprocessPipeline>, SSE2_16bit::Multiply<PostprocessPipeline>, SSE2_16bit::Multiply<PostprocessPipeline>, Unsupported_16bit::Multiply);
+template <typename Callback>
+void (*Int16Mult<Callback>::Multiply)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_16bit::Multiply<Callback>, AVX2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, Unsupported_16bit::Multiply);
struct Int16 {
typedef int16_t Integer;
@@ -172,24 +171,24 @@ struct Int16 {
static void (*SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
// Multiply C = A * B, presuming A and B have been prepared.
- template <typename PostprocessPipeline>
- static void Multiply(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
- Int16Mult<PostprocessPipeline>::Multiply(A, B, C, pipeline, A_rows, width, B_cols);
+ template <typename Callback>
+ static void Multiply(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+ Int16Mult<Callback>::Multiply(A, B, A_rows, width, B_cols, callback);
}
static const char *const kName;
};
/* 8-bit matrix multiplication */
-template <typename PostprocessPipeline>
+template <typename Callback>
class Int8Mult {
public:
// Multiply C = A * B, presuming A and B have been prepared.
- static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols);
+ static void (*Multiply)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback);
};
-template <typename PostprocessPipeline>
-void (*Int8Mult<PostprocessPipeline>::Multiply)(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply<PostprocessPipeline>, AVX2_8bit::Multiply<PostprocessPipeline>, SSSE3_8bit::Multiply<PostprocessPipeline>, SSSE3_8bit::Multiply<PostprocessPipeline>, Unsupported_8bit::Multiply);
+template <typename Callback>
+void (*Int8Mult<Callback>::Multiply)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_8bit::Multiply<Callback>, AVX2_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, Unsupported_8bit::Multiply);
struct Int8 {
typedef int8_t Integer;
@@ -219,9 +218,9 @@ struct Int8 {
static void (*SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
// Multiply C = A * B, presuming A and B have been prepared.
- template <typename PostprocessPipeline>
- static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
- Int8Mult<PostprocessPipeline>::Multiply(A, B, C, pipeline, A_rows, width, B_cols);
+ template <typename Callback>
+ static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+ Int8Mult<Callback>::Multiply(A, B, A_rows, width, B_cols, callback);
}
static const char *const kName;
diff --git a/multiply.h b/multiply.h
index dff06f3..5a7c94a 100644
--- a/multiply.h
+++ b/multiply.h
@@ -2,9 +2,9 @@
#include "interleave.h"
#include "intrinsics.h"
-#include "postprocess_pipeline.h"
#include "vec_utils.h"
#include "vec_traits.h"
+#include "callbacks.h"
namespace intgemm {
@@ -144,13 +144,13 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i)
// B_cols must be a multiple of 8.
// Multiply16
#define INTGEMM_MULTIPLY16(Integer, target, cpu_type) \
-template <typename PostprocessPipeline> target static void Multiply(const int16_t *A, const int16_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \
+template <typename Callback> target static void Multiply(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \
assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \
- auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \
+ auto callback_impl = callbacks::CallbackImpl<Callback, cpu_type>(callback); \
const Integer *B0_col = reinterpret_cast<const Integer *>(B); \
for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -194,9 +194,7 @@ template <typename PostprocessPipeline> target static void Multiply(const int16_
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
/*The specific implementation may need to reduce further.*/ \
auto total = PermuteSummer(pack0123, pack4567); \
- auto offset = A_rowidx * B_cols + B0_colidx; \
- auto result = inited_pipeline.run(total, offset); \
- writer(C, offset, result); \
+ callback_impl(total, A_rowidx, B0_colidx, A_rows, width, B_cols); \
} \
} \
} \
@@ -350,14 +348,14 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
}
//INTGEMM_AVX2 or INTGEMM_SSSE3 multiply
#define INTGEMM_MULTIPLY8(Integer, target, cpu_type) \
- template <typename PostprocessPipeline> target static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \
+ template <typename Callback> target static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \
assert(width % sizeof(Integer) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
const int simd_width = width / sizeof(Integer); \
+ auto callback_impl = callbacks::CallbackImpl<Callback, cpu_type>(callback); \
const Integer *B0_col = reinterpret_cast<const Integer*>(B); \
- auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \
/*Go over 8 columns of B at a time.*/ \
for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
/*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -412,9 +410,7 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
auto total = PermuteSummer(pack0123, pack4567); \
- auto offset = A_rowidx * B_cols + B0_colidx; \
- auto result = inited_pipeline.run(total, offset); \
- writer(C, offset, result); \
+ callback_impl(total, A_rowidx, B0_colidx, A_rows, width, B_cols); \
} \
} \
} \
diff --git a/postprocess.h b/postprocess.h
deleted file mode 100644
index 53c5a3e..0000000
--- a/postprocess.h
+++ /dev/null
@@ -1,390 +0,0 @@
-#pragma once
-
-#include "intrinsics.h"
-#include "postprocess_pipeline.h"
-#include "types.h"
-#include "vec_utils.h"
-#include "vec_traits.h"
-
-// TODO: We support some postprocess in few variations e.g. we support ReLU for
-// float -> float, int8 -> int8, int16 -> int16. Maybe it would be a good idea
-// to pass input type and output type as a template parameter of postprocess?
-
-namespace intgemm {
-
-/*
- * Unquantize
- */
-class Unquantize {
-public:
- float unquantize_multiplier;
-
- Unquantize(float unquantize_multiplier) : unquantize_multiplier(unquantize_multiplier) {}
-};
-
-template <>
-class PostprocessImpl<Unquantize, CPUType::SSE2> {
-public:
- using InputRegister = dvector_t<CPUType::SSE2, int>;
- using OutputRegister = dvector_t<CPUType::SSE2, float>;
-
- INTGEMM_SSE2 PostprocessImpl(const Unquantize& config) {
- unquantize_multiplier = set1_ps<__m128>(config.unquantize_multiplier);
- }
-
- INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
- return {
- mul_ps(cvtepi32_ps(input.first), unquantize_multiplier),
- mul_ps(cvtepi32_ps(input.second), unquantize_multiplier),
- };
- }
-
-private:
- __m128 unquantize_multiplier;
-};
-
-template <>
-class PostprocessImpl<Unquantize, CPUType::AVX2> {
-public:
- using InputRegister = __m256i;
- using OutputRegister = __m256;
-
- INTGEMM_AVX2 PostprocessImpl(const Unquantize& config) {
- unquantize_multiplier = set1_ps<__m256>(config.unquantize_multiplier);
- }
-
- INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
- return mul_ps(cvtepi32_ps(input), unquantize_multiplier);
- }
-
-private:
- __m256 unquantize_multiplier;
-};
-
-#ifndef INTGEMM_NO_AVX512
-
-template <>
-class PostprocessImpl<Unquantize, CPUType::AVX512BW> {
-public:
- using InputRegister = __m512i;
- using OutputRegister = __m512;
-
- INTGEMM_AVX512BW PostprocessImpl(const Unquantize& config) {
- unquantize_multiplier = set1_ps<__m512>(config.unquantize_multiplier);
- }
-
- INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
- return mul_ps(cvtepi32_ps(input), unquantize_multiplier);
- }
-
-private:
- __m512 unquantize_multiplier;
-};
-
-#endif
-
-/*
- * Add a bias term
- */
-class AddBias {
-public:
- const float* bias;
- const Index length;
-
- AddBias(const float* bias, Index length) : bias(bias), length(length) {}
-};
-
-template <>
-class PostprocessImpl<AddBias, CPUType::SSE2> {
-public:
- using InputRegister = dvector_t<CPUType::SSE2, float>;
- using OutputRegister = dvector_t<CPUType::SSE2, float>;
-
- PostprocessImpl(const AddBias& config) : config(config) {}
-
- INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
- auto bias_term0123 = *reinterpret_cast<const __m128*>(config.bias + (offset % config.length));
- auto bias_term4567 = *reinterpret_cast<const __m128*>(config.bias + (offset % config.length) + 4);
- return {
- add_ps(input.first, bias_term0123),
- add_ps(input.second, bias_term4567),
- };
- }
-
-private:
- const AddBias config;
-};
-
-template <>
-class PostprocessImpl<AddBias, CPUType::AVX2> {
-public:
- using InputRegister = __m256;
- using OutputRegister = __m256;
-
- PostprocessImpl(const AddBias& config) : config(config) {}
-
- INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
- auto bias_term = *reinterpret_cast<const __m256*>(config.bias + (offset % config.length));
- return add_ps(input, bias_term);
- }
-
-private:
- const AddBias config;
-};
-
-#ifndef INTGEMM_NO_AVX512
-
-template <>
-class PostprocessImpl<AddBias, CPUType::AVX512BW> {
-public:
- using InputRegister = __m512;
- using OutputRegister = __m512;
-
- PostprocessImpl(const AddBias& config) : config(config) {}
-
- INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
- auto bias_term = *reinterpret_cast<const __m512*>(config.bias + (offset % config.length));
- return add_ps(input, bias_term);
- }
-
-private:
- const AddBias config;
-};
-
-#endif
-
-/*
- * ReLU
- */
-class ReLU {};
-
-template <>
-class PostprocessImpl<ReLU, CPUType::SSE2> {
-public:
- using InputRegister = dvector_t<CPUType::SSE2, float>;
- using OutputRegister = dvector_t<CPUType::SSE2, float>;
-
- PostprocessImpl(const ReLU& config) {}
-
- INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = set1_ps<__m128>(0.f);
- return {
- max_ps(const_zero, input.first),
- max_ps(const_zero, input.second),
- };
- }
-};
-
-template <>
-class PostprocessImpl<ReLU, CPUType::SSSE3> : public PostprocessImpl<ReLU, CPUType::SSE2> {};
-
-template <>
-class PostprocessImpl<ReLU, CPUType::AVX2> {
-public:
- using InputRegister = __m256;
- using OutputRegister = __m256;
-
- PostprocessImpl(const ReLU& config) {}
-
- INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = set1_ps<__m256>(0.f);
- return max_ps(const_zero, input);
- }
-};
-
-#ifndef INTGEMM_NO_AVX512
-
-template <>
-class PostprocessImpl<ReLU, CPUType::AVX512BW> {
-public:
- using InputRegister = __m512;
- using OutputRegister = __m512;
-
- PostprocessImpl(const ReLU& config) {}
-
- INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = set1_ps<__m512>(0.f);
- return max_ps(const_zero, input);
- }
-};
-
-#endif
-
-/*
- * ReLU_int8
- */
-class ReLU_int8 {};
-
-template <>
-class PostprocessImpl<ReLU_int8, CPUType::SSE2> {
-public:
- using InputRegister = __m128i;
- using OutputRegister = __m128i;
-
- PostprocessImpl(const ReLU_int8& config) {}
-
- INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = setzero_si<__m128i>();
- return _mm_and_si128(_mm_cmplt_epi8(const_zero, input), input);
- }
-};
-
-template <>
-class PostprocessImpl<ReLU_int8, CPUType::AVX2> {
-public:
- using InputRegister = __m256i;
- using OutputRegister = __m256i;
-
- PostprocessImpl(const ReLU_int8& config) {}
-
- INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = setzero_si<__m256i>();
- return max_epi8(const_zero, input);
- }
-};
-
-#ifndef INTGEMM_NO_AVX512
-
-template <>
-class PostprocessImpl<ReLU_int8, CPUType::AVX512BW> {
-public:
- using InputRegister = __m512i;
- using OutputRegister = __m512i;
-
- PostprocessImpl(const ReLU_int8& config) {}
-
- INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = setzero_si<__m512i>();
- return max_epi8(const_zero, input);
- }
-};
-
-#endif
-
-/*
- * ReLU_int16
- */
-class ReLU_int16 {};
-
-template <>
-class PostprocessImpl<ReLU_int16, CPUType::SSE2> {
-public:
- using InputRegister = __m128i;
- using OutputRegister = __m128i;
-
- PostprocessImpl(const ReLU_int16& config) {}
-
- INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = setzero_si<__m128i>();
- return max_epi16(const_zero, input);
- }
-};
-
-template <>
-class PostprocessImpl<ReLU_int16, CPUType::AVX2> {
-public:
- using InputRegister = __m256i;
- using OutputRegister = __m256i;
-
- PostprocessImpl(const ReLU_int16& config) {}
-
- INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = setzero_si<__m256i>();
- return max_epi16(const_zero, input);
- }
-};
-
-#ifndef INTGEMM_NO_AVX512
-
-template <>
-class PostprocessImpl<ReLU_int16, CPUType::AVX512BW> {
-public:
- using InputRegister = __m512i;
- using OutputRegister = __m512i;
-
- PostprocessImpl(const ReLU_int16& config) {}
-
- INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = setzero_si<__m512i>();
- return max_epi16(const_zero, input);
- }
-};
-
-#endif
-
-/*
- * Sigmoid (uses Taylor series approximation of e^x)
- */
-class Sigmoid {};
-
-template <>
-class PostprocessImpl<Sigmoid, CPUType::AVX2> {
-public:
- using InputRegister = __m256;
- using OutputRegister = __m256;
-
- PostprocessImpl(const Sigmoid& config) {}
-
- INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
- static const auto const_zero = set1_ps<__m256>(0.f);
- static const auto const_one = set1_ps<__m256>(1.f);
-
- auto x = input;
- auto minus_x = sub_ps(const_zero, x);
- auto e_x = exp_approx_taylor(x);
- auto e_minus_x = exp_approx_taylor(minus_x);
-
- auto sigmoid_case1 = _mm256_rcp_ps(add_ps(const_one, e_minus_x));
- auto sigmoid_case2 = mul_ps(e_x, _mm256_rcp_ps(add_ps(const_one, e_x)));
-
- auto nonnegative_x_mask = _mm256_cmp_ps(const_zero, x, _CMP_LT_OS);
- return _mm256_blendv_ps(sigmoid_case1, sigmoid_case2, nonnegative_x_mask);
- }
-};
-
-/*
- * Tanh (uses Taylor series approximation of e^x)
- */
-class Tanh {};
-
-template <>
-class PostprocessImpl<Tanh, CPUType::AVX2> {
-public:
- using InputRegister = __m256;
- using OutputRegister = __m256;
-
- PostprocessImpl(const Tanh& config) {}
-
- INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
- const static auto const_zero = setzero_ps<__m256>();
-
- auto e_x = exp_approx_taylor(input);
- auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input));
-
- return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
- }
-};
-
-#ifndef INTGEMM_NO_AVX512
-
-template <>
-class PostprocessImpl<Tanh, CPUType::AVX512BW> {
-public:
- using InputRegister = __m512;
- using OutputRegister = __m512;
-
- PostprocessImpl(const Tanh& config) {}
-
- INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
- const static auto const_zero = setzero_ps<__m512>();
-
- auto e_x = exp_approx_taylor(input);
- auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input));
-
- return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
- }
-};
-
-#endif
-
-}
diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h
deleted file mode 100644
index 361ff2b..0000000
--- a/postprocess_pipeline.h
+++ /dev/null
@@ -1,113 +0,0 @@
-#pragma once
-
-#include "intrinsics.h"
-#include "types.h"
-#include "utils.h"
-
-#include <tuple>
-
-namespace intgemm {
-
-template <typename... Stages>
-using PostprocessPipeline = std::tuple<Stages...>;
-
-template <typename... Stages>
-constexpr std::tuple<Stages...> CreatePostprocessPipeline(Stages&&... stages) {
- return std::make_tuple(std::forward<Stages>(stages)...);
-}
-
-template <typename Postprocess, CPUType CpuType>
-class PostprocessImpl;
-
-namespace { // anonymous namespace
-
-template <typename... Stages>
-using input_register_type = typename std::tuple_element<
- 0,
- std::tuple<Stages...>
- >::type::InputRegister;
-
-template <typename... Stages>
-using output_register_type = typename std::tuple_element<
- std::tuple_size<std::tuple<Stages...>>::value - 1,
- std::tuple<Stages...>
- >::type::OutputRegister;
-
-template <typename FirstStage, typename... RestStages>
-constexpr std::tuple<RestStages...> DropFirstStage(const std::tuple<FirstStage, RestStages...>& pipeline) {
- return make_subtuple(pipeline, sequence_popfront<make_sequence<sizeof...(RestStages) + 1>>());
-}
-
-template <CPUType CpuType>
-constexpr std::tuple<> InitPostprocessPipelineImpl(std::tuple<> pipeline) {
- return std::tuple<>();
-}
-
-template <CPUType CpuType, typename FirstStage, typename... RestStages>
-constexpr std::tuple<PostprocessImpl<FirstStage, CpuType>, PostprocessImpl<RestStages, CpuType>...> InitPostprocessPipelineImpl(std::tuple<FirstStage, RestStages...> pipeline) {
- return std::tuple_cat(
- std::tuple<PostprocessImpl<FirstStage, CpuType>>(PostprocessImpl<FirstStage, CpuType>(std::get<0>(pipeline))),
- InitPostprocessPipelineImpl<CpuType, RestStages...>(DropFirstStage(pipeline))
- );
-}
-
-template <CPUType CpuType>
-struct RunPostprocessPipelineImpl;
-
-#define RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(attribute, cpu_type) \
- template <> \
- struct RunPostprocessPipelineImpl<cpu_type> { \
- template <typename Stage> \
- attribute static constexpr output_register_type<Stage> \
- run(std::tuple<Stage> pipeline, input_register_type<Stage> input, Index offset) { \
- return std::get<0>(pipeline).run(input, offset); \
- } \
- template <typename... Stages> \
- attribute static constexpr output_register_type<Stages...> \
- run(std::tuple<Stages...> pipeline, input_register_type<Stages...> input, Index offset) { \
- return run( \
- DropFirstStage(pipeline), \
- std::get<0>(pipeline).run(input, offset), offset); \
- } \
- };
-
-RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSE2, CPUType::SSE2)
-RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSSE3, CPUType::SSSE3)
-RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX2, CPUType::AVX2)
-RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::AVX512BW)
-
-} // anonymous namespace
-
-template <CPUType CpuType, typename... Stages>
-class InitedPostprocessPipeline {};
-
-template <CPUType CpuType, typename... Stages>
-constexpr InitedPostprocessPipeline<CpuType, Stages...> InitPostprocessPipeline(std::tuple<Stages...> pipeline) {
- return InitedPostprocessPipeline<CpuType, Stages...>(pipeline);
-}
-
-#define INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(attribute, cpu_type) \
- template <typename... Stages> \
- class InitedPostprocessPipeline<cpu_type, Stages...> { \
- public: \
- using InputRegister = input_register_type<PostprocessImpl<Stages, cpu_type>...>; \
- using OutputRegister = output_register_type<PostprocessImpl<Stages, cpu_type>...>; \
- InitedPostprocessPipeline(std::tuple<Stages...> pipeline) \
- : inited_pipeline(InitPostprocessPipelineImpl<cpu_type, Stages...>(pipeline)) {} \
- attribute inline OutputRegister run(InputRegister input, Index offset) { \
- return RunPostprocessPipelineImpl<cpu_type>::run(inited_pipeline, input, offset); \
- } \
- attribute inline void run(const InputRegister* input, unsigned length, OutputRegister* output) { \
- for (unsigned i = 0; i < length; ++i) \
- output[i] = RunPostprocessPipelineImpl<cpu_type>::run(inited_pipeline, input[i], i * sizeof(InputRegister)); \
- } \
- private: \
- const std::tuple<PostprocessImpl<Stages, cpu_type>...> inited_pipeline; \
- };
-
-INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSE2, CPUType::SSE2)
-INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSSE3, CPUType::SSSE3)
-INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX2, CPUType::AVX2)
-INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::AVX512BW)
-
-}
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 93d7127..e7fcf77 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -3,7 +3,7 @@
#include "interleave.h"
#include "intgemm.h"
#include "multiply.h"
-#include "postprocess.h"
+#include "callbacks.h"
#include <algorithm>
#include <cassert>
@@ -361,7 +361,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), A_rows, width, B_cols);
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Dummy());
AlignedVector<Integer> B_quant(B.size());
Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
@@ -410,7 +410,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult), AddBias(bias.begin(), B_cols)), A_rows, width, B_cols);
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Dummy());
AlignedVector<Integer> B_quant(B.size());
Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc
deleted file mode 100644
index 3bc7f74..0000000
--- a/test/postprocess/add_bias_test.cc
+++ /dev/null
@@ -1,95 +0,0 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "postprocess.h"
-
-#include <numeric>
-
-namespace intgemm {
-
-INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) {
- if (kCPU < CPUType::SSE2)
- return;
-
- AlignedVector<float> input(8);
- AlignedVector<float> bias(8);
- AlignedVector<float> output(8);
-
- std::iota(input.begin(), input.end(), -2);
- std::iota(bias.begin(), bias.end(), 0);
-
- auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size()));
- auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0);
- output.as<__m128>()[0] = output_tmp.first;
- output.as<__m128>()[1] = output_tmp.second;
-
- CHECK(output[0] == -2.f); // input = -2, bias = 0
- CHECK(output[1] == 0.f); // input = -1, bias = 1
- CHECK(output[2] == 2.f); // input = 0, bias = 2
- CHECK(output[3] == 4.f); // input = 1, bias = 3
- CHECK(output[4] == 6.f); // input = 2, bias = 4
- CHECK(output[5] == 8.f); // input = 3, bias = 5
- CHECK(output[6] == 10.f); // input = 4, bias = 6
- CHECK(output[7] == 12.f); // input = 5, bias = 7
-}
-
-INTGEMM_AVX2 TEST_CASE("AddBias AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- AlignedVector<float> input(8);
- AlignedVector<float> bias(8);
- AlignedVector<float> output(8);
-
- std::iota(input.begin(), input.end(), -4);
- std::iota(bias.begin(), bias.end(), 0);
-
- auto postproc = PostprocessImpl<AddBias, CPUType::AVX2>(AddBias(bias.begin(), bias.size()));
- *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
-
- CHECK(output[0] == -4.f); // input = -4, bias = 0
- CHECK(output[1] == -2.f); // input = -3, bias = 1
- CHECK(output[2] == 0.f); // input = -2, bias = 2
- CHECK(output[3] == 2.f); // input = -1, bias = 3
- CHECK(output[4] == 4.f); // input = 0, bias = 4
- CHECK(output[5] == 6.f); // input = 1, bias = 5
- CHECK(output[6] == 8.f); // input = 2, bias = 6
- CHECK(output[7] == 10.f); // input = 3, bias = 7
-}
-
-#ifndef INTGEMM_NO_AVX512
-
-INTGEMM_AVX512BW TEST_CASE("AddBias AVX512",) {
- if (kCPU < CPUType::AVX512BW)
- return;
-
- AlignedVector<float> input(16);
- AlignedVector<float> bias(16);
- AlignedVector<float> output(16);
-
- std::iota(input.begin(), input.end(), -8);
- std::iota(bias.begin(), bias.end(), 0);
-
- auto postproc = PostprocessImpl<AddBias, CPUType::AVX512BW>(AddBias(bias.begin(), bias.size()));
- *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0);
-
- CHECK(output[0] == -8.f); // input = -8, bias = 0
- CHECK(output[1] == -6.f); // input = -7, bias = 1
- CHECK(output[2] == -4.f); // input = -6, bias = 2
- CHECK(output[3] == -2.f); // input = -5, bias = 3
- CHECK(output[4] == 0.f); // input = -4, bias = 4
- CHECK(output[5] == 2.f); // input = -3, bias = 5
- CHECK(output[6] == 4.f); // input = -2, bias = 6
- CHECK(output[7] == 6.f); // input = -1, bias = 7
- CHECK(output[8] == 8.f); // input = 0, bias = 8
- CHECK(output[9] == 10.f); // input = 1, bias = 9
- CHECK(output[10] == 12.f); // input = 2, bias = 10
- CHECK(output[11] == 14.f); // input = 3, bias = 11
- CHECK(output[12] == 16.f); // input = 4, bias = 12
- CHECK(output[13] == 18.f); // input = 5, bias = 13
- CHECK(output[14] == 20.f); // input = 6, bias = 14
- CHECK(output[15] == 22.f); // input = 7, bias = 15
-}
-
-#endif
-
-}
diff --git a/test/postprocess/pipeline_test.cc b/test/postprocess/pipeline_test.cc
deleted file mode 100644
index 144ee48..0000000
--- a/test/postprocess/pipeline_test.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "postprocess.h"
-
-#include <numeric>
-
-namespace intgemm {
-
-INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
- if (kCPU < CPUType::AVX2)
- return;
-
- AlignedVector<int32_t> input(8);
- AlignedVector<float> output(8);
-
- std::iota(input.begin(), input.end(), -2);
-
- auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
- auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
- *output.as<__m256>() = inited_pipeline.run(*input.as<__m256i>(), 0);
-
- CHECK(output[0] == 0.0f); // input = -2
- CHECK(output[1] == 0.0f); // input = -1
- CHECK(output[2] == 0.0f); // input = 0
- CHECK(output[3] == 0.5f); // input = 1
- CHECK(output[4] == 1.0f); // input = 2
- CHECK(output[5] == 1.5f); // input = 3
- CHECK(output[6] == 2.0f); // input = 4
- CHECK(output[7] == 2.5f); // input = 5
-}
-
-INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-ReLU") {
- if (kCPU < CPUType::AVX2)
- return;
-
- AlignedVector<int32_t> input(16);
- AlignedVector<float> output(16);
-
- std::iota(input.begin(), input.end(), -8);
-
- auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
- auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
- inited_pipeline.run(input.as<__m256i>(), 2, output.as<__m256>());
-
- CHECK(output[0] == 0.f); // input = -8
- CHECK(output[1] == 0.f); // input = -7
- CHECK(output[2] == 0.f); // input = -6
- CHECK(output[3] == 0.f); // input = -5
- CHECK(output[4] == 0.f); // input = -4
- CHECK(output[5] == 0.f); // input = -3
- CHECK(output[6] == 0.f); // input = -2
- CHECK(output[7] == 0.f); // input = -1
- CHECK(output[8] == 0.0f); // input = 0
- CHECK(output[9] == 0.5f); // input = 1
- CHECK(output[10] == 1.0f); // input = 2
- CHECK(output[11] == 1.5f); // input = 3
- CHECK(output[12] == 2.0f); // input = 4
- CHECK(output[13] == 2.5f); // input = 5
- CHECK(output[14] == 3.0f); // input = 6
- CHECK(output[15] == 3.5f); // input = 7
-}
-
-}
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
deleted file mode 100644
index a560790..0000000
--- a/test/postprocess/relu_test.cc
+++ /dev/null
@@ -1,213 +0,0 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "postprocess.h"
-
-#include <numeric>
-
-namespace intgemm {
-
-/*
- * ReLU: float -> float
- */
-INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
- if (kCPU < CPUType::SSE2)
- return;
-
- AlignedVector<float> input(8);
- AlignedVector<float> output(8);
- std::iota(input.begin(), input.end(), -2);
-
- auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU());
- auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0);
- output.as<__m128>()[0] = output_tmp.first;
- output.as<__m128>()[1] = output_tmp.second;
-
- CHECK(output[0] == 0.f); // input = -2
- CHECK(output[1] == 0.f); // input = -1
- CHECK(output[2] == 0.f); // input = 0
- CHECK(output[3] == 1.f); // input = 1
- CHECK(output[4] == 2.f); // input = 2
- CHECK(output[5] == 3.f); // input = 3
- CHECK(output[6] == 4.f); // input = 4
- CHECK(output[7] == 5.f); // input = 5
-}
-
-INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- AlignedVector<float> input(8);
- AlignedVector<float> output(8);
-
- std::iota(input.begin(), input.end(), -4);
-
- auto postproc = PostprocessImpl<ReLU, CPUType::AVX2>(ReLU());
- *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
-
- CHECK(output[0] == 0.f); // input = -4
- CHECK(output[1] == 0.f); // input = -3
- CHECK(output[2] == 0.f); // input = -2
- CHECK(output[3] == 0.f); // input = -1
- CHECK(output[4] == 0.f); // input = 0
- CHECK(output[5] == 1.f); // input = 1
- CHECK(output[6] == 2.f); // input = 2
- CHECK(output[7] == 3.f); // input = 3
-}
-
-#ifndef INTGEMM_NO_AVX512
-
-INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
- if (kCPU < CPUType::AVX512BW)
- return;
-
- AlignedVector<float> input(16);
- AlignedVector<float> output(16);
-
- std::iota(input.begin(), input.end(), -8);
-
- auto postproc = PostprocessImpl<ReLU, CPUType::AVX512BW>(ReLU());
- *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0);
-
- CHECK(output[0] == 0.f); // input = -8
- CHECK(output[1] == 0.f); // input = -7
- CHECK(output[2] == 0.f); // input = -6
- CHECK(output[3] == 0.f); // input = -5
- CHECK(output[4] == 0.f); // input = -4
- CHECK(output[5] == 0.f); // input = -3
- CHECK(output[6] == 0.f); // input = -2
- CHECK(output[7] == 0.f); // input = -1
- CHECK(output[8] == 0.f); // input = 0
- CHECK(output[9] == 1.f); // input = 1
- CHECK(output[10] == 2.f); // input = 2
- CHECK(output[11] == 3.f); // input = 3
- CHECK(output[12] == 4.f); // input = 4
- CHECK(output[13] == 5.f); // input = 5
- CHECK(output[14] == 6.f); // input = 6
- CHECK(output[15] == 7.f); // input = 7
-}
-
-#endif
-
-/*
- * ReLU: int8 -> int8
- */
-INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) {
- if (kCPU < CPUType::SSE2)
- return;
-
- const unsigned NEGATIVE_NUMBERS = 10;
-
- AlignedVector<int8_t> input(16);
- AlignedVector<int8_t> output(16);
-
- std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
-
- auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8());
- *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
-
- for (auto i = 0; i < output.size(); ++i)
- CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
-}
-
-INTGEMM_AVX2 TEST_CASE("ReLU_int8 AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- const unsigned NEGATIVE_NUMBERS = 10;
-
- AlignedVector<int8_t> input(32);
- AlignedVector<int8_t> output(32);
-
- std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
-
- auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX2>(ReLU_int8());
- *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0);
-
- for (auto i = 0; i < output.size(); ++i)
- CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
-}
-
-#ifndef INTGEMM_NO_AVX512
-
-INTGEMM_AVX512BW TEST_CASE("ReLU_int8 AVX512",) {
- if (kCPU < CPUType::AVX512BW)
- return;
-
- const unsigned NEGATIVE_NUMBERS = 30;
-
- AlignedVector<int8_t> input(64);
- AlignedVector<int8_t> output(64);
-
- std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
-
- auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX512BW>(ReLU_int8());
- *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0);
-
- for (auto i = 0; i < output.size(); ++i)
- CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
-}
-
-#endif
-
-/*
- * ReLU: int16 -> int16
- */
-INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) {
- if (kCPU < CPUType::SSE2)
- return;
-
- const unsigned NEGATIVE_NUMBERS = 5;
-
- AlignedVector<int16_t> input(8);
- AlignedVector<int16_t> output(8);
-
- std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
-
- auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16());
- *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
-
- for (auto i = 0; i < output.size(); ++i)
- CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
-}
-
-INTGEMM_AVX2 TEST_CASE("ReLU_int16 AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- const unsigned NEGATIVE_NUMBERS = 10;
-
- AlignedVector<int16_t> input(16);
- AlignedVector<int16_t> output(16);
-
- std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
-
- auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX2>(ReLU_int16());
- *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0);
-
- for (auto i = 0; i < output.size(); ++i)
- CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
-}
-
-#ifndef INTGEMM_NO_AVX512
-
-INTGEMM_AVX512BW TEST_CASE("ReLU_int16 AVX512",) {
- if (kCPU < CPUType::AVX512BW)
- return;
-
- const unsigned NEGATIVE_NUMBERS = 15;
-
- AlignedVector<int16_t> input(32);
- AlignedVector<int16_t> output(32);
-
- std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
-
- auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX512BW>(ReLU_int16());
- *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0);
-
- for (auto i = 0; i < output.size(); ++i)
- CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
-}
-
-#endif
-
-}
diff --git a/test/postprocess/sigmoid_test.cc b/test/postprocess/sigmoid_test.cc
deleted file mode 100644
index 43c713c..0000000
--- a/test/postprocess/sigmoid_test.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "postprocess.h"
-
-#include <numeric>
-
-namespace intgemm {
-
-INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- const float error_tolerance = 0.001f;
-
- AlignedVector<float> input(8);
- AlignedVector<float> output(8);
-
- std::iota(input.begin(), input.end(), -4);
-
- auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid());
- *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
-
- CHECK_EPS(output[0], 0.0179862f, error_tolerance); // input = -4
- CHECK_EPS(output[1], 0.0474259f, error_tolerance); // input = -3
- CHECK_EPS(output[2], 0.1192029f, error_tolerance); // input = -2
- CHECK_EPS(output[3], 0.2689414f, error_tolerance); // input = -1
- CHECK_EPS(output[4], 0.5f , error_tolerance); // input = 0
- CHECK_EPS(output[5], 0.7310586f, error_tolerance); // input = 1
- CHECK_EPS(output[6], 0.8807970f, error_tolerance); // input = 2
- CHECK_EPS(output[7], 0.9525740f, error_tolerance); // input = 3
-}
-
-}
diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc
deleted file mode 100644
index f0e4dc2..0000000
--- a/test/postprocess/tanh_test.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "postprocess.h"
-
-#include <numeric>
-
-namespace intgemm {
-
-INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- const float error_tolerance = 0.001f;
-
- AlignedVector<float> input(8);
- AlignedVector<float> output(8);
-
- std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; });
-
- auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh());
- *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
-
- CHECK_EPS(output[0], -0.7615942f, error_tolerance); // input = -1
- CHECK_EPS(output[1], -0.6351490f, error_tolerance); // input = -0.75
- CHECK_EPS(output[2], -0.4621172f, error_tolerance); // input = -0.5
- CHECK_EPS(output[3], -0.2449187f, error_tolerance); // input = -0.25
- CHECK_EPS(output[4], 0.0f , error_tolerance); // input = 0
- CHECK_EPS(output[5], 0.2449187f, error_tolerance); // input = 0.25
- CHECK_EPS(output[6], 0.4621172f, error_tolerance); // input = 0.5
- CHECK_EPS(output[7], 0.6351490f, error_tolerance); // input = 0.75
-}
-
-}
diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc
deleted file mode 100644
index 45e6bc4..0000000
--- a/test/postprocess/unquantize_test.cc
+++ /dev/null
@@ -1,88 +0,0 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "postprocess.h"
-
-#include <numeric>
-
-namespace intgemm {
-
-INTGEMM_SSE2 TEST_CASE("Unquantize SSE2",) {
- if (kCPU < CPUType::SSE2)
- return;
-
- AlignedVector<int32_t> input(8);
- AlignedVector<float> output(8);
- std::iota(input.begin(), input.end(), -2);
-
- auto postproc = PostprocessImpl<Unquantize, CPUType::SSE2>(Unquantize(0.5f));
- auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
- output.as<__m128>()[0] = output_tmp.first;
- output.as<__m128>()[1] = output_tmp.second;
-
- CHECK(output[0] == -1.0f); // input = -2
- CHECK(output[1] == -0.5f); // input = -1
- CHECK(output[2] == 0.0f); // input = 0
- CHECK(output[3] == 0.5f); // input = 1
- CHECK(output[4] == 1.0f); // input = 2
- CHECK(output[5] == 1.5f); // input = 3
- CHECK(output[6] == 2.0f); // input = 4
- CHECK(output[7] == 2.5f); // input = 5
-}
-
-INTGEMM_AVX2 TEST_CASE("Unquantize AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- AlignedVector<int32_t> input(8);
- AlignedVector<float> output(8);
-
- std::iota(input.begin(), input.end(), -4);
-
- auto postproc = PostprocessImpl<Unquantize, CPUType::AVX2>(Unquantize(0.5f));
- *output.as<__m256>() = postproc.run(*input.as<__m256i>(), 0);
-
- CHECK(output[0] == -2.0f); // input = -4
- CHECK(output[1] == -1.5f); // input = -3
- CHECK(output[2] == -1.0f); // input = -2
- CHECK(output[3] == -0.5f); // input = -1
- CHECK(output[4] == 0.0f); // input = 0
- CHECK(output[5] == 0.5f); // input = 1
- CHECK(output[6] == 1.0f); // input = 2
- CHECK(output[7] == 1.5f); // input = 3
-}
-
-#ifndef INTGEMM_NO_AVX512
-
-INTGEMM_AVX512BW TEST_CASE("Unquantize AVX512",) {
- if (kCPU < CPUType::AVX512BW)
- return;
-
- AlignedVector<int32_t> input(16);
- AlignedVector<float> output(16);
-
- std::iota(input.begin(), input.end(), -8);
-
- auto postproc = PostprocessImpl<Unquantize, CPUType::AVX512BW>(Unquantize(0.5f));
- *output.as<__m512>() = postproc.run(*input.as<__m512i>(), 0);
-
- CHECK(output[0] == -4.0f); // input = -8
- CHECK(output[1] == -3.5f); // input = -7
- CHECK(output[2] == -3.0f); // input = -6
- CHECK(output[3] == -2.5f); // input = -5
- CHECK(output[4] == -2.0f); // input = -4
- CHECK(output[5] == -1.5f); // input = -3
- CHECK(output[6] == -1.0f); // input = -2
- CHECK(output[7] == -0.5f); // input = -1
- CHECK(output[8] == 0.0f); // input = 0
- CHECK(output[9] == 0.5f); // input = 1
- CHECK(output[10] == 1.0f); // input = 2
- CHECK(output[11] == 1.5f); // input = 3
- CHECK(output[12] == 2.0f); // input = 4
- CHECK(output[13] == 2.5f); // input = 5
- CHECK(output[14] == 3.0f); // input = 6
- CHECK(output[15] == 3.5f); // input = 7
-}
-
-#endif
-
-}
diff --git a/vec_utils.h b/vec_utils.h
index a5f1469..7254063 100644
--- a/vec_utils.h
+++ b/vec_utils.h
@@ -1,6 +1,7 @@
#pragma once
#include "intrinsics.h"
+#include "utils.h"
namespace intgemm {