
github.com/marian-nmt/intgemm.git
author    Mateusz Chudyk <mateuszchudyk@gmail.com>  2020-03-27 17:02:12 +0300
committer Mateusz Chudyk <mateuszchudyk@gmail.com>  2020-03-27 17:02:17 +0300
commit    4b65293dafce60cd958149be9abb400cc9c7057f (patch)
tree      aaeb91a47a5f34b750071737c3f2a00fe205b568
parent    c351bd5793ccc36738ecfe921479edd588f723cf (diff)
-rw-r--r--  CMakeLists.txt                  1
-rw-r--r--  benchmarks/benchmark.cc         2
-rw-r--r--  benchmarks/biasmultiply.cc      6
-rw-r--r--  example.cc                      4
-rw-r--r--  interleave.h                   82
-rw-r--r--  test/add127_test.cc             8
-rw-r--r--  test/multiply_test.cc          10
-rw-r--r--  test/multiply_tiling_test.cc    4
-rw-r--r--  test/prepare_b.cc             167
9 files changed, 227 insertions(+), 57 deletions(-)
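
What changes for callers: PrepareB's template parameter is now the tile width in columns (TileColumns) rather than a multiplier on a fixed eight-column tile (TileColumnsMultiplier), so the call sites below all move from PrepareB<1> to PrepareB<8>. A minimal sketch of the new convention, mirroring example.cc; the headers and quantization multiplier follow example.cc, and the sizes are hypothetical:

    #include "intgemm.h"
    #include "aligned.h"
    #include <algorithm>

    // Sketch only: sizes must satisfy the asserts in PrepareB
    // (cols divisible by the rounded-up tile width, rows by the
    // number of elements per SIMD register).
    void prepare_b_sketch() {
      const intgemm::Index width = 256, B_cols = 256;
      const float quant_mult = 127.0f / 2.0f;  // 8-bit multiplier as in example.cc
      intgemm::AlignedVector<float> B(width * B_cols);
      std::fill(B.begin(), B.end(), 0.5f);     // placeholder contents
      intgemm::AlignedVector<int8_t> B_prepared(B.size());
      // The template argument is now the tile width itself; <8> reproduces
      // the layout the old PrepareB<1> produced.
      intgemm::Int8::template PrepareB<8>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols);
    }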
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ec1579a..57b0c53 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -58,6 +58,7 @@ add_executable(tests
test/add127_test.cc
test/multiply_test.cc
test/multiply_tiling_test.cc
+ test/prepare_b.cc
test/prepare_b_quantized_transposed.cc
test/prepare_b_transposed.cc
test/quantize_test.cc
diff --git a/benchmarks/benchmark.cc b/benchmarks/benchmark.cc
index e4b3815..4f04765 100644
--- a/benchmarks/benchmark.cc
+++ b/benchmarks/benchmark.cc
@@ -77,7 +77,7 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t>
AlignedVector<Integer> A_prepared(m.A_rows * m.width);
Backend::PrepareA(m.A.begin(), A_prepared.begin(), quant_mult, m.A_rows, m.width);
AlignedVector<Integer> B_prepared(m.width * m.B_cols);
- Backend::template PrepareB<1>(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols);
+ Backend::template PrepareB<8>(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols);
AlignedVector<float> output(m.A_rows * m.B_cols);
// Burn in
Backend::template Multiply<1, 1>(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::UnquantizeAndWrite(unquant_mult, output.begin()));
diff --git a/benchmarks/biasmultiply.cc b/benchmarks/biasmultiply.cc
index a95af19..e031609 100644
--- a/benchmarks/biasmultiply.cc
+++ b/benchmarks/biasmultiply.cc
@@ -35,7 +35,7 @@ std::chrono::duration<double> testNew(Index A_rows, Index width, Index B_cols) {
AlignedVector<uint8_t> A_prep(A.size());
AlignedVector<int8_t> B_prep(B.size());
Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
- Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+ Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
@@ -74,7 +74,7 @@ std::chrono::duration<double> testOld(Index A_rows, Index width, Index B_cols) {
AlignedVector<int8_t> A_prep(A.size());
AlignedVector<int8_t> B_prep(B.size());
Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
- Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+ Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
@@ -107,7 +107,7 @@ std::chrono::duration<double> testOld_nobias(Index A_rows, Index width, Index B_
AlignedVector<int8_t> A_prep(A.size());
AlignedVector<int8_t> B_prep(B.size());
Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
- Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+ Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
diff --git a/example.cc b/example.cc
index 1bf0bf3..2a94364 100644
--- a/example.cc
+++ b/example.cc
@@ -48,7 +48,7 @@ int main() {
intgemm::Int16::PrepareA(A.begin(), A_prepared.begin(), quant_mult, A_rows, width);
// Quantize and reshape B.
// Typically you will do this once when parameters are loaded, not every time.
- intgemm::Int16::template PrepareB<1>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols);
+ intgemm::Int16::template PrepareB<8>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols);
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
@@ -67,7 +67,7 @@ int main() {
intgemm::Int8::PrepareA(A.begin(), A_prepared.begin(), quant_mult, A_rows, width);
// Quantize and reshape B.
// Typically you will do this once when parameters are loaded, not every time.
- intgemm::Int8::template PrepareB<1>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols);
+ intgemm::Int8::template PrepareB<8>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols);
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
diff --git a/interleave.h b/interleave.h
index 8680ce8..8bf20cd 100644
--- a/interleave.h
+++ b/interleave.h
@@ -8,6 +8,7 @@
#include <algorithm>
#include <cassert>
#include <stdint.h>
+#include <iostream>
namespace intgemm {
@@ -187,31 +188,32 @@ template <class Register> static inline void Transpose8InLane(
// 257 273
// ... ...
-template <typename Type>
+template <typename Type, Index TileColumns>
struct PrepareB_InnerLoop;
#define INTGEMM_PREPARE_B_8_INNER_LOOP(target, Register) \
template <typename Iterator, typename Quantizer> \
- target static void body(Register* output, const Quantizer &quantizer, const float* input, Index cols, Index row, Index col) { \
+ target static void body(Register* output, const Quantizer &quantizer, const float* input, Index rows, Index cols, Index row, Index col) { \
static constexpr Index I = Iterator::template I<0>(); \
- output[8 * I + 0] = quantizer.ForReshape(input + cols * (row + 0) + 8 * I + col, cols); \
- output[8 * I + 1] = quantizer.ForReshape(input + cols * (row + 1) + 8 * I + col, cols); \
- output[8 * I + 2] = quantizer.ForReshape(input + cols * (row + 4) + 8 * I + col, cols); \
- output[8 * I + 3] = quantizer.ForReshape(input + cols * (row + 5) + 8 * I + col, cols); \
- output[8 * I + 4] = quantizer.ForReshape(input + cols * (row + 8) + 8 * I + col, cols); \
- output[8 * I + 5] = quantizer.ForReshape(input + cols * (row + 9) + 8 * I + col, cols); \
- output[8 * I + 6] = quantizer.ForReshape(input + cols * (row + 12) + 8 * I + col, cols); \
- output[8 * I + 7] = quantizer.ForReshape(input + cols * (row + 13) + 8 * I + col, cols); \
- Interleave8(output[8 * I + 0], output[8 * I + 1]); \
- Interleave8(output[8 * I + 2], output[8 * I + 3]); \
- Interleave8(output[8 * I + 4], output[8 * I + 5]); \
- Interleave8(output[8 * I + 6], output[8 * I + 7]); \
- Transpose16InLane(output[8 * I + 0], output[8 * I + 1], output[8 * I + 2], output[8 * I + 3], \
- output[8 * I + 4], output[8 * I + 5], output[8 * I + 6], output[8 * I + 7]); \
+ const Index offset = rows / 8 * TileColumns; \
+ output[0 / TileColumns * offset + 8 * I + 0] = quantizer.ForReshape(input + cols * (row + 0) + 8 * I + col, cols); \
+ output[1 / TileColumns * offset + 8 * I + 1] = quantizer.ForReshape(input + cols * (row + 1) + 8 * I + col, cols); \
+ output[2 / TileColumns * offset + 8 * I + 2] = quantizer.ForReshape(input + cols * (row + 4) + 8 * I + col, cols); \
+ output[3 / TileColumns * offset + 8 * I + 3] = quantizer.ForReshape(input + cols * (row + 5) + 8 * I + col, cols); \
+ output[4 / TileColumns * offset + 8 * I + 4] = quantizer.ForReshape(input + cols * (row + 8) + 8 * I + col, cols); \
+ output[5 / TileColumns * offset + 8 * I + 5] = quantizer.ForReshape(input + cols * (row + 9) + 8 * I + col, cols); \
+ output[6 / TileColumns * offset + 8 * I + 6] = quantizer.ForReshape(input + cols * (row + 12) + 8 * I + col, cols); \
+ output[7 / TileColumns * offset + 8 * I + 7] = quantizer.ForReshape(input + cols * (row + 13) + 8 * I + col, cols); \
+ Interleave8(output[0 / TileColumns * offset + 8 * I + 0], output[1 / TileColumns * offset + 8 * I + 1]); \
+ Interleave8(output[2 / TileColumns * offset + 8 * I + 2], output[3 / TileColumns * offset + 8 * I + 3]); \
+ Interleave8(output[4 / TileColumns * offset + 8 * I + 4], output[5 / TileColumns * offset + 8 * I + 5]); \
+ Interleave8(output[6 / TileColumns * offset + 8 * I + 6], output[7 / TileColumns * offset + 8 * I + 7]); \
+ Transpose16InLane(output[0 / TileColumns * offset + 8 * I + 0], output[1 / TileColumns * offset + 8 * I + 1], output[2 / TileColumns * offset + 8 * I + 2], output[3 / TileColumns * offset + 8 * I + 3], \
+ output[4 / TileColumns * offset + 8 * I + 4], output[5 / TileColumns * offset + 8 * I + 5], output[6 / TileColumns * offset + 8 * I + 6], output[7 / TileColumns * offset + 8 * I + 7]); \
}
-template <>
-struct PrepareB_InnerLoop<int8_t> {
+template <Index TileColumns>
+struct PrepareB_InnerLoop<int8_t, TileColumns> {
INTGEMM_PREPARE_B_8_INNER_LOOP(INTGEMM_SSSE3, __m128i)
INTGEMM_PREPARE_B_8_INNER_LOOP(INTGEMM_AVX2, __m256i)
INTGEMM_PREPARE_B_8_INNER_LOOP(INTGEMM_AVX512BW, __m512i)
@@ -219,49 +221,51 @@ struct PrepareB_InnerLoop<int8_t> {
#define INTGEMM_PREPARE_B_16_INNER_LOOP(target, Register) \
template <typename Iterator, typename Quantizer> \
- target static void body(Register* output, const Quantizer &quantizer, const float* input, Index cols, Index row, Index col) { \
+ target static void body(Register* output, const Quantizer &quantizer, const float* input, Index rows, Index cols, Index row, Index col) { \
static constexpr Index I = Iterator::template I<0>(); \
- output[8 * I + 0] = quantizer.ForReshape(input + cols * (row + 0) + 8 * I + col, cols); \
- output[8 * I + 1] = quantizer.ForReshape(input + cols * (row + 1) + 8 * I + col, cols); \
- output[8 * I + 2] = quantizer.ForReshape(input + cols * (row + 2) + 8 * I + col, cols); \
- output[8 * I + 3] = quantizer.ForReshape(input + cols * (row + 3) + 8 * I + col, cols); \
- output[8 * I + 4] = quantizer.ForReshape(input + cols * (row + 4) + 8 * I + col, cols); \
- output[8 * I + 5] = quantizer.ForReshape(input + cols * (row + 5) + 8 * I + col, cols); \
- output[8 * I + 6] = quantizer.ForReshape(input + cols * (row + 6) + 8 * I + col, cols); \
- output[8 * I + 7] = quantizer.ForReshape(input + cols * (row + 7) + 8 * I + col, cols); \
- Transpose16InLane(output[8 * I + 0], output[8 * I + 1], output[8 * I + 2], output[8 * I + 3], \
- output[8 * I + 4], output[8 * I + 5], output[8 * I + 6], output[8 * I + 7]); \
+ const Index offset = rows / 8 * TileColumns; \
+ output[0 / TileColumns * offset + 8 * I + 0] = quantizer.ForReshape(input + cols * (row + 0) + 8 * I + col, cols); \
+ output[1 / TileColumns * offset + 8 * I + 1] = quantizer.ForReshape(input + cols * (row + 1) + 8 * I + col, cols); \
+ output[2 / TileColumns * offset + 8 * I + 2] = quantizer.ForReshape(input + cols * (row + 2) + 8 * I + col, cols); \
+ output[3 / TileColumns * offset + 8 * I + 3] = quantizer.ForReshape(input + cols * (row + 3) + 8 * I + col, cols); \
+ output[4 / TileColumns * offset + 8 * I + 4] = quantizer.ForReshape(input + cols * (row + 4) + 8 * I + col, cols); \
+ output[5 / TileColumns * offset + 8 * I + 5] = quantizer.ForReshape(input + cols * (row + 5) + 8 * I + col, cols); \
+ output[6 / TileColumns * offset + 8 * I + 6] = quantizer.ForReshape(input + cols * (row + 6) + 8 * I + col, cols); \
+ output[7 / TileColumns * offset + 8 * I + 7] = quantizer.ForReshape(input + cols * (row + 7) + 8 * I + col, cols); \
+ Transpose16InLane(output[0 / TileColumns * offset + 8 * I + 0], output[1 / TileColumns * offset + 8 * I + 1], output[2 / TileColumns * offset + 8 * I + 2], output[3 / TileColumns * offset + 8 * I + 3], \
+ output[4 / TileColumns * offset + 8 * I + 4], output[5 / TileColumns * offset + 8 * I + 5], output[6 / TileColumns * offset + 8 * I + 6], output[7 / TileColumns * offset + 8 * I + 7]); \
}
-template <>
-struct PrepareB_InnerLoop<int16_t> {
+template <Index TileColumns>
+struct PrepareB_InnerLoop<int16_t, TileColumns> {
INTGEMM_PREPARE_B_16_INNER_LOOP(INTGEMM_SSSE3, __m128i)
INTGEMM_PREPARE_B_16_INNER_LOOP(INTGEMM_AVX2, __m256i)
INTGEMM_PREPARE_B_16_INNER_LOOP(INTGEMM_AVX512BW, __m512i)
};
#define INTGEMM_PREPARE_B(target, Quantizer, Integer) \
-template <Index TileColumnsMultiplier> \
+template <Index TileColumns> \
target static inline void PrepareB(const float *input, Integer *output, float quant_mult, Index rows, Index cols) { \
- static constexpr Index Columns = 8 * TileColumnsMultiplier; \
using Register = Quantizer::Register; \
- const Index RegisterElems = sizeof(Register) / sizeof(Integer); \
+ static constexpr Index TileColumnsRoundUp = (TileColumns + 7) / 8 * 8; \
+ constexpr Index RegisterElems = sizeof(Register) / sizeof(Integer); \
\
Quantizer quantizer = Quantizer(quant_mult); \
Register *output_it = reinterpret_cast<Register*>(output); \
\
- assert(cols % Columns == 0); \
- assert(rows % (RegisterElems * TileColumnsMultiplier) == 0); \
+ assert(cols % TileColumnsRoundUp == 0); \
+ assert(rows % RegisterElems == 0); \
assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \
assert(reinterpret_cast<uintptr_t>(output_it) % sizeof(Register) == 0); \
\
- for (Index c = 0; c < cols; c += Columns) { \
- for (Index r = 0; r < rows; r += RegisterElems, output_it += Columns) { \
+ for (Index c = 0; c < cols; c += TileColumns) { \
+ for (Index r = 0; r < rows; r += RegisterElems, output_it += TileColumns) { \
+ std::cout << TileColumns << ", r: " << r << ", c: " << c << ", output: " << output_it << "/" << static_cast<const void*>(output + rows * cols) << std::endl; \
/* Quantize and perform a transpose with height sizeof(Register) and width TileColumns. \
This isn't quite Transpose8InLane because it's half the number of columns, \
so each register starts with two rows instead of being one row. \
The quantizers know to skip a row.*/ \
- StaticLoop<PrepareB_InnerLoop<Integer>, MakeStaticLoopIterator<TileColumnsMultiplier>>(output_it, quantizer, input, cols, r, c); \
+ StaticLoop<PrepareB_InnerLoop<Integer, TileColumns>, MakeStaticLoopIterator<TileColumnsRoundUp / 8>>(output_it, quantizer, input, rows, cols, r, c); \
} \
} \
}
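
In the rewritten inner loops above, each step still produces eight registers, but register k of a step now lands at output[k / TileColumns * offset + 8 * I + k] with offset = rows / 8 * TileColumns. For k < TileColumns the integer division is zero, so those registers keep their old positions; a register index that reaches the tile width is displaced offset registers ahead. A standalone sketch of that arithmetic, with rows = 32 and TileColumns = 7 chosen purely for illustration:

    #include <cstdio>

    using Index = unsigned int;

    // Prints where each of the eight registers of one inner-loop step
    // (I = 0) lands under the new indexing. Values are illustrative only.
    int main() {
      const Index rows = 32, TileColumns = 7, I = 0;
      const Index offset = rows / 8 * TileColumns;  // 28 here
      for (Index k = 0; k < 8; ++k)
        std::printf("register %u -> output[%u]\n", k, k / TileColumns * offset + 8 * I + k);
      return 0;
    }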
diff --git a/test/add127_test.cc b/test/add127_test.cc
index 7ebb93e..b192291 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -46,7 +46,7 @@ template <class Routine> void TestPrepareBias(Index rows, Index cols) {
AlignedVector<int8_t> B_prep(inputB.size());
AlignedVector<int8_t> B_quant(inputB.size());
- Routine::template PrepareB<1>(inputB.begin(), B_prep.begin(), quant_mult, rows, cols);
+ Routine::template PrepareB<8>(inputB.begin(), B_prep.begin(), quant_mult, rows, cols);
Routine::Quantize(inputB.begin(), B_quant.begin(), quant_mult, inputB.size());
AlignedVector<float> inputBias(cols);
@@ -107,7 +107,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
AlignedVector<uint8_t> A_prep(A.size());
AlignedVector<int8_t> B_prep(B.size());
Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
- Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+ Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
@@ -172,7 +172,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt
AlignedVector<int8_t> B_prep(B.size());
Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); // Non-shifted version
- Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+ Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
@@ -228,7 +228,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
AlignedVector<int8_t> B_prep(B.size());
Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); // Non-shifted version
- Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+ Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 9deff69..6956c56 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -67,7 +67,7 @@ template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) {
typedef typename Routine::Integer Integer;
// Call Prepare
AlignedVector<Integer> test(input.size());
- Routine::template PrepareB<1>(input.begin(), test.begin(), 1, rows, cols);
+ Routine::template PrepareB<8>(input.begin(), test.begin(), 1, rows, cols);
// Compute reference output.
AlignedVector<Integer> quantized(input.size());
@@ -119,7 +119,7 @@ template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 1
}
typedef typename Routine::Integer Integer;
AlignedVector<Integer> prepared(input.size());
- Routine::template PrepareB<1>(input.begin(), prepared.begin(), 1, rows, cols);
+ Routine::template PrepareB<8>(input.begin(), prepared.begin(), 1, rows, cols);
int kSelectCols = 24;
Index select_cols[kSelectCols];
@@ -140,7 +140,7 @@ template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 1
}
}
AlignedVector<Integer> ref(rows * kSelectCols);
- Routine::template PrepareB<1>(selected.begin(), ref.begin(), 1, rows, kSelectCols);
+ Routine::template PrepareB<8>(selected.begin(), ref.begin(), 1, rows, kSelectCols);
CHECK_MESSAGE(memcmp(ref.begin(), test.begin(), sizeof(Integer) * rows * kSelectCols) == 0, "Reference:\n" <<
PrintMatrix(ref.begin(), rows, kSelectCols) << PrintMatrix(test.begin(), rows, kSelectCols));
}
@@ -273,7 +273,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
AlignedVector<Integer> A_prep(A.size());
AlignedVector<Integer> B_prep(B.size());
Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
- Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+ Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
Routine::template Multiply<1, 1>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
@@ -329,7 +329,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
AlignedVector<Integer> A_prep(A.size());
AlignedVector<Integer> B_prep(B.size());
Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
- Routine::template PrepareB<1>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+ Routine::template PrepareB<8>(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
diff --git a/test/multiply_tiling_test.cc b/test/multiply_tiling_test.cc
index ba50756..73f0c33 100644
--- a/test/multiply_tiling_test.cc
+++ b/test/multiply_tiling_test.cc
@@ -11,7 +11,6 @@
#include <random>
namespace intgemm {
-namespace {
template <typename Backend, Index TileRows, Index TileColumnsMultiplier>
bool Test(const AlignedVector<float>& A, const AlignedVector<float>& B, Index A_rows, Index width, Index B_cols, float quant_mult) {
@@ -23,7 +22,7 @@ bool Test(const AlignedVector<float>& A, const AlignedVector<float>& B, Index A_
AlignedVector<int32_t> reference(output.size());
Backend::PrepareA(A.begin(), A_quantized.begin(), quant_mult, A_rows, width);
- Backend::template PrepareB<TileColumnsMultiplier>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols);
+ Backend::template PrepareB<TileColumnsMultiplier * 8>(B.begin(), B_prepared.begin(), quant_mult, width, B_cols);
Backend::template Multiply<TileRows, TileColumnsMultiplier>(A_quantized.begin(), B_prepared.begin(), A_rows, width, B_cols, callbacks::Write<int32_t>(output.begin()));
references::Quantize(B.begin(), B_quantized.begin(), quant_mult, B_quantized.size());
@@ -159,4 +158,3 @@ TEST_CASE("Multiply AVX2 16bit - custom tiling", "") {
#endif
}
-}
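
multiply_tiling_test.cc keeps its TileColumnsMultiplier parameter and converts at the call site: the old parameter counted eight-column tiles, so the new tile width is TileColumnsMultiplier * 8. A compile-time sketch of the correspondence; NewTileColumns is a hypothetical helper for illustration, not part of the library:

    #include <cstddef>

    // Hypothetical conversion from the old convention to the new one.
    template <std::size_t TileColumnsMultiplier>
    constexpr std::size_t NewTileColumns = TileColumnsMultiplier * 8;

    static_assert(NewTileColumns<1> == 8,  "old PrepareB<1> is new PrepareB<8>");
    static_assert(NewTileColumns<2> == 16, "old PrepareB<2> is new PrepareB<16>");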
diff --git a/test/prepare_b.cc b/test/prepare_b.cc
new file mode 100644
index 0000000..563accd
--- /dev/null
+++ b/test/prepare_b.cc
@@ -0,0 +1,167 @@
+#include "test.h"
+#include "../aligned.h"
+#include "../avx2_gemm.h"
+#include "../avx512_gemm.h"
+#include "../sse2_gemm.h"
+#include "../ssse3_gemm.h"
+
+#include <cstring>
+#include <iostream>
+#include <math.h>
+
+namespace intgemm {
+namespace {
+
+template <typename Backend, Index TileColumns>
+void PrepareBRef(const float* input, typename Backend::Integer* output, float quant_mult, Index B_rows, Index B_cols) {
+ using vec_t = intgemm::vector_t<Backend::kUses, typename Backend::Integer>;
+ constexpr Index vec_len = sizeof(vec_t) / sizeof(typename Backend::Integer);
+
+ for (Index c = 0; c < B_cols; c += TileColumns)
+ for (Index r = 0; r < B_rows; r += vec_len)
+ for (Index ci = 0; ci < TileColumns; ++ci)
+ for (Index ri = 0; ri < vec_len; ++ri) {
+ *output++ = input[(r + ri) * B_cols + (c + ci)] * quant_mult;
+ }
+}
+
+template <typename Backend, Index TileColumns>
+bool TestInner(const AlignedVector<float>& input, Index B_rows, Index B_cols, float quant_mult) {
+ bool success = true;
+
+ AlignedVector<typename Backend::Integer> output(input.size());
+ Backend::template PrepareB<TileColumns>(input.begin(), output.begin(), quant_mult, B_rows, B_cols);
+
+ AlignedVector<typename Backend::Integer> reference(input.size());
+ PrepareBRef<Backend, TileColumns>(input.begin(), reference.begin(), quant_mult, B_rows, B_cols);
+
+ if (TileColumns != 7 && TileColumns != 8) {
+ std::cout << "Input:" << std::endl;
+ for (int i = 0; i < B_rows; ++i) {
+ for (int j = 0; j < B_cols; ++j)
+ std::cout << input[i * B_cols + j] << ", ";
+ std::cout << std::endl;
+ }
+
+ std::cout << "Output:" << std::endl;
+ for (int i = 0; i < B_rows; ++i) {
+ for (int j = 0; j < B_cols; ++j)
+ std::cout << output[i * B_cols + j] << ", ";
+ std::cout << std::endl;
+ }
+
+ std::cout << "Reference:" << std::endl;
+ for (int i = 0; i < B_rows; ++i) {
+ for (int j = 0; j < B_cols; ++j)
+ std::cout << reference[i * B_cols + j] << ", ";
+ std::cout << std::endl;
+ }
+ }
+
+ for (std::size_t i = 0; i < output.size(); ++i) {
+ if (output[i] != reference[i]) {
+ UNSCOPED_INFO("Error at " << i << ", output = " << int(output[i]) << ", reference = " << int(reference[i]));
+ success = false;
+ break;
+ }
+ }
+ return success;
+}
+
+template <typename Backend, Index TileColumns>
+bool Test(Index B_rows, Index B_cols, float quant_mult) {
+ AlignedVector<float> input(B_rows * B_cols);
+
+ std::generate(input.begin(), input.end(), []() {
+ static constexpr int divider = sizeof(intgemm::vector_t<Backend::kUses, typename Backend::Integer>) / sizeof(typename Backend::Integer);
+ static int value = 0;
+ return (value++) % divider;
+ });
+
+ return TestInner<Backend, TileColumns>(input, B_rows, B_cols, quant_mult);
+}
+
+TEST_CASE("PrepareB SSE2", "") {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ // CHECK(Test<SSE2_16bit, 1>(32, 32, 2.0f));
+ // CHECK(Test<SSE2_16bit, 2>(32, 2*16, 2.0f));
+ // CHECK(Test<SSE2_16bit, 3>(32, 3*16, 2.0f));
+ // CHECK(Test<SSE2_16bit, 4>(32, 4*8, 2.0f));
+ // CHECK(Test<SSE2_16bit, 5>(32, 5*8, 2.0f));
+ // CHECK(Test<SSE2_16bit, 6>(32, 6*8, 2.0f));
+ CHECK(Test<SSE2_16bit, 7>(32, 7*8, 2.0f));
+ CHECK(Test<SSE2_16bit, 8>(32, 32, 2.0f));
+ // CHECK(Test<SSE2_16bit, 9>(32, 9*16, 2.0f));
+}
+
+TEST_CASE("PrepareB SSSE3", "") {
+ if (kCPU < CPUType::SSSE3)
+ return;
+
+ // CHECK(Test<SSSE3_8bit, 1>(32, 1*128, 2.0f));
+ // CHECK(Test<SSSE3_8bit, 2>(32, 2*128, 2.0f));
+ // CHECK(Test<SSSE3_8bit, 3>(32, 3*128, 2.0f));
+ // CHECK(Test<SSSE3_8bit, 4>(32, 4*128, 2.0f));
+ // CHECK(Test<SSSE3_8bit, 5>(32, 5*128, 2.0f));
+ // CHECK(Test<SSSE3_8bit, 6>(32, 6*128, 2.0f));
+ // CHECK(Test<SSSE3_8bit, 7>(32, 7*128, 2.0f));
+ // CHECK(Test<SSSE3_8bit, 8>(32, 8*128, 2.0f));
+ // CHECK(Test<SSSE3_8bit, 9>(32, 9*128, 2.0f));
+}
+
+TEST_CASE("PrepareB AVX2", "") {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ // CHECK(Test<AVX2_8bit, 1>(32, 1*128, 2.0f));
+ // CHECK(Test<AVX2_8bit, 2>(32, 2*128, 2.0f));
+ // CHECK(Test<AVX2_8bit, 3>(32, 3*128, 2.0f));
+ // CHECK(Test<AVX2_8bit, 4>(32, 4*128, 2.0f));
+ // CHECK(Test<AVX2_8bit, 5>(32, 5*128, 2.0f));
+ // CHECK(Test<AVX2_8bit, 6>(32, 6*128, 2.0f));
+ // CHECK(Test<AVX2_8bit, 7>(32, 7*128, 2.0f));
+ // CHECK(Test<AVX2_8bit, 8>(32, 8*128, 2.0f));
+ // CHECK(Test<AVX2_8bit, 9>(32, 9*128, 2.0f));
+
+ // CHECK(Test<AVX2_16bit, 1>(32, 1*128, 2.0f));
+ // CHECK(Test<AVX2_16bit, 2>(32, 2*128, 2.0f));
+ // CHECK(Test<AVX2_16bit, 3>(32, 3*128, 2.0f));
+ // CHECK(Test<AVX2_16bit, 4>(32, 4*128, 2.0f));
+ // CHECK(Test<AVX2_16bit, 5>(32, 5*128, 2.0f));
+ // CHECK(Test<AVX2_16bit, 6>(32, 6*128, 2.0f));
+ // CHECK(Test<AVX2_16bit, 7>(32, 7*128, 2.0f));
+ // CHECK(Test<AVX2_16bit, 8>(32, 8*128, 2.0f));
+ // CHECK(Test<AVX2_16bit, 9>(32, 9*128, 2.0f));
+}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+ TEST_CASE("PrepareB AVX512", "") {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ // CHECK(Test<AVX512_8bit, 1>(32, 1*128, 2.0f));
+ // CHECK(Test<AVX512_8bit, 2>(32, 2*128, 2.0f));
+ // CHECK(Test<AVX512_8bit, 3>(32, 3*128, 2.0f));
+ // CHECK(Test<AVX512_8bit, 4>(32, 4*128, 2.0f));
+ // CHECK(Test<AVX512_8bit, 5>(32, 5*128, 2.0f));
+ // CHECK(Test<AVX512_8bit, 6>(32, 6*128, 2.0f));
+ // CHECK(Test<AVX512_8bit, 7>(32, 7*128, 2.0f));
+ // CHECK(Test<AVX512_8bit, 8>(32, 8*128, 2.0f));
+ // CHECK(Test<AVX512_8bit, 9>(32, 9*128, 2.0f));
+
+ // CHECK(Test<AVX512_16bit, 1>(32, 1*128, 2.0f));
+ // CHECK(Test<AVX512_16bit, 2>(32, 2*128, 2.0f));
+ // CHECK(Test<AVX512_16bit, 3>(32, 3*128, 2.0f));
+ // CHECK(Test<AVX512_16bit, 4>(32, 4*128, 2.0f));
+ // CHECK(Test<AVX512_16bit, 5>(32, 5*128, 2.0f));
+ // CHECK(Test<AVX512_16bit, 6>(32, 6*128, 2.0f));
+ // CHECK(Test<AVX512_16bit, 7>(32, 7*128, 2.0f));
+ // CHECK(Test<AVX512_16bit, 8>(32, 8*128, 2.0f));
+ // CHECK(Test<AVX512_16bit, 9>(32, 9*128, 2.0f));
+ }
+#endif
+
+}
+}