Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--avx2_gemm.h9
-rw-r--r--avx512_gemm.h17
-rw-r--r--benchmark.cc22
-rw-r--r--example.cc4
-rw-r--r--intgemm.h60
-rw-r--r--multiply.h19
-rw-r--r--sse2_gemm.h5
-rw-r--r--ssse3_gemm.h5
-rw-r--r--test/multiply_test.cc76
9 files changed, 109 insertions, 108 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h
index c560b49..d278e3e 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -82,10 +82,12 @@ struct AVX2_16bit {
AVX2 static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
}
-
+/*
AVX2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
Multiply16__m256i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
}
+ */
+ MULTIPLY16_define(__m256i, AVX2)
constexpr static const char *const kName = "16-bit AVX2";
@@ -174,11 +176,12 @@ struct AVX2_8bit {
AVX2 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
}
-
+/*
AVX2 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
//Multiply8_SSE2OrAVX2<Multiply8_AVXAVX2, __m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols);
Multiply8_SSE2OrAVX2__m256i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
- }
+ }*/
+ MULTIPLY8_define(__m256i, AVX2)
constexpr static const char *const kName = "8-bit AVX2";
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 93a5997..97d2e73 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -155,11 +155,12 @@ struct AVX512_16bit {
AVX512F static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
avx512f::SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end);
}
-
+/*
AVX512F static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
// The unquantization is only 256-bit wide because there are 8 results.
Multiply16__m512i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
- }
+ }*/
+ MULTIPLY16_define(__m512i, AVX512F)
constexpr static const char *const kName = "16-bit AVX512";
@@ -206,16 +207,17 @@ struct AVX512_8bit {
// Special AVX512 implementation due to having 32 registers (so I don't have to
// allocate registers manually) and no sign instruction.
- AVX512BW static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
+ template <class WriteC>
+ AVX512BW static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
typedef __m512i Integer;
- typedef __m256 Float; // For quantization we only do 8 at a time.
+ //typedef __m256 Float; // For quantization we only do 8 at a time.
// This is copy-paste from Multiply8_SSE2OrAVX2.
assert(width % sizeof(Integer) == 0);
assert(B_cols % 8 == 0);
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
- assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0);
- Float unquant_reg = set1_ps<Float>(unquant_mult);
+ //assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0);
+ //Float unquant_reg = set1_ps<Float>(unquant_mult);
const int simd_width = width / sizeof(Integer);
const Integer *B0_col = reinterpret_cast<const Integer*>(B);
// Added for AVX512.
@@ -312,7 +314,8 @@ struct AVX512_8bit {
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
- WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);
+ //WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);
+ functor(A_rowidx, B_cols, B0_colidx, total);
}
}
}
diff --git a/benchmark.cc b/benchmark.cc
index a629049..965fd68 100644
--- a/benchmark.cc
+++ b/benchmark.cc
@@ -62,7 +62,7 @@ struct RandomMatrices {
AlignedVector<float> A, B;
};
-template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t> &stats) {
+template <class Backend, class WriteC> void Run(const RandomMatrices &m, std::vector<uint64_t> &stats) {
typedef typename Backend::Integer Integer;
float quant_mult = 127.0 / 2;
float unquant_mult = 1.0 / (quant_mult * quant_mult);
@@ -72,20 +72,20 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t>
Backend::PrepareB(m.B.get(), B_prepared.get(), quant_mult, m.width, m.B_cols);
AlignedVector<float> output(m.A_rows * m.B_cols);
// Burn in
- Backend::Multiply(A_prepared.get(), B_prepared.get(), output.get(), unquant_mult, m.A_rows, m.width, m.B_cols);
+ Backend::template Multiply<WriteC>(A_prepared.get(), B_prepared.get(), JustUnquantizeC(output.get(), unquant_mult), m.A_rows, m.width, m.B_cols);
{
StopWatch w(stats);
- Backend::Multiply(A_prepared.get(), B_prepared.get(), output.get(), unquant_mult, m.A_rows, m.width, m.B_cols);
+ Backend::template Multiply<WriteC>(A_prepared.get(), B_prepared.get(), JustUnquantizeC(output.get(), unquant_mult), m.A_rows, m.width, m.B_cols);
}
}
-template <class Backend> void RunAll(RandomMatrices *matrices, RandomMatrices *matrices_end, std::vector<std::vector<uint64_t> > &stats) {
+template <class Backend, class WriteC> void RunAll(RandomMatrices *matrices, RandomMatrices *matrices_end, std::vector<std::vector<uint64_t> > &stats) {
if (Backend::kUses > kCPU) return;
std::size_t size = matrices_end - matrices;
if (stats.size() < size)
stats.resize(size);
for (std::size_t i = 0; i < size; ++i) {
- Run<Backend>(matrices[i], stats[i]);
+ Run<Backend, WriteC>(matrices[i], stats[i]);
}
}
@@ -169,13 +169,13 @@ int main(int argc, char ** argv) {
for (int samples = 0; samples < kSamples; ++samples) {
std::cerr << "Sample " << samples << " / " << kSamples << std::endl;
RandomMatrices *end = (samples < 4) ? matrices_end : full_sample;
- RunAll<SSSE3_8bit>(matrices, end, stats.ssse3_8bit);
- RunAll<SSE2_16bit>(matrices, end, stats.sse2_16bit);
- RunAll<AVX2_8bit>(matrices, end, stats.avx2_8bit);
- RunAll<AVX2_16bit>(matrices, end, stats.avx2_16bit);
+ RunAll<SSSE3_8bit, JustUnquantizeC>(matrices, end, stats.ssse3_8bit);
+ RunAll<SSE2_16bit, JustUnquantizeC>(matrices, end, stats.sse2_16bit);
+ RunAll<AVX2_8bit, JustUnquantizeC>(matrices, end, stats.avx2_8bit);
+ RunAll<AVX2_16bit, JustUnquantizeC>(matrices, end, stats.avx2_16bit);
#ifndef INTGEMM_NO_AVX512
- RunAll<AVX512_8bit>(matrices, end, stats.avx512_8bit);
- RunAll<AVX512_16bit>(matrices, end, stats.avx512_16bit);
+ RunAll<AVX512_8bit, JustUnquantizeC>(matrices, end, stats.avx512_8bit);
+ RunAll<AVX512_16bit, JustUnquantizeC>(matrices, end, stats.avx512_16bit);
#endif
}
diff --git a/example.cc b/example.cc
index 9ec9a7c..f68dd03 100644
--- a/example.cc
+++ b/example.cc
@@ -49,7 +49,7 @@ int main() {
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
- intgemm::Int16<intgemm::JustUnquantizeC>::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols);
+ intgemm::Int16<intgemm::JustUnquantizeC>::Multiply(A_prepared.get(), B_prepared.get(), intgemm::JustUnquantizeC(C.get(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols);
// Sanity check. C will be row major.
assert(fabs(C[0] - top_left_reference) < 0.05);
}
@@ -68,7 +68,7 @@ int main() {
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
- intgemm::Int8<intgemm::JustUnquantizeC>::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols);
+ intgemm::Int8<intgemm::JustUnquantizeC>::Multiply(A_prepared.get(), B_prepared.get(), intgemm::JustUnquantizeC(C.get(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols);
// Sanity check. C will be row major.
assert(fabs(C[0] - top_left_reference) < 0.05);
}
diff --git a/intgemm.h b/intgemm.h
index e0d3658..76414a1 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -56,7 +56,9 @@
namespace intgemm {
-struct Unsupported_16bit {
+template<class WriteC>
+class Unsupported_16bit {
+public:
static void Quantize(const float *, int16_t *, float, Index) {
throw UnsupportedCPU();
}
@@ -66,13 +68,15 @@ struct Unsupported_16bit {
static void SelectColumnsB(const int16_t *, int16_t *, Index, const Index *, const Index *) {
throw UnsupportedCPU();
}
- static void Multiply(const int16_t *, const int16_t *, float *, float, Index, Index, Index) {
+ static void Multiply(const int16_t *, const int16_t *, WriteC, Index, Index, Index) {
throw UnsupportedCPU();
}
constexpr static const char *const kName = "16-bit Unsupported";
};
-struct Unsupported_8bit {
+template<class WriteC>
+class Unsupported_8bit {
+public:
static void Quantize(const float *, int8_t *, float, Index) {
throw UnsupportedCPU();
}
@@ -82,7 +86,7 @@ struct Unsupported_8bit {
static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) {
throw UnsupportedCPU();
}
- static void Multiply(const int8_t *, const int8_t *, float *, float, Index, Index, Index) {
+ static void Multiply(const int8_t *, const int8_t *, WriteC, Index, Index, Index) {
throw UnsupportedCPU();
}
constexpr static const char *const kName = "8-bit Unsupported";
@@ -132,7 +136,7 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported)
}
/* 16-bit matrix multiplication. */
-template<class cOperator>
+template<class WriteC>
class Int16 {
public:
typedef int16_t Integer;
@@ -163,28 +167,28 @@ public:
static void (*SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
// Multiply C = A * B, presuming A and B have been prepared.
- static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
+ static void (*Multiply)(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols);
static const char *const kName;
};
-template <class cOperator>
-void (*Int16<cOperator>::Quantize)(const float *input, int16_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize);
+template <class WriteC>
+void (*Int16<WriteC>::Quantize)(const float *input, int16_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit<WriteC>::Quantize);
-template <class cOperator>
-void (*Int16<cOperator>::PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB);
+template <class WriteC>
+void (*Int16<WriteC>::PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit<WriteC>::PrepareB);
-template <class cOperator>
-void (*Int16<cOperator>::SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_16bit::SelectColumnsB, AVX2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, Unsupported_16bit::SelectColumnsB);
+template <class WriteC>
+void (*Int16<WriteC>::SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_16bit::SelectColumnsB, AVX2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, Unsupported_16bit<WriteC>::SelectColumnsB);
-template <class cOperator>
-void (*Int16<cOperator>::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply);
+template <class WriteC>
+void (*Int16<WriteC>::Multiply)(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply<WriteC>, AVX2_16bit::Multiply<WriteC>, SSE2_16bit::Multiply<WriteC>, SSE2_16bit::Multiply<WriteC>, Unsupported_16bit<WriteC>::Multiply);
-template <class cOperator>
-const char *const Int16<cOperator>::kName = ChooseCPU(AVX512_16bit::kName, AVX2_16bit::kName, SSE2_16bit::kName, SSE2_16bit::kName, Unsupported_16bit::kName);
+template <class WriteC>
+const char *const Int16<WriteC>::kName = ChooseCPU(AVX512_16bit::kName, AVX2_16bit::kName, SSE2_16bit::kName, SSE2_16bit::kName, Unsupported_16bit<WriteC>::kName);
/* 8-bit matrix multiplication */
-template<class cOperator>
+template<class WriteC>
class Int8 {
public:
typedef int8_t Integer;
@@ -214,25 +218,25 @@ public:
static void (*SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
// Multiply C = A * B, presuming A and B have been prepared.
- static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
+ static void (*Multiply)(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols);
static const char *const kName;
};
-template<class cOperator>
-void (*Int8<cOperator>::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize);
+template<class WriteC>
+void (*Int8<WriteC>::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit<WriteC>::Quantize, Unsupported_8bit<WriteC>::Quantize);
-template<class cOperator>
-void (*Int8<cOperator>::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB);
+template<class WriteC>
+void (*Int8<WriteC>::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit<WriteC>::PrepareB, Unsupported_8bit<WriteC>::PrepareB);
-template<class cOperator>
-void (*Int8<cOperator>::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB);
+template<class WriteC>
+void (*Int8<WriteC>::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit<WriteC>::SelectColumnsB, Unsupported_8bit<WriteC>::SelectColumnsB);
-template<class cOperator>
-void (*Int8<cOperator>::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply);
+template<class WriteC>
+void (*Int8<WriteC>::Multiply)(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply<WriteC>, AVX2_8bit::Multiply<WriteC>, SSSE3_8bit::Multiply<WriteC>, Unsupported_8bit<WriteC>::Multiply, Unsupported_8bit<WriteC>::Multiply);
-template<class cOperator>
-const char *const Int8<cOperator>::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit::kName, Unsupported_8bit::kName);
+template<class WriteC>
+const char *const Int8<WriteC>::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit<WriteC>::kName, Unsupported_8bit<WriteC>::kName);
const CPUType kCPU = ChooseCPU(CPU_AVX512BW, CPU_AVX2, CPU_SSSE3, CPU_SSE2, CPU_UNSUPPORTED);
diff --git a/multiply.h b/multiply.h
index 5ff6bb0..b761625 100644
--- a/multiply.h
+++ b/multiply.h
@@ -134,9 +134,9 @@ PACK_DEFINE(AVX512F, __m512i)
// A_rows can be anything non-negative.
// width must be a multiple of the register size.
// B_cols must be a multiple of 8.
-//template <class Integer, class WriteC> inline void Multiply16(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
+// Multiply16
#define MULTIPLY16_define(Integer, target) \
- template <class WriteC> target inline void Multiply16##Integer(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \
+ template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \
assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
@@ -192,13 +192,6 @@ PACK_DEFINE(AVX512F, __m512i)
} \
} \
-MULTIPLY16_define(__m128i, SSE2)
-
-MULTIPLY16_define(__m256i, AVX2)
-#ifndef INTGEMM_NO_AVX512
-MULTIPLY16_define(__m512i, AVX512F)
-#endif
-//MULTIPLY16_define(__m256i, AVX2)
/* 8-bit matrix multiply used by AVX and AVX2.
* These have two peculiar properties:
* 1. The sign instructions don't exist in AVX512.
@@ -346,9 +339,9 @@ SSSE3 inline static void InnerSSSE3(
sum6 = adds_epi16(sum6, maddubs_epi16(a_positive, sign_epi8(b[6], a)));
sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a)));
}
-
+//AVX2 or SSSE3 multiply
#define MULTIPLY8_define(Integer, target) \
-template <class WriteC> target inline void Multiply8_SSE2OrAVX2##Integer(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \
+template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \
assert(width % sizeof(Integer) == 0); \
assert(B_cols % 8 == 0); \
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
@@ -415,10 +408,6 @@ template <class WriteC> target inline void Multiply8_SSE2OrAVX2##Integer(const i
} \
} \
-MULTIPLY8_define(__m128i, SSSE3)
-
-MULTIPLY8_define(__m256i, AVX2)
-
// Find the maximum absolute value of packed float32s.
/*
diff --git a/sse2_gemm.h b/sse2_gemm.h
index a7d60ae..3ede6d8 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -73,11 +73,12 @@ struct SSE2_16bit {
//TODO #DEFINE
sse2::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end);
}
-
+/*
SSE2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
//TODO #DEFINE
Multiply16__m128i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
- }
+ }*/
+ MULTIPLY16_define(__m128i, SSE2)
constexpr static const char *const kName = "16-bit SSE2";
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index bb2df94..d384500 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -97,11 +97,12 @@ struct SSSE3_8bit {
SSSE3 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
ssse3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
}
-
+/*
SSSE3 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
//Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
Multiply8_SSE2OrAVX2__m128i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
- }
+ }*/
+ MULTIPLY8_define(__m128i, SSSE3)
constexpr static const char *const kName = "8-bit SSSE3";
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 26ca9be..ee5e2c6 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -331,7 +331,7 @@ void Compare(const float *float_ref, const float *int_ref, const float *int_test
CHECK_MESSAGE(fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size));
}
-template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_cols,
+template <class Routine, class WriteC> void TestMultiply(Index A_rows, Index width, Index B_cols,
float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
typedef typename Routine::Integer Integer;
std::ostringstream info;
@@ -357,7 +357,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
Routine::PrepareB(B.get(), B_prep.get(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.get(), B_prep.get(), test_C.get(), unquant_mult, A_rows, width, B_cols);
+ Routine::template Multiply<WriteC>(A_prep.get(), B_prep.get(), WriteC(test_C.get(), unquant_mult), A_rows, width, B_cols);
AlignedVector<Integer> B_quant(width * B_cols);
Routine::Quantize(B.get(), B_quant.get(), quant_mult, width * B_cols);
@@ -374,63 +374,63 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
if (kCPU < CPU_SSE2) return;
- TestMultiply<SSE2_16bit>(8, 256, 256, .1, 1, 0.01);
- TestMultiply<SSE2_16bit>(8, 2048, 256, .1, 1, 0.02);
- TestMultiply<SSE2_16bit>(320, 256, 256, .1, 1, 0.01);
- TestMultiply<SSE2_16bit>(472, 256, 256, .1, 1, 0.01);
- TestMultiply<SSE2_16bit>(248, 256, 256, .1, 1, 0.01);
- TestMultiply<SSE2_16bit>(200, 256, 256, .1, 1, 0.01);
+ TestMultiply<SSE2_16bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.01);
+ TestMultiply<SSE2_16bit, JustUnquantizeC>(8, 2048, 256, .1, 1, 0.02);
+ TestMultiply<SSE2_16bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.01);
+ TestMultiply<SSE2_16bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.01);
+ TestMultiply<SSE2_16bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.01);
+ TestMultiply<SSE2_16bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.01);
}
TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
if (kCPU < CPU_SSSE3) return;
- TestMultiply<SSSE3_8bit>(8, 256, 256, 1.2, 1.2, 0.064, 0.026);
- TestMultiply<SSSE3_8bit>(8, 2048, 256, 33, 33, 4.4, 4.4);
- TestMultiply<SSSE3_8bit>(320, 256, 256, 1.9, 1.9, 0.1, 0.01);
- TestMultiply<SSSE3_8bit>(472, 256, 256, 2.1, 2.1, 0.1, 0.011);
- TestMultiply<SSSE3_8bit>(248, 256, 256, 1.7, 1.7, 0.1, 0.012);
- TestMultiply<SSSE3_8bit>(200, 256, 256, 1.8, 1.9, 0.1, 0.011);
+ TestMultiply<SSSE3_8bit, JustUnquantizeC>(8, 256, 256, 1.2, 1.2, 0.064, 0.026);
+ TestMultiply<SSSE3_8bit, JustUnquantizeC>(8, 2048, 256, 33, 33, 4.4, 4.4);
+ TestMultiply<SSSE3_8bit, JustUnquantizeC>(320, 256, 256, 1.9, 1.9, 0.1, 0.01);
+ TestMultiply<SSSE3_8bit, JustUnquantizeC>(472, 256, 256, 2.1, 2.1, 0.1, 0.011);
+ TestMultiply<SSSE3_8bit, JustUnquantizeC>(248, 256, 256, 1.7, 1.7, 0.1, 0.012);
+ TestMultiply<SSSE3_8bit, JustUnquantizeC>(200, 256, 256, 1.8, 1.9, 0.1, 0.011);
}
TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
if (kCPU < CPU_AVX2) return;
- TestMultiply<AVX2_8bit>(8, 256, 256, .1, 1, 0.1);
- TestMultiply<AVX2_8bit>(8, 2048, 256, 19, 19, 1.8, 1.8);
- TestMultiply<AVX2_8bit>(320, 256, 256, .1, 1, 0.1);
- TestMultiply<AVX2_8bit>(472, 256, 256, .1, 1, 0.1);
- TestMultiply<AVX2_8bit>(248, 256, 256, .1, 1, 0.1);
- TestMultiply<AVX2_8bit>(200, 256, 256, .1, 1, 0.1);
+ TestMultiply<AVX2_8bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.1);
+ TestMultiply<AVX2_8bit, JustUnquantizeC>(8, 2048, 256, 19, 19, 1.8, 1.8);
+ TestMultiply<AVX2_8bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.1);
+ TestMultiply<AVX2_8bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.1);
+ TestMultiply<AVX2_8bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.1);
+ TestMultiply<AVX2_8bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.1);
}
TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
if (kCPU < CPU_AVX2) return;
- TestMultiply<AVX2_16bit>(8, 256, 256, .1, 1, 0.01);
- TestMultiply<AVX2_16bit>(8, 2048, 256, .1, 1, 0.02);
- TestMultiply<AVX2_16bit>(320, 256, 256, .1, 1, 0.01);
- TestMultiply<AVX2_16bit>(472, 256, 256, .1, 1, 0.01);
- TestMultiply<AVX2_16bit>(248, 256, 256, .1, 1, 0.01);
- TestMultiply<AVX2_16bit>(200, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX2_16bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX2_16bit, JustUnquantizeC>(8, 2048, 256, .1, 1, 0.02);
+ TestMultiply<AVX2_16bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX2_16bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX2_16bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX2_16bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.01);
}
#ifndef INTGEMM_NO_AVX512
TEST_CASE ("Multiply AVX512 8bit", "[multiply]") {
if (kCPU < CPU_AVX512BW) return;
- TestMultiply<AVX512_8bit>(8, 256, 256, .1, 1, 0.062);
- TestMultiply<AVX512_8bit>(8, 2048, 256, 4.2, 4, 0.41, 0.37);
- TestMultiply<AVX512_8bit>(320, 256, 256, .1, 1, 0.06);
- TestMultiply<AVX512_8bit>(472, 256, 256, .1, 1, 0.06);
- TestMultiply<AVX512_8bit>(248, 256, 256, .1, 1, 0.06);
- TestMultiply<AVX512_8bit>(200, 256, 256, .1, 1, 0.06);
+ TestMultiply<AVX512_8bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.062);
+ TestMultiply<AVX512_8bit, JustUnquantizeC>(8, 2048, 256, 4.2, 4, 0.41, 0.37);
+ TestMultiply<AVX512_8bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.06);
+ TestMultiply<AVX512_8bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.06);
+ TestMultiply<AVX512_8bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.06);
+ TestMultiply<AVX512_8bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.06);
}
TEST_CASE ("Multiply AVX512 16bit", "[multiply]") {
if (kCPU < CPU_AVX512BW) return;
- TestMultiply<AVX512_16bit>(8, 256, 256, .1, 1, 0.01);
- TestMultiply<AVX512_16bit>(8, 2048, 256, .1, 1, 0.011);
- TestMultiply<AVX512_16bit>(320, 256, 256, .1, 1, 0.01);
- TestMultiply<AVX512_16bit>(472, 256, 256, .1, 1, 0.01);
- TestMultiply<AVX512_16bit>(248, 256, 256, .1, 1, 0.01);
- TestMultiply<AVX512_16bit>(200, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX512_16bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX512_16bit, JustUnquantizeC>(8, 2048, 256, .1, 1, 0.011);
+ TestMultiply<AVX512_16bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX512_16bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX512_16bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.01);
+ TestMultiply<AVX512_16bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.01);
}
#endif