github.com/marian-nmt/intgemm.git
author     Mateusz Chudyk <mateuszchudyk@gmail.com>    2019-08-23 13:52:44 +0300
committer  GitHub <noreply@github.com>                 2019-08-23 13:52:44 +0300
commit     2d9f646d5fe397c1a8660b91f2a5021c215b35ae (patch)
tree       cdfa69bc76ccc44415011ff5eb17452d1b0e0f3c
parent     66c40eed8b649abe2f903ceca2279abe78d5f385 (diff)
parent     773cb5271efca4c7d9efe1143f22cf81a472f774 (diff)
Merge pull request #30 from kpu/add127_fullupcast
Add127 fullupcast
-rw-r--r--  CMakeLists.txt                                        10
-rw-r--r--  avx2_gemm.h                                           50
-rw-r--r--  avx512_gemm.h                                         35
-rw-r--r--  benchmarks/benchmark.cc (renamed from benchmark.cc)    0
-rw-r--r--  benchmarks/biasmultiply.cc                           239
-rw-r--r--  interleave.h                                          17
-rw-r--r--  intgemm.cc                                             4
-rw-r--r--  intgemm.h                                             40
-rw-r--r--  intrinsics.h                                          12
-rw-r--r--  multiply.h                                           155
-rw-r--r--  ssse3_gemm.h                                          53
-rw-r--r--  test/add127_test.cc                                  258
-rw-r--r--  test/multiply_test.cc                                 48
-rw-r--r--  test/test.cc                                          56
-rw-r--r--  test/test.h                                           16
-rw-r--r--  test_mull.cpp                                        328
16 files changed, 1267 insertions, 54 deletions
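
Before the per-file hunks, a note on the idea behind this merge: A is quantized to unsigned 8-bit by clamping to [-127, 127] and adding 127, so the kernels can use the unsigned-by-signed maddubs_epi16 path; the constant +127 offset is compensated by folding 127 * colsum(B) (in unquantized units) into the bias ahead of time. A minimal scalar sketch of that compensation, assuming a plain row-major quantized B (the library actually works on the interleaved PrepareB layout); the helper name is hypothetical:

#include <cstddef>
#include <cstdint>

// bias[c] -= 127 * sum_k B_quant[k][c] * unquant_mult,
// which is exactly the error introduced by shifting every A value up by 127.
void PrepareBiasScalar(const int8_t *B_quant, float *bias, float unquant_mult,
                       std::size_t width, std::size_t B_cols) {
  for (std::size_t c = 0; c < B_cols; ++c) {
    int32_t colsum = 0;
    for (std::size_t k = 0; k < width; ++k)
      colsum += B_quant[k * B_cols + c];
    bias[c] -= 127.0f * static_cast<float>(colsum) * unquant_mult;
  }
}

With unquant_mult = alpha*alpha/(127*127), this correction equals colsum * alpha*alpha/127, which is why the callbacks below use unquant_mult_forprep = -alpha*alpha/127.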
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3c6f3ff..322fc88 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -34,16 +34,20 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/config.h.in ${CMAKE_CURRENT_BINARY_DI
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR})
-foreach(exe example benchmark)
- add_executable(${exe} ${exe}.cc intgemm.cc)
+foreach(exe benchmark biasmultiply)
+ add_executable(${exe} benchmarks/${exe}.cc intgemm.cc)
endforeach()
+add_executable(example example.cc intgemm.cc)
+
add_executable(tests
+ test/test.cc
+
# General tests
test/multiply_test.cc
test/quantize_test.cc
- test/test.cc
test/utils_test.cc
+ test/add127_test.cc
# Kernels tests
test/kernels/add_bias_test.cc
diff --git a/avx2_gemm.h b/avx2_gemm.h
index ed3f895..58221d9 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -103,6 +103,10 @@ class QuantizeTile8 {
return Tile(input, input + 8, input + 16, input + 24);
}
+ INTGEMM_AVX2 inline __m256i ConsecutiveU(const float *input) {
+ return TileU(input, input + 8, input + 16, input + 24);
+ }
+
INTGEMM_AVX2 inline __m256i ForReshape(const float *input, Index cols) {
// Put higher rows in the second half of the register. These will jumble
// around in the same way then conveniently land in the right place.
@@ -132,6 +136,32 @@ class QuantizeTile8 {
// and the values are only used for GEMM.
return _mm256_permutevar8x32_epi32(packed, shuffle_param);
}
+
+ // A version that produces uint8_t output
+ INTGEMM_AVX2 inline __m256i TileU(const float *input0, const float *input1, const float *input2, const float *input3) {
+ // Looking at the assembly, gcc has pulled this outside the loops calling this.
+ const __m256i neg127 = _mm256_set1_epi8(-127);
+ const __m256i pos127 = _mm256_set1_epi8(127);
+ const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
+ // Grab 4 registers at a time in 32-bit format.
+ __m256i g0 = avx2::QuantizerGrab(input0, mult_);
+ __m256i g1 = avx2::QuantizerGrab(input1, mult_);
+ __m256i g2 = avx2::QuantizerGrab(input2, mult_);
+ __m256i g3 = avx2::QuantizerGrab(input3, mult_);
+ // Pack 32-bit to 16-bit.
+ __m256i packed0 = _mm256_packs_epi32(g0, g1);
+ __m256i packed1 = _mm256_packs_epi32(g2, g3);
+ // Pack 16-bit to 8-bit.
+ __m256i packed = _mm256_packs_epi16(packed0, packed1);
+ // Ban -128.
+ packed = _mm256_max_epi8(packed, neg127); //Could be removed if we use +128
+ packed = _mm256_add_epi8(packed, pos127);
+ // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31
+ // Or as 32-bit integers 0 2 4 6 1 3 5 7
+ // Technically this could be removed so long as the rows are bigger than 16
+ // and the values are only used for GEMM.
+ return _mm256_permutevar8x32_epi32(packed, shuffle_param);
+ }
const __m256 mult_;
};
@@ -160,6 +190,22 @@ struct AVX2_8bit {
}
}
+ // Currently A is prepared by quantization but this could theoretically change.
+ INTGEMM_AVX2 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) {
+ QuantizeU(input, output, quant_mult, rows * cols);
+ }
+
+ // Just quantize everything in order.
+ INTGEMM_AVX2 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) {
+ assert(size % 32 == 0);
+ assert(reinterpret_cast<uintptr_t>(input) % 32 == 0);
+ avx2::QuantizeTile8 q(quant_mult);
+ const float *end = input + size;
+ for (; input != end; input += 32, output += 32) {
+ *reinterpret_cast<__m256i*>(output) = q.ConsecutiveU(input);
+ }
+ }
+
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 32;
static const Index kBTileCol = 8;
@@ -171,6 +217,10 @@ struct AVX2_8bit {
}
INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::AVX2)
+
+ INTGEMM_MULTIPLY8NEW(__m256i, INTGEMM_AVX2, CPUType::AVX2)
+
+ INTGEMM_PREPAREBIASFOR8(__m256i, INTGEMM_AVX2, CPUType::AVX2)
constexpr static const char *const kName = "8-bit INTGEMM_AVX2";
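
A rough scalar equivalent of the QuantizeU/TileU path added here, ignoring the SIMD packing and permute and the size/alignment asserts (hypothetical helper; the vector code rounds via cvtps with the CPU's current rounding mode):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

void QuantizeUScalar(const float *input, uint8_t *output, float quant_mult, std::size_t size) {
  for (std::size_t i = 0; i < size; ++i) {
    int v = static_cast<int>(std::nearbyint(input[i] * quant_mult));
    v = std::max(-127, std::min(127, v));      // saturate and ban -128
    output[i] = static_cast<uint8_t>(v + 127); // shift into [0, 254]
  }
}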
diff --git a/avx512_gemm.h b/avx512_gemm.h
index e7d675c..d72859b 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -204,6 +204,35 @@ struct AVX512_8bit {
}
}
+ // Preparing A for the signed/unsigned multiplication. Using add 127
+ /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
+ INTGEMM_AVX512BW static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) {
+ QuantizeU(input, output, quant_mult, rows * cols);
+ }
+
+ // Technically output can be unaligned in Quantize.
+ // But then it will need to be aligned for Multiply.
+ // Convert to 8-bit unsigned integers.
+ /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
+
+ INTGEMM_AVX512BW static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) {
+ assert(size % 16 == 0);
+ assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
+ const __m512i neg127 = _mm512_set1_epi32(-127);
+ const __m128i pos127 = _mm_set1_epi8(127);
+ const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
+ const float *end = input + size;
+ for (; input < end; input += 16, output += 16) {
+ __m512i asint = avx512f::QuantizerGrab(input, quant_mult_reg);
+ asint = _mm512_max_epi32(asint, neg127);
+
+ //First convert to 8 bit then add and finally store,
+ //because _mm512_mask_cvtsepi32_storeu_epi8 saturates to signed
+ __m128i as8bit = _mm512_cvtsepi32_epi8(asint);
+ *reinterpret_cast<__m128i*>(output) = _mm_add_epi8(as8bit, pos127);
+ }
+ }
+
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 64;
static const Index kBTileCol = 8;
@@ -333,6 +362,12 @@ struct AVX512_8bit {
}
}
+ //INTGEMM_PREPARE_BIAS_FOR_8(INTGEMM_AVX2, __m256)
+
+ INTGEMM_MULTIPLY8NEW(__m512i, INTGEMM_AVX512BW, CPUType::AVX2)
+
+ INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2)
+
constexpr static const char *const kName = "8-bit AVX512";
static const CPUType kUses = CPUType::AVX512BW;
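
The AVX-512 variant has to order the steps differently: clamp in 32-bit, narrow with the signed-saturating _mm512_cvtsepi32_epi8, and only then add 127 in the 128-bit domain, since a saturating narrow or masked store would clip anything above 127. Per element this amounts to (scalar sketch, hypothetical helper):

#include <algorithm>
#include <cstdint>

uint8_t NarrowAndShift(int32_t asint) {
  asint = std::max<int32_t>(asint, -127);                        // _mm512_max_epi32(asint, neg127)
  int8_t as8 = static_cast<int8_t>(std::min<int32_t>(asint, 127)); // signed-saturating narrow
  return static_cast<uint8_t>(as8 + 127);                        // _mm_add_epi8(as8bit, pos127)
}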
diff --git a/benchmark.cc b/benchmarks/benchmark.cc
index 6b01304..6b01304 100644
--- a/benchmark.cc
+++ b/benchmarks/benchmark.cc
diff --git a/benchmarks/biasmultiply.cc b/benchmarks/biasmultiply.cc
new file mode 100644
index 0000000..69ee776
--- /dev/null
+++ b/benchmarks/biasmultiply.cc
@@ -0,0 +1,239 @@
+#include "intgemm.h"
+#include "aligned.h"
+#include <chrono>
+#include <random>
+#include <iostream>
+
+using namespace intgemm;
+
+template <class Routine>
+void testOld(Index rows, Index cols) {
+
+}
+
+template <class Routine>
+std::chrono::duration<double> testNew(Index A_rows, Index width, Index B_cols) {
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = dist(gen);
+ }
+
+ float alpha = 2.0f;
+ float quant_mult = 127/alpha;
+ float unquant_mult = 1.0/(quant_mult*quant_mult);
+
+ AlignedVector<uint8_t> A_prep(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
+ Routine::PrepareBiasFor8(1, B_prep.begin(), 1, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+ auto start = std::chrono::system_clock::now();
+ Routine::Multiply8new(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+ auto end = std::chrono::system_clock::now();
+
+ std::chrono::duration<double> elapsed_seconds = end-start;
+ return elapsed_seconds;
+
+}
+
+template <class Routine>
+std::chrono::duration<double> testOld(Index A_rows, Index width, Index B_cols) {
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = dist(gen);
+ }
+
+ float alpha = 2.0f;
+ float quant_mult = 127/alpha;
+ float unquant_mult = 1.0/(quant_mult*quant_mult);
+
+ AlignedVector<int8_t> A_prep(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ auto start = std::chrono::system_clock::now();
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+ auto end = std::chrono::system_clock::now();
+
+ std::chrono::duration<double> elapsed_seconds = end-start;
+ return elapsed_seconds;
+
+}
+
+template <class Routine>
+std::chrono::duration<double> testOld_nobias(Index A_rows, Index width, Index B_cols) {
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+
+ float alpha = 2.0f;
+ float quant_mult = 127/alpha;
+ float unquant_mult = 1.0/(quant_mult*quant_mult);
+
+ AlignedVector<int8_t> A_prep(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ auto start = std::chrono::system_clock::now();
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
+ auto end = std::chrono::system_clock::now();
+
+ std::chrono::duration<double> elapsed_seconds = end-start;
+ return elapsed_seconds;
+
+}
+
+int main(int argc, char ** argv) {
+ int repeat = 1000;
+ if (argc > 1) {
+ repeat = atoi(argv[1]);
+ }
+
+ std::chrono::duration<double> oldSSSE3_nobias = testOld_nobias<SSSE3_8bit>(1, 64, 8);
+ for (int i = 0; i<repeat; i++) {
+ oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(8, 256, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(8, 2048, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(320, 256, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(472, 256, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(248, 256, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(200, 256, 256);
+ }
+
+ std::cout << repeat << " iterations of Old SSSE3 without bias took: " << oldSSSE3_nobias.count() << " seconds." << std::endl;
+
+ std::chrono::duration<double> oldSSSE3 = testOld<SSSE3_8bit>(1, 64, 8);
+ for (int i = 0; i<repeat; i++) {
+ oldSSSE3 += testOld<SSSE3_8bit>(8, 256, 256);
+ oldSSSE3 += testOld<SSSE3_8bit>(8, 2048, 256);
+ oldSSSE3 += testOld<SSSE3_8bit>(320, 256, 256);
+ oldSSSE3 += testOld<SSSE3_8bit>(472, 256, 256);
+ oldSSSE3 += testOld<SSSE3_8bit>(248, 256, 256);
+ oldSSSE3 += testOld<SSSE3_8bit>(200, 256, 256);
+ }
+
+ std::cout << repeat << " iterations of Old SSSE3 took: " << oldSSSE3.count() << " seconds." << std::endl;
+
+ std::chrono::duration<double> newTimeSSSE3 = testOld<SSSE3_8bit>(1, 64, 8);
+ for (int i = 0; i<repeat; i++) {
+ newTimeSSSE3 += testNew<SSSE3_8bit>(8, 256, 256);
+ newTimeSSSE3 += testNew<SSSE3_8bit>(8, 2048, 256);
+ newTimeSSSE3 += testNew<SSSE3_8bit>(320, 256, 256);
+ newTimeSSSE3 += testNew<SSSE3_8bit>(472, 256, 256);
+ newTimeSSSE3 += testNew<SSSE3_8bit>(248, 256, 256);
+ newTimeSSSE3 += testNew<SSSE3_8bit>(200, 256, 256);
+ }
+
+ std::cout << repeat << " iterations of New SSSE3 took: " << newTimeSSSE3.count() << " seconds." << std::endl;
+
+ std::chrono::duration<double> oldAVX2_nobias = testOld_nobias<AVX2_8bit>(1, 64, 8);
+ for (int i = 0; i<repeat; i++) {
+ oldAVX2_nobias += testOld_nobias<AVX2_8bit>(8, 256, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2_8bit>(8, 2048, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2_8bit>(320, 256, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2_8bit>(472, 256, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2_8bit>(248, 256, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2_8bit>(200, 256, 256);
+ }
+
+ std::cout << repeat << " iterations of Old AVX2 without bias took: " << oldAVX2_nobias.count() << " seconds." << std::endl;
+
+ std::chrono::duration<double> oldAVX2 = testOld<AVX2_8bit>(1, 64, 8);
+ for (int i = 0; i<repeat; i++) {
+ oldAVX2 += testOld<AVX2_8bit>(8, 256, 256);
+ oldAVX2 += testOld<AVX2_8bit>(8, 2048, 256);
+ oldAVX2 += testOld<AVX2_8bit>(320, 256, 256);
+ oldAVX2 += testOld<AVX2_8bit>(472, 256, 256);
+ oldAVX2 += testOld<AVX2_8bit>(248, 256, 256);
+ oldAVX2 += testOld<AVX2_8bit>(200, 256, 256);
+ }
+
+ std::cout << repeat << " iterations of Old AVX2 took: " << oldAVX2.count() << " seconds." << std::endl;
+
+ std::chrono::duration<double> newTimeAVX2 = testOld<AVX2_8bit>(1, 64, 8);
+ for (int i = 0; i<repeat; i++) {
+ newTimeAVX2 += testNew<AVX2_8bit>(8, 256, 256);
+ newTimeAVX2 += testNew<AVX2_8bit>(8, 2048, 256);
+ newTimeAVX2 += testNew<AVX2_8bit>(320, 256, 256);
+ newTimeAVX2 += testNew<AVX2_8bit>(472, 256, 256);
+ newTimeAVX2 += testNew<AVX2_8bit>(248, 256, 256);
+ newTimeAVX2 += testNew<AVX2_8bit>(200, 256, 256);
+ }
+
+ std::cout << repeat << " iterations of New AVX2 took: " << newTimeAVX2.count() << " seconds." << std::endl;
+
+ if (kCPU < CPUType::AVX512BW) return 0;
+ std::chrono::duration<double> oldAVX512_nobias = testOld_nobias<AVX512_8bit>(1, 64, 8);
+ for (int i = 0; i<repeat; i++) {
+ oldAVX512_nobias += testOld_nobias<AVX512_8bit>(8, 256, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512_8bit>(8, 2048, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512_8bit>(320, 256, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512_8bit>(472, 256, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512_8bit>(248, 256, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512_8bit>(200, 256, 256);
+ }
+
+ std::cout << repeat << " iterations of Old AVX512 without bias took: " << oldAVX512_nobias.count() << " seconds." << std::endl;
+
+ std::chrono::duration<double> oldAVX512 = testOld<AVX512_8bit>(1, 64, 8);
+ for (int i = 0; i<repeat; i++) {
+ oldAVX512 += testOld<AVX512_8bit>(8, 256, 256);
+ oldAVX512 += testOld<AVX512_8bit>(8, 2048, 256);
+ oldAVX512 += testOld<AVX512_8bit>(320, 256, 256);
+ oldAVX512 += testOld<AVX512_8bit>(472, 256, 256);
+ oldAVX512 += testOld<AVX512_8bit>(248, 256, 256);
+ oldAVX512 += testOld<AVX512_8bit>(200, 256, 256);
+ }
+
+ std::cout << repeat << " iterations of Old AVX512 took: " << oldAVX512.count() << " seconds." << std::endl;
+
+ std::chrono::duration<double> newTimeAVX512 = testOld<AVX512_8bit>(1, 64, 8);
+ for (int i = 0; i<repeat; i++) {
+ newTimeAVX512 += testNew<AVX512_8bit>(8, 256, 256);
+ newTimeAVX512 += testNew<AVX512_8bit>(8, 2048, 256);
+ newTimeAVX512 += testNew<AVX512_8bit>(320, 256, 256);
+ newTimeAVX512 += testNew<AVX512_8bit>(472, 256, 256);
+ newTimeAVX512 += testNew<AVX512_8bit>(248, 256, 256);
+ newTimeAVX512 += testNew<AVX512_8bit>(200, 256, 256);
+ }
+
+ std::cout << repeat << " iterations of New AVX512 took: " << newTimeAVX512.count() << " seconds." << std::endl;
+
+
+}
diff --git a/interleave.h b/interleave.h
index 202062f..cc06378 100644
--- a/interleave.h
+++ b/interleave.h
@@ -270,4 +270,21 @@ target static inline void SelectColumnsOfB(const Register *input, Register *outp
} \
} \
+#define INTGEMM_PREPARE_BIAS_FOR_8(target, Register) \
+target static inline void PrepareBiasFor8(const float *input, float *bias, float alpha, Index rows, Index cols) { \
+ assert(cols*sizeof(float) % sizeof(Register) == 0); \
+ constexpr int stride = sizeof(Register) / sizeof(float); \
+ Register alpha_reg = set1_ps<Register>(alpha); \
+ for (Index c = 0; c<cols; c+=stride) { \
+ Register vectorsum = set1_ps<Register>(0.0f); \
+ for (Index r = 0; r < rows; r++) { \
+ Register column_stride = load_ps<Register>(input + r*cols + c); \
+ vectorsum = add_ps(vectorsum, column_stride); \
+ } \
+ Register *towrite = reinterpret_cast<Register *>(bias + c); \
+ vectorsum = mul_ps(vectorsum, alpha_reg); \
+ *towrite = sub_ps(*towrite, vectorsum); \
+ } \
+} \
+
} // namespace intgemm
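
This float-space variant of the bias preparation computes bias[c] -= alpha * sum_r B[r][c] on the raw float B, which matches the integer PrepareBiasFor8 path up to quantization error: with quant_mult = 127/alpha, 127 * colsum(B_quant) * unquant_mult is approximately alpha * colsum(B_float). A scalar restatement (hypothetical helper, row-major input as the macro assumes):

#include <cstddef>

void PrepareBiasFor8Scalar(const float *B, float *bias, float alpha,
                           std::size_t rows, std::size_t cols) {
  for (std::size_t c = 0; c < cols; ++c) {
    float colsum = 0.0f;
    for (std::size_t r = 0; r < rows; ++r)
      colsum += B[r * cols + c];
    bias[c] -= alpha * colsum;
  }
}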
diff --git a/intgemm.cc b/intgemm.cc
index 6928f0c..8e86f3c 100644
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -16,10 +16,14 @@ const char *const Int16::kName = ChooseCPU(AVX512_16bit::kName, AVX2_16bit::kNam
void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize);
+void (*Int8::QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_8bit::QuantizeU, AVX2_8bit::QuantizeU, SSSE3_8bit::QuantizeU, Unsupported_8bit::QuantizeU, Unsupported_8bit::QuantizeU);
+
void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB);
void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB);
+//void (*Int8::PrepareBiasFor8)(const float *input, float *bias, float alpha, Index rows, Index cols) = ChooseCPU(AVX512_8bit::PrepareBiasFor8, AVX2_8bit::PrepareBiasFor8, SSSE3_8bit::PrepareBiasFor8, Unsupported_8bit::PrepareBiasFor8, Unsupported_8bit::PrepareBiasFor8);
+
const char *const Int8::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit::kName, Unsupported_8bit::kName);
const CPUType kCPU = ChooseCPU(CPUType::AVX512BW, CPUType::AVX2, CPUType::SSSE3, CPUType::SSE2, CPUType::UNSUPPORTED);
diff --git a/intgemm.h b/intgemm.h
index c096aae..d4155b5 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -76,9 +76,16 @@ struct Unsupported_8bit {
static void Quantize(const float *, int8_t *, float, Index) {
throw UnsupportedCPU();
}
+ static void QuantizeU(const float *, uint8_t *, float, Index) {
+ throw UnsupportedCPU();
+ }
static void PrepareB(const float *, int8_t *, float, Index, Index) {
throw UnsupportedCPU();
}
+ template<class Callback>
+ static void PrepareBiasFor8(const int8_t, const int8_t *, Index, Index, Index, Callback) {
+ throw UnsupportedCPU();
+ }
static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) {
throw UnsupportedCPU();
}
@@ -86,6 +93,10 @@ struct Unsupported_8bit {
static void Multiply(const int8_t *, const int8_t *, Index, Index, Index, Callback) {
throw UnsupportedCPU();
}
+ template<class Callback>
+ static void Multiply8new(const uint8_t *, const int8_t *, Index, Index, Index, Callback) {
+ throw UnsupportedCPU();
+ }
constexpr static const char *const kName = "8-bit Unsupported";
};
@@ -184,11 +195,20 @@ class Int8Mult {
public:
// Multiply C = A * B, presuming A and B have been prepared.
static void (*Multiply)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback);
+ static void (*Multiply8new)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback);
+ static void (*PrepareBiasFor8)(const int8_t A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback);
};
template <typename Callback>
void (*Int8Mult<Callback>::Multiply)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_8bit::Multiply<Callback>, AVX2_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, Unsupported_8bit::Multiply);
+template <class Callback>
+void (*Int8Mult<Callback>::Multiply8new)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_8bit::Multiply8new<Callback>, AVX2_8bit::Multiply8new<Callback>, SSSE3_8bit::Multiply8new<Callback>, SSSE3_8bit::Multiply8new<Callback>, Unsupported_8bit::Multiply8new);
+
+template <class Callback>
+void (*Int8Mult<Callback>::PrepareBiasFor8)(const int8_t A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_8bit::PrepareBiasFor8<Callback>, AVX2_8bit::PrepareBiasFor8<Callback>, SSSE3_8bit::PrepareBiasFor8<Callback>, SSSE3_8bit::PrepareBiasFor8<Callback>, Unsupported_8bit::PrepareBiasFor8);
+
+
struct Int8 {
typedef int8_t Integer;
@@ -206,8 +226,18 @@ struct Int8 {
Quantize(input, output, quant_mult, rows * cols);
}
+ static inline void PrepareANew(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
+ QuantizeU(input, reinterpret_cast<uint8_t *>(output), quant_mult, rows * cols);
+ }
+
// Multiply floats by quant_mult then convert to 8-bit integers with saturation.
static void (*Quantize)(const float *input, int8_t *output, float quant_mult, Index size);
+
+ // Multiply floats by quant_mult then convert to 8-bit integers with saturation.
+ static void (*QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size);
+
+ // PrepareB
+ //static void (*PrepareBiasFor8)(const float *input, float *bias, float alpha, Index rows, Index cols);
// Warning: the output of PrepareB depends on the CPU.
// It will match the Multiply function on the same CPU though.
@@ -221,6 +251,16 @@ struct Int8 {
static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
Int8Mult<Callback>::Multiply(A, B, A_rows, width, B_cols, callback);
}
+
+ template<class Callback>
+ static void Multiply8new(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+ Int8Mult<Callback>::Multiply8new((const uint8_t *)A, B, A_rows, width, B_cols, callback);
+ }
+
+ template<class Callback>
+ static void PrepareBiasFor8(const int8_t A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+ Int8Mult<Callback>::PrepareBiasFor8(A, B, A_rows, width, B_cols, callback);
+ }
static const char *const kName;
};
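
Putting the new dispatching entry points together, a minimal, untested usage sketch that mirrors benchmarks/biasmultiply.cc but goes through the Int8 struct declared here; alpha and the buffer handling are illustrative, the names are the ones added in this header:

#include "intgemm.h"
#include "aligned.h"

using namespace intgemm;

void MultiplyWithShiftedA(const float *A, const float *B, float *bias_in,
                          float *bias_prepped, float *C,
                          Index A_rows, Index width, Index B_cols) {
  float alpha = 2.0f;                  // assumed range of the inputs, illustrative
  float quant_mult = 127.0f / alpha;
  float unquant_mult = 1.0f / (quant_mult * quant_mult);

  AlignedVector<int8_t> A_prep(A_rows * width);  // holds the +127-shifted (unsigned) values
  AlignedVector<int8_t> B_prep(width * B_cols);
  Int8::PrepareANew(A, A_prep.begin(), quant_mult, A_rows, width);
  Int8::PrepareB(B, B_prep.begin(), quant_mult, width, B_cols);

  // Fold the +127 compensation into the bias; note the negated multiplier.
  float unquant_mult_forprep = -alpha * alpha / 127.0f;
  Int8::PrepareBiasFor8(1, B_prep.begin(), 1, width, B_cols,
      callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias_in, bias_prepped));

  Int8::Multiply8new(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
      callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias_prepped, C));
}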
diff --git a/intrinsics.h b/intrinsics.h
index e19efd7..9be5296 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -16,10 +16,11 @@ namespace intgemm {
* Define a bunch of intrinstics as overloaded functions so they work with
* templates.
*/
+template <class Register> static inline Register load_ps(float const* from);
template <class Register> static inline Register loadu_ps(const float* mem_addr);
-template <class Register> static inline Register set1_epi8(int8_t to);
template <class Register> static inline Register set1_epi16(int16_t to);
template <class Register> static inline Register set1_epi32(int32_t to);
+template <class Register> static inline Register set1_epi8(int8_t to);
template <class Register> static inline Register set1_pd(double to);
template <class Register> static inline Register set1_ps(float to);
template <class Register> static inline Register setzero_pd();
@@ -73,6 +74,9 @@ INTGEMM_SSE2 static inline __m128 div_ps(__m128 a, __m128 b) {
/*
* Missing i32gather_ps for SSE2
*/
+template <> INTGEMM_SSE2 inline __m128 load_ps<__m128>(const float* from) {
+ return _mm_load_ps(from);
+}
template <> INTGEMM_SSE2 inline __m128 loadu_ps(const float* mem_addr) {
return _mm_loadu_ps(mem_addr);
}
@@ -242,6 +246,9 @@ INTGEMM_AVX2 static inline __m256 i32gather_ps(float const *base_addr, __m256i v
template <> INTGEMM_AVX2 inline __m256 loadu_ps(const float* mem_addr) {
return _mm256_loadu_ps(mem_addr);
}
+template <> INTGEMM_AVX2 inline __m256 load_ps<__m256>(const float* from) {
+ return _mm256_load_ps(from);
+}
INTGEMM_AVX2 static inline __m256i madd_epi16(__m256i first, __m256i second) {
return _mm256_madd_epi16(first, second);
}
@@ -479,6 +486,9 @@ template <> INTGEMM_AVX512BW inline __m512 setzero_ps<__m512>() {
template <> INTGEMM_AVX512BW inline __m512i setzero_si<__m512i>() {
return _mm512_setzero_si512();
}
+template <> INTGEMM_AVX512BW inline __m512 load_ps<__m512>(const float* from) {
+ return _mm512_load_ps(from);
+}
/*
* Missing sign_epi8
*/
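
The load_ps specializations follow the file's existing pattern of overloading each intrinsic on the register type, so width-generic template code can pick the SSE2/AVX2/AVX-512 instruction by instantiation alone. Roughly (illustration only; real callers also carry the matching INTGEMM_* target attribute):

template <class Register> static inline Register AddTwoAligned(const float *a, const float *b) {
  return add_ps(load_ps<Register>(a), load_ps<Register>(b));
}
// AddTwoAligned<__m128>(p, q) lowers to _mm_load_ps/_mm_add_ps,
// AddTwoAligned<__m256>(p, q) to the _mm256 forms, and so on.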
diff --git a/multiply.h b/multiply.h
index aef4aab..c8f636b 100644
--- a/multiply.h
+++ b/multiply.h
@@ -192,7 +192,160 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
} \
} \
-/* 8-bit matrix multiply used by AVX and INTGEMM_AVX2.
+//An int8_prepbias version of the above code, using the add 127 technique
+#define INTGEMM_PREPAREBIASFOR8(Integer, target, cpu_type) \
+ template <class Callback> target static void PrepareBiasFor8(const int8_t A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \
+ assert(width % (sizeof(Integer) / sizeof(int8_t)) == 0); \
+ assert(B_cols % 8 == 0); \
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
+ const int simd_width = width / (sizeof(Integer) / sizeof(int8_t)); \
+ auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
+ const Integer *B0_col = reinterpret_cast<const Integer *>(B); \
+ const Integer a = set1_epi8<Integer>(A); \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
+ for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
+ /*const Integer *A_row = reinterpret_cast<const Integer*>(A + A_rowidx * width);*/ \
+ /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \
+ Iterate over shared (inner) dimension.*/ \
+ int k = 0; \
+ Integer sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \
+ Integer sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \
+ Integer sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \
+ Integer sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \
+ Integer sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \
+ Integer sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \
+ Integer sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \
+ Integer sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \
+ /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is declared here.*/ \
+ Integer ones = set1_epi16<Integer>(1); \
+ sum0 = madd_epi16(sum0, ones); \
+ sum1 = madd_epi16(sum1, ones); \
+ sum2 = madd_epi16(sum2, ones); \
+ sum3 = madd_epi16(sum3, ones); \
+ sum4 = madd_epi16(sum4, ones); \
+ sum5 = madd_epi16(sum5, ones); \
+ sum6 = madd_epi16(sum6, ones); \
+ sum7 = madd_epi16(sum7, ones); \
+ for (int k = 1; k < simd_width; ++k) { \
+ /*Integer a = *(A_row + k);*/ \
+ /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ \
+ Integer mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \
+ Integer mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \
+ Integer mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \
+ Integer mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \
+ Integer mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \
+ Integer mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \
+ Integer mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \
+ Integer mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \
+ /* Upcast to 32-bit and horizontally add.*/ \
+ mult0 = madd_epi16(mult0, ones); \
+ mult1 = madd_epi16(mult1, ones); \
+ mult2 = madd_epi16(mult2, ones); \
+ mult3 = madd_epi16(mult3, ones); \
+ mult4 = madd_epi16(mult4, ones); \
+ mult5 = madd_epi16(mult5, ones); \
+ mult6 = madd_epi16(mult6, ones); \
+ mult7 = madd_epi16(mult7, ones); \
+ /*Add in 32bit*/ \
+ sum0 = add_epi32(sum0, mult0); \
+ sum1 = add_epi32(sum1, mult1); \
+ sum2 = add_epi32(sum2, mult2); \
+ sum3 = add_epi32(sum3, mult3); \
+ sum4 = add_epi32(sum4, mult4); \
+ sum5 = add_epi32(sum5, mult5); \
+ sum6 = add_epi32(sum6, mult6); \
+ sum7 = add_epi32(sum7, mult7); \
+ \
+ } \
+ /* Reduce sums within 128-bit lanes.*/ \
+ Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
+ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
+ /*The specific implementation may need to reduce further.*/ \
+ auto total = PermuteSummer(pack0123, pack4567); \
+ RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \
+ } \
+ } \
+} \
+
+//An int8 version of the above code, using the add 127 technique
+#define INTGEMM_MULTIPLY8NEW(Integer, target, cpu_type) \
+ template <class Callback> target static void Multiply8new(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \
+ assert(width % (sizeof(Integer) / sizeof(int8_t)) == 0); \
+ assert(B_cols % 8 == 0); \
+ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
+ const int simd_width = width / (sizeof(Integer) / sizeof(int8_t)); \
+ auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
+ const Integer *B0_col = reinterpret_cast<const Integer *>(B); \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
+ for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
+ const Integer *A_row = reinterpret_cast<const Integer*>(A + A_rowidx * width); \
+ /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \
+ Iterate over shared (inner) dimension.*/ \
+ int k = 0; \
+ Integer a = *(A_row + k); \
+ Integer sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \
+ Integer sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \
+ Integer sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \
+ Integer sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \
+ Integer sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \
+ Integer sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \
+ Integer sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \
+ Integer sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \
+ /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is declared here.*/ \
+ Integer ones = set1_epi16<Integer>(1); \
+ sum0 = madd_epi16(sum0, ones); \
+ sum1 = madd_epi16(sum1, ones); \
+ sum2 = madd_epi16(sum2, ones); \
+ sum3 = madd_epi16(sum3, ones); \
+ sum4 = madd_epi16(sum4, ones); \
+ sum5 = madd_epi16(sum5, ones); \
+ sum6 = madd_epi16(sum6, ones); \
+ sum7 = madd_epi16(sum7, ones); \
+ for (int k = 1; k < simd_width; ++k) { \
+ Integer a = *(A_row + k); \
+ /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ \
+ Integer mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \
+ Integer mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \
+ Integer mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \
+ Integer mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \
+ Integer mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \
+ Integer mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \
+ Integer mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \
+ Integer mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \
+ /* Upcast to 32-bit and horizontally add.*/ \
+ mult0 = madd_epi16(mult0, ones); \
+ mult1 = madd_epi16(mult1, ones); \
+ mult2 = madd_epi16(mult2, ones); \
+ mult3 = madd_epi16(mult3, ones); \
+ mult4 = madd_epi16(mult4, ones); \
+ mult5 = madd_epi16(mult5, ones); \
+ mult6 = madd_epi16(mult6, ones); \
+ mult7 = madd_epi16(mult7, ones); \
+ /*Add in 32bit*/ \
+ sum0 = add_epi32(sum0, mult0); \
+ sum1 = add_epi32(sum1, mult1); \
+ sum2 = add_epi32(sum2, mult2); \
+ sum3 = add_epi32(sum3, mult3); \
+ sum4 = add_epi32(sum4, mult4); \
+ sum5 = add_epi32(sum5, mult5); \
+ sum6 = add_epi32(sum6, mult6); \
+ sum7 = add_epi32(sum7, mult7); \
+ \
+ } \
+ /* Reduce sums within 128-bit lanes.*/ \
+ Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
+ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
+ /*The specific implementation may need to reduce further.*/ \
+ auto total = PermuteSummer(pack0123, pack4567); \
+ RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \
+ } \
+ } \
+} \
+
+/* 8-bit matrix multiply used by AVX and AVX2.
* These have two peculiar properties:
* 1. The sign instructions don't exist in AVX512.
* 2. 16 registers means gcc's register allocation failed so I wrote it in my
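
Both new macros share one inner pattern: maddubs_epi16 multiplies unsigned A bytes against signed B bytes into saturating 16-bit pair sums, and madd_epi16 against a register of ones widens those to 32-bit before accumulation; PrepareBiasFor8 runs the same loop with A fixed to set1_epi8(1), i.e. it produces column sums of B for the bias correction. A scalar reference of what Multiply8new computes (hypothetical helper; row-major B rather than the interleaved layout, and it ignores the intermediate 16-bit saturation of maddubs):

#include <cstddef>
#include <cstdint>

// C[r][c] = unquant_mult * sum_k A_u8[r][k] * B_s8[k][c] + bias[c],
// where A_u8 holds the +127-shifted values and bias was corrected beforehand.
void Multiply8newScalar(const uint8_t *A, const int8_t *B, float *C,
                        float unquant_mult, const float *bias,
                        std::size_t A_rows, std::size_t width, std::size_t B_cols) {
  for (std::size_t r = 0; r < A_rows; ++r) {
    for (std::size_t c = 0; c < B_cols; ++c) {
      int32_t sum = 0;
      for (std::size_t k = 0; k < width; ++k)
        sum += static_cast<int32_t>(A[r * width + k]) * static_cast<int32_t>(B[k * B_cols + c]);
      C[r * B_cols + c] = sum * unquant_mult + bias[c];
    }
  }
}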
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 24cc179..fcd0de8 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -35,6 +35,11 @@ class QuantizeTile8 {
return Tile(input, input + 8);
}
+ INTGEMM_SSSE3 inline __m128i ConsecutiveU(const float *input) {
+ return TileU(input, input + 8);
+ }
+
+
private:
// Quantize 16xfloat into 16xint8_t
INTGEMM_SSSE3 inline __m128i Tile(const float *input0, const float *input1) {
@@ -48,7 +53,7 @@ class QuantizeTile8 {
__m128i packed = _mm_packs_epi16(packed0, packed1);
/* Ban -128.
* Don't use the SSE4.1 instruction _mm_max_epi8(packed, neg127). Instead,
- * use INTGEMM_SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8.
+ * use SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8.
* The first generates 0xff for fields -128.
* The second subtracts 0xff from -128 which has the effect of converting
* to -127.
@@ -59,6 +64,29 @@ class QuantizeTile8 {
// No permute needed. packs is in order for SSE.
}
+ INTGEMM_SSSE3 inline __m128i TileU(const float *input0, const float *input1) {
+ const __m128i neg128 = _mm_set1_epi8(-128);
+ const __m128i pos127 = _mm_set1_epi8(127);
+ __m128i g0 = QuantizerGrab(input0, mult_reg_);
+ __m128i g1 = QuantizerGrab(input0 + 4, mult_reg_);
+ __m128i g2 = QuantizerGrab(input1, mult_reg_);
+ __m128i g3 = QuantizerGrab(input1 + 4, mult_reg_);
+ __m128i packed0 = _mm_packs_epi32(g0, g1);
+ __m128i packed1 = _mm_packs_epi32(g2, g3);
+ __m128i packed = _mm_packs_epi16(packed0, packed1);
+ /* Ban -128.
+ * Don't use the SSE4.1 instruction _mm_max_epi8(packed, neg127). Instead,
+ * use SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8.
+ * The first generates 0xff for fields -128.
+ * The second subtracts 0xff from -128 which has the effect of converting
+ * to -127.
+ */
+ // packed = _mm_max_epi8(packed, neg127);
+ __m128i evils = _mm_cmpeq_epi8(packed, neg128);
+ return _mm_add_epi8(_mm_sub_epi8(packed, evils), pos127);
+ // No permute needed. packs is in order for SSE.
+ }
+
private:
const __m128 mult_reg_;
};
@@ -86,6 +114,23 @@ struct SSSE3_8bit {
}
}
+ // Version with unsigned int + 127
+ // Currently A is prepared by quantization but this could theoretically change.
+ INTGEMM_SSSE3 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) {
+ QuantizeU(input, output, quant_mult, rows * cols);
+ }
+
+ INTGEMM_SSSE3 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) {
+ assert(size % 16 == 0);
+ assert(reinterpret_cast<uintptr_t>(input) % 16 == 0);
+ assert(reinterpret_cast<uintptr_t>(output) % 16 == 0);
+ ssse3::QuantizeTile8 q(quant_mult);
+ const float *end = input + size;
+ for (; input != end; input += 16, output += 16) {
+ *reinterpret_cast<__m128i*>(output) = q.ConsecutiveU(input);
+ }
+ }
+
// Tile size for B; B must be a multiple of this block size.
static const Index kBTileRow = 16;
static const Index kBTileCol = 8;
@@ -97,7 +142,13 @@ struct SSSE3_8bit {
}
INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::SSE2)
+
+ INTGEMM_MULTIPLY8NEW(__m128i, INTGEMM_SSSE3, CPUType::SSE2)
+ //INTGEMM_PREPARE_BIAS_FOR_8(INTGEMM_SSSE3, __m128)
+
+ INTGEMM_PREPAREBIASFOR8(__m128i, INTGEMM_SSSE3, CPUType::SSE2)
+
constexpr static const char *const kName = "8-bit INTGEMM_SSSE3";
static const CPUType kUses = CPUType::SSSE3;
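
The SSSE3 TileU reuses the SSE2-only trick described in its comment for banning -128: _mm_cmpeq_epi8 yields 0xff exactly in the -128 lanes, subtracting that nudges those lanes to -127, and +127 then shifts everything into [0, 254]. Per lane (scalar sketch, hypothetical helper):

#include <cstdint>

uint8_t BanMinus128AndShift(int8_t packed) {
  int8_t evil = (packed == -128) ? -1 : 0;            // _mm_cmpeq_epi8(packed, neg128)
  int8_t fixed = static_cast<int8_t>(packed - evil);  // -128 - (-1) == -127
  return static_cast<uint8_t>(fixed + 127);           // _mm_add_epi8(..., pos127)
}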
diff --git a/test/add127_test.cc b/test/add127_test.cc
new file mode 100644
index 0000000..e466c9d
--- /dev/null
+++ b/test/add127_test.cc
@@ -0,0 +1,258 @@
+#include "test/test.h"
+
+namespace intgemm {
+
+void newBias(const float * input, float * bias, float* output, float alpha, Index rows, Index width, Index cols) {
+ AlignedVector<float> intermediate(rows*cols);
+ AlignedVector<float> ones(rows*width);
+ for (auto&& it : ones) {
+ it = 1;
+ }
+ SlowRefFloat(ones.begin(), input, intermediate.begin(), rows, width, cols);
+ for (auto&& it : intermediate) {
+ it = it*alpha;
+ }
+
+
+ for (Index c = 0; c<cols; c++) {
+ output[c] = bias[c] - intermediate.begin()[rows*c];
+ }
+
+}
+
+void SlowSumB(const int8_t * input, float * bias, float* output, float alpha, Index rows, Index cols) {
+ for (Index r = 0; r<rows; r++) {
+ for (Index c = 0; c<cols; c++) {
+ output[c] += input[r * cols + c];
+ }
+ }
+
+ for (Index c = 0; c<cols; c++) {
+ output[c] = bias[c] + output[c]*alpha;
+ }
+}
+
+void CompareAs(int8_t * output_old, uint8_t * output_new, Index rows, Index cols) {
+ for (Index r = 0; r<rows; r++) {
+ for (Index c = 0; c<cols; c++) {
+ int a = int(output_old[rows*c + r]);
+ int b = int(output_new[rows*c + r]);
+ INFO("Inaccurate at row: " << r << " column " << c << ' '
+ << a << ' ' << b);
+ CHECK(a+127 == b);
+ }
+ }
+}
+
+void CompareBiases(const float *bias_ref, const float *bias, Index cols) {
+ for (std::size_t i = 0; i < cols; ++i) {
+ INFO("Inaccurate at " << i << ' ' << bias_ref[i] << ' ' << bias[i]);
+ CHECK(fabs(bias_ref[i] - bias[i]) < 0.0001);
+ }
+}
+
+template <class Routine> void TestPrepareA(Index rows, Index cols) {
+ std::mt19937 gen;
+ // Go somewhat out of range too.
+ std::uniform_real_distribution<float> dist(-2, 2);
+ // Create array.
+ AlignedVector<float> inputA(rows * cols);
+ for (auto& it : inputA) {
+ it = dist(gen);
+ }
+ AlignedVector<int8_t> oldA(rows * cols);
+ AlignedVector<uint8_t> newA(rows * cols);
+ float quant_mult = 64; //From example
+ Routine::PrepareA(inputA.begin(), oldA.begin(), quant_mult, rows, cols);
+ Routine::PrepareA(inputA.begin(), newA.begin(), quant_mult, rows, cols);
+ CompareAs(oldA.begin(), newA.begin(), rows, cols);
+}
+
+template <class Routine> void TestPrepareBias(Index rows, Index cols) {
+ std::mt19937 gen;
+ // Go somewhat out of range too.
+ std::uniform_real_distribution<float> dist(-30.0, 30.0);
+ // Create array.
+ AlignedVector<float> inputB(rows * cols);
+ for (auto& it : inputB) {
+ it = dist(gen);
+ }
+
+ float alpha = 25;
+ float quant_mult = 127/alpha;
+
+ AlignedVector<int8_t> B_prep(inputB.size());
+ AlignedVector<int8_t> B_quant(inputB.size());
+ Routine::PrepareB(inputB.begin(), B_prep.begin(), quant_mult, rows, cols);
+ Routine::Quantize(inputB.begin(), B_quant.begin(), quant_mult, inputB.size());
+
+
+ AlignedVector<float> inputBias(cols);
+ AlignedVector<float> goldBias(cols);
+
+ for (auto& it : goldBias) {
+ it = 0;
+ }
+ for (auto& it : inputBias) {
+ it = dist(gen);
+ }
+
+ float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f);
+
+ SlowSumB(B_quant.begin(), inputBias.begin(), goldBias.begin(), alpha, rows, cols);
+
+ Routine::PrepareBiasFor8(1, B_prep.begin(), 1, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin()));
+
+ CompareBiases(goldBias.begin(), inputBias.begin(), cols);
+}
+
+template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = dist(gen);
+ }
+
+ float alpha = 2.0f;
+ float quant_mult = 127/alpha;
+ float unquant_mult = 1.0/(quant_mult*quant_mult);
+
+ AlignedVector<uint8_t> A_prep(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ /*REFERENCE MULTIPLICATION
+ *
+ *
+ */
+ AlignedVector<int8_t> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
+ AlignedVector<float> slowint_C(test_C.size());
+ // Taking the original A_preparation which means A would be int8_t
+ AlignedVector<int8_t> A_prep2(A.size());
+ Routine::PrepareA(A.begin(), A_prep2.begin(), quant_mult, A_rows, width);
+ SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin());
+
+ AlignedVector<float> float_C(test_C.size());
+ SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin());
+
+ /*ACTUAL MULTIPLICATION
+ *
+ */
+ float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
+ Routine::PrepareBiasFor8(1, B_prep.begin(), 1, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+ //Routine::PrepareBiasFor8(B.begin(), bias.begin(), alpha, width, B_cols);
+ Routine::Multiply8new(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+ Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
+/*
+// Bias
+TEST_CASE("PrepareBias SSSE3", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestPrepareBias<SSSE3_8bit>(8,8);
+ TestPrepareBias<SSSE3_8bit>(256,256);
+ TestPrepareBias<SSSE3_8bit>(2048,256);
+ TestPrepareBias<SSSE3_8bit>(512,512);
+}
+
+TEST_CASE("PrepareBias AVX2", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestPrepareBias<AVX2_8bit>(8,8);
+ TestPrepareBias<AVX2_8bit>(256,256);
+ TestPrepareBias<AVX2_8bit>(2048,256);
+ TestPrepareBias<AVX2_8bit>(512,512);
+}
+
+TEST_CASE("PrepareBias AVX512F", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ #ifndef INTGEMM_NO_AVX512
+ TestPrepareBias<AVX512_8bit>(8,8);
+ TestPrepareBias<AVX512_8bit>(256,256);
+ TestPrepareBias<AVX512_8bit>(2048,256);
+ TestPrepareBias<AVX512_8bit>(512,512);
+ #endif
+}*/
+
+//A
+TEST_CASE("PrepareA SSSE3", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestPrepareA<SSSE3_8bit>(64,64);
+ TestPrepareA<SSSE3_8bit>(256,256);
+ TestPrepareA<SSSE3_8bit>(512,512);
+ TestPrepareA<SSSE3_8bit>(2048,256);
+}
+
+TEST_CASE("PrepareA AVX2", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestPrepareA<AVX2_8bit>(64,64);
+ TestPrepareA<AVX2_8bit>(256,256);
+ TestPrepareA<AVX2_8bit>(512,512);
+ TestPrepareA<AVX2_8bit>(2048,256);
+}
+
+TEST_CASE("PrepareA AVX512F", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ #ifndef INTGEMM_NO_AVX512
+ TestPrepareA<AVX512_8bit>(64,64);
+ TestPrepareA<AVX512_8bit>(256,256);
+ TestPrepareA<AVX512_8bit>(512,512);
+ TestPrepareA<AVX512_8bit>(2048,256);
+ #endif
+}
+
+// Multiply
+
+TEST_CASE ("Multiply SSSE3 8bit with new bias", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyBiasNew<SSSE3_8bit>(1, 64, 8, 0.11, 0.1, 0.06, 0.05);
+ TestMultiplyBiasNew<SSSE3_8bit>(8, 256, 256, 0.45, 0.54, 0.17, 0.16); // 0.064, 0.026);
+ TestMultiplyBiasNew<SSSE3_8bit>(8, 2048, 256, 1.7, 1.7, 0.46, 0.43); // 4.4, 4.4);
+ TestMultiplyBiasNew<SSSE3_8bit>(320, 256, 256, 0.56, 0.64, 0.16, 0.15); // 0.1, 0.01);
+ TestMultiplyBiasNew<SSSE3_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); // 0.1, 0.011);
+ TestMultiplyBiasNew<SSSE3_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); // 0.1, 0.012);
+ TestMultiplyBiasNew<SSSE3_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16); // 0.1, 0.011);
+}
+
+TEST_CASE ("Multiply AVX2 8bit with new bias", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyBiasNew<AVX2_8bit>(1, 64, 8, 0.11, 0.11, 0.06, 0.05);
+ TestMultiplyBiasNew<AVX2_8bit>(8, 256, 256, 0.49, 0.54, 0.17, 0.16); //0.1, 0);
+ TestMultiplyBiasNew<AVX2_8bit>(8, 2048, 256, 1.57, 1.66, 0.46, 0.46); //1.8, 1.8);
+ TestMultiplyBiasNew<AVX2_8bit>(320, 256, 256, 0.49, 0.64, 0.16, 0.15); //0.1, 0);
+ TestMultiplyBiasNew<AVX2_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); //0.1, 0);
+ TestMultiplyBiasNew<AVX2_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); //0.1, 0);
+ TestMultiplyBiasNew<AVX2_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16); //0.1, 0);
+}
+
+TEST_CASE ("Multiply AVX512F 8bit with new bias", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyBiasNew<AVX512_8bit>(1, 64, 8, 0.11, 0.11, 0.06, 0.05);
+ TestMultiplyBiasNew<AVX512_8bit>(8, 256, 256, 0.48, 0.54, 0.17, 0.16); //, 1.6, 1.6); //0.1, 0);
+ TestMultiplyBiasNew<AVX512_8bit>(8, 2048, 256, 1.57, 1.66, 0.46, 0.43); //1.8, 1.8);
+ TestMultiplyBiasNew<AVX512_8bit>(320, 256, 256, 0.48, 0.64, 0.16, 0.15); //0.1, 0);
+ TestMultiplyBiasNew<AVX512_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); //0.1, 0);
+ TestMultiplyBiasNew<AVX512_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); //0.1, 0);
+ TestMultiplyBiasNew<AVX512_8bit>(200, 256, 256, 0.57, 0.74, 0.17, 0.16); //0.1, 0);
+}
+
+} //namespace intgemm
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 5dac484..82a11ef 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -15,7 +15,6 @@
#include <iostream>
#include <memory>
#include <random>
-#include <sstream>
namespace intgemm {
@@ -286,53 +285,6 @@ TEST_CASE("MaxAbsolute AVX512F", "[max]") {
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
// Compute A*B slowly in floats.
-void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias=nullptr) {
- for (Index r = 0; r < A_rows; ++r) {
- for (Index c = 0; c < B_cols; ++c) {
- float sum = 0.0f;
- for (Index k = 0; k < width; ++k) {
- sum += A[r * width + k] * B[k * B_cols + c];
- }
- if (bias) {
- C[r * B_cols + c] = sum + bias[c];
- } else {
- C[r * B_cols + c] = sum;
- }
- }
- }
-}
-
-// Compute A*B slowly from integers.
-template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr) {
- for (Index r = 0; r < A_rows; ++r) {
- for (Index c = 0; c < B_cols; ++c) {
- int32_t sum = 0;
- for (Index k = 0; k < width; ++k) {
- sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]);
- }
- if (bias) {
- C[r * B_cols + c] = sum * unquant_mult + bias[c];
- } else {
- C[r * B_cols + c] = sum * unquant_mult;
- }
- }
- }
-}
-
-void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
- float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) {
- float int_sum = 0.0, float_sum = 0.0;
- for (std::size_t i = 0; i < size; ++i) {
- float int_diff = int_ref[i] - int_test[i];
- float float_diff = float_ref[i] - int_test[i];
- CHECK_MESSAGE(fabs(int_diff) <= int_tolerance, test_info << "Inaccurate compared to int reference at " << i << ' ' << int_ref[i] << ' ' << int_test[i]);
- CHECK_MESSAGE(fabs(float_diff) <= float_tolerance, test_info << "Inaccurate compared to float reference at " << i << ' ' << float_ref[i] << ' ' << int_test[i]);
- int_sum += int_diff * int_diff;
- float_sum += float_diff * float_diff;
- }
- CHECK_MESSAGE(fabs(sqrt(float_sum / size)) <= MSE_float_tolerance, test_info << "Float MSE = " << sqrt(float_sum / size));
- CHECK_MESSAGE(fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size));
-}
template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_cols,
float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
diff --git a/test/test.cc b/test/test.cc
index 58c62f8..cb45b73 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -4,3 +4,59 @@
int main(int argc, char ** argv) {
return Catch::Session().run(argc, argv);
}
+
+namespace intgemm {
+
+void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias) {
+ for (Index r = 0; r < A_rows; ++r) {
+ for (Index c = 0; c < B_cols; ++c) {
+ float sum = 0.0f;
+ for (Index k = 0; k < width; ++k) {
+ sum += A[r * width + k] * B[k * B_cols + c];
+ }
+ if (bias) {
+ C[r * B_cols + c] = sum + bias[c];
+ } else {
+ C[r * B_cols + c] = sum;
+ }
+ }
+ }
+}
+
+// Compute A*B slowly from integers.
+template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) {
+ for (Index r = 0; r < A_rows; ++r) {
+ for (Index c = 0; c < B_cols; ++c) {
+ int32_t sum = 0;
+ for (Index k = 0; k < width; ++k) {
+ sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]);
+ }
+ if (bias) {
+ C[r * B_cols + c] = sum * unquant_mult + bias[c];
+ } else {
+ C[r * B_cols + c] = sum * unquant_mult;
+ }
+ }
+ }
+}
+
+template void SlowRefInt<int8_t>(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
+template void SlowRefInt<int16_t>(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
+template void SlowRefInt<int32_t>(const int32_t *A, const int32_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
+
+void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
+ float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) {
+ float int_sum = 0.0, float_sum = 0.0;
+ for (std::size_t i = 0; i < size; ++i) {
+ float int_diff = int_ref[i] - int_test[i];
+ float float_diff = float_ref[i] - int_test[i];
+ CHECK_MESSAGE(fabs(int_diff) <= int_tolerance, test_info << "Inaccurate compared to int reference at " << i << ' ' << int_ref[i] << ' ' << int_test[i]);
+ CHECK_MESSAGE(fabs(float_diff) <= float_tolerance, test_info << "Inaccurate compared to float reference at " << i << ' ' << float_ref[i] << ' ' << int_test[i]);
+ int_sum += int_diff * int_diff;
+ float_sum += float_diff * float_diff;
+ }
+ CHECK_MESSAGE(fabs(sqrt(float_sum / size)) <= MSE_float_tolerance, test_info << "Float MSE = " << sqrt(float_sum / size));
+ CHECK_MESSAGE(fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size));
+}
+
+} //namespace intgemm
diff --git a/test/test.h b/test/test.h
index f2f745d..1034337 100644
--- a/test/test.h
+++ b/test/test.h
@@ -1,4 +1,9 @@
+#pragma once
+
#include "3rd_party/catch.hpp"
+#include <sstream>
+#include "intgemm.h"
+#include "aligned.h"
#include "config.h"
@@ -14,3 +19,14 @@
} while(0)
#define KERNEL_TEST_CASE(name) TEST_CASE("Kernel: " name, "[kernel_test]")
+
+namespace intgemm {
+void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias=nullptr);
+
+// Compute A*B slowly from integers.
+template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr);
+
+void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
+ float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance);
+
+} //namespace intgemm
diff --git a/test_mull.cpp b/test_mull.cpp
new file mode 100644
index 0000000..e83f1c9
--- /dev/null
+++ b/test_mull.cpp
@@ -0,0 +1,328 @@
+#include "intgemm.cc"
+#include "aligned.h"
+#include <iostream>
+#include <random>
+#include <string>
+#include <algorithm>
+#include <fstream>
+#include <sstream>
+
+
+/*Adapted from https://www.bfilipek.com/2018/07/string-view-perf-followup.html . We should probably go string_view way
+inline void tokenizeLine(std::string& str, std::vector<std::string>& output,
+ std::string delimeter = " ") {
+ auto first = std::begin(str);
+
+ while (first != str.end()) {
+ const auto second = std::find_first_of(first, std::end(str), std::begin(delimeter), std::end(delimeter));
+
+ if (first != second) {
+ output.emplace_back(str.substr(std::distance(std::begin(str), first), std::distance(first, second)));
+ }
+
+ if (second == str.end())
+ break;
+
+ first = std::next(second);
+ }
+}
+
+//This is a different parsing method, without stringStream
+template<class StringType>
+void ReadInFile(StringType infile) {
+ std::ifstream in(infile);
+ std::string line;
+
+ //First line, Info about the matrix
+ std::getline(in, line);
+ std::istringstream iss(line);
+ std::string temp1, temp2, temp3, temp4;
+ int RowsA, ColsA, RowsB, ColsB;
+ if (!(iss >> temp1 >> RowsA >> temp2 >> ColsA >> temp3 >> RowsB >> temp4 >> ColsB)) {
+ std::cerr << "Error parsing line 1 " << std::endl;
+ exit(1);
+ }
+
+ //Second line, get QuantMult
+ std::getline(in, line);
+ std::istringstream iss2(line);
+ float quantMultA, quantMultB;
+ if (!(iss2 >> temp1 >> quantMultA >> temp2 >> quantMultA)) {
+ std::cerr << "Error parsing line 2 " << std::endl;
+ exit(1);
+ }
+ std::getline(in, line); //Just some text
+ //Fourth line, AQuant
+ std::vector<int> AQuant;
+ std::getline(in, line);
+ std::vector<std::string> tmp_container;
+ tokenizeLine(line, tmp_container);
+ if (tmp_container.size() != RowsA*ColsA) {
+ std::cerr << "Error parsing matrix A. Size mismatch. Expected " << RowsA*ColsA << " got " << tmp_container.size() << std::endl;
+ }
+ for (auto&& num : tmp_container) {
+ AQuant.push_back(std::stoi(num));
+ }
+ tmp_container.resize(0);
+
+ std::getline(in, line); //Just some text
+ //Sixth line, B_raw
+ std::vector<float> B_raw;
+ std::getline(in, line);
+ tokenizeLine(line, tmp_container);
+ if (tmp_container.size() != RowsB*ColsB) {
+ std::cerr << "Error parsing matrix B. Size mismatch. Expected " << RowsB*ColsB << " got " << tmp_container.size() << std::endl;
+ }
+ for (auto&& num : tmp_container) {
+ B_raw.push_back(std::stof(num));
+ }
+ tmp_container.resize(0);
+
+ std::getline(in, line); //Just some text
+ //Eight line, Bias
+ std::vector<float> Bias;
+ std::getline(in, line);
+ tokenizeLine(line, tmp_container);
+ if (tmp_container.size() != ColsB) {
+ std::cerr << "Error parsing bias. Size mismatch. Expected " << ColsB << " got " << tmp_container.size() << std::endl;
+ }
+ for (auto&& num : tmp_container) {
+ Bias.push_back(std::stof(num));
+ }
+ tmp_container.resize(0);
+
+}
+
+*/
+template<class StringType>
+void ReadInFile(StringType infile) {
+ std::ifstream in(infile);
+ std::string line;
+
+ //First line, Info about the matrix
+ std::getline(in, line);
+ std::istringstream iss(line);
+ std::string temp1, temp2, temp3, temp4;
+ int RowsA, ColsA, RowsB, ColsB;
+ if (!(iss >> temp1 >> RowsA >> temp2 >> ColsA >> temp3 >> RowsB >> temp4 >> ColsB)) {
+ std::cerr << "Error parsing line 1 " << std::endl;
+ exit(1);
+ }
+
+ //Second line, get QuantMult
+ std::getline(in, line);
+ std::istringstream iss2(line);
+ float quantMultA, quantMultB;
+ if (!(iss2 >> temp1 >> quantMultA >> temp2 >> quantMultA)) {
+ std::cerr << "Error parsing line 2 " << std::endl;
+ exit(1);
+ }
+ std::getline(in, line); //Just some text for human readability
+
+ //4th line, AQuant
+ std::vector<int> AQuant;
+ std::getline(in, line);
+ std::istringstream iss3(line);
+ for (int i = 0; i < RowsA*ColsA; i++) {
+ int num;
+ if (!(iss3 >> num)) {
+ std::cerr << "Error parsing matrix A at " << i << std::endl;;
+ }
+ AQuant.push_back(num);
+ }
+
+ std::getline(in, line); //Just some text for human readability
+ //6th line, B_raw
+ std::vector<float> B_raw;
+ std::getline(in, line);
+ std::istringstream iss4(line);
+ for (int i = 0; i < RowsB*ColsB; i++) {
+ float num;
+ if (!(iss4 >> num)) {
+ std::cerr << "Error parsing matrix B " << std::endl;
+ }
+ B_raw.push_back(num);
+ }
+
+ std::getline(in, line); //Just some text for human readability
+ //8th line, Bias
+ std::vector<float> Bias;
+ std::getline(in, line);
+ std::istringstream iss5(line);
+ for (int i = 0; i < ColsB; i++) {
+ float num;
+ if (!(iss5 >> num)) {
+ std::cerr << "Error parsing matrix bias " << std::endl;
+ }
+ Bias.push_back(num);
+ }
+}
+
+using namespace intgemm;
+template<class T>
+void printMatrix(T* data, Index rows, Index cols) {
+ std::cout << "[";
+ for (int i = 0; i<rows; i++) {
+ std::cout << "[";
+ for (int j =0; j<cols; j++) {
+ std::cout << (float)data[i*cols + j];
+ if (j != cols - 1) {
+ std::cout << ", ";
+ }
+ }
+ std::cout << "]";
+ if (i != rows -1) {
+ std::cout << ',' << std::endl;
+ }
+ }
+ std::cout << "]" << std::endl;
+}
+
+void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias) {
+ for (Index r = 0; r < A_rows; ++r) {
+ for (Index c = 0; c < B_cols; ++c) {
+ float sum = 0.0f;
+ for (Index k = 0; k < width; ++k) {
+ sum += A[r * width + k] * B[k * B_cols + c];
+ }
+ if (bias) {
+ C[r * B_cols + c] = sum + bias[c];
+ } else {
+ C[r * B_cols + c] = sum;
+ }
+ }
+ }
+}
+
+// Compute A*B slowly from integers.
+template <class Integer>
+void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) {
+ for (Index r = 0; r < A_rows; ++r) {
+ for (Index c = 0; c < B_cols; ++c) {
+ int32_t sum = 0;
+ for (Index k = 0; k < width; ++k) {
+ sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]);
+ }
+ if (bias) {
+ C[r * B_cols + c] = sum * unquant_mult + bias[c];
+ } else {
+ C[r * B_cols + c] = sum * unquant_mult;
+ }
+ }
+ }
+}
+
+int main() {
+
+ const Index A_rows = 1;
+ const Index width = 2048;
+ const Index B_cols = 8;
+
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+
+ float alpha = 2.0f;
+ float quant_mult = 127/alpha;
+ float unquant_mult = 1.0 / (quant_mult * quant_mult);
+
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-2.0f, 2.0f);
+
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = dist(gen);
+ }
+
+ AlignedVector<float> bias_orig(B_cols);
+ for (int i = 0; i < bias.size(); i++) {
+ bias_orig[i] = bias[i];
+ }
+
+ AlignedVector<int8_t> A_prep(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+
+ AVX2_8bit::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ AVX2_8bit::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+/*
+ std::cout << "A:" << std::endl;
+ printMatrix(A.begin(), A_rows, width);
+ std::cout << "B:" << std::endl;
+ printMatrix(B.begin(), width, B_cols);
+ std::cout << "bias:" << std::endl;
+ printMatrix(bias.begin(), 1, B_cols);*/
+
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ AVX2_8bit::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+ //AVX2_8bit::Multiply(A_prep.begin(), B_prep.begin(), JustUnquantizeC(test_C.begin(), unquant_mult), A_rows, width, B_cols);
+ std::cout << "Old multiply:" << std::endl;
+ printMatrix(test_C.begin(), A_rows, B_cols);
+
+ //NEEEXT
+ AlignedVector<uint8_t> A_prep2(A.size());
+ AVX2_8bit::PrepareA(A.begin(), A_prep2.begin(), quant_mult, A_rows, width);
+
+ AVX2_8bit::PrepareBiasFor8(B.begin(), bias.begin(), alpha, width, B_cols);
+
+ //printMatrix(bias.begin(), 1, B_cols); //Print bias
+
+ AVX2_8bit::Multiply8new(reinterpret_cast<uint8_t*>(A_prep2.begin()), B_prep.begin(), A_rows, width, B_cols, UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+ //AVX2_8bit::Multiply8new(reinterpret_cast<uint8_t*>(A_prep.begin()), B_prep.begin(), JustUnquantizeC(test_C.begin(), unquant_mult), A_rows, width, B_cols);
+
+ AlignedVector<int16_t> A_prep3(A.size());
+ AlignedVector<int16_t> B_prep3(B.size());
+ std::cout << "New multiply:" << std::endl;
+ printMatrix(test_C.begin(), A_rows, B_cols);
+ for (int i = 0; i < A_prep2.size(); i++) {
+ A_prep3[i] = A_prep2[i];
+ }
+ AVX2_16bit::PrepareB(B.begin(), B_prep3.begin(), quant_mult, width, B_cols);
+ AVX2_16bit::Multiply(A_prep3.begin(), B_prep3.begin(), A_rows, width, B_cols, UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+ std::cout << "New multiply, 16 bit:" << std::endl;
+ printMatrix(test_C.begin(), A_rows, B_cols);
+
+ //FULL INTS
+ AlignedVector<float> C_slowint(A_rows * B_cols);
+ AlignedVector<int8_t> B_quant(width * B_cols);
+ AVX2_8bit::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
+
+ SlowRefInt(A_prep.begin(), B_quant.begin(), C_slowint.begin(),
+ unquant_mult, A_rows, width, B_cols, bias_orig.begin());
+
+
+ std::cout << "Reference int8:" << std::endl;
+ printMatrix(C_slowint.begin(), A_rows, B_cols);
+
+ //FULL INT16
+ AlignedVector<int16_t> A_prep4(A.size());
+ for (int i = 0; i < A_prep2.size(); i++) {
+ A_prep4[i] = A_prep[i];
+ }
+
+ AlignedVector<float> C_slowint2(A_rows * B_cols);
+ AlignedVector<int16_t> B_quant2(width * B_cols);
+ AVX2_16bit::Quantize(B.begin(), B_quant2.begin(), quant_mult, B.size());
+
+ SlowRefInt(A_prep4.begin(), B_quant2.begin(), C_slowint2.begin(),
+ unquant_mult, A_rows, width, B_cols, bias_orig.begin());
+
+
+ std::cout << "Reference int16:" << std::endl;
+ printMatrix(C_slowint2.begin(), A_rows, B_cols);
+
+ //FLOATS
+ AlignedVector<float> C(A_rows * B_cols);
+
+ SlowRefFloat(A.begin(), B.begin(), C.begin(), A_rows, width, B_cols, bias_orig.begin());
+ std::cout << "Reference float:" << std::endl;
+ printMatrix(C.begin(), A_rows, B_cols);
+
+}