diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-03-27 17:02:12 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-03-27 17:02:17 +0300 |
commit | 4b65293dafce60cd958149be9abb400cc9c7057f (patch) | |
tree | aaeb91a47a5f34b750071737c3f2a00fe205b568 | |
parent | c351bd5793ccc36738ecfe921479edd588f723cf (diff) |
TEMPmultiply-tiling
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | benchmarks/benchmark.cc | 2 | ||||
-rw-r--r-- | benchmarks/biasmultiply.cc | 6 | ||||
-rw-r--r-- | example.cc | 4 | ||||
-rw-r--r-- | interleave.h | 82 | ||||
-rw-r--r-- | test/add127_test.cc | 8 | ||||
-rw-r--r-- | test/multiply_test.cc | 10 | ||||
-rw-r--r-- | test/multiply_tiling_test.cc | 4 | ||||
-rw-r--r-- | test/prepare_b.cc | 167 |
9 files changed, 227 insertions, 57 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index ec1579a..57b0c53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,6 +58,7 @@ add_executable(tests test/add127_test.cc test/multiply_test.cc test/multiply_tiling_test + test/prepare_b.cc test/prepare_b_quantized_transposed.cc test/prepare_b_transposed.cc test/quantize_test.cc diff --git a/benchmarks/benchmark.cc b/benchmarks/benchmark.cc index e4b3815..4f04765 100644 --- a/benchmarks/benchmark.cc +++ b/benchmarks/benchmark.cc @@ -77,7 +77,7 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t> AlignedVector<Integer> A_prepared(m.A_rows * m.width); Backend::PrepareA(m.A.begin(), A_prepared.begin(), quant_mult, m.A_rows, m.width); AlignedVector<Integer> B_prepared(m.width * m.B_cols); - Backend::template PrepareB<1>(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols); + Backend::template PrepareB<8>(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols); AlignedVector<float> output(m.A_rows * m.B_cols); // Burn in Backend::template Multiply<1, 1>(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::UnquantizeAndWrite(unquant_mult, output.begin())); diff --git a/benchmarks/biasmultiply.cc b/benchmarks/biasmultiply.cc index a95af19..e031609 100644 --- a/benchmarks/biasmultiply.cc +++ b/benchmarks/biasmultiply.cc @@ -35,7 +35,7 @@ std::chrono::duration<double> testNew(Index A_rows, Index width, Index B_cols) { AlignedVector<uint8_t> A_prep(A.size()); AlignedVector<int8_t> B_prep(B.size()); Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); - Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); @@ -74,7 +74,7 @@ std::chrono::duration<double> testOld(Index A_rows, Index width, Index B_cols) { AlignedVector<int8_t> A_prep(A.size()); AlignedVector<int8_t> B_prep(B.size()); Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); - Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); @@ -107,7 +107,7 @@ std::chrono::duration<double> testOld_nobias(Index A_rows, Index width, Index B_ AlignedVector<int8_t> A_prep(A.size()); AlignedVector<int8_t> B_prep(B.size()); Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); - Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); @@ -48,7 +48,7 @@ int main() { intgemm::Int16::PrepareA(A.begin(), A_prepared.begin(), quant_mult, A_rows, width); // Quantize and reshape B. // Typically you will do this once when parameters are loaded, not every time. - intgemm::Int16::template PrepareB<1>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); + intgemm::Int16::template PrepareB<8>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. @@ -67,7 +67,7 @@ int main() { intgemm::Int8::PrepareA(A.begin(), A_prepared.begin(), quant_mult, A_rows, width); // Quantize and reshape B. // Typically you will do this once when parameters are loaded, not every time. - intgemm::Int8::template PrepareB<1>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); + intgemm::Int8::template PrepareB<8>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. diff --git a/interleave.h b/interleave.h index 8680ce8..8bf20cd 100644 --- a/interleave.h +++ b/interleave.h @@ -8,6 +8,7 @@ #include <algorithm> #include <cassert> #include <stdint.h> +#include <iostream> namespace intgemm { @@ -187,31 +188,32 @@ template <class Register> static inline void Transpose8InLane( // 257 273 // ... ... -template <typename Type> +template <typename Type, Index TileColumns> struct PrepareB_InnerLoop; #define INTGEMM_PREPARE_B_8_INNER_LOOP(target, Register) \ template <typename Iterator, typename Quantizer> \ - target static void body(Register* output, const Quantizer &quantizer, const float* input, Index cols, Index row, Index col) { \ + target static void body(Register* output, const Quantizer &quantizer, const float* input, Index rows, Index cols, Index row, Index col) { \ static constexpr Index I = Iterator::template I<0>(); \ - output[8 * I + 0] = quantizer.ForReshape(input + cols * (row + 0) + 8 * I + col, cols); \ - output[8 * I + 1] = quantizer.ForReshape(input + cols * (row + 1) + 8 * I + col, cols); \ - output[8 * I + 2] = quantizer.ForReshape(input + cols * (row + 4) + 8 * I + col, cols); \ - output[8 * I + 3] = quantizer.ForReshape(input + cols * (row + 5) + 8 * I + col, cols); \ - output[8 * I + 4] = quantizer.ForReshape(input + cols * (row + 8) + 8 * I + col, cols); \ - output[8 * I + 5] = quantizer.ForReshape(input + cols * (row + 9) + 8 * I + col, cols); \ - output[8 * I + 6] = quantizer.ForReshape(input + cols * (row + 12) + 8 * I + col, cols); \ - output[8 * I + 7] = quantizer.ForReshape(input + cols * (row + 13) + 8 * I + col, cols); \ - Interleave8(output[8 * I + 0], output[8 * I + 1]); \ - Interleave8(output[8 * I + 2], output[8 * I + 3]); \ - Interleave8(output[8 * I + 4], output[8 * I + 5]); \ - Interleave8(output[8 * I + 6], output[8 * I + 7]); \ - Transpose16InLane(output[8 * I + 0], output[8 * I + 1], output[8 * I + 2], output[8 * I + 3], \ - output[8 * I + 4], output[8 * I + 5], output[8 * I + 6], output[8 * I + 7]); \ + const Index offset = rows / 8 * TileColumns; \ + output[0 / TileColumns * offset + 8 * I + 0] = quantizer.ForReshape(input + cols * (row + 0) + 8 * I + col, cols); \ + output[1 / TileColumns * offset + 8 * I + 1] = quantizer.ForReshape(input + cols * (row + 1) + 8 * I + col, cols); \ + output[2 / TileColumns * offset + 8 * I + 2] = quantizer.ForReshape(input + cols * (row + 4) + 8 * I + col, cols); \ + output[3 / TileColumns * offset + 8 * I + 3] = quantizer.ForReshape(input + cols * (row + 5) + 8 * I + col, cols); \ + output[4 / TileColumns * offset + 8 * I + 4] = quantizer.ForReshape(input + cols * (row + 8) + 8 * I + col, cols); \ + output[5 / TileColumns * offset + 8 * I + 5] = quantizer.ForReshape(input + cols * (row + 9) + 8 * I + col, cols); \ + output[6 / TileColumns * offset + 8 * I + 6] = quantizer.ForReshape(input + cols * (row + 12) + 8 * I + col, cols); \ + output[7 / TileColumns * offset + 8 * I + 7] = quantizer.ForReshape(input + cols * (row + 13) + 8 * I + col, cols); \ + Interleave8(output[0 / TileColumns * offset + 8 * I + 0], output[1 / TileColumns * offset + 8 * I + 1]); \ + Interleave8(output[2 / TileColumns * offset + 8 * I + 2], output[3 / TileColumns * offset + 8 * I + 3]); \ + Interleave8(output[4 / TileColumns * offset + 8 * I + 4], output[5 / TileColumns * offset + 8 * I + 5]); \ + Interleave8(output[6 / TileColumns * offset + 8 * I + 6], output[7 / TileColumns * offset + 8 * I + 7]); \ + Transpose16InLane(output[0 / TileColumns * offset + 8 * I + 0], output[1 / TileColumns * offset + 8 * I + 1], output[2 / TileColumns * offset + 8 * I + 2], output[3 / TileColumns * offset + 8 * I + 3], \ + output[4 / TileColumns * offset + 8 * I + 4], output[5 / TileColumns * offset + 8 * I + 5], output[6 / TileColumns * offset + 8 * I + 6], output[7 / TileColumns * offset + 8 * I + 7]); \ } -template <> -struct PrepareB_InnerLoop<int8_t> { +template <Index TileColumns> +struct PrepareB_InnerLoop<int8_t, TileColumns> { INTGEMM_PREPARE_B_8_INNER_LOOP(INTGEMM_SSSE3, __m128i) INTGEMM_PREPARE_B_8_INNER_LOOP(INTGEMM_AVX2, __m256i) INTGEMM_PREPARE_B_8_INNER_LOOP(INTGEMM_AVX512BW, __m512i) @@ -219,49 +221,51 @@ struct PrepareB_InnerLoop<int8_t> { #define INTGEMM_PREPARE_B_16_INNER_LOOP(target, Register) \ template <typename Iterator, typename Quantizer> \ - target static void body(Register* output, const Quantizer &quantizer, const float* input, Index cols, Index row, Index col) { \ + target static void body(Register* output, const Quantizer &quantizer, const float* input, Index rows, Index cols, Index row, Index col) { \ static constexpr Index I = Iterator::template I<0>(); \ - output[8 * I + 0] = quantizer.ForReshape(input + cols * (row + 0) + 8 * I + col, cols); \ - output[8 * I + 1] = quantizer.ForReshape(input + cols * (row + 1) + 8 * I + col, cols); \ - output[8 * I + 2] = quantizer.ForReshape(input + cols * (row + 2) + 8 * I + col, cols); \ - output[8 * I + 3] = quantizer.ForReshape(input + cols * (row + 3) + 8 * I + col, cols); \ - output[8 * I + 4] = quantizer.ForReshape(input + cols * (row + 4) + 8 * I + col, cols); \ - output[8 * I + 5] = quantizer.ForReshape(input + cols * (row + 5) + 8 * I + col, cols); \ - output[8 * I + 6] = quantizer.ForReshape(input + cols * (row + 6) + 8 * I + col, cols); \ - output[8 * I + 7] = quantizer.ForReshape(input + cols * (row + 7) + 8 * I + col, cols); \ - Transpose16InLane(output[8 * I + 0], output[8 * I + 1], output[8 * I + 2], output[8 * I + 3], \ - output[8 * I + 4], output[8 * I + 5], output[8 * I + 6], output[8 * I + 7]); \ + const Index offset = rows / 8 * TileColumns; \ + output[0 / TileColumns * 8 * I + 0] = quantizer.ForReshape(input + cols * (row + 0) + 8 * I + col, cols); \ + output[1 / TileColumns * 8 * I + 1] = quantizer.ForReshape(input + cols * (row + 1) + 8 * I + col, cols); \ + output[2 / TileColumns * 8 * I + 2] = quantizer.ForReshape(input + cols * (row + 2) + 8 * I + col, cols); \ + output[3 / TileColumns * 8 * I + 3] = quantizer.ForReshape(input + cols * (row + 3) + 8 * I + col, cols); \ + output[4 / TileColumns * 8 * I + 4] = quantizer.ForReshape(input + cols * (row + 4) + 8 * I + col, cols); \ + output[5 / TileColumns * 8 * I + 5] = quantizer.ForReshape(input + cols * (row + 5) + 8 * I + col, cols); \ + output[6 / TileColumns * 8 * I + 6] = quantizer.ForReshape(input + cols * (row + 6) + 8 * I + col, cols); \ + output[7 / TileColumns * 8 * I + 7] = quantizer.ForReshape(input + cols * (row + 7) + 8 * I + col, cols); \ + Transpose16InLane(output[0 / TileColumns * offset + 8 * I + 0], output[1 / TileColumns * offset + 8 * I + 1], output[2 / TileColumns * offset + 8 * I + 2], output[3 / TileColumns * offset + 8 * I + 3], \ + output[4 / TileColumns * offset + 8 * I + 4], output[5 / TileColumns * offset + 8 * I + 5], output[6 / TileColumns * offset + 8 * I + 6], output[7 / TileColumns * offset + 8 * I + 7]); \ } -template <> -struct PrepareB_InnerLoop<int16_t> { +template <Index TileColumns> +struct PrepareB_InnerLoop<int16_t, TileColumns> { INTGEMM_PREPARE_B_16_INNER_LOOP(INTGEMM_SSSE3, __m128i) INTGEMM_PREPARE_B_16_INNER_LOOP(INTGEMM_AVX2, __m256i) INTGEMM_PREPARE_B_16_INNER_LOOP(INTGEMM_AVX512BW, __m512i) }; #define INTGEMM_PREPARE_B(target, Quantizer, Integer) \ -template <Index TileColumnsMultiplier> \ +template <Index TileColumns> \ target static inline void PrepareB(const float *input, Integer *output, float quant_mult, Index rows, Index cols) { \ - static constexpr Index Columns = 8 * TileColumnsMultiplier; \ using Register = Quantizer::Register; \ - const Index RegisterElems = sizeof(Register) / sizeof(Integer); \ + static constexpr Index TileColumnsRoundUp = (TileColumns + 7) / 8 * 8; \ + constexpr Index RegisterElems = sizeof(Register) / sizeof(Integer); \ \ Quantizer quantizer = Quantizer(quant_mult); \ Register *output_it = reinterpret_cast<Register*>(output); \ \ - assert(cols % Columns == 0); \ - assert(rows % (RegisterElems * TileColumnsMultiplier) == 0); \ + assert(cols % TileColumnsRoundUp == 0); \ + assert(rows % RegisterElems == 0); \ assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \ assert(reinterpret_cast<uintptr_t>(output_it) % sizeof(Register) == 0); \ \ - for (Index c = 0; c < cols; c += Columns) { \ - for (Index r = 0; r < rows; r += RegisterElems, output_it += Columns) { \ + for (Index c = 0; c < cols - TileColumns; c += TileColumns) { \ + for (Index r = 0; r < rows; r += RegisterElems, output_it += TileColumns) { \ + std::cout << TileColumns << ", r: " << r << ", c: " << c << ", output: " << output_it << "/" << output + rows * cols * sizeof(Integer) << std::endl; \ /* Quantize and perform a transpose with height sizeof(Register) and width Columns. \ This isn't quite Transpose8InLane because it's half the number of columns, \ so each register starts with two rows instead of being one row. \ The quantizers know to skip a row.*/ \ - StaticLoop<PrepareB_InnerLoop<Integer>, MakeStaticLoopIterator<TileColumnsMultiplier>>(output_it, quantizer, input, cols, r, c); \ + StaticLoop<PrepareB_InnerLoop<Integer, TileColumns>, MakeStaticLoopIterator<TileColumnsRoundUp / 8>>(output_it, quantizer, input, rows, cols, r, c); \ } \ } \ } diff --git a/test/add127_test.cc b/test/add127_test.cc index 7ebb93e..b192291 100644 --- a/test/add127_test.cc +++ b/test/add127_test.cc @@ -46,7 +46,7 @@ template <class Routine> void TestPrepareBias(Index rows, Index cols) { AlignedVector<int8_t> B_prep(inputB.size()); AlignedVector<int8_t> B_quant(inputB.size()); - Routine::template PrepareB<1>(inputB.begin(), B_prep.begin(), quant_mult, rows, cols); + Routine::template PrepareB<8>(inputB.begin(), B_prep.begin(), quant_mult, rows, cols); Routine::Quantize(inputB.begin(), B_quant.begin(), quant_mult, inputB.size()); AlignedVector<float> inputBias(cols); @@ -107,7 +107,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind AlignedVector<uint8_t> A_prep(A.size()); AlignedVector<int8_t> B_prep(B.size()); Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); - Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); @@ -172,7 +172,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt AlignedVector<int8_t> B_prep(B.size()); Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); //Non shited version - Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); @@ -228,7 +228,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In AlignedVector<int8_t> B_prep(B.size()); Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); //Non shited version - Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 9deff69..6956c56 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -67,7 +67,7 @@ template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) { typedef typename Routine::Integer Integer; // Call Prepare AlignedVector<Integer> test(input.size()); - Routine::template PrepareB<1>(input.begin(), test.begin(), 1, rows, cols); + Routine::template PrepareB<8>(input.begin(), test.begin(), 1, rows, cols); // Compute reference output. AlignedVector<Integer> quantized(input.size()); @@ -119,7 +119,7 @@ template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 1 } typedef typename Routine::Integer Integer; AlignedVector<Integer> prepared(input.size()); - Routine::template PrepareB<1>(input.begin(), prepared.begin(), 1, rows, cols); + Routine::template PrepareB<8>(input.begin(), prepared.begin(), 1, rows, cols); int kSelectCols = 24; Index select_cols[kSelectCols]; @@ -140,7 +140,7 @@ template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 1 } } AlignedVector<Integer> ref(rows * kSelectCols); - Routine::template PrepareB<1>(selected.begin(), ref.begin(), 1, rows, kSelectCols); + Routine::template PrepareB<8>(selected.begin(), ref.begin(), 1, rows, kSelectCols); CHECK_MESSAGE(memcmp(ref.begin(), test.begin(), sizeof(Integer) * rows * kSelectCols) == 0, "Reference:\n" << PrintMatrix(ref.begin(), rows, kSelectCols) << PrintMatrix(test.begin(), rows, kSelectCols)); } @@ -273,7 +273,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co AlignedVector<Integer> A_prep(A.size()); AlignedVector<Integer> B_prep(B.size()); Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); - Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); Routine::template Multiply<1, 1>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin())); @@ -329,7 +329,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index AlignedVector<Integer> A_prep(A.size()); AlignedVector<Integer> B_prep(B.size()); Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); - Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); diff --git a/test/multiply_tiling_test.cc b/test/multiply_tiling_test.cc index ba50756..73f0c33 100644 --- a/test/multiply_tiling_test.cc +++ b/test/multiply_tiling_test.cc @@ -11,7 +11,6 @@ #include <random> namespace intgemm { -namespace { template <typename Backend, Index TileRows, Index TileColumnsMultiplier> bool Test(const AlignedVector<float>& A, const AlignedVector<float>& B, Index A_rows, Index width, Index B_cols, float quant_mult) { @@ -23,7 +22,7 @@ bool Test(const AlignedVector<float>& A, const AlignedVector<float>& B, Index A_ AlignedVector<int32_t> reference(output.size()); Backend::PrepareA(A.begin(), A_quantized.begin(), quant_mult, A_rows, width); - Backend::template PrepareB<TileColumnsMultiplier>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); + Backend::template PrepareB<TileColumnsMultiplier * 8>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); Backend::template Multiply<TileRows, TileColumnsMultiplier>(A_quantized.begin(), B_prepared.begin(), A_rows, width, B_cols, callbacks::Write<int32_t>(output.begin())); references::Quantize(B.begin(), B_quantized.begin(), quant_mult, B_quantized.size()); @@ -159,4 +158,3 @@ TEST_CASE("Multiply AVX2 16bit - custom tiling", "") { #endif } -} diff --git a/test/prepare_b.cc b/test/prepare_b.cc new file mode 100644 index 0000000..563accd --- /dev/null +++ b/test/prepare_b.cc @@ -0,0 +1,167 @@ +#include "test.h" +#include "../aligned.h" +#include "../avx2_gemm.h" +#include "../avx512_gemm.h" +#include "../sse2_gemm.h" +#include "../ssse3_gemm.h" + +#include <cstring> +#include <iostream> +#include <math.h> + +namespace intgemm { +namespace { + +template <typename Backend, Index TileColumns> +void PrepareBRef(const float* input, typename Backend::Integer* output, float quant_mult, Index B_rows, Index B_cols) { + using vec_t = intgemm::vector_t<Backend::kUses, typename Backend::Integer>; + constexpr Index vec_len = sizeof(vec_t) / sizeof(typename Backend::Integer); + + for (Index c = 0; c < B_cols; c += TileColumns) + for (Index r = 0; r < B_rows; r += vec_len) + for (Index ci = 0; ci < TileColumns; ++ci) + for (Index ri = 0; ri < vec_len; ++ri) { + *output++ = input[(r + ri) * B_cols + (c + ci)] * quant_mult; + } +} + +template <typename Backend, Index TileColumns> +bool TestInner(const AlignedVector<float>& input, Index B_rows, Index B_cols, float quant_mult) { + bool success = true; + + AlignedVector<typename Backend::Integer> output(input.size()); + Backend::template PrepareB<TileColumns>(input.begin(), output.begin(), quant_mult, B_rows, B_cols); + + AlignedVector<typename Backend::Integer> reference(input.size()); + PrepareBRef<Backend, TileColumns>(input.begin(), reference.begin(), quant_mult, B_rows, B_cols); + + if (TileColumns != 7 && TileColumns != 8) { + std::cout << "Input:" << std::endl; + for (int i = 0; i < B_rows; ++i) { + for (int j = 0; j < B_cols; ++j) + std::cout << input[i * B_cols + j] << ", "; + std::cout << std::endl; + } + + std::cout << "Output:" << std::endl; + for (int i = 0; i < B_rows; ++i) { + for (int j = 0; j < B_cols; ++j) + std::cout << output[i * B_cols + j] << ", "; + std::cout << std::endl; + } + + std::cout << "Reference:" << std::endl; + for (int i = 0; i < B_rows; ++i) { + for (int j = 0; j < B_cols; ++j) + std::cout << reference[i * B_cols + j] << ", "; + std::cout << std::endl; + } + } + + for (std::size_t i = 0; i < output.size(); ++i) { + if (output[i] != reference[i]) { + UNSCOPED_INFO("Error at " << i << ", output = " << int(output[i]) << ", reference = " << int(reference[i])); + success = false; + break; + } + } + return success; +} + +template <typename Backend, Index TileColumns> +bool Test(Index B_rows, Index B_cols, float quant_mult) { + AlignedVector<float> input(B_rows * B_cols); + + std::generate(input.begin(), input.end(), []() { + static constexpr int divider = sizeof(intgemm::vector_t<Backend::kUses, typename Backend::Integer>) / sizeof(typename Backend::Integer); + static int value = 0; + return (value++) % divider; + }); + + return TestInner<Backend, TileColumns>(input, B_rows, B_cols, quant_mult); +} + +TEST_CASE("PrepareB SSE2", "") { + if (kCPU < CPUType::SSE2) + return; + + // CHECK(Test<SSE2_16bit, 1>(32, 32, 2.0f)); + // CHECK(Test<SSE2_16bit, 2>(32, 2*16, 2.0f)); + // CHECK(Test<SSE2_16bit, 3>(32, 3*16, 2.0f)); + // CHECK(Test<SSE2_16bit, 4>(32, 4*8, 2.0f)); + // CHECK(Test<SSE2_16bit, 5>(32, 5*8, 2.0f)); + // CHECK(Test<SSE2_16bit, 6>(32, 6*8, 2.0f)); + CHECK(Test<SSE2_16bit, 7>(32, 7*8, 2.0f)); + CHECK(Test<SSE2_16bit, 8>(32, 32, 2.0f)); + // CHECK(Test<SSE2_16bit, 9>(32, 9*16, 2.0f)); +} + +TEST_CASE("PrepareB SSSE3", "") { + if (kCPU < CPUType::SSSE3) + return; + + // CHECK(Test<SSSE3_8bit, 1>(32, 1*128, 2.0f)); + // CHECK(Test<SSSE3_8bit, 2>(32, 2*128, 2.0f)); + // CHECK(Test<SSSE3_8bit, 3>(32, 3*128, 2.0f)); + // CHECK(Test<SSSE3_8bit, 4>(32, 4*128, 2.0f)); + // CHECK(Test<SSSE3_8bit, 5>(32, 5*128, 2.0f)); + // CHECK(Test<SSSE3_8bit, 6>(32, 6*128, 2.0f)); + // CHECK(Test<SSSE3_8bit, 7>(32, 7*128, 2.0f)); + // CHECK(Test<SSSE3_8bit, 8>(32, 8*128, 2.0f)); + // CHECK(Test<SSSE3_8bit, 9>(32, 9*128, 2.0f)); +} + +TEST_CASE("PrepareB AVX2", "") { + if (kCPU < CPUType::AVX2) + return; + + // CHECK(Test<AVX2_8bit, 1>(32, 1*128, 2.0f)); + // CHECK(Test<AVX2_8bit, 2>(32, 2*128, 2.0f)); + // CHECK(Test<AVX2_8bit, 3>(32, 3*128, 2.0f)); + // CHECK(Test<AVX2_8bit, 4>(32, 4*128, 2.0f)); + // CHECK(Test<AVX2_8bit, 5>(32, 5*128, 2.0f)); + // CHECK(Test<AVX2_8bit, 6>(32, 6*128, 2.0f)); + // CHECK(Test<AVX2_8bit, 7>(32, 7*128, 2.0f)); + // CHECK(Test<AVX2_8bit, 8>(32, 8*128, 2.0f)); + // CHECK(Test<AVX2_8bit, 9>(32, 9*128, 2.0f)); + + // CHECK(Test<AVX2_16bit, 1>(32, 1*128, 2.0f)); + // CHECK(Test<AVX2_16bit, 2>(32, 2*128, 2.0f)); + // CHECK(Test<AVX2_16bit, 3>(32, 3*128, 2.0f)); + // CHECK(Test<AVX2_16bit, 4>(32, 4*128, 2.0f)); + // CHECK(Test<AVX2_16bit, 5>(32, 5*128, 2.0f)); + // CHECK(Test<AVX2_16bit, 6>(32, 6*128, 2.0f)); + // CHECK(Test<AVX2_16bit, 7>(32, 7*128, 2.0f)); + // CHECK(Test<AVX2_16bit, 8>(32, 8*128, 2.0f)); + // CHECK(Test<AVX2_16bit, 9>(32, 9*128, 2.0f)); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512 + TEST_CASE("PrepareB AVX512", "") { + if (kCPU < CPUType::AVX512BW) + return; + + // CHECK(Test<AVX512_8bit, 1>(32, 1*128, 2.0f)); + // CHECK(Test<AVX512_8bit, 2>(32, 2*128, 2.0f)); + // CHECK(Test<AVX512_8bit, 3>(32, 3*128, 2.0f)); + // CHECK(Test<AVX512_8bit, 4>(32, 4*128, 2.0f)); + // CHECK(Test<AVX512_8bit, 5>(32, 5*128, 2.0f)); + // CHECK(Test<AVX512_8bit, 6>(32, 6*128, 2.0f)); + // CHECK(Test<AVX512_8bit, 7>(32, 7*128, 2.0f)); + // CHECK(Test<AVX512_8bit, 8>(32, 8*128, 2.0f)); + // CHECK(Test<AVX512_8bit, 9>(32, 9*128, 2.0f)); + + // CHECK(Test<AVX512_16bit, 1>(32, 1*128, 2.0f)); + // CHECK(Test<AVX512_16bit, 2>(32, 2*128, 2.0f)); + // CHECK(Test<AVX512_16bit, 3>(32, 3*128, 2.0f)); + // CHECK(Test<AVX512_16bit, 4>(32, 4*128, 2.0f)); + // CHECK(Test<AVX512_16bit, 5>(32, 5*128, 2.0f)); + // CHECK(Test<AVX512_16bit, 6>(32, 6*128, 2.0f)); + // CHECK(Test<AVX512_16bit, 7>(32, 7*128, 2.0f)); + // CHECK(Test<AVX512_16bit, 8>(32, 8*128, 2.0f)); + // CHECK(Test<AVX512_16bit, 9>(32, 9*128, 2.0f)); + } +#endif + +} +} |