Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-05-24 15:45:11 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-06-18 16:38:53 +0300
commitf0785bea3b42a8e5ab7e322b5ad0dc1e9018d65f (patch)
treeacee1462bbf5fd31ec58a1988c2d4f03009eefcf
parent5239e8820c3afa68abd679bfe7cad7e7b9b9893e (diff)
Add support for postprocess pipeline
-rw-r--r--CMakeLists.txt1
-rw-r--r--README.md2
-rw-r--r--avx2_gemm.h15
-rw-r--r--avx512_gemm.h11
-rw-r--r--benchmark.cc4
-rw-r--r--cops.h212
-rw-r--r--example.cc4
-rw-r--r--intgemm.h38
-rw-r--r--multiply.h43
-rw-r--r--postprocess.h221
-rw-r--r--postprocess_pipeline.h140
-rw-r--r--sse2_gemm.h2
-rw-r--r--ssse3_gemm.h13
-rw-r--r--test/multiply_test.cc7
-rw-r--r--test/pipeline_test.cc32
-rw-r--r--test/relu_test.cc133
-rw-r--r--vec_utils.h6
17 files changed, 532 insertions, 352 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index dc89b83..0d42c23 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,6 +33,7 @@ endforeach()
include_directories(.)
add_executable(tests
test/multiply_test.cc
+ test/pipeline_test.cc
test/quantize_test.cc
test/relu_test.cc
intgemm.cc
diff --git a/README.md b/README.md
index 83d25be..7742cbb 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ intgemm::Int16::PrepareA(A, A_prepared, quant_mult, A_rows, width);
/* Prepare B for multiplication. This is typically done offline. */
intgemm::Int16::PrepareB(B, B_prepared, quant_mult, width, B_cols);
/* Multiply and produce results in C */
-intgemm::Int16::Multiply<intgemm::JustUnquantizeC>(A_prepared.begin(), B_prepared.begin(), intgemm::JustUnquantizeC(C.begin(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols);
+intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols);
```
For 8-bit, use `Int8` instead of `Int16`.
diff --git a/avx2_gemm.h b/avx2_gemm.h
index c5ca0bc..a03ff09 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -80,7 +80,7 @@ struct AVX2_16bit {
avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, OnAVX2)
+ INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, CPUType::CPU_AVX2)
constexpr static const char *const kName = "16-bit INTGEMM_AVX2";
@@ -163,22 +163,13 @@ struct AVX2_8bit {
static const Index kBTileRow = 32;
static const Index kBTileCol = 8;
-/*
- INTGEMM_AVX2 static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
- PrepareBFor8(input, output, avx2::QuantizeTile8(quant_mult), rows, cols);
- }*/
-
INTGEMM_PREPARE_B_8(INTGEMM_AVX2, avx2::QuantizeTile8)
INTGEMM_AVX2 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
}
-/*
- INTGEMM_AVX2 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
- //Multiply8_SSE2OrAVX2<Multiply8_AVXAVX2, __m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols);
- Multiply8_SSE2OrAVX2__m256i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
- }*/
- INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, OnAVX2)
+
+ INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::CPU_AVX2)
constexpr static const char *const kName = "8-bit INTGEMM_AVX2";
diff --git a/avx512_gemm.h b/avx512_gemm.h
index c9233a6..28b94bf 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -166,7 +166,7 @@ struct AVX512_16bit {
}
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
- INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, OnAVX2)
+ INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, CPUType::CPU_AVX2)
constexpr static const char *const kName = "16-bit AVX512";
@@ -217,8 +217,8 @@ struct AVX512_8bit {
// Special AVX512 implementation due to having 32 registers (so I don't have to
// allocate registers manually) and no sign instruction.
- template <class WriteC>
- INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) {
+ template <typename PostprocessPipeline>
+ INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
typedef __m512i Integer;
//typedef __m256 Float; // For quantization we only do 8 at a time.
// This is copy-paste from Multiply8_SSE2OrAVX2.
@@ -227,7 +227,7 @@ struct AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
// There's 8 results for INTGEMM_AVX2 to handle.
- typename WriteC::OnAVX2 write_C(C);
+ auto inited_pipeline = InitPostprocessPipeline<CPUType::CPU_AVX2>(pipeline);
const int simd_width = width / sizeof(Integer);
const Integer *B0_col = reinterpret_cast<const Integer*>(B);
// Added for AVX512.
@@ -324,7 +324,8 @@ struct AVX512_8bit {
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
- write_C(A_rowidx, B_cols, B0_colidx, total);
+ auto result = inited_pipeline.run(total);
+ writer(C, A_rowidx, B_cols, B0_colidx, result);
}
}
}
diff --git a/benchmark.cc b/benchmark.cc
index 5e3d9d8..6477792 100644
--- a/benchmark.cc
+++ b/benchmark.cc
@@ -78,10 +78,10 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t>
Backend::PrepareB(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols);
AlignedVector<float> output(m.A_rows * m.B_cols);
// Burn in
- Backend::Multiply(A_prepared.begin(), B_prepared.begin(), JustUnquantizeC(output.begin(), unquant_mult), m.A_rows, m.width, m.B_cols);
+ Backend::Multiply(A_prepared.begin(), B_prepared.begin(), output.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), m.A_rows, m.width, m.B_cols);
{
StopWatch w(stats);
- Backend::Multiply(A_prepared.begin(), B_prepared.begin(), JustUnquantizeC(output.begin(), unquant_mult), m.A_rows, m.width, m.B_cols);
+ Backend::Multiply(A_prepared.begin(), B_prepared.begin(), output.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), m.A_rows, m.width, m.B_cols);
}
}
diff --git a/cops.h b/cops.h
deleted file mode 100644
index 6d8c771..0000000
--- a/cops.h
+++ /dev/null
@@ -1,212 +0,0 @@
-#pragma once
-
-#include "intrinsics.h"
-#include "vec_utils.h"
-
-#include <cassert>
-#include <exception>
-
-namespace intgemm {
-
-class JustUnquantizeC {
- public:
- JustUnquantizeC(float *C, float unquant_mult) : C_(C), unquant_mult_(unquant_mult) {}
-
- class OnSSE2 {
- public:
- INTGEMM_SSE2 explicit OnSSE2(const JustUnquantizeC &from)
- : C_(from.C_), unquant_mult_(set1_ps<__m128>(from.unquant_mult_)) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
- }
-
- INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
- storeu_ps(C_ + rowIDX*cols + colIDX , unquantize(result.pack0123, unquant_mult_));
- storeu_ps(C_ + rowIDX*cols + colIDX + 4, unquantize(result.pack4567, unquant_mult_));
- }
- private:
- float *C_;
- __m128 unquant_mult_;
- };
-
- class OnAVX2 {
- public:
- INTGEMM_AVX2 explicit OnAVX2(const JustUnquantizeC &from)
- : C_(from.C_), unquant_mult_(set1_ps<__m256>(from.unquant_mult_)) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
- }
-
- INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
- storeu_ps(C_ + rowIDX*cols + colIDX, unquantize(result, unquant_mult_));
- }
-
- private:
- float *C_;
- __m256 unquant_mult_;
- };
-
- private:
- float *C_;
- float unquant_mult_;
-};
-
-class BiasAddUnquantizeC {
- public:
- BiasAddUnquantizeC(float *C, const float *bias, float unquant_mult) : C_(C), bias_(bias), unquant_mult_(unquant_mult) {}
-
- class OnSSE2 {
- public:
- INTGEMM_SSE2 explicit OnSSE2(const BiasAddUnquantizeC &from)
- : C_(from.C_), bias_(from.bias_), unquant_mult_(_mm_set1_ps(from.unquant_mult_)) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
- }
-
- INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
- auto biasSection0123 = loadu_ps<__m128>(bias_ + colIDX);
- auto biasSection4567 = loadu_ps<__m128>(bias_ + colIDX + 4);
- storeu_ps(C_ + rowIDX*cols + colIDX, add_ps(unquantize(result.pack0123, unquant_mult_), biasSection0123));
- storeu_ps(C_ + rowIDX*cols + colIDX + 4, add_ps(unquantize(result.pack4567, unquant_mult_), biasSection4567));
- }
- private:
- float *C_;
- const float *bias_;
- __m128 unquant_mult_;
- };
-
- class OnAVX2 {
- public:
- INTGEMM_AVX2 explicit OnAVX2(const BiasAddUnquantizeC &from)
- : C_(from.C_), bias_(from.bias_), unquant_mult_(_mm256_set1_ps(from.unquant_mult_)) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
- }
-
- INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
- auto biasSection = loadu_ps<__m256>(bias_ + colIDX);
- storeu_ps(C_ + rowIDX*cols + colIDX, add_ps(unquantize(result, unquant_mult_), biasSection));
- }
-
- private:
- float *C_;
- const float *bias_;
- __m256 unquant_mult_;
- };
-
- private:
- float *C_;
- const float *bias_;
- float unquant_mult_;
-};
-
-class Identity {
- public:
- explicit Identity(int32_t *C) : C_(C) {}
-
- class OnSSE2 {
- public:
- INTGEMM_SSE2 explicit OnSSE2(const Identity &from)
- : C_(from.C_) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
- }
-
- INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
- _mm_storeu_si128(reinterpret_cast<__m128i*>(C_ + rowIDX*cols + colIDX), result.pack0123);
- _mm_storeu_si128(reinterpret_cast<__m128i*>(C_ + rowIDX*cols + colIDX + 4), result.pack4567);
- }
- private:
- int32_t *C_;
- };
-
- class OnAVX2 {
- public:
- INTGEMM_AVX2 explicit OnAVX2(const Identity &from)
- : C_(from.C_) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
- }
-
- INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C_ + rowIDX*cols + colIDX), result);
- }
-
- private:
- int32_t *C_;
- };
-
- private:
- int32_t *C_;
-};
-
-class ReLU {
- public:
- explicit ReLU(float *C, float unquant_mult) : C_(C), unquant_mult_(unquant_mult) {}
-
- class OnSSE2 {
- public:
- INTGEMM_SSE2 explicit OnSSE2(const ReLU& from)
- : C_(from.C_), unquant_mult_(set1_ps<__m128>(from.unquant_mult_)) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
- }
-
- INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
- static const auto zeros_ = setzero_ps<__m128>();
-
- auto unquantized0123 = unquantize(result.pack0123, unquant_mult_);
- auto nonnegative0123 = max_ps(zeros_, unquantized0123);
- storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative0123);
-
- auto unquantized4567 = unquantize(result.pack4567, unquant_mult_);
- auto nonnegative4567 = max_ps(zeros_, unquantized4567);
- storeu_ps(C_ + rowIDX*cols + colIDX + 4, nonnegative4567);
- }
-
- private:
- float* C_;
- __m128 unquant_mult_;
- };
-
- using OnSSSE3 = OnSSE2;
-
- class OnAVX2 {
- public:
- INTGEMM_AVX2 explicit OnAVX2(const ReLU& from)
- : C_(from.C_), unquant_mult_(set1_ps<__m256>(from.unquant_mult_)) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
- }
-
- INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
- static const auto zeros_ = setzero_ps<__m256>();
-
- auto nonnegative = max_ps(zeros_, unquantize(result, unquant_mult_));
- storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative);
- }
-
- private:
- float* C_;
- __m256 unquant_mult_;
- };
-
-#ifndef INTGEMM_NO_AVX512
- class OnAVX512 {
- public:
- INTGEMM_AVX512BW explicit OnAVX512(const ReLU& from)
- : C_(from.C_), unquant_mult_(set1_ps<__m512>(from.unquant_mult_)) {
- assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m512i) == 0);
- }
-
- INTGEMM_AVX512BW inline void operator()(Index rowIDX, Index cols, Index colIDX, __m512i result) {
- static const auto zeros_ = setzero_ps<__m512>();
-
- auto nonnegative = max_ps(zeros_, unquantize(result, unquant_mult_));
- storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative);
- }
-
- private:
- float* C_;
- __m512 unquant_mult_;
- };
-#endif
-
- private:
- float* C_;
- float unquant_mult_;
-};
-
-}
diff --git a/example.cc b/example.cc
index 06c7a59..d01c820 100644
--- a/example.cc
+++ b/example.cc
@@ -51,7 +51,7 @@ int main() {
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
- intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), intgemm::JustUnquantizeC(C.begin(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols);
+ intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols);
// Sanity check. C will be row major.
assert(fabs(C[0] - top_left_reference) < 0.05);
}
@@ -70,7 +70,7 @@ int main() {
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
- intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), intgemm::JustUnquantizeC(C.begin(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols);
+ intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols);
// Sanity check. C will be row major.
assert(fabs(C[0] - top_left_reference) < 0.05);
}
diff --git a/intgemm.h b/intgemm.h
index 37f1652..e380758 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -48,7 +48,7 @@
#include "sse2_gemm.h"
#include "ssse3_gemm.h"
#include "avx2_gemm.h"
-#include "cops.h"
+#include "postprocess.h"
#ifndef INTGEMM_NO_AVX512
#include "avx512_gemm.h"
#endif
@@ -67,8 +67,8 @@ struct Unsupported_16bit {
static void SelectColumnsB(const int16_t *, int16_t *, Index, const Index *, const Index *) {
throw UnsupportedCPU();
}
- template<class WriteC>
- static void Multiply(const int16_t *, const int16_t *, WriteC, Index, Index, Index) {
+ template <typename PostprocessPipeline>
+ static void Multiply(const int16_t *, const int16_t *, float *, PostprocessPipeline, Index, Index, Index) {
throw UnsupportedCPU();
}
constexpr static const char *const kName = "16-bit Unsupported";
@@ -84,8 +84,8 @@ struct Unsupported_8bit {
static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) {
throw UnsupportedCPU();
}
- template<class WriteC>
- static void Multiply(const int8_t *, const int8_t *, WriteC, Index, Index, Index) {
+ template <typename PostprocessPipeline>
+ static void Multiply(const int8_t *, const int8_t *, float *, PostprocessPipeline, Index, Index, Index) {
throw UnsupportedCPU();
}
constexpr static const char *const kName = "8-bit Unsupported";
@@ -133,15 +133,15 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported)
}
/* 16-bit matrix multiplication. */
-template<class WriteC>
+template <typename PostprocessPipeline>
class Int16Mult {
public:
// Multiply C = A * B, presuming A and B have been prepared.
- static void (*Multiply)(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols);
+ static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols);
};
-template <class WriteC>
-void (*Int16Mult<WriteC>::Multiply)(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply<WriteC>, AVX2_16bit::Multiply<WriteC>, SSE2_16bit::Multiply<WriteC>, SSE2_16bit::Multiply<WriteC>, Unsupported_16bit::Multiply);
+template <typename PostprocessPipeline>
+void (*Int16Mult<PostprocessPipeline>::Multiply)(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply<PostprocessPipeline>, AVX2_16bit::Multiply<PostprocessPipeline>, SSE2_16bit::Multiply<PostprocessPipeline>, SSE2_16bit::Multiply<PostprocessPipeline>, Unsupported_16bit::Multiply);
struct Int16 {
typedef int16_t Integer;
@@ -172,24 +172,24 @@ struct Int16 {
static void (*SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
// Multiply C = A * B, presuming A and B have been prepared.
- template<class WriteC>
- static void Multiply(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
- Int16Mult<WriteC>::Multiply(A, B, functor, A_rows, width, B_cols);
+ template <typename PostprocessPipeline>
+ static void Multiply(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
+ Int16Mult<PostprocessPipeline>::Multiply(A, B, C, pipeline, A_rows, width, B_cols);
}
static const char *const kName;
};
/* 8-bit matrix multiplication */
-template<class WriteC>
+template <typename PostprocessPipeline>
class Int8Mult {
public:
// Multiply C = A * B, presuming A and B have been prepared.
- static void (*Multiply)(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols);
+ static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols);
};
-template <class WriteC>
-void (*Int8Mult<WriteC>::Multiply)(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply<WriteC>, AVX2_8bit::Multiply<WriteC>, SSSE3_8bit::Multiply<WriteC>, SSSE3_8bit::Multiply<WriteC>, Unsupported_8bit::Multiply);
+template <typename PostprocessPipeline>
+void (*Int8Mult<PostprocessPipeline>::Multiply)(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply<PostprocessPipeline>, AVX2_8bit::Multiply<PostprocessPipeline>, SSSE3_8bit::Multiply<PostprocessPipeline>, SSSE3_8bit::Multiply<PostprocessPipeline>, Unsupported_8bit::Multiply);
struct Int8 {
typedef int8_t Integer;
@@ -219,9 +219,9 @@ struct Int8 {
static void (*SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
// Multiply C = A * B, presuming A and B have been prepared.
- template<class WriteC>
- static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
- Int8Mult<WriteC>::Multiply(A, B, functor, A_rows, width, B_cols);
+ template <typename PostprocessPipeline>
+ static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
+ Int8Mult<PostprocessPipeline>::Multiply(A, B, C, pipeline, A_rows, width, B_cols);
}
static const char *const kName;
diff --git a/multiply.h b/multiply.h
index 4642616..3b05252 100644
--- a/multiply.h
+++ b/multiply.h
@@ -2,10 +2,30 @@
#include "interleave.h"
#include "intrinsics.h"
+#include "postprocess_pipeline.h"
#include "vec_utils.h"
namespace intgemm {
+INTGEMM_SSE2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m128 result) {
+ *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX) = result;
+}
+
+INTGEMM_SSE2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, RegisterPair128 result) {
+ *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX) = result.pack0123;
+ *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX + 4) = result.pack4567;
+}
+
+INTGEMM_AVX2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m256 result) {
+ *reinterpret_cast<__m256*>(C + rowIDX*cols + colIDX) = result;
+}
+
+#ifndef INTGEMM_NO_AVX512
+INTGEMM_AVX512BW static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m512 result) {
+ *reinterpret_cast<__m512*>(C + rowIDX*cols + colIDX) = result;
+}
+#endif
+
INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
// Fold to just using the first 64 bits.
__m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
@@ -17,9 +37,9 @@ INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
return *reinterpret_cast<float*>(&a);
}
-INTGEMM_SSE2 static inline MultiplyResult128 PermuteSummer(__m128i pack0123, __m128i pack4567) {
+INTGEMM_SSE2 static inline RegisterPair128i PermuteSummer(__m128i pack0123, __m128i pack4567) {
// No op for 128 bits: already reduced fully.
- MultiplyResult128 ret;
+ RegisterPair128i ret;
ret.pack0123 = pack0123;
ret.pack4567 = pack4567;
return ret;
@@ -126,14 +146,14 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i)
// width must be a multiple of the register size.
// B_cols must be a multiple of 8.
// Multiply16
-#define INTGEMM_MULTIPLY16(Integer, target, WriteCSubType) \
- template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \
+#define INTGEMM_MULTIPLY16(Integer, target, cpu_type) \
+template <typename PostprocessPipeline> target static void Multiply(const int16_t *A, const int16_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \
assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \
- typename WriteC::WriteCSubType write_C(C); \
+ auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \
const Integer *B0_col = reinterpret_cast<const Integer *>(B); \
for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -177,7 +197,8 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i)
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
/*The specific implementation may need to reduce further.*/ \
auto total = PermuteSummer(pack0123, pack4567); \
- write_C(A_rowidx, B_cols, B0_colidx, total); \
+ auto result = inited_pipeline.run(total); \
+ writer(C, A_rowidx, B_cols, B0_colidx, result); \
} \
} \
} \
@@ -330,15 +351,15 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a)));
}
//INTGEMM_AVX2 or INTGEMM_SSSE3 multiply
-#define INTGEMM_MULTIPLY8(Integer, target, WriteCSubType) \
-template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \
+#define INTGEMM_MULTIPLY8(Integer, target, cpu_type) \
+ template <typename PostprocessPipeline> target static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \
assert(width % sizeof(Integer) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
const int simd_width = width / sizeof(Integer); \
const Integer *B0_col = reinterpret_cast<const Integer*>(B); \
- typename WriteC::WriteCSubType c_writer(C); \
+ auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \
/*Go over 8 columns of B at a time.*/ \
for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
/*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -393,8 +414,8 @@ template <class WriteC> target static void Multiply(const int8_t *A, const int8_
Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
auto total = PermuteSummer(pack0123, pack4567); \
- /*WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);*/ \
- c_writer(A_rowidx, B_cols, B0_colidx, total); \
+ auto result = inited_pipeline.run(total); \
+ writer(C, A_rowidx, B_cols, B0_colidx, result); \
} \
} \
} \
diff --git a/postprocess.h b/postprocess.h
new file mode 100644
index 0000000..7336819
--- /dev/null
+++ b/postprocess.h
@@ -0,0 +1,221 @@
+#pragma once
+
+#include "intrinsics.h"
+#include "postprocess_pipeline.h"
+#include "vec_utils.h"
+
+namespace intgemm {
+
+/*
+ * Unquantize
+ */
+class Unquantize {
+public:
+ float unquantize_multiplier;
+
+ Unquantize(float unquantize_multiplier) : unquantize_multiplier(unquantize_multiplier) {}
+};
+
+template <>
+class PostprocessImpl<Unquantize, CPUType::CPU_SSE2> {
+public:
+ using InputRegister = RegisterPair128i;
+ using OutputRegister = RegisterPair128;
+
+ INTGEMM_SSE2 PostprocessImpl(const Unquantize& config) {
+ unquantize_multiplier = set1_ps<__m128>(config.unquantize_multiplier);
+ }
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input) {
+ return {
+ mul_ps(cvtepi32_ps(input.pack0123), unquantize_multiplier),
+ mul_ps(cvtepi32_ps(input.pack4567), unquantize_multiplier),
+ };
+ }
+
+private:
+ __m128 unquantize_multiplier;
+};
+
+template <>
+class PostprocessImpl<Unquantize, CPUType::CPU_AVX2> {
+public:
+ using InputRegister = __m256i;
+ using OutputRegister = __m256;
+
+ INTGEMM_AVX2 PostprocessImpl(const Unquantize& config) {
+ unquantize_multiplier = set1_ps<__m256>(config.unquantize_multiplier);
+ }
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input) {
+ return mul_ps(cvtepi32_ps(input), unquantize_multiplier);
+ }
+
+private:
+ __m256 unquantize_multiplier;
+};
+
+template <>
+class PostprocessImpl<Unquantize, CPUType::CPU_AVX512BW> {
+public:
+ using InputRegister = __m512i;
+ using OutputRegister = __m512;
+
+ INTGEMM_AVX512BW PostprocessImpl(const Unquantize& config) {
+ unquantize_multiplier = set1_ps<__m512>(config.unquantize_multiplier);
+ }
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input) {
+ return mul_ps(cvtepi32_ps(input), unquantize_multiplier);
+ }
+
+private:
+ __m512 unquantize_multiplier;
+};
+
+/*
+ * Identity
+ */
+class Identity {};
+
+template <>
+class PostprocessImpl<Identity, CPUType::CPU_SSE2> {
+public:
+ using InputRegister = RegisterPair128i;
+ using OutputRegister = RegisterPair128i;
+
+ PostprocessImpl(const Identity& config) {}
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input) {
+ return input;
+ }
+};
+
+template <>
+class PostprocessImpl<Identity, CPUType::CPU_AVX2> {
+public:
+ using InputRegister = __m256i;
+ using OutputRegister = __m256i;
+
+ PostprocessImpl(const Identity& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input) {
+ return input;
+ }
+};
+
+template <>
+class PostprocessImpl<Identity, CPUType::CPU_AVX512BW> {
+public:
+ using InputRegister = __m512i;
+ using OutputRegister = __m512i;
+
+ PostprocessImpl(const Identity& config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input) {
+ return input;
+ }
+};
+
+/*
+ * Add a bias term
+ */
+class AddBias {
+public:
+ const float* bias;
+
+ AddBias(const float* bias) : bias(bias) {}
+};
+
+template <>
+class PostprocessImpl<AddBias, CPUType::CPU_SSE2> {
+public:
+ using InputRegister = RegisterPair128;
+ using OutputRegister = RegisterPair128;
+
+ PostprocessImpl(const AddBias& config) : config(config) {}
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input) {
+ auto bias_term0123 = *reinterpret_cast<const __m128*>(config.bias);
+ auto bias_term4567 = *reinterpret_cast<const __m128*>(config.bias);
+ return {
+ add_ps(input.pack0123, bias_term0123),
+ add_ps(input.pack4567, bias_term4567),
+ };
+ }
+
+private:
+ const AddBias config;
+};
+
+template <>
+class PostprocessImpl<AddBias, CPUType::CPU_AVX2> {
+public:
+ using InputRegister = __m256;
+ using OutputRegister = __m256;
+
+ PostprocessImpl(const AddBias& config) : config(config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input) {
+ auto bias_term = *reinterpret_cast<const __m256*>(config.bias);
+ return add_ps(input, bias_term);
+ }
+
+private:
+ const AddBias config;
+};
+
+/*
+ * ReLU
+ */
+class ReLU {};
+
+template <>
+class PostprocessImpl<ReLU, CPUType::CPU_SSE2> {
+public:
+ using InputRegister = RegisterPair128;
+ using OutputRegister = RegisterPair128;
+
+ PostprocessImpl(const ReLU& config) {}
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input) {
+ static const auto const_zero = set1_ps<__m128>(0.f);
+ return {
+ max_ps(const_zero, input.pack0123),
+ max_ps(const_zero, input.pack4567),
+ };
+ }
+};
+
+template <>
+class PostprocessImpl<ReLU, CPUType::CPU_SSSE3> : public PostprocessImpl<ReLU, CPUType::CPU_SSE2> {};
+
+template <>
+class PostprocessImpl<ReLU, CPUType::CPU_AVX2> {
+public:
+ using InputRegister = __m256;
+ using OutputRegister = __m256;
+
+ PostprocessImpl(const ReLU& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input) {
+ static const auto const_zero = set1_ps<__m256>(0.f);
+ return max_ps(const_zero, input);
+ }
+};
+
+template <>
+class PostprocessImpl<ReLU, CPUType::CPU_AVX512BW> {
+public:
+ using InputRegister = __m512;
+ using OutputRegister = __m512;
+
+ PostprocessImpl(const ReLU& config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input) {
+ static const auto const_zero = set1_ps<__m512>(0.f);
+ return max_ps(const_zero, input);
+ }
+};
+
+}
diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h
new file mode 100644
index 0000000..0783ff0
--- /dev/null
+++ b/postprocess_pipeline.h
@@ -0,0 +1,140 @@
+#pragma once
+
+#include "intrinsics.h"
+#include "types.h"
+
+#include <tuple>
+
+namespace intgemm {
+
+template <typename... Stages>
+using PostprocessPipeline = std::tuple<Stages...>;
+
+template <typename... Stages>
+constexpr std::tuple<Stages...> CreatePostprocessPipeline(const Stages&... stages) {
+ return std::tuple<Stages...>(stages...);
+}
+
+template <typename Postprocess, CPUType CpuType>
+class PostprocessImpl;
+
+namespace { // anonymous namespace
+
+template <std::size_t... I>
+struct integer_seq {};
+
+template <std::size_t N, std::size_t... I>
+struct integer_seq_from_one_s : integer_seq_from_one_s<N - 1, N - 1, I...> {};
+
+template <std::size_t... I>
+struct integer_seq_from_one_s<1, I...> : integer_seq<I...> {};
+
+template <typename... Types>
+using integer_seq_from_one = integer_seq_from_one_s<sizeof...(Types) + 1>;
+
+template <typename Stage>
+struct remove_first_stage_type_s { using type = std::tuple<>;};
+
+template <typename FirstStage, typename... RestStages>
+struct remove_first_stage_type_s<std::tuple<FirstStage, RestStages...>> { using type = std::tuple<RestStages...>; };
+
+template <typename... Stages>
+using remove_first_stage_type = typename remove_first_stage_type_s<Stages...>::type;
+
+template <typename FirstStage, typename... RestStages>
+struct get_first_stage_type_s { using type = FirstStage; };
+
+template <typename Stage>
+struct get_first_stage_type_s<Stage> { using type = Stage; };
+
+template <typename... Stages>
+using get_first_stage_type = typename get_first_stage_type_s<Stages...>::type;
+
+template <typename FirstStage, typename... RestStages>
+struct get_last_stage_type_s { using type = typename get_last_stage_type_s<RestStages...>::type; };
+
+template <typename Stage>
+struct get_last_stage_type_s<Stage> { using type = Stage; };
+
+template <typename... Stages>
+using get_last_stage_type = typename get_last_stage_type_s<Stages...>::type;
+
+template <typename Tuple, typename std::size_t...I>
+constexpr remove_first_stage_type<Tuple> ShiftPostprocessPipelineImpl(const Tuple& pipeline, integer_seq<I...>) {
+ return CreatePostprocessPipeline(std::get<I>(pipeline)...);
+}
+
+template <typename FirstStage, typename... RestStages>
+constexpr std::tuple<RestStages...> ShiftPostprocessPipeline(const std::tuple<FirstStage, RestStages...>& pipeline) {
+ return ShiftPostprocessPipelineImpl(pipeline, integer_seq_from_one<std::tuple<FirstStage, RestStages...>>());
+}
+
+template <CPUType CpuType, typename Stage>
+constexpr std::tuple<PostprocessImpl<Stage, CpuType>> InitPostprocessPipelineImpl(std::tuple<Stage> pipeline) {
+ return std::tuple<PostprocessImpl<Stage, CpuType>>(PostprocessImpl<Stage, CpuType>(std::get<0>(pipeline)));
+}
+
+template <CPUType CpuType, typename FirstStage, typename... RestStages>
+constexpr std::tuple<PostprocessImpl<FirstStage, CpuType>, PostprocessImpl<RestStages, CpuType>...> InitPostprocessPipelineImpl(std::tuple<FirstStage, RestStages...> pipeline) {
+ return std::tuple_cat(
+ std::tuple<PostprocessImpl<FirstStage, CpuType>>(PostprocessImpl<FirstStage, CpuType>(std::get<0>(pipeline))),
+ InitPostprocessPipelineImpl<CpuType, RestStages...>(ShiftPostprocessPipeline(pipeline))
+ );
+}
+
+template <CPUType CpuType>
+struct RunPostprocessPipelineImpl;
+
+#define RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(attribute, cpu_type) \
+ template <> \
+ struct RunPostprocessPipelineImpl<cpu_type> { \
+ template <typename Stage> \
+ attribute static constexpr typename Stage::OutputRegister \
+ run(std::tuple<Stage> pipeline, typename Stage::InputRegister input) { \
+ return std::get<0>(pipeline).run(input); \
+ } \
+ template <typename... Stages> \
+ attribute static constexpr typename get_last_stage_type<Stages...>::OutputRegister \
+ run(std::tuple<Stages...> pipeline, typename get_first_stage_type<Stages...>::InputRegister input) { \
+ return run( \
+ ShiftPostprocessPipeline(pipeline), \
+ std::get<0>(pipeline).run(input)); \
+ } \
+ };
+
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSE2, CPUType::CPU_SSE2)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSSE3, CPUType::CPU_SSSE3)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX2, CPUType::CPU_AVX2)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::CPU_AVX512BW)
+
+} // anonymous namespace
+
+template <CPUType CpuType, typename... Stages>
+class InitedPostprocessPipeline {};
+
+template <CPUType CpuType, typename... Stages>
+constexpr InitedPostprocessPipeline<CpuType, Stages...> InitPostprocessPipeline(std::tuple<Stages...> pipeline) {
+ return InitedPostprocessPipeline<CpuType, Stages...>(pipeline);
+}
+
+#define INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(attribute, cpu_type) \
+ template <typename... Stages> \
+ class InitedPostprocessPipeline<cpu_type, Stages...> { \
+ public: \
+ using InputRegister = typename get_first_stage_type<PostprocessImpl<Stages, cpu_type>...>::InputRegister; \
+ using OutputRegister = typename get_last_stage_type<PostprocessImpl<Stages, cpu_type>...>::OutputRegister; \
+ InitedPostprocessPipeline(std::tuple<Stages...> pipeline) \
+ : inited_pipeline(InitPostprocessPipelineImpl<cpu_type, Stages...>(pipeline)) {} \
+ attribute inline OutputRegister run(InputRegister input) { \
+ return RunPostprocessPipelineImpl<cpu_type>::run(inited_pipeline, input); \
+ } \
+ private: \
+ const std::tuple<PostprocessImpl<Stages, cpu_type>...> inited_pipeline; \
+ };
+
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSE2, CPUType::CPU_SSE2)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSSE3, CPUType::CPU_SSSE3)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX2, CPUType::CPU_AVX2)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::CPU_AVX512BW)
+
+}
diff --git a/sse2_gemm.h b/sse2_gemm.h
index 3ca263f..dfccc5c 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -72,7 +72,7 @@ struct SSE2_16bit {
//TODO #DEFINE
sse2::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, OnSSE2)
+ INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, CPUType::CPU_SSE2)
constexpr static const char *const kName = "16-bit INTGEMM_SSE2";
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 9c21467..4e12b90 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -88,21 +88,14 @@ struct SSSE3_8bit {
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 16;
static const Index kBTileCol = 8;
-/*
- INTGEMM_SSSE3 static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
- PrepareBFor8(input, output, ssse3::QuantizeTile8(quant_mult), rows, cols);
- }*/
+
INTGEMM_PREPARE_B_8(INTGEMM_SSSE3, ssse3::QuantizeTile8)
INTGEMM_SSSE3 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
ssse3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
}
-/*
- INTGEMM_SSSE3 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
- //Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
- Multiply8_SSE2OrAVX2__m128i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
- }*/
- INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, OnSSE2)
+
+ INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::CPU_SSE2)
constexpr static const char *const kName = "8-bit INTGEMM_SSSE3";
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 0cfec02..0c0becf 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -1,7 +1,8 @@
-#include "intgemm.h"
#include "aligned.h"
#include "interleave.h"
+#include "intgemm.h"
#include "multiply.h"
+#include "postprocess.h"
#define CATCH_CONFIG_RUNNER
#include "3rd_party/catch.hpp"
@@ -366,7 +367,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.begin(), B_prep.begin(), JustUnquantizeC(test_C.begin(), unquant_mult), A_rows, width, B_cols);
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), A_rows, width, B_cols);
AlignedVector<Integer> B_quant(B.size());
Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
@@ -415,7 +416,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.begin(), B_prep.begin(), BiasAddUnquantizeC(test_C.begin(), bias.begin(), unquant_mult), A_rows, width, B_cols);
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult), AddBias(bias.begin())), A_rows, width, B_cols);
AlignedVector<Integer> B_quant(B.size());
Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
diff --git a/test/pipeline_test.cc b/test/pipeline_test.cc
new file mode 100644
index 0000000..dc3d71e
--- /dev/null
+++ b/test/pipeline_test.cc
@@ -0,0 +1,32 @@
+#include "3rd_party/catch.hpp"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
+ if (kCPU < CPU_AVX2)
+ return;
+
+ int raw_input[8];
+ std::iota(raw_input, raw_input + 8, -2);
+
+ auto input = *reinterpret_cast<__m256i*>(raw_input);
+ auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
+ auto inited_pipeline = InitPostprocessPipeline<CPU_AVX2>(pipeline);
+ auto output = inited_pipeline.run(input);
+
+ float* raw_output = reinterpret_cast<float*>(&output);
+
+ CHECK(raw_output[0] == 0.0f); // input = -2
+ CHECK(raw_output[1] == 0.0f); // input = -1
+ CHECK(raw_output[2] == 0.0f); // input = 0
+ CHECK(raw_output[3] == 0.5f); // input = 1
+ CHECK(raw_output[4] == 1.0f); // input = 2
+ CHECK(raw_output[5] == 1.5f); // input = 3
+ CHECK(raw_output[6] == 2.0f); // input = 4
+ CHECK(raw_output[7] == 2.5f); // input = 5
+}
+
+}
diff --git a/test/relu_test.cc b/test/relu_test.cc
index 264cd5a..0c677c9 100644
--- a/test/relu_test.cc
+++ b/test/relu_test.cc
@@ -1,5 +1,5 @@
#include "3rd_party/catch.hpp"
-#include "cops.h"
+#include "postprocess.h"
#include <numeric>
@@ -8,56 +8,48 @@ namespace intgemm {
INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
if (kCPU < CPU_SSE2)
return;
- const unsigned N = 4;
- int32_t raw_input[2 * N];
- std::iota(raw_input, raw_input + 2 * N, -2);
-
- MultiplyResult128 input;
- input.pack0123 = *reinterpret_cast<__m128i*>(raw_input);
- input.pack4567 = *reinterpret_cast<__m128i*>(raw_input + N);
-
- float output[2 * N];
- std::fill(output, output + 2 * N, 42);
-
- auto postproc = ReLU::OnSSE2(ReLU(output, 1.f));
- postproc(0, 1, 0, input);
-
- CHECK(output[0] == 0.f); // input = -2
- CHECK(output[1] == 0.f); // input = -1
- CHECK(output[2] == 0.f); // input = 0
- CHECK(output[3] == 1.f); // input = 1
- CHECK(output[4] == 2.f); // input = 2
- CHECK(output[5] == 3.f); // input = 3
- CHECK(output[6] == 4.f); // input = 4
- CHECK(output[7] == 5.f); // input = 5
+ float raw_input[8];
+ std::iota(raw_input, raw_input + 8, -2);
+
+ RegisterPair128 input;
+ input.pack0123 = *reinterpret_cast<__m128*>(raw_input);
+ input.pack4567 = *reinterpret_cast<__m128*>(raw_input + 4);
+
+ auto postproc = PostprocessImpl<ReLU, CPUType::CPU_SSE2>(ReLU());
+ auto output = postproc.run(input);
+ auto raw_output = reinterpret_cast<float*>(&output);
+
+ CHECK(raw_output[0] == 0.f); // input = -2
+ CHECK(raw_output[1] == 0.f); // input = -1
+ CHECK(raw_output[2] == 0.f); // input = 0
+ CHECK(raw_output[3] == 1.f); // input = 1
+ CHECK(raw_output[4] == 2.f); // input = 2
+ CHECK(raw_output[5] == 3.f); // input = 3
+ CHECK(raw_output[6] == 4.f); // input = 4
+ CHECK(raw_output[7] == 5.f); // input = 5
}
INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) {
if (kCPU < CPU_AVX2)
return;
- const unsigned N = 8;
-
- int32_t raw_input[N];
- std::iota(raw_input, raw_input + N, -4);
-
- auto input = *reinterpret_cast<__m256i*>(raw_input);
-
- float output[N];
- std::fill(output, output + N, 42);
-
- auto postproc = ReLU::OnAVX2(ReLU(output, 1.f));
- postproc(0, 1, 0, input);
-
- CHECK(output[0] == 0.f); // input = -4
- CHECK(output[1] == 0.f); // input = -3
- CHECK(output[2] == 0.f); // input = -2
- CHECK(output[3] == 0.f); // input = -1
- CHECK(output[4] == 0.f); // input = 0
- CHECK(output[5] == 1.f); // input = 1
- CHECK(output[6] == 2.f); // input = 2
- CHECK(output[7] == 3.f); // input = 3
+ float raw_input[8];
+ std::iota(raw_input, raw_input + 8, -4);
+
+ auto input = *reinterpret_cast<__m256*>(raw_input);
+ auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX2>(ReLU());
+ auto output = postproc.run(input);
+ auto raw_output = reinterpret_cast<float*>(&output);
+
+ CHECK(raw_output[0] == 0.f); // input = -4
+ CHECK(raw_output[1] == 0.f); // input = -3
+ CHECK(raw_output[2] == 0.f); // input = -2
+ CHECK(raw_output[3] == 0.f); // input = -1
+ CHECK(raw_output[4] == 0.f); // input = 0
+ CHECK(raw_output[5] == 1.f); // input = 1
+ CHECK(raw_output[6] == 2.f); // input = 2
+ CHECK(raw_output[7] == 3.f); // input = 3
}
#ifndef INTGEMM_NO_AVX512
@@ -66,35 +58,30 @@ INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
if (kCPU < CPU_AVX512BW)
return;
- const unsigned N = 16;
-
- int32_t raw_input[N];
- std::iota(raw_input, raw_input + N, -8);
-
- auto input = *reinterpret_cast<__m512i*>(raw_input);
-
- float output[N];
- std::fill(output, output + N, 42);
-
- auto postproc = ReLU::OnAVX512(ReLU(output, 1.f));
- postproc(0, 1, 0, input);
-
- CHECK(output[0] == 0.f); // input = -8
- CHECK(output[1] == 0.f); // input = -7
- CHECK(output[2] == 0.f); // input = -6
- CHECK(output[3] == 0.f); // input = -5
- CHECK(output[4] == 0.f); // input = -4
- CHECK(output[5] == 0.f); // input = -3
- CHECK(output[6] == 0.f); // input = -2
- CHECK(output[7] == 0.f); // input = -1
- CHECK(output[8] == 0.f); // input = 0
- CHECK(output[9] == 1.f); // input = 1
- CHECK(output[10] == 2.f); // input = 2
- CHECK(output[11] == 3.f); // input = 3
- CHECK(output[12] == 4.f); // input = 4
- CHECK(output[13] == 5.f); // input = 5
- CHECK(output[14] == 6.f); // input = 6
- CHECK(output[15] == 7.f); // input = 7
+ float raw_input[16];
+ std::iota(raw_input, raw_input + 16, -8);
+
+ auto input = *reinterpret_cast<__m512*>(raw_input);
+ auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX512BW>(ReLU());
+ auto output = postproc.run(input);
+ auto raw_output = reinterpret_cast<float*>(&output);
+
+ CHECK(raw_output[0] == 0.f); // input = -8
+ CHECK(raw_output[1] == 0.f); // input = -7
+ CHECK(raw_output[2] == 0.f); // input = -6
+ CHECK(raw_output[3] == 0.f); // input = -5
+ CHECK(raw_output[4] == 0.f); // input = -4
+ CHECK(raw_output[5] == 0.f); // input = -3
+ CHECK(raw_output[6] == 0.f); // input = -2
+ CHECK(raw_output[7] == 0.f); // input = -1
+ CHECK(raw_output[8] == 0.f); // input = 0
+ CHECK(raw_output[9] == 1.f); // input = 1
+ CHECK(raw_output[10] == 2.f); // input = 2
+ CHECK(raw_output[11] == 3.f); // input = 3
+ CHECK(raw_output[12] == 4.f); // input = 4
+ CHECK(raw_output[13] == 5.f); // input = 5
+ CHECK(raw_output[14] == 6.f); // input = 6
+ CHECK(raw_output[15] == 7.f); // input = 7
}
#endif
diff --git a/vec_utils.h b/vec_utils.h
index 3574ba0..fb6aea4 100644
--- a/vec_utils.h
+++ b/vec_utils.h
@@ -4,10 +4,14 @@
namespace intgemm {
-struct MultiplyResult128 {
+struct RegisterPair128i {
__m128i pack0123, pack4567;
};
+struct RegisterPair128 {
+ __m128 pack0123, pack4567;
+};
+
/*
*
* Quantize