author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-08-23 13:52:44 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-08-23 13:52:44 +0300 |
commit | 2d9f646d5fe397c1a8660b91f2a5021c215b35ae (patch) | |
tree | cdfa69bc76ccc44415011ff5eb17452d1b0e0f3c | |
parent | 66c40eed8b649abe2f903ceca2279abe78d5f385 (diff) | |
parent | 773cb5271efca4c7d9efe1143f22cf81a472f774 (diff) |
Merge pull request #30 from kpu/add127_fullupcast
Add127 fullupcast
-rw-r--r-- | CMakeLists.txt | 10
-rw-r--r-- | avx2_gemm.h | 50
-rw-r--r-- | avx512_gemm.h | 35
-rw-r--r-- | benchmarks/benchmark.cc (renamed from benchmark.cc) | 0
-rw-r--r-- | benchmarks/biasmultiply.cc | 239
-rw-r--r-- | interleave.h | 17
-rw-r--r-- | intgemm.cc | 4
-rw-r--r-- | intgemm.h | 40
-rw-r--r-- | intrinsics.h | 12
-rw-r--r-- | multiply.h | 155
-rw-r--r-- | ssse3_gemm.h | 53
-rw-r--r-- | test/add127_test.cc | 258
-rw-r--r-- | test/multiply_test.cc | 48
-rw-r--r-- | test/test.cc | 56
-rw-r--r-- | test/test.h | 16
-rw-r--r-- | test_mull.cpp | 328
16 files changed, 1267 insertions, 54 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c6f3ff..322fc88 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,16 +34,20 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/config.h.in ${CMAKE_CURRENT_BINARY_DI include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}) -foreach(exe example benchmark) - add_executable(${exe} ${exe}.cc intgemm.cc) +foreach(exe benchmark biasmultiply) + add_executable(${exe} benchmarks/${exe}.cc intgemm.cc) endforeach() +add_executable(example example.cc intgemm.cc) + add_executable(tests + test/test.cc + # General tests test/multiply_test.cc test/quantize_test.cc - test/test.cc test/utils_test.cc + test/add127_test.cc # Kernels tests test/kernels/add_bias_test.cc diff --git a/avx2_gemm.h b/avx2_gemm.h index ed3f895..58221d9 100644 --- a/avx2_gemm.h +++ b/avx2_gemm.h @@ -103,6 +103,10 @@ class QuantizeTile8 { return Tile(input, input + 8, input + 16, input + 24); } + INTGEMM_AVX2 inline __m256i ConsecutiveU(const float *input) { + return TileU(input, input + 8, input + 16, input + 24); + } + INTGEMM_AVX2 inline __m256i ForReshape(const float *input, Index cols) { // Put higher rows in the second half of the register. These will jumble // around in the same way then conveniently land in the right place. @@ -132,6 +136,32 @@ class QuantizeTile8 { // and the values are only used for GEMM. return _mm256_permutevar8x32_epi32(packed, shuffle_param); } + + //A version that produces uint8_ts + INTGEMM_AVX2 inline __m256i TileU(const float *input0, const float *input1, const float *input2, const float *input3) { + // Looking at the assembly, gcc has pulled this outside the loops calling this. + const __m256i neg127 = _mm256_set1_epi8(-127); + const __m256i pos127 = _mm256_set1_epi8(127); + const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + // Grab 4 registers at a time in 32-bit format. + __m256i g0 = avx2::QuantizerGrab(input0, mult_); + __m256i g1 = avx2::QuantizerGrab(input1, mult_); + __m256i g2 = avx2::QuantizerGrab(input2, mult_); + __m256i g3 = avx2::QuantizerGrab(input3, mult_); + // Pack 32-bit to 16-bit. + __m256i packed0 = _mm256_packs_epi32(g0, g1); + __m256i packed1 = _mm256_packs_epi32(g2, g3); + // Pack 16-bit to 8-bit. + __m256i packed = _mm256_packs_epi16(packed0, packed1); + // Ban -128. + packed = _mm256_max_epi8(packed, neg127); //Could be removed if we use +128 + packed = _mm256_add_epi8(packed, pos127); + // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31 + // Or as 32-bit integers 0 2 4 6 1 3 5 7 + // Technically this could be removed so long as the rows are bigger than 16 + // and the values are only used for GEMM. + return _mm256_permutevar8x32_epi32(packed, shuffle_param); + } const __m256 mult_; }; @@ -160,6 +190,22 @@ struct AVX2_8bit { } } + // Currently A is prepared by quantization but this could theoretically change. + INTGEMM_AVX2 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) { + QuantizeU(input, output, quant_mult, rows * cols); + } + + // Just quantize everything in order. 
+ INTGEMM_AVX2 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) { + assert(size % 32 == 0); + assert(reinterpret_cast<uintptr_t>(input) % 32 == 0); + avx2::QuantizeTile8 q(quant_mult); + const float *end = input + size; + for (; input != end; input += 32, output += 32) { + *reinterpret_cast<__m256i*>(output) = q.ConsecutiveU(input); + } + } + // Tile size for B; B must be a multiple of this block size. static const Index kBTileRow = 32; static const Index kBTileCol = 8; @@ -171,6 +217,10 @@ struct AVX2_8bit { } INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::AVX2) + + INTGEMM_MULTIPLY8NEW(__m256i, INTGEMM_AVX2, CPUType::AVX2) + + INTGEMM_PREPAREBIASFOR8(__m256i, INTGEMM_AVX2, CPUType::AVX2) constexpr static const char *const kName = "8-bit INTGEMM_AVX2"; diff --git a/avx512_gemm.h b/avx512_gemm.h index e7d675c..d72859b 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -204,6 +204,35 @@ struct AVX512_8bit { } } + // Preparing A for the signed/unsigned multiplication. Using add 127 + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_AVX512BW static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) { + QuantizeU(input, output, quant_mult, rows * cols); + } + + // Technically output can be unaligned in Quantize. + // But then it will need to be aligned for Multiply. + // Convert to 8-bit signed integers. + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + + INTGEMM_AVX512BW static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) { + assert(size % 16 == 0); + assert(reinterpret_cast<uintptr_t>(input) % 64 == 0); + const __m512i neg127 = _mm512_set1_epi32(-127); + const __m128i pos127 = _mm_set1_epi8(127); + const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult); + const float *end = input + size; + for (; input < end; input += 16, output += 16) { + __m512i asint = avx512f::QuantizerGrab(input, quant_mult_reg); + asint = _mm512_max_epi32(asint, neg127); + + //First convert to 8 bit then add and finally store, + //because _mm512_mask_cvtsepi32_storeu_epi8 saturates to signed + __m128i as8bit = _mm512_cvtsepi32_epi8(asint); + *reinterpret_cast<__m128i*>(output) = _mm_add_epi8(as8bit, pos127); + } + } + // Tile size for B; B must be a multiple of this block size. 
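A minimal scalar sketch (not part of the patch) of what the new QuantizeU routines above compute per element, assuming the same round-to-nearest behaviour as QuantizerGrab:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical scalar equivalent of QuantizeU: scale, round, clamp to
// [-127, 127] (banning -128, as the vector code does), then shift by +127
// so the value fits in an unsigned byte.
inline uint8_t QuantizeUScalar(float value, float quant_mult) {
  int32_t q = static_cast<int32_t>(std::nearbyint(value * quant_mult));
  q = std::max(-127, std::min(127, q));
  return static_cast<uint8_t>(q + 127);  // result lies in [0, 254]
}
```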
static const Index kBTileRow = 64; static const Index kBTileCol = 8; @@ -333,6 +362,12 @@ struct AVX512_8bit { } } + //INTGEMM_PREPARE_BIAS_FOR_8(INTGEMM_AVX2, __m256) + + INTGEMM_MULTIPLY8NEW(__m512i, INTGEMM_AVX512BW, CPUType::AVX2) + + INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2) + constexpr static const char *const kName = "8-bit AVX512"; static const CPUType kUses = CPUType::AVX512BW; diff --git a/benchmark.cc b/benchmarks/benchmark.cc index 6b01304..6b01304 100644 --- a/benchmark.cc +++ b/benchmarks/benchmark.cc diff --git a/benchmarks/biasmultiply.cc b/benchmarks/biasmultiply.cc new file mode 100644 index 0000000..69ee776 --- /dev/null +++ b/benchmarks/biasmultiply.cc @@ -0,0 +1,239 @@ +#include "intgemm.h" +#include "aligned.h" +#include <chrono> +#include <random> +#include <iostream> + +using namespace intgemm; + +template <class Routine> +void testOld(Index rows, Index cols) { + +} + +template <class Routine> +std::chrono::duration<double> testNew(Index A_rows, Index width, Index B_cols) { + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = dist(gen); + } + + float alpha = 2.0f; + float quant_mult = 127/alpha; + float unquant_mult = 1.0/(quant_mult*quant_mult); + + AlignedVector<uint8_t> A_prep(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on + Routine::PrepareBiasFor8(1, B_prep.begin(), 1, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin())); + auto start = std::chrono::system_clock::now(); + Routine::Multiply8new(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + auto end = std::chrono::system_clock::now(); + + std::chrono::duration<double> elapsed_seconds = end-start; + return elapsed_seconds; + +} + +template <class Routine> +std::chrono::duration<double> testOld(Index A_rows, Index width, Index B_cols) { + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = dist(gen); + } + + float alpha = 2.0f; + float quant_mult = 127/alpha; + float unquant_mult = 1.0/(quant_mult*quant_mult); + + AlignedVector<int8_t> A_prep(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + auto start = std::chrono::system_clock::now(); + Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + auto end = std::chrono::system_clock::now(); + + std::chrono::duration<double> elapsed_seconds = end-start; + 
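The `unquant_mult_forprep` constant in testNew above is where the add-127 trick is paid for: PrepareA now emits A + 127, so the integer product (A + 127)·B carries an extra 127·colsum(B) in every output column, and PrepareBiasFor8 folds the matching correction into the bias. A scalar sketch of that adjustment, assuming quant_mult = 127/alpha as in this benchmark (the helper name is hypothetical):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar version of the bias correction: subtract the column sums
// of quantized B, scaled back to float units. With quant_mult = 127/alpha,
// 127 * unquant_mult = alpha*alpha/127, the magnitude of unquant_mult_forprep.
std::vector<float> PrepareBiasScalar(const std::vector<int8_t>& B_quant,  // row-major, rows x cols
                                     const std::vector<float>& bias,
                                     float alpha, std::size_t rows, std::size_t cols) {
  std::vector<float> corrected(cols);
  for (std::size_t c = 0; c < cols; ++c) {
    std::int32_t colsum = 0;
    for (std::size_t r = 0; r < rows; ++r) colsum += B_quant[r * cols + c];
    corrected[c] = bias[c] - colsum * (alpha * alpha / 127.0f);
  }
  return corrected;
}
```

Multiply8new then adds this corrected bias through the usual UnquantizeAndAddBiasAndWrite callback, so the extra 127·colsum(B) term cancels out.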
return elapsed_seconds; + +} + +template <class Routine> +std::chrono::duration<double> testOld_nobias(Index A_rows, Index width, Index B_cols) { + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + + float alpha = 2.0f; + float quant_mult = 127/alpha; + float unquant_mult = 1.0/(quant_mult*quant_mult); + + AlignedVector<int8_t> A_prep(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + auto start = std::chrono::system_clock::now(); + Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin())); + auto end = std::chrono::system_clock::now(); + + std::chrono::duration<double> elapsed_seconds = end-start; + return elapsed_seconds; + +} + +int main(int argc, char ** argv) { + int repeat = 1000; + if (argc > 1) { + repeat = atoi(argv[1]); + } + + std::chrono::duration<double> oldSSSE3_nobias = testOld_nobias<SSSE3_8bit>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(8, 256, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(8, 2048, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(320, 256, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(472, 256, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(248, 256, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3_8bit>(200, 256, 256); + } + + std::cout << repeat << " iterations of Old SSSE3 without bias took: " << oldSSSE3_nobias.count() << " seconds." << std::endl; + + std::chrono::duration<double> oldSSSE3 = testOld<SSSE3_8bit>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldSSSE3 += testOld<SSSE3_8bit>(8, 256, 256); + oldSSSE3 += testOld<SSSE3_8bit>(8, 2048, 256); + oldSSSE3 += testOld<SSSE3_8bit>(320, 256, 256); + oldSSSE3 += testOld<SSSE3_8bit>(472, 256, 256); + oldSSSE3 += testOld<SSSE3_8bit>(248, 256, 256); + oldSSSE3 += testOld<SSSE3_8bit>(200, 256, 256); + } + + std::cout << repeat << " iterations of Old SSSE3 took: " << oldSSSE3.count() << " seconds." << std::endl; + + std::chrono::duration<double> newTimeSSSE3 = testOld<SSSE3_8bit>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + newTimeSSSE3 += testNew<SSSE3_8bit>(8, 256, 256); + newTimeSSSE3 += testNew<SSSE3_8bit>(8, 2048, 256); + newTimeSSSE3 += testNew<SSSE3_8bit>(320, 256, 256); + newTimeSSSE3 += testNew<SSSE3_8bit>(472, 256, 256); + newTimeSSSE3 += testNew<SSSE3_8bit>(248, 256, 256); + newTimeSSSE3 += testNew<SSSE3_8bit>(200, 256, 256); + } + + std::cout << repeat << " iterations of New SSSE3 took: " << newTimeSSSE3.count() << " seconds." << std::endl; + + std::chrono::duration<double> oldAVX2_nobias = testOld_nobias<AVX2_8bit>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX2_nobias += testOld_nobias<AVX2_8bit>(8, 256, 256); + oldAVX2_nobias += testOld_nobias<AVX2_8bit>(8, 2048, 256); + oldAVX2_nobias += testOld_nobias<AVX2_8bit>(320, 256, 256); + oldAVX2_nobias += testOld_nobias<AVX2_8bit>(472, 256, 256); + oldAVX2_nobias += testOld_nobias<AVX2_8bit>(248, 256, 256); + oldAVX2_nobias += testOld_nobias<AVX2_8bit>(200, 256, 256); + } + + std::cout << repeat << " iterations of Old AVX2 without bias took: " << oldAVX2_nobias.count() << " seconds." 
<< std::endl; + + std::chrono::duration<double> oldAVX2 = testOld<AVX2_8bit>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX2 += testOld<AVX2_8bit>(8, 256, 256); + oldAVX2 += testOld<AVX2_8bit>(8, 2048, 256); + oldAVX2 += testOld<AVX2_8bit>(320, 256, 256); + oldAVX2 += testOld<AVX2_8bit>(472, 256, 256); + oldAVX2 += testOld<AVX2_8bit>(248, 256, 256); + oldAVX2 += testOld<AVX2_8bit>(200, 256, 256); + } + + std::cout << repeat << " iterations of Old AVX2 took: " << oldAVX2.count() << " seconds." << std::endl; + + std::chrono::duration<double> newTimeAVX2 = testOld<AVX2_8bit>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + newTimeAVX2 += testNew<AVX2_8bit>(8, 256, 256); + newTimeAVX2 += testNew<AVX2_8bit>(8, 2048, 256); + newTimeAVX2 += testNew<AVX2_8bit>(320, 256, 256); + newTimeAVX2 += testNew<AVX2_8bit>(472, 256, 256); + newTimeAVX2 += testNew<AVX2_8bit>(248, 256, 256); + newTimeAVX2 += testNew<AVX2_8bit>(200, 256, 256); + } + + std::cout << repeat << " iterations of New AVX2 took: " << newTimeAVX2.count() << " seconds." << std::endl; + + if (kCPU < CPUType::AVX512BW) return 0; + std::chrono::duration<double> oldAVX512_nobias = testOld_nobias<AVX512_8bit>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX512_nobias += testOld_nobias<AVX512_8bit>(8, 256, 256); + oldAVX512_nobias += testOld_nobias<AVX512_8bit>(8, 2048, 256); + oldAVX512_nobias += testOld_nobias<AVX512_8bit>(320, 256, 256); + oldAVX512_nobias += testOld_nobias<AVX512_8bit>(472, 256, 256); + oldAVX512_nobias += testOld_nobias<AVX512_8bit>(248, 256, 256); + oldAVX512_nobias += testOld_nobias<AVX512_8bit>(200, 256, 256); + } + + std::cout << repeat << " iterations of Old AVX512 without bias took: " << oldAVX512_nobias.count() << " seconds." << std::endl; + + std::chrono::duration<double> oldAVX512 = testOld<AVX512_8bit>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX512 += testOld<AVX512_8bit>(8, 256, 256); + oldAVX512 += testOld<AVX512_8bit>(8, 2048, 256); + oldAVX512 += testOld<AVX512_8bit>(320, 256, 256); + oldAVX512 += testOld<AVX512_8bit>(472, 256, 256); + oldAVX512 += testOld<AVX512_8bit>(248, 256, 256); + oldAVX512 += testOld<AVX512_8bit>(200, 256, 256); + } + + std::cout << repeat << " iterations of Old AVX512 took: " << oldAVX512.count() << " seconds." << std::endl; + + std::chrono::duration<double> newTimeAVX512 = testOld<AVX512_8bit>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + newTimeAVX512 += testNew<AVX512_8bit>(8, 256, 256); + newTimeAVX512 += testNew<AVX512_8bit>(8, 2048, 256); + newTimeAVX512 += testNew<AVX512_8bit>(320, 256, 256); + newTimeAVX512 += testNew<AVX512_8bit>(472, 256, 256); + newTimeAVX512 += testNew<AVX512_8bit>(248, 256, 256); + newTimeAVX512 += testNew<AVX512_8bit>(200, 256, 256); + } + + std::cout << repeat << " iterations of New AVX512 took: " << newTimeAVX512.count() << " seconds." 
<< std::endl; + + +} diff --git a/interleave.h b/interleave.h index 202062f..cc06378 100644 --- a/interleave.h +++ b/interleave.h @@ -270,4 +270,21 @@ target static inline void SelectColumnsOfB(const Register *input, Register *outp } \ } \ +#define INTGEMM_PREPARE_BIAS_FOR_8(target, Register) \ +target static inline void PrepareBiasFor8(const float *input, float *bias, float alpha, Index rows, Index cols) { \ + assert(cols*sizeof(float) % sizeof(Register) == 0); \ + constexpr int stride = sizeof(Register) / sizeof(float); \ + Register alpha_reg = set1_ps<Register>(alpha); \ + for (Index c = 0; c<cols; c+=stride) { \ + Register vectorsum = set1_ps<Register>(0.0f); \ + for (Index r = 0; r < rows; r++) { \ + Register column_stride = load_ps<Register>(input + r*cols + c); \ + vectorsum = add_ps(vectorsum, column_stride); \ + } \ + Register *towrite = reinterpret_cast<Register *>(bias + c); \ + vectorsum = mul_ps(vectorsum, alpha_reg); \ + *towrite = sub_ps(*towrite, vectorsum); \ + } \ +} \ + } // namespace intgemm @@ -16,10 +16,14 @@ const char *const Int16::kName = ChooseCPU(AVX512_16bit::kName, AVX2_16bit::kNam void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize); +void (*Int8::QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_8bit::QuantizeU, AVX2_8bit::QuantizeU, SSSE3_8bit::QuantizeU, Unsupported_8bit::QuantizeU, Unsupported_8bit::QuantizeU); + void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB); void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB); +//void (*Int8::PrepareBiasFor8)(const float *input, float *bias, float alpha, Index rows, Index cols) = ChooseCPU(AVX512_8bit::PrepareBiasFor8, AVX2_8bit::PrepareBiasFor8, SSSE3_8bit::PrepareBiasFor8, Unsupported_8bit::PrepareBiasFor8, Unsupported_8bit::PrepareBiasFor8); + const char *const Int8::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit::kName, Unsupported_8bit::kName); const CPUType kCPU = ChooseCPU(CPUType::AVX512BW, CPUType::AVX2, CPUType::SSSE3, CPUType::SSE2, CPUType::UNSUPPORTED); @@ -76,9 +76,16 @@ struct Unsupported_8bit { static void Quantize(const float *, int8_t *, float, Index) { throw UnsupportedCPU(); } + static void QuantizeU(const float *, uint8_t *, float, Index) { + throw UnsupportedCPU(); + } static void PrepareB(const float *, int8_t *, float, Index, Index) { throw UnsupportedCPU(); } + template<class Callback> + static void PrepareBiasFor8(const int8_t, const int8_t *, Index, Index, Index, Callback) { + throw UnsupportedCPU(); + } static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) { throw UnsupportedCPU(); } @@ -86,6 +93,10 @@ struct Unsupported_8bit { static void Multiply(const int8_t *, const int8_t *, Index, Index, Index, Callback) { throw UnsupportedCPU(); } + template<class Callback> + static void Multiply8new(const uint8_t *, const int8_t *, Index, Index, Index, Callback) { + throw UnsupportedCPU(); + 
} constexpr static const char *const kName = "8-bit Unsupported"; }; @@ -184,11 +195,20 @@ class Int8Mult { public: // Multiply C = A * B, presuming A and B have been prepared. static void (*Multiply)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback); + static void (*Multiply8new)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback); + static void (*PrepareBiasFor8)(const int8_t A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback); }; template <typename Callback> void (*Int8Mult<Callback>::Multiply)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_8bit::Multiply<Callback>, AVX2_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, Unsupported_8bit::Multiply); +template <class Callback> +void (*Int8Mult<Callback>::Multiply8new)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_8bit::Multiply8new<Callback>, AVX2_8bit::Multiply8new<Callback>, SSSE3_8bit::Multiply8new<Callback>, SSSE3_8bit::Multiply8new<Callback>, Unsupported_8bit::Multiply8new); + +template <class Callback> +void (*Int8Mult<Callback>::PrepareBiasFor8)(const int8_t A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_8bit::PrepareBiasFor8<Callback>, AVX2_8bit::PrepareBiasFor8<Callback>, SSSE3_8bit::PrepareBiasFor8<Callback>, SSSE3_8bit::PrepareBiasFor8<Callback>, Unsupported_8bit::PrepareBiasFor8); + + struct Int8 { typedef int8_t Integer; @@ -206,8 +226,18 @@ struct Int8 { Quantize(input, output, quant_mult, rows * cols); } + static inline void PrepareANew(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { + QuantizeU(input, reinterpret_cast<uint8_t *>(output), quant_mult, rows * cols); + } + // Multiply floats by quant_mult then convert to 8-bit integers with saturation. static void (*Quantize)(const float *input, int8_t *output, float quant_mult, Index size); + + // Multiply floats by quant_mult then convert to 8-bit integers with saturation. + static void (*QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size); + + // PrepareB + //static void (*PrepareBiasFor8)(const float *input, float *bias, float alpha, Index rows, Index cols); // Warning: the output of PrepareB depends on the CPU. // It will match the Multiply function on the same CPU though. @@ -221,6 +251,16 @@ struct Int8 { static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { Int8Mult<Callback>::Multiply(A, B, A_rows, width, B_cols, callback); } + + template<class Callback> + static void Multiply8new(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + Int8Mult<Callback>::Multiply8new((const uint8_t *)A, B, A_rows, width, B_cols, callback); + } + + template<class Callback> + static void PrepareBiasFor8(const int8_t A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + Int8Mult<Callback>::PrepareBiasFor8(A, B, A_rows, width, B_cols, callback); + } static const char *const kName; }; diff --git a/intrinsics.h b/intrinsics.h index e19efd7..9be5296 100644 --- a/intrinsics.h +++ b/intrinsics.h @@ -16,10 +16,11 @@ namespace intgemm { * Define a bunch of intrinstics as overloaded functions so they work with * templates. 
*/ +template <class Register> static inline Register load_ps(float const* from); template <class Register> static inline Register loadu_ps(const float* mem_addr); -template <class Register> static inline Register set1_epi8(int8_t to); template <class Register> static inline Register set1_epi16(int16_t to); template <class Register> static inline Register set1_epi32(int32_t to); +template <class Register> static inline Register set1_epi8(int8_t to); template <class Register> static inline Register set1_pd(double to); template <class Register> static inline Register set1_ps(float to); template <class Register> static inline Register setzero_pd(); @@ -73,6 +74,9 @@ INTGEMM_SSE2 static inline __m128 div_ps(__m128 a, __m128 b) { /* * Missing i32gather_ps for SSE2 */ +template <> INTGEMM_SSE2 inline __m128 load_ps<__m128>(const float* from) { + return _mm_load_ps(from); +} template <> INTGEMM_SSE2 inline __m128 loadu_ps(const float* mem_addr) { return _mm_loadu_ps(mem_addr); } @@ -242,6 +246,9 @@ INTGEMM_AVX2 static inline __m256 i32gather_ps(float const *base_addr, __m256i v template <> INTGEMM_AVX2 inline __m256 loadu_ps(const float* mem_addr) { return _mm256_loadu_ps(mem_addr); } +template <> INTGEMM_AVX2 inline __m256 load_ps<__m256>(const float* from) { + return _mm256_load_ps(from); +} INTGEMM_AVX2 static inline __m256i madd_epi16(__m256i first, __m256i second) { return _mm256_madd_epi16(first, second); } @@ -479,6 +486,9 @@ template <> INTGEMM_AVX512BW inline __m512 setzero_ps<__m512>() { template <> INTGEMM_AVX512BW inline __m512i setzero_si<__m512i>() { return _mm512_setzero_si512(); } +template <> INTGEMM_AVX512BW inline __m512 load_ps<__m512>(const float* from) { + return _mm512_load_ps(from); +} /* * Missing sign_epi8 */ @@ -192,7 +192,160 @@ template <typename Callback> target static void Multiply(const int16_t *A, const } \ } \ -/* 8-bit matrix multiply used by AVX and INTGEMM_AVX2. +//An int8_prepbias version of the above code, using the add 127 technique +#define INTGEMM_PREPAREBIASFOR8(Integer, target, cpu_type) \ + template <class Callback> target static void PrepareBiasFor8(const int8_t A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ + assert(width % (sizeof(Integer) / sizeof(int8_t)) == 0); \ + assert(B_cols % 8 == 0); \ + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \ + const int simd_width = width / (sizeof(Integer) / sizeof(int8_t)); \ + auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ + const Integer *B0_col = reinterpret_cast<const Integer *>(B); \ + const Integer a = set1_epi8<Integer>(A); \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ + /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ + for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ + /*const Integer *A_row = reinterpret_cast<const Integer*>(A + A_rowidx * width);*/ \ + /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. 
\ + Iterate over shared (inner) dimension.*/ \ + int k = 0; \ + Integer sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Integer sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Integer sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Integer sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Integer sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Integer sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Integer sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Integer sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is declared here.*/ \ + Integer ones = set1_epi16<Integer>(1); \ + sum0 = madd_epi16(sum0, ones); \ + sum1 = madd_epi16(sum1, ones); \ + sum2 = madd_epi16(sum2, ones); \ + sum3 = madd_epi16(sum3, ones); \ + sum4 = madd_epi16(sum4, ones); \ + sum5 = madd_epi16(sum5, ones); \ + sum6 = madd_epi16(sum6, ones); \ + sum7 = madd_epi16(sum7, ones); \ + for (int k = 1; k < simd_width; ++k) { \ + /*Integer a = *(A_row + k);*/ \ + /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ \ + Integer mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Integer mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Integer mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Integer mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Integer mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Integer mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Integer mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Integer mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + /* Upcast to 32-bit and horizontally add.*/ \ + mult0 = madd_epi16(mult0, ones); \ + mult1 = madd_epi16(mult1, ones); \ + mult2 = madd_epi16(mult2, ones); \ + mult3 = madd_epi16(mult3, ones); \ + mult4 = madd_epi16(mult4, ones); \ + mult5 = madd_epi16(mult5, ones); \ + mult6 = madd_epi16(mult6, ones); \ + mult7 = madd_epi16(mult7, ones); \ + /*Add in 32bit*/ \ + sum0 = add_epi32(sum0, mult0); \ + sum1 = add_epi32(sum1, mult1); \ + sum2 = add_epi32(sum2, mult2); \ + sum3 = add_epi32(sum3, mult3); \ + sum4 = add_epi32(sum4, mult4); \ + sum5 = add_epi32(sum5, mult5); \ + sum6 = add_epi32(sum6, mult6); \ + sum7 = add_epi32(sum7, mult7); \ + \ + } \ + /* Reduce sums within 128-bit lanes.*/ \ + Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + /*The specific implementation may need to reduce further.*/ \ + auto total = PermuteSummer(pack0123, pack4567); \ + RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \ + } \ + } \ +} \ + +//An int8 version of the above code, using the add 127 technique +#define INTGEMM_MULTIPLY8NEW(Integer, target, cpu_type) \ + template <class Callback> target static void Multiply8new(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ + assert(width % (sizeof(Integer) / sizeof(int8_t)) == 0); \ + assert(B_cols % 8 == 0); \ + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \ + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \ + const int simd_width = width / (sizeof(Integer) / sizeof(int8_t)); \ + auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ + const Integer *B0_col = reinterpret_cast<const Integer *>(B); \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ + /* Process one row of A at a time. 
Doesn't seem to be faster to do multiple rows of A at once.*/ \ + for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ + const Integer *A_row = reinterpret_cast<const Integer*>(A + A_rowidx * width); \ + /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \ + Iterate over shared (inner) dimension.*/ \ + int k = 0; \ + Integer a = *(A_row + k); \ + Integer sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Integer sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Integer sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Integer sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Integer sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Integer sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Integer sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Integer sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is declared here.*/ \ + Integer ones = set1_epi16<Integer>(1); \ + sum0 = madd_epi16(sum0, ones); \ + sum1 = madd_epi16(sum1, ones); \ + sum2 = madd_epi16(sum2, ones); \ + sum3 = madd_epi16(sum3, ones); \ + sum4 = madd_epi16(sum4, ones); \ + sum5 = madd_epi16(sum5, ones); \ + sum6 = madd_epi16(sum6, ones); \ + sum7 = madd_epi16(sum7, ones); \ + for (int k = 1; k < simd_width; ++k) { \ + Integer a = *(A_row + k); \ + /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ \ + Integer mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Integer mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Integer mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Integer mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Integer mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Integer mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Integer mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Integer mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + /* Upcast to 32-bit and horizontally add.*/ \ + mult0 = madd_epi16(mult0, ones); \ + mult1 = madd_epi16(mult1, ones); \ + mult2 = madd_epi16(mult2, ones); \ + mult3 = madd_epi16(mult3, ones); \ + mult4 = madd_epi16(mult4, ones); \ + mult5 = madd_epi16(mult5, ones); \ + mult6 = madd_epi16(mult6, ones); \ + mult7 = madd_epi16(mult7, ones); \ + /*Add in 32bit*/ \ + sum0 = add_epi32(sum0, mult0); \ + sum1 = add_epi32(sum1, mult1); \ + sum2 = add_epi32(sum2, mult2); \ + sum3 = add_epi32(sum3, mult3); \ + sum4 = add_epi32(sum4, mult4); \ + sum5 = add_epi32(sum5, mult5); \ + sum6 = add_epi32(sum6, mult6); \ + sum7 = add_epi32(sum7, mult7); \ + \ + } \ + /* Reduce sums within 128-bit lanes.*/ \ + Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + /*The specific implementation may need to reduce further.*/ \ + auto total = PermuteSummer(pack0123, pack4567); \ + RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \ + } \ + } \ +} \ + +/* 8-bit matrix multiply used by AVX and AVX2. * These have two peculiar properties: * 1. The sign instructions don't exist in AVX512. * 2. 
16 registers means gcc's register allocation failed so I wrote it in my diff --git a/ssse3_gemm.h b/ssse3_gemm.h index 24cc179..fcd0de8 100644 --- a/ssse3_gemm.h +++ b/ssse3_gemm.h @@ -35,6 +35,11 @@ class QuantizeTile8 { return Tile(input, input + 8); } + INTGEMM_SSSE3 inline __m128i ConsecutiveU(const float *input) { + return TileU(input, input + 8); + } + + private: // Quantize 16xfloat into 16xint8_t INTGEMM_SSSE3 inline __m128i Tile(const float *input0, const float *input1) { @@ -48,7 +53,7 @@ class QuantizeTile8 { __m128i packed = _mm_packs_epi16(packed0, packed1); /* Ban -128. * Don't use the SSE4.1 instruction _mm_max_epi8(packed, neg127). Instead, - * use INTGEMM_SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8. + * use SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8. * The first generates 0xff for fields -128. * The second subtracts 0xff from -128 which has the effect of converting * to -127. @@ -59,6 +64,29 @@ class QuantizeTile8 { // No permute needed. packs is in order for SSE. } + INTGEMM_SSSE3 inline __m128i TileU(const float *input0, const float *input1) { + const __m128i neg128 = _mm_set1_epi8(-128); + const __m128i pos127 = _mm_set1_epi8(127); + __m128i g0 = QuantizerGrab(input0, mult_reg_); + __m128i g1 = QuantizerGrab(input0 + 4, mult_reg_); + __m128i g2 = QuantizerGrab(input1, mult_reg_); + __m128i g3 = QuantizerGrab(input1 + 4, mult_reg_); + __m128i packed0 = _mm_packs_epi32(g0, g1); + __m128i packed1 = _mm_packs_epi32(g2, g3); + __m128i packed = _mm_packs_epi16(packed0, packed1); + /* Ban -128. + * Don't use the SSE4.1 instruction _mm_max_epi8(packed, neg127). Instead, + * use SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8. + * The first generates 0xff for fields -128. + * The second subtracts 0xff from -128 which has the effect of converting + * to -127. + */ + // packed = _mm_max_epi8(packed, neg127); + __m128i evils = _mm_cmpeq_epi8(packed, neg128); + return _mm_add_epi8(_mm_sub_epi8(packed, evils), pos127); + // No permute needed. packs is in order for SSE. + } + private: const __m128 mult_reg_; }; @@ -86,6 +114,23 @@ struct SSSE3_8bit { } } + // Version with unsigned int + 127 + // Currently A is prepared by quantization but this could theoretically change. + INTGEMM_SSSE3 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) { + QuantizeU(input, output, quant_mult, rows * cols); + } + + INTGEMM_SSSE3 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) { + assert(size % 16 == 0); + assert(reinterpret_cast<uintptr_t>(input) % 16 == 0); + assert(reinterpret_cast<uintptr_t>(output) % 16 == 0); + ssse3::QuantizeTile8 q(quant_mult); + const float *end = input + size; + for (; input != end; input += 16, output += 16) { + *reinterpret_cast<__m128i*>(output) = q.ConsecutiveU(input); + } + } + // Tile size for B; B must be a multiple of this block size. 
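The SSSE3 TileU above bans -128 without SSE4.1's _mm_max_epi8; per lane, the cmpeq/sub pair amounts to this scalar sketch (not part of the patch):

```cpp
#include <cstdint>

// Per-lane view of the SSE2 trick in TileU: _mm_cmpeq_epi8 yields 0xFF (-1)
// exactly where a lane equals -128, and _mm_sub_epi8 of that mask bumps
// -128 up to -127 while leaving every other lane untouched.
inline int8_t BanNeg128(int8_t x) {
  int8_t mask = (x == -128) ? -1 : 0;    // _mm_cmpeq_epi8(packed, neg128)
  return static_cast<int8_t>(x - mask);  // _mm_sub_epi8(packed, evils)
}
```

TileU then adds 127 on top, landing in [0, 254] just like the AVX2 and AVX512 unsigned paths.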
static const Index kBTileRow = 16; static const Index kBTileCol = 8; @@ -97,7 +142,13 @@ struct SSSE3_8bit { } INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::SSE2) + + INTGEMM_MULTIPLY8NEW(__m128i, INTGEMM_SSSE3, CPUType::SSE2) + //INTGEMM_PREPARE_BIAS_FOR_8(INTGEMM_SSSE3, __m128) + + INTGEMM_PREPAREBIASFOR8(__m128i, INTGEMM_SSSE3, CPUType::SSE2) + constexpr static const char *const kName = "8-bit INTGEMM_SSSE3"; static const CPUType kUses = CPUType::SSSE3; diff --git a/test/add127_test.cc b/test/add127_test.cc new file mode 100644 index 0000000..e466c9d --- /dev/null +++ b/test/add127_test.cc @@ -0,0 +1,258 @@ +#include "test/test.h" + +namespace intgemm { + +void newBias(const float * input, float * bias, float* output, float alpha, Index rows, Index width, Index cols) { + AlignedVector<float> intermediate(rows*cols); + AlignedVector<float> ones(rows*width); + for (auto&& it : ones) { + it = 1; + } + SlowRefFloat(ones.begin(), input, intermediate.begin(), rows, width, cols); + for (auto&& it : intermediate) { + it = it*alpha; + } + + + for (Index c = 0; c<cols; c++) { + output[c] = bias[c] - intermediate.begin()[rows*c]; + } + +} + +void SlowSumB(const int8_t * input, float * bias, float* output, float alpha, Index rows, Index cols) { + for (Index r = 0; r<rows; r++) { + for (Index c = 0; c<cols; c++) { + output[c] += input[r * cols + c]; + } + } + + for (Index c = 0; c<cols; c++) { + output[c] = bias[c] + output[c]*alpha; + } +} + +void CompareAs(int8_t * output_old, uint8_t * output_new, Index rows, Index cols) { + for (Index r = 0; r<rows; r++) { + for (Index c = 0; c<cols; c++) { + int a = int(output_old[rows*c + r]); + int b = int(output_new[rows*c + r]); + INFO("Inaccurate at row: " << r << " column " << c << ' ' + << a << ' ' << b); + CHECK(a+127 == b); + } + } +} + +void CompareBiases(const float *bias_ref, const float *bias, Index cols) { + for (std::size_t i = 0; i < cols; ++i) { + INFO("Inaccurate at " << i << ' ' << bias_ref[i] << ' ' << bias[i]); + CHECK(fabs(bias_ref[i] - bias[i]) < 0.0001); + } +} + +template <class Routine> void TestPrepareA(Index rows, Index cols) { + std::mt19937 gen; + // Go somewhat out of range too. + std::uniform_real_distribution<float> dist(-2, 2); + // Create array. + AlignedVector<float> inputA(rows * cols); + for (auto& it : inputA) { + it = dist(gen); + } + AlignedVector<int8_t> oldA(rows * cols); + AlignedVector<uint8_t> newA(rows * cols); + float quant_mult = 64; //From example + Routine::PrepareA(inputA.begin(), oldA.begin(), quant_mult, rows, cols); + Routine::PrepareA(inputA.begin(), newA.begin(), quant_mult, rows, cols); + CompareAs(oldA.begin(), newA.begin(), rows, cols); +} + +template <class Routine> void TestPrepareBias(Index rows, Index cols) { + std::mt19937 gen; + // Go somewhat out of range too. + std::uniform_real_distribution<float> dist(-30.0, 30.0); + // Create array. 
+ AlignedVector<float> inputB(rows * cols); + for (auto& it : inputB) { + it = dist(gen); + } + + float alpha = 25; + float quant_mult = 127/alpha; + + AlignedVector<int8_t> B_prep(inputB.size()); + AlignedVector<int8_t> B_quant(inputB.size()); + Routine::PrepareB(inputB.begin(), B_prep.begin(), quant_mult, rows, cols); + Routine::Quantize(inputB.begin(), B_quant.begin(), quant_mult, inputB.size()); + + + AlignedVector<float> inputBias(cols); + AlignedVector<float> goldBias(cols); + + for (auto& it : goldBias) { + it = 0; + } + for (auto& it : inputBias) { + it = dist(gen); + } + + float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); + + SlowSumB(B_quant.begin(), inputBias.begin(), goldBias.begin(), alpha, rows, cols); + + Routine::PrepareBiasFor8(1, B_prep.begin(), 1, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin())); + + CompareBiases(goldBias.begin(), inputBias.begin(), cols); +} + +template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Index B_cols, + float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) { + std::ostringstream info; + info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n'; + + // Initialize A and B. + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = dist(gen); + } + + float alpha = 2.0f; + float quant_mult = 127/alpha; + float unquant_mult = 1.0/(quant_mult*quant_mult); + + AlignedVector<uint8_t> A_prep(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + /*REFERENCE MULTIPLICATION + * + * + */ + AlignedVector<int8_t> B_quant(B.size()); + Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size()); + AlignedVector<float> slowint_C(test_C.size()); + // Taking the original A_preparation which means A would be int8_t + AlignedVector<int8_t> A_prep2(A.size()); + Routine::PrepareA(A.begin(), A_prep2.begin(), quant_mult, A_rows, width); + SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin()); + + AlignedVector<float> float_C(test_C.size()); + SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin()); + + /*ACTUAL MULTIPLICATION + * + */ + float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on + Routine::PrepareBiasFor8(1, B_prep.begin(), 1, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin())); + //Routine::PrepareBiasFor8(B.begin(), bias.begin(), alpha, width, B_cols); + Routine::Multiply8new(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + + Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); +} + +/* +// Bias +TEST_CASE("PrepareBias SSSE3", "[Add127]") { + if (kCPU < CPUType::SSSE3) return; + TestPrepareBias<SSSE3_8bit>(8,8); + 
TestPrepareBias<SSSE3_8bit>(256,256); + TestPrepareBias<SSSE3_8bit>(2048,256); + TestPrepareBias<SSSE3_8bit>(512,512); +} + +TEST_CASE("PrepareBias AVX2", "[Add127]") { + if (kCPU < CPUType::AVX2) return; + TestPrepareBias<AVX2_8bit>(8,8); + TestPrepareBias<AVX2_8bit>(256,256); + TestPrepareBias<AVX2_8bit>(2048,256); + TestPrepareBias<AVX2_8bit>(512,512); +} + +TEST_CASE("PrepareBias AVX512F", "[Add127]") { + if (kCPU < CPUType::AVX512BW) return; + #ifndef INTGEMM_NO_AVX512 + TestPrepareBias<AVX512_8bit>(8,8); + TestPrepareBias<AVX512_8bit>(256,256); + TestPrepareBias<AVX512_8bit>(2048,256); + TestPrepareBias<AVX512_8bit>(512,512); + #endif +}*/ + +//A +TEST_CASE("PrepareA SSSE3", "[Add127]") { + if (kCPU < CPUType::SSSE3) return; + TestPrepareA<SSSE3_8bit>(64,64); + TestPrepareA<SSSE3_8bit>(256,256); + TestPrepareA<SSSE3_8bit>(512,512); + TestPrepareA<SSSE3_8bit>(2048,256); +} + +TEST_CASE("PrepareA AVX2", "[Add127]") { + if (kCPU < CPUType::AVX2) return; + TestPrepareA<AVX2_8bit>(64,64); + TestPrepareA<AVX2_8bit>(256,256); + TestPrepareA<AVX2_8bit>(512,512); + TestPrepareA<AVX2_8bit>(2048,256); +} + +TEST_CASE("PrepareA AVX512F", "[Add127]") { + if (kCPU < CPUType::AVX512BW) return; + #ifndef INTGEMM_NO_AVX512 + TestPrepareA<AVX512_8bit>(64,64); + TestPrepareA<AVX512_8bit>(256,256); + TestPrepareA<AVX512_8bit>(512,512); + TestPrepareA<AVX512_8bit>(2048,256); + #endif +} + +// Multiply + +TEST_CASE ("Multiply SSSE3 8bit with new bias", "[Add127]") { + if (kCPU < CPUType::SSSE3) return; + TestMultiplyBiasNew<SSSE3_8bit>(1, 64, 8, 0.11, 0.1, 0.06, 0.05); + TestMultiplyBiasNew<SSSE3_8bit>(8, 256, 256, 0.45, 0.54, 0.17, 0.16); // 0.064, 0.026); + TestMultiplyBiasNew<SSSE3_8bit>(8, 2048, 256, 1.7, 1.7, 0.46, 0.43); // 4.4, 4.4); + TestMultiplyBiasNew<SSSE3_8bit>(320, 256, 256, 0.56, 0.64, 0.16, 0.15); // 0.1, 0.01); + TestMultiplyBiasNew<SSSE3_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); // 0.1, 0.011); + TestMultiplyBiasNew<SSSE3_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); // 0.1, 0.012); + TestMultiplyBiasNew<SSSE3_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16); // 0.1, 0.011); +} + +TEST_CASE ("Multiply AVX2 8bit with new bias", "[Add127]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBiasNew<AVX2_8bit>(1, 64, 8, 0.11, 0.11, 0.06, 0.05); + TestMultiplyBiasNew<AVX2_8bit>(8, 256, 256, 0.49, 0.54, 0.17, 0.16); //0.1, 0); + TestMultiplyBiasNew<AVX2_8bit>(8, 2048, 256, 1.57, 1.66, 0.46, 0.46); //1.8, 1.8); + TestMultiplyBiasNew<AVX2_8bit>(320, 256, 256, 0.49, 0.64, 0.16, 0.15); //0.1, 0); + TestMultiplyBiasNew<AVX2_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); //0.1, 0); + TestMultiplyBiasNew<AVX2_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); //0.1, 0); + TestMultiplyBiasNew<AVX2_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16); //0.1, 0); +} + +TEST_CASE ("Multiply AVX512F 8bit with new bias", "[Add127]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyBiasNew<AVX512_8bit>(1, 64, 8, 0.11, 0.11, 0.06, 0.05); + TestMultiplyBiasNew<AVX512_8bit>(8, 256, 256, 0.48, 0.54, 0.17, 0.16); //, 1.6, 1.6); //0.1, 0); + TestMultiplyBiasNew<AVX512_8bit>(8, 2048, 256, 1.57, 1.66, 0.46, 0.43); //1.8, 1.8); + TestMultiplyBiasNew<AVX512_8bit>(320, 256, 256, 0.48, 0.64, 0.16, 0.15); //0.1, 0); + TestMultiplyBiasNew<AVX512_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); //0.1, 0); + TestMultiplyBiasNew<AVX512_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); //0.1, 0); + TestMultiplyBiasNew<AVX512_8bit>(200, 256, 256, 0.57, 0.74, 0.17, 0.16); //0.1, 0); +} + +} //namespace intgemm diff --git 
a/test/multiply_test.cc b/test/multiply_test.cc index 5dac484..82a11ef 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -15,7 +15,6 @@ #include <iostream> #include <memory> #include <random> -#include <sstream> namespace intgemm { @@ -286,53 +285,6 @@ TEST_CASE("MaxAbsolute AVX512F", "[max]") { // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. // Compute A*B slowly in floats. -void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias=nullptr) { - for (Index r = 0; r < A_rows; ++r) { - for (Index c = 0; c < B_cols; ++c) { - float sum = 0.0f; - for (Index k = 0; k < width; ++k) { - sum += A[r * width + k] * B[k * B_cols + c]; - } - if (bias) { - C[r * B_cols + c] = sum + bias[c]; - } else { - C[r * B_cols + c] = sum; - } - } - } -} - -// Compute A*B slowly from integers. -template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr) { - for (Index r = 0; r < A_rows; ++r) { - for (Index c = 0; c < B_cols; ++c) { - int32_t sum = 0; - for (Index k = 0; k < width; ++k) { - sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]); - } - if (bias) { - C[r * B_cols + c] = sum * unquant_mult + bias[c]; - } else { - C[r * B_cols + c] = sum * unquant_mult; - } - } - } -} - -void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info, - float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) { - float int_sum = 0.0, float_sum = 0.0; - for (std::size_t i = 0; i < size; ++i) { - float int_diff = int_ref[i] - int_test[i]; - float float_diff = float_ref[i] - int_test[i]; - CHECK_MESSAGE(fabs(int_diff) <= int_tolerance, test_info << "Inaccurate compared to int reference at " << i << ' ' << int_ref[i] << ' ' << int_test[i]); - CHECK_MESSAGE(fabs(float_diff) <= float_tolerance, test_info << "Inaccurate compared to float reference at " << i << ' ' << float_ref[i] << ' ' << int_test[i]); - int_sum += int_diff * int_diff; - float_sum += float_diff * float_diff; - } - CHECK_MESSAGE(fabs(sqrt(float_sum / size)) <= MSE_float_tolerance, test_info << "Float MSE = " << sqrt(float_sum / size)); - CHECK_MESSAGE(fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size)); -} template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_cols, float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) { diff --git a/test/test.cc b/test/test.cc index 58c62f8..cb45b73 100644 --- a/test/test.cc +++ b/test/test.cc @@ -4,3 +4,59 @@ int main(int argc, char ** argv) { return Catch::Session().run(argc, argv); } + +namespace intgemm { + +void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias) { + for (Index r = 0; r < A_rows; ++r) { + for (Index c = 0; c < B_cols; ++c) { + float sum = 0.0f; + for (Index k = 0; k < width; ++k) { + sum += A[r * width + k] * B[k * B_cols + c]; + } + if (bias) { + C[r * B_cols + c] = sum + bias[c]; + } else { + C[r * B_cols + c] = sum; + } + } + } +} + +// Compute A*B slowly from integers. 
+template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) { + for (Index r = 0; r < A_rows; ++r) { + for (Index c = 0; c < B_cols; ++c) { + int32_t sum = 0; + for (Index k = 0; k < width; ++k) { + sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]); + } + if (bias) { + C[r * B_cols + c] = sum * unquant_mult + bias[c]; + } else { + C[r * B_cols + c] = sum * unquant_mult; + } + } + } +} + +template void SlowRefInt<int8_t>(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias); +template void SlowRefInt<int16_t>(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias); +template void SlowRefInt<int32_t>(const int32_t *A, const int32_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias); + +void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info, + float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) { + float int_sum = 0.0, float_sum = 0.0; + for (std::size_t i = 0; i < size; ++i) { + float int_diff = int_ref[i] - int_test[i]; + float float_diff = float_ref[i] - int_test[i]; + CHECK_MESSAGE(fabs(int_diff) <= int_tolerance, test_info << "Inaccurate compared to int reference at " << i << ' ' << int_ref[i] << ' ' << int_test[i]); + CHECK_MESSAGE(fabs(float_diff) <= float_tolerance, test_info << "Inaccurate compared to float reference at " << i << ' ' << float_ref[i] << ' ' << int_test[i]); + int_sum += int_diff * int_diff; + float_sum += float_diff * float_diff; + } + CHECK_MESSAGE(fabs(sqrt(float_sum / size)) <= MSE_float_tolerance, test_info << "Float MSE = " << sqrt(float_sum / size)); + CHECK_MESSAGE(fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size)); +} + +} //namespace intgemm diff --git a/test/test.h b/test/test.h index f2f745d..1034337 100644 --- a/test/test.h +++ b/test/test.h @@ -1,4 +1,9 @@ +#pragma once + #include "3rd_party/catch.hpp" +#include <sstream> +#include "intgemm.h" +#include "aligned.h" #include "config.h" @@ -14,3 +19,14 @@ } while(0) #define KERNEL_TEST_CASE(name) TEST_CASE("Kernel: " name, "[kernel_test]") + +namespace intgemm { +void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias=nullptr); + +// Compute A*B slowly from integers. +template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr); + +void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info, + float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance); + +} //namespace intgemm diff --git a/test_mull.cpp b/test_mull.cpp new file mode 100644 index 0000000..e83f1c9 --- /dev/null +++ b/test_mull.cpp @@ -0,0 +1,328 @@ +#include "intgemm.cc" +#include "aligned.h" +#include <iostream> +#include <random> +#include <string> +#include <algorithm> +#include <fstream> +#include <sstream> + + +/*Adapted from https://www.bfilipek.com/2018/07/string-view-perf-followup.html . 
We should probably go string_view way +inline void tokenizeLine(std::string& str, std::vector<std::string>& output, + std::string delimeter = " ") { + auto first = std::begin(str); + + while (first != str.end()) { + const auto second = std::find_first_of(first, std::end(str), std::begin(delimeter), std::end(delimeter)); + + if (first != second) { + output.emplace_back(str.substr(std::distance(std::begin(str), first), std::distance(first, second))); + } + + if (second == str.end()) + break; + + first = std::next(second); + } +} + +//This is a different parsing method, without stringStream +template<class StringType> +void ReadInFile(StringType infile) { + std::ifstream in(infile); + std::string line; + + //First line, Info about the matrix + std::getline(in, line); + std::istringstream iss(line); + std::string temp1, temp2, temp3, temp4; + int RowsA, ColsA, RowsB, ColsB; + if (!(iss >> temp1 >> RowsA >> temp2 >> ColsA >> temp3 >> RowsB >> temp4 >> ColsB)) { + std::cerr << "Error parsing line 1 " << std::endl; + exit(1); + } + + //Second line, get QuantMult + std::getline(in, line); + std::istringstream iss2(line); + float quantMultA, quantMultB; + if (!(iss2 >> temp1 >> quantMultA >> temp2 >> quantMultA)) { + std::cerr << "Error parsing line 2 " << std::endl; + exit(1); + } + std::getline(in, line); //Just some text + //Fourth line, AQuant + std::vector<int> AQuant; + std::getline(in, line); + std::vector<std::string> tmp_container; + tokenizeLine(line, tmp_container); + if (tmp_container.size() != RowsA*ColsA) { + std::cerr << "Error parsing matrix A. Size mismatch. Expected " << RowsA*ColsA << " got " << tmp_container.size() << std::endl; + } + for (auto&& num : tmp_container) { + AQuant.push_back(std::stoi(num)); + } + tmp_container.resize(0); + + std::getline(in, line); //Just some text + //Sixth line, B_raw + std::vector<float> B_raw; + std::getline(in, line); + tokenizeLine(line, tmp_container); + if (tmp_container.size() != RowsB*ColsB) { + std::cerr << "Error parsing matrix B. Size mismatch. Expected " << RowsB*ColsB << " got " << tmp_container.size() << std::endl; + } + for (auto&& num : tmp_container) { + B_raw.push_back(std::stof(num)); + } + tmp_container.resize(0); + + std::getline(in, line); //Just some text + //Eight line, Bias + std::vector<float> Bias; + std::getline(in, line); + tokenizeLine(line, tmp_container); + if (tmp_container.size() != ColsB) { + std::cerr << "Error parsing bias. Size mismatch. 
Expected " << ColsB << " got " << tmp_container.size() << std::endl; + } + for (auto&& num : tmp_container) { + Bias.push_back(std::stof(num)); + } + tmp_container.resize(0); + +} + +*/ +template<class StringType> +void ReadInFile(StringType infile) { + std::ifstream in(infile); + std::string line; + + //First line, Info about the matrix + std::getline(in, line); + std::istringstream iss(line); + std::string temp1, temp2, temp3, temp4; + int RowsA, ColsA, RowsB, ColsB; + if (!(iss >> temp1 >> RowsA >> temp2 >> ColsA >> temp3 >> RowsB >> temp4 >> ColsB)) { + std::cerr << "Error parsing line 1 " << std::endl; + exit(1); + } + + //Second line, get QuantMult + std::getline(in, line); + std::istringstream iss2(line); + float quantMultA, quantMultB; + if (!(iss2 >> temp1 >> quantMultA >> temp2 >> quantMultA)) { + std::cerr << "Error parsing line 2 " << std::endl; + exit(1); + } + std::getline(in, line); //Just some text for human readability + + //4th line, AQuant + std::vector<int> AQuant; + std::getline(in, line); + std::istringstream iss3(line); + for (int i = 0; i < RowsA*ColsA; i++) { + int num; + if (!(iss3 >> num)) { + std::cerr << "Error parsing matrix A at " << i << std::endl;; + } + AQuant.push_back(num); + } + + std::getline(in, line); //Just some text for human readability + //6th line, B_raw + std::vector<float> B_raw; + std::getline(in, line); + std::istringstream iss4(line); + for (int i = 0; i < RowsB*ColsB; i++) { + float num; + if (!(iss4 >> num)) { + std::cerr << "Error parsing matrix B " << std::endl; + } + B_raw.push_back(num); + } + + std::getline(in, line); //Just some text for human readability + //8th line, Bias + std::vector<float> Bias; + std::getline(in, line); + std::istringstream iss5(line); + for (int i = 0; i < ColsB; i++) { + float num; + if (!(iss5 >> num)) { + std::cerr << "Error parsing matrix bias " << std::endl; + } + Bias.push_back(num); + } +} + +using namespace intgemm; +template<class T> +void printMatrix(T* data, Index rows, Index cols) { + std::cout << "["; + for (int i = 0; i<rows; i++) { + std::cout << "["; + for (int j =0; j<cols; j++) { + std::cout << (float)data[i*cols + j]; + if (j != cols - 1) { + std::cout << ", "; + } + } + std::cout << "]"; + if (i != rows -1) { + std::cout << ',' << std::endl; + } + } + std::cout << "]" << std::endl; +} + +void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias) { + for (Index r = 0; r < A_rows; ++r) { + for (Index c = 0; c < B_cols; ++c) { + float sum = 0.0f; + for (Index k = 0; k < width; ++k) { + sum += A[r * width + k] * B[k * B_cols + c]; + } + if (bias) { + C[r * B_cols + c] = sum + bias[c]; + } else { + C[r * B_cols + c] = sum; + } + } + } +} + +// Compute A*B slowly from integers. 
+template <class Integer> +void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) { + for (Index r = 0; r < A_rows; ++r) { + for (Index c = 0; c < B_cols; ++c) { + int32_t sum = 0; + for (Index k = 0; k < width; ++k) { + sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]); + } + if (bias) { + C[r * B_cols + c] = sum * unquant_mult + bias[c]; + } else { + C[r * B_cols + c] = sum * unquant_mult; + } + } + } +} + +int main() { + + const Index A_rows = 1; + const Index width = 2048; + const Index B_cols = 8; + + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + + float alpha = 2.0f; + float quant_mult = 127/alpha; + float unquant_mult = 1.0 / (quant_mult * quant_mult); + + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-2.0f, 2.0f); + + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = dist(gen); + } + + AlignedVector<float> bias_orig(B_cols); + for (int i = 0; i < bias.size(); i++) { + bias_orig[i] = bias[i]; + } + + AlignedVector<int8_t> A_prep(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + + AVX2_8bit::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + AVX2_8bit::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); +/* + std::cout << "A:" << std::endl; + printMatrix(A.begin(), A_rows, width); + std::cout << "B:" << std::endl; + printMatrix(B.begin(), width, B_cols); + std::cout << "bias:" << std::endl; + printMatrix(bias.begin(), 1, B_cols);*/ + + + AlignedVector<float> test_C(A_rows * B_cols); + + AVX2_8bit::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + //AVX2_8bit::Multiply(A_prep.begin(), B_prep.begin(), JustUnquantizeC(test_C.begin(), unquant_mult), A_rows, width, B_cols); + std::cout << "Old multiply:" << std::endl; + printMatrix(test_C.begin(), A_rows, B_cols); + + //NEEEXT + AlignedVector<uint8_t> A_prep2(A.size()); + AVX2_8bit::PrepareA(A.begin(), A_prep2.begin(), quant_mult, A_rows, width); + + AVX2_8bit::PrepareBiasFor8(B.begin(), bias.begin(), alpha, width, B_cols); + + //printMatrix(bias.begin(), 1, B_cols); //Print bias + + AVX2_8bit::Multiply8new(reinterpret_cast<uint8_t*>(A_prep2.begin()), B_prep.begin(), A_rows, width, B_cols, UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + //AVX2_8bit::Multiply8new(reinterpret_cast<uint8_t*>(A_prep.begin()), B_prep.begin(), JustUnquantizeC(test_C.begin(), unquant_mult), A_rows, width, B_cols); + + AlignedVector<int16_t> A_prep3(A.size()); + AlignedVector<int16_t> B_prep3(B.size()); + std::cout << "New multiply:" << std::endl; + printMatrix(test_C.begin(), A_rows, B_cols); + for (int i = 0; i < A_prep2.size(); i++) { + A_prep3[i] = A_prep2[i]; + } + AVX2_16bit::PrepareB(B.begin(), B_prep3.begin(), quant_mult, width, B_cols); + AVX2_16bit::Multiply(A_prep3.begin(), B_prep3.begin(), A_rows, width, B_cols, UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + + std::cout << "New multiply, 16 bit:" << std::endl; + printMatrix(test_C.begin(), A_rows, B_cols); + + //FULL INTS + AlignedVector<float> C_slowint(A_rows * B_cols); + AlignedVector<int8_t> B_quant(width * B_cols); + AVX2_8bit::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size()); + + SlowRefInt(A_prep.begin(), B_quant.begin(), 
C_slowint.begin(), + unquant_mult, A_rows, width, B_cols, bias_orig.begin()); + + + std::cout << "Reference int8:" << std::endl; + printMatrix(C_slowint.begin(), A_rows, B_cols); + + //FULL INT16 + AlignedVector<int16_t> A_prep4(A.size()); + for (int i = 0; i < A_prep2.size(); i++) { + A_prep4[i] = A_prep[i]; + } + + AlignedVector<float> C_slowint2(A_rows * B_cols); + AlignedVector<int16_t> B_quant2(width * B_cols); + AVX2_16bit::Quantize(B.begin(), B_quant2.begin(), quant_mult, B.size()); + + SlowRefInt(A_prep4.begin(), B_quant2.begin(), C_slowint2.begin(), + unquant_mult, A_rows, width, B_cols, bias_orig.begin()); + + + std::cout << "Reference int16:" << std::endl; + printMatrix(C_slowint2.begin(), A_rows, B_cols); + + //FLOATS + AlignedVector<float> C(A_rows * B_cols); + + SlowRefFloat(A.begin(), B.begin(), C.begin(), A_rows, width, B_cols, bias_orig.begin()); + std::cout << "Reference float:" << std::endl; + printMatrix(C.begin(), A_rows, B_cols); + +} |
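Taken together, the patch adds an unsigned-A path: maddubs_epi16 treats its first operand as unsigned bytes and its second as signed, so A is shifted into unsigned range by +127 and the bias absorbs the correction. A minimal usage sketch of the new API as exercised by the tests and benchmark above (AVX2 backend chosen for illustration; sizes and error handling omitted):

```cpp
#include "intgemm.h"
#include "aligned.h"
#include <algorithm>

using namespace intgemm;

// Sketch of the add-127 flow added by this PR, mirroring testNew/TestMultiplyBiasNew.
// Assumes an AVX2-capable CPU; the dispatching Int8 interface could be used instead.
void MultiplyWithNewBias(const float* A, const float* B, const float* bias_in,
                         float* C, Index A_rows, Index width, Index B_cols) {
  const float alpha = 2.0f;                      // assumed activation range
  const float quant_mult = 127.0f / alpha;
  const float unquant_mult = 1.0f / (quant_mult * quant_mult);

  AlignedVector<uint8_t> A_prep(A_rows * width); // note: A is now unsigned
  AlignedVector<int8_t> B_prep(width * B_cols);
  AlignedVector<float> bias(B_cols);
  std::copy(bias_in, bias_in + B_cols, bias.begin());

  AVX2_8bit::PrepareA(A, A_prep.begin(), quant_mult, A_rows, width);
  AVX2_8bit::PrepareB(B, B_prep.begin(), quant_mult, width, B_cols);

  // Fold the -127 * colsum(B) correction into the bias; the minus sign inverts the later add.
  const float unquant_mult_forprep = -alpha * alpha / 127.0f;
  AVX2_8bit::PrepareBiasFor8(1, B_prep.begin(), 1, width, B_cols,
      callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));

  // Multiply unsigned A against signed B, unquantize, and add the corrected bias.
  AVX2_8bit::Multiply8new(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
      callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), C));
}
```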