Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-06-18 16:34:01 +0300
committer: Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-06-18 16:41:12 +0300
commit: ee2742f9135500a957393a6bdf3ee2ac1e7f9862 (patch)
tree: 60b3f74362c574d5b3fcb76522679c1e40f358b2
parent: f5c08aea11e572e57216f812c0123f0627ff2853 (diff)
Make CPUType scoped enum
-rw-r--r-- avx2_gemm.h 8
-rw-r--r-- avx512_gemm.h 8
-rw-r--r-- intgemm.cc 2
-rw-r--r-- postprocess.h 24
-rw-r--r-- postprocess_pipeline.h 16
-rw-r--r-- sse2_gemm.h 4
-rw-r--r-- ssse3_gemm.h 4
-rw-r--r-- test/multiply_test.cc 50
-rw-r--r-- test/pipeline_test.cc 8
-rw-r--r-- test/quantize_test.cc 8
-rw-r--r-- test/relu_test.cc 12
-rw-r--r-- types.h 8
12 files changed, 79 insertions, 73 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h
index a03ff09..1482090 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -80,11 +80,11 @@ struct AVX2_16bit {
avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, CPUType::CPU_AVX2)
+ INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, CPUType::AVX2)
constexpr static const char *const kName = "16-bit INTGEMM_AVX2";
- static const CPUType kUses = CPU_AVX2;
+ static const CPUType kUses = CPUType::AVX2;
};
namespace avx2 {
@@ -169,11 +169,11 @@ struct AVX2_8bit {
avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::CPU_AVX2)
+ INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::AVX2)
constexpr static const char *const kName = "8-bit INTGEMM_AVX2";
- static const CPUType kUses = CPU_AVX2;
+ static const CPUType kUses = CPUType::AVX2;
};
} // namespace intgemm
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 4d6f0db..043dfae 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -166,11 +166,11 @@ struct AVX512_16bit {
}
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
- INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, CPUType::CPU_AVX2)
+ INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, CPUType::AVX2)
constexpr static const char *const kName = "16-bit AVX512";
- static const CPUType kUses = CPU_AVX512BW;
+ static const CPUType kUses = CPUType::AVX512BW;
};
struct AVX512_8bit {
@@ -227,7 +227,7 @@ struct AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
// There's 8 results for INTGEMM_AVX2 to handle.
- auto inited_pipeline = InitPostprocessPipeline<CPUType::CPU_AVX2>(pipeline);
+ auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
const int simd_width = width / sizeof(Integer);
const Integer *B0_col = reinterpret_cast<const Integer*>(B);
// Added for AVX512.
@@ -333,7 +333,7 @@ struct AVX512_8bit {
constexpr static const char *const kName = "8-bit AVX512";
- static const CPUType kUses = CPU_AVX512BW;
+ static const CPUType kUses = CPUType::AVX512BW;
};
} // namespace intgemm
diff --git a/intgemm.cc b/intgemm.cc
index 511eafb..39ba227 100644
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -22,7 +22,7 @@ void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, co
const char *const Int8::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit::kName, Unsupported_8bit::kName);
-const CPUType kCPU = ChooseCPU(CPU_AVX512BW, CPU_AVX2, CPU_SSSE3, CPU_SSE2, CPU_UNSUPPORTED);
+const CPUType kCPU = ChooseCPU(CPUType::AVX512BW, CPUType::AVX2, CPUType::SSSE3, CPUType::SSE2, CPUType::UNSUPPORTED);
float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(avx512f::MaxAbsolute, avx2::MaxAbsolute, sse2::MaxAbsolute, sse2::MaxAbsolute, Unsupported_MaxAbsolute);
diff --git a/postprocess.h b/postprocess.h
index 76755b3..0855548 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -18,7 +18,7 @@ public:
};
template <>
-class PostprocessImpl<Unquantize, CPUType::CPU_SSE2> {
+class PostprocessImpl<Unquantize, CPUType::SSE2> {
public:
using InputRegister = RegisterPair128i;
using OutputRegister = RegisterPair128;
@@ -39,7 +39,7 @@ private:
};
template <>
-class PostprocessImpl<Unquantize, CPUType::CPU_AVX2> {
+class PostprocessImpl<Unquantize, CPUType::AVX2> {
public:
using InputRegister = __m256i;
using OutputRegister = __m256;
@@ -57,7 +57,7 @@ private:
};
template <>
-class PostprocessImpl<Unquantize, CPUType::CPU_AVX512BW> {
+class PostprocessImpl<Unquantize, CPUType::AVX512BW> {
public:
using InputRegister = __m512i;
using OutputRegister = __m512;
@@ -80,7 +80,7 @@ private:
class Identity {};
template <>
-class PostprocessImpl<Identity, CPUType::CPU_SSE2> {
+class PostprocessImpl<Identity, CPUType::SSE2> {
public:
using InputRegister = RegisterPair128i;
using OutputRegister = RegisterPair128i;
@@ -93,7 +93,7 @@ public:
};
template <>
-class PostprocessImpl<Identity, CPUType::CPU_AVX2> {
+class PostprocessImpl<Identity, CPUType::AVX2> {
public:
using InputRegister = __m256i;
using OutputRegister = __m256i;
@@ -106,7 +106,7 @@ public:
};
template <>
-class PostprocessImpl<Identity, CPUType::CPU_AVX512BW> {
+class PostprocessImpl<Identity, CPUType::AVX512BW> {
public:
using InputRegister = __m512i;
using OutputRegister = __m512i;
@@ -130,7 +130,7 @@ public:
};
template <>
-class PostprocessImpl<AddBias, CPUType::CPU_SSE2> {
+class PostprocessImpl<AddBias, CPUType::SSE2> {
public:
using InputRegister = RegisterPair128;
using OutputRegister = RegisterPair128;
@@ -151,7 +151,7 @@ private:
};
template <>
-class PostprocessImpl<AddBias, CPUType::CPU_AVX2> {
+class PostprocessImpl<AddBias, CPUType::AVX2> {
public:
using InputRegister = __m256;
using OutputRegister = __m256;
@@ -173,7 +173,7 @@ private:
class ReLU {};
template <>
-class PostprocessImpl<ReLU, CPUType::CPU_SSE2> {
+class PostprocessImpl<ReLU, CPUType::SSE2> {
public:
using InputRegister = RegisterPair128;
using OutputRegister = RegisterPair128;
@@ -190,10 +190,10 @@ public:
};
template <>
-class PostprocessImpl<ReLU, CPUType::CPU_SSSE3> : public PostprocessImpl<ReLU, CPUType::CPU_SSE2> {};
+class PostprocessImpl<ReLU, CPUType::SSSE3> : public PostprocessImpl<ReLU, CPUType::SSE2> {};
template <>
-class PostprocessImpl<ReLU, CPUType::CPU_AVX2> {
+class PostprocessImpl<ReLU, CPUType::AVX2> {
public:
using InputRegister = __m256;
using OutputRegister = __m256;
@@ -207,7 +207,7 @@ public:
};
template <>
-class PostprocessImpl<ReLU, CPUType::CPU_AVX512BW> {
+class PostprocessImpl<ReLU, CPUType::AVX512BW> {
public:
using InputRegister = __m512;
using OutputRegister = __m512;
diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h
index 01f3f11..ad26ac5 100644
--- a/postprocess_pipeline.h
+++ b/postprocess_pipeline.h
@@ -71,10 +71,10 @@ struct RunPostprocessPipelineImpl;
} \
};
-RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSE2, CPUType::CPU_SSE2)
-RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSSE3, CPUType::CPU_SSSE3)
-RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX2, CPUType::CPU_AVX2)
-RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::CPU_AVX512BW)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSE2, CPUType::SSE2)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSSE3, CPUType::SSSE3)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX2, CPUType::AVX2)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::AVX512BW)
} // anonymous namespace
@@ -105,9 +105,9 @@ constexpr InitedPostprocessPipeline<CpuType, Stages...> InitPostprocessPipeline(
const std::tuple<PostprocessImpl<Stages, cpu_type>...> inited_pipeline; \
};
-INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSE2, CPUType::CPU_SSE2)
-INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSSE3, CPUType::CPU_SSSE3)
-INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX2, CPUType::CPU_AVX2)
-INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::CPU_AVX512BW)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSE2, CPUType::SSE2)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSSE3, CPUType::SSSE3)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX2, CPUType::AVX2)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::AVX512BW)
}
diff --git a/sse2_gemm.h b/sse2_gemm.h
index dfccc5c..6b7e698 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -72,11 +72,11 @@ struct SSE2_16bit {
//TODO #DEFINE
sse2::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, CPUType::CPU_SSE2)
+ INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, CPUType::SSE2)
constexpr static const char *const kName = "16-bit INTGEMM_SSE2";
- static const CPUType kUses = CPU_SSE2;
+ static const CPUType kUses = CPUType::SSE2;
};
} // namespace intgemm
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 4e12b90..6038889 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -95,11 +95,11 @@ struct SSSE3_8bit {
ssse3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
}
- INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::CPU_SSE2)
+ INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::SSE2)
constexpr static const char *const kName = "8-bit INTGEMM_SSSE3";
- static const CPUType kUses = CPU_SSSE3;
+ static const CPUType kUses = CPUType::SSSE3;
};
} // namespace intgemm
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 00256a8..82062fe 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -52,7 +52,7 @@ template <class V> void SlowTranspose(const V *from, V *to, Index rows, Index co
}
INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") {
- if (kCPU < CPU_SSE2) return;
+ if (kCPU < CPUType::SSE2) return;
const unsigned N = 8;
AlignedVector<int16_t> input(N * N);
std::iota(input.begin(), input.end(), 0);
@@ -70,7 +70,7 @@ INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") {
}
INTGEMM_SSSE3 TEST_CASE("Transpose 8", "[transpose]") {
- if (kCPU < CPU_SSSE3) return;
+ if (kCPU < CPUType::SSSE3) return;
const unsigned N = 16;
AlignedVector<int8_t> input(N * N);
std::iota(input.begin(), input.end(), 0);
@@ -125,7 +125,7 @@ template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) {
}
TEST_CASE("Prepare AVX512", "[prepare]") {
- if (kCPU < CPU_AVX512BW) return;
+ if (kCPU < CPUType::AVX512BW) return;
#ifndef INTGEMM_NO_AVX512
TestPrepare<AVX512_8bit>(64, 8);
TestPrepare<AVX512_8bit>(256, 32);
@@ -135,20 +135,20 @@ TEST_CASE("Prepare AVX512", "[prepare]") {
}
TEST_CASE("Prepare AVX2", "[prepare]") {
- if (kCPU < CPU_AVX2) return;
+ if (kCPU < CPUType::AVX2) return;
TestPrepare<AVX2_8bit>(64, 32);
TestPrepare<AVX2_16bit>(64, 32);
}
TEST_CASE("Prepare SSSE3", "[prepare]") {
- if (kCPU < CPU_SSSE3) return;
+ if (kCPU < CPUType::SSSE3) return;
TestPrepare<SSSE3_8bit>(16, 8);
TestPrepare<SSSE3_8bit>(32, 16);
TestPrepare<SSSE3_8bit>(32, 32);
}
TEST_CASE("Prepare SSE2", "[prepare]") {
- if (kCPU < CPU_SSE2) return;
+ if (kCPU < CPUType::SSE2) return;
TestPrepare<SSE2_16bit>(8, 8);
TestPrepare<SSE2_16bit>(32, 32);
}
@@ -190,7 +190,7 @@ template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 1
}
TEST_CASE("SelectColumnsB AVX512", "[select]") {
- if (kCPU < CPU_AVX512BW) return;
+ if (kCPU < CPUType::AVX512BW) return;
#ifndef INTGEMM_NO_AVX512
TestSelectColumnsB<AVX512_8bit>();
TestSelectColumnsB<AVX512_16bit>(256, 256);
@@ -198,19 +198,19 @@ TEST_CASE("SelectColumnsB AVX512", "[select]") {
}
TEST_CASE("SelectColumnsB AVX2", "[select]") {
- if (kCPU < CPU_AVX2) return;
+ if (kCPU < CPUType::AVX2) return;
TestSelectColumnsB<AVX2_8bit>(256, 256);
TestSelectColumnsB<AVX2_16bit>(256, 256);
}
TEST_CASE("SelectColumnsB SSSE3", "[select]") {
- if (kCPU < CPU_SSSE3) return;
+ if (kCPU < CPUType::SSSE3) return;
TestSelectColumnsB<SSSE3_8bit>();
TestSelectColumnsB<SSSE3_8bit>(256, 256);
}
TEST_CASE("SelectColumnsB SSE2", "[select]") {
- if (kCPU < CPU_SSE2) return;
+ if (kCPU < CPUType::SSE2) return;
TestSelectColumnsB<SSE2_16bit>();
TestSelectColumnsB<SSE2_16bit>(256, 256);
}
@@ -254,17 +254,17 @@ template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute(
}
TEST_CASE("MaxAbsolute SSE2", "[max]") {
- if (kCPU < CPU_SSE2) return;
+ if (kCPU < CPUType::SSE2) return;
TestMaxAbsolute<sse2::MaxAbsolute>();
}
TEST_CASE("MaxAbsolute AVX2", "[max]") {
- if (kCPU < CPU_AVX2) return;
+ if (kCPU < CPUType::AVX2) return;
TestMaxAbsolute<avx2::MaxAbsolute>();
}
TEST_CASE("MaxAbsolute AVX512F", "[max]") {
- if (kCPU < CPU_AVX512BW) return;
+ if (kCPU < CPUType::AVX512BW) return;
#ifndef INTGEMM_NO_AVX512
TestMaxAbsolute<avx512f::MaxAbsolute>();
#endif
@@ -432,7 +432,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
}
TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
- if (kCPU < CPU_SSE2) return;
+ if (kCPU < CPUType::SSE2) return;
TestMultiply<SSE2_16bit>(8, 256, 256, .1, 1, 0.01);
TestMultiply<SSE2_16bit>(8, 2048, 256, .1, 1, 0.02);
TestMultiply<SSE2_16bit>(320, 256, 256, .1, 1, 0.01);
@@ -442,7 +442,7 @@ TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
}
TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
- if (kCPU < CPU_SSE2) return;
+ if (kCPU < CPUType::SSE2) return;
TestMultiplyBias<SSE2_16bit>(8, 256, 256, .1, 1, 0.01);
TestMultiplyBias<SSE2_16bit>(8, 2048, 256, .1, 1, 0.02);
TestMultiplyBias<SSE2_16bit>(320, 256, 256, .1, 1, 0.01);
@@ -452,7 +452,7 @@ TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
}
TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
- if (kCPU < CPU_SSSE3) return;
+ if (kCPU < CPUType::SSSE3) return;
TestMultiply<SSSE3_8bit>(8, 256, 256, 1.2, 1.2, 0.064, 0.026);
TestMultiply<SSSE3_8bit>(8, 2048, 256, 33, 33, 4.4, 4.4);
TestMultiply<SSSE3_8bit>(320, 256, 256, 1.9, 1.9, 0.1, 0.01);
@@ -462,7 +462,7 @@ TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
}
TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
- if (kCPU < CPU_SSSE3) return;
+ if (kCPU < CPUType::SSSE3) return;
TestMultiplyBias<SSSE3_8bit>(8, 256, 256, 1.2, 1.2, 0.064, 0.026);
TestMultiplyBias<SSSE3_8bit>(8, 2048, 256, 33, 33, 4.4, 4.4);
TestMultiplyBias<SSSE3_8bit>(320, 256, 256, 1.9, 1.9, 0.1, 0.01);
@@ -472,7 +472,7 @@ TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
}
TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
- if (kCPU < CPU_AVX2) return;
+ if (kCPU < CPUType::AVX2) return;
TestMultiply<AVX2_8bit>(8, 256, 256, .1, 1, 0.1);
TestMultiply<AVX2_8bit>(8, 2048, 256, 19, 19, 1.8, 1.8);
TestMultiply<AVX2_8bit>(320, 256, 256, .1, 1, 0.1);
@@ -482,7 +482,7 @@ TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
}
TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
- if (kCPU < CPU_AVX2) return;
+ if (kCPU < CPUType::AVX2) return;
TestMultiplyBias<AVX2_8bit>(8, 256, 256, .1, 1, 0.1);
TestMultiplyBias<AVX2_8bit>(8, 2048, 256, 19, 19, 1.8, 1.8);
TestMultiplyBias<AVX2_8bit>(320, 256, 256, .1, 1, 0.1);
@@ -492,7 +492,7 @@ TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
}
TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
- if (kCPU < CPU_AVX2) return;
+ if (kCPU < CPUType::AVX2) return;
TestMultiply<AVX2_16bit>(8, 256, 256, .1, 1, 0.01);
TestMultiply<AVX2_16bit>(8, 2048, 256, .1, 1, 0.02);
TestMultiply<AVX2_16bit>(320, 256, 256, .1, 1, 0.01);
@@ -502,7 +502,7 @@ TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
}
TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
- if (kCPU < CPU_AVX2) return;
+ if (kCPU < CPUType::AVX2) return;
TestMultiplyBias<AVX2_16bit>(8, 256, 256, .1, 1, 0.01);
TestMultiplyBias<AVX2_16bit>(8, 2048, 256, .1, 1, 0.02);
TestMultiplyBias<AVX2_16bit>(320, 256, 256, .1, 1, 0.01);
@@ -513,7 +513,7 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
#ifndef INTGEMM_NO_AVX512
TEST_CASE ("Multiply AVX512 8bit", "[multiply]") {
- if (kCPU < CPU_AVX512BW) return;
+ if (kCPU < CPUType::AVX512BW) return;
TestMultiply<AVX512_8bit>(8, 256, 256, .1, 1, 0.062);
TestMultiply<AVX512_8bit>(8, 2048, 256, 4.2, 4, 0.41, 0.37);
TestMultiply<AVX512_8bit>(320, 256, 256, .1, 1, 0.06);
@@ -523,7 +523,7 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
}
TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") {
- if (kCPU < CPU_AVX512BW) return;
+ if (kCPU < CPUType::AVX512BW) return;
TestMultiplyBias<AVX512_8bit>(8, 256, 256, .1, 1, 0.062);
TestMultiplyBias<AVX512_8bit>(8, 2048, 256, 4.2, 4, 0.41, 0.37);
TestMultiplyBias<AVX512_8bit>(320, 256, 256, .1, 1, 0.06);
@@ -533,7 +533,7 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
}
TEST_CASE ("Multiply AVX512 16bit", "[multiply]") {
- if (kCPU < CPU_AVX512BW) return;
+ if (kCPU < CPUType::AVX512BW) return;
TestMultiply<AVX512_16bit>(8, 256, 256, .1, 1, 0.01);
TestMultiply<AVX512_16bit>(8, 2048, 256, .1, 1, 0.011);
TestMultiply<AVX512_16bit>(320, 256, 256, .1, 1, 0.01);
@@ -543,7 +543,7 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
}
TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") {
- if (kCPU < CPU_AVX512BW) return;
+ if (kCPU < CPUType::AVX512BW) return;
TestMultiplyBias<AVX512_16bit>(8, 256, 256, .1, 1, 0.01);
TestMultiplyBias<AVX512_16bit>(8, 2048, 256, .1, 1, 0.011);
TestMultiplyBias<AVX512_16bit>(320, 256, 256, .1, 1, 0.01);
diff --git a/test/pipeline_test.cc b/test/pipeline_test.cc
index 389b9d7..1b8c21d 100644
--- a/test/pipeline_test.cc
+++ b/test/pipeline_test.cc
@@ -6,7 +6,7 @@
namespace intgemm {
INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
- if (kCPU < CPU_AVX2)
+ if (kCPU < CPUType::AVX2)
return;
__m256i input;
@@ -19,7 +19,7 @@ INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
std::fill(raw_output, raw_output + 8, 42);
auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
- auto inited_pipeline = InitPostprocessPipeline<CPU_AVX2>(pipeline);
+ auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
output = inited_pipeline.run(input, 0);
CHECK(raw_output[0] == 0.0f); // input = -2
@@ -33,7 +33,7 @@ INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
}
INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-ReLU") {
- if (kCPU < CPU_AVX2)
+ if (kCPU < CPUType::AVX2)
return;
__m256i input[2];
@@ -46,7 +46,7 @@ INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-R
std::fill(raw_output, raw_output + 16, 42);
auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
- auto inited_pipeline = InitPostprocessPipeline<CPU_AVX2>(pipeline);
+ auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
inited_pipeline.run(input, 2, output);
CHECK(raw_output[0] == 0.f); // input = -8
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index e68fc34..fb866f1 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -74,23 +74,23 @@ template <class Backend> bool TestMany() {
}
TEST_CASE ("Quantize SSE2", "[quantize]") {
- if (kCPU < CPU_SSE2) return;
+ if (kCPU < CPUType::SSE2) return;
CHECK(TestMany<SSE2_16bit>());
}
TEST_CASE ("Quantize SSE3", "[quantize]") {
- if (kCPU < CPU_SSSE3) return;
+ if (kCPU < CPUType::SSSE3) return;
CHECK(TestMany<SSSE3_8bit>());
}
TEST_CASE ("Quantize AVX2", "[quantize]") {
- if (kCPU < CPU_AVX2) return;
+ if (kCPU < CPUType::AVX2) return;
CHECK(TestMany<AVX2_8bit>());
CHECK(TestMany<AVX2_16bit>());
}
#ifndef INTGEMM_NO_AVX512
TEST_CASE ("Quantize AVX512", "[quantize]") {
- if (kCPU < CPU_AVX512BW) return;
+ if (kCPU < CPUType::AVX512BW) return;
CHECK(TestMany<AVX512_8bit>());
CHECK(TestMany<AVX512_16bit>());
}
diff --git a/test/relu_test.cc b/test/relu_test.cc
index 0a72a29..183f415 100644
--- a/test/relu_test.cc
+++ b/test/relu_test.cc
@@ -6,7 +6,7 @@
namespace intgemm {
INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
- if (kCPU < CPU_SSE2)
+ if (kCPU < CPUType::SSE2)
return;
float raw_input[8];
@@ -16,7 +16,7 @@ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
input.pack0123 = *reinterpret_cast<__m128*>(raw_input);
input.pack4567 = *reinterpret_cast<__m128*>(raw_input + 4);
- auto postproc = PostprocessImpl<ReLU, CPUType::CPU_SSE2>(ReLU());
+ auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU());
auto output = postproc.run(input, 0);
auto raw_output = reinterpret_cast<float*>(&output);
@@ -31,14 +31,14 @@ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
}
INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) {
- if (kCPU < CPU_AVX2)
+ if (kCPU < CPUType::AVX2)
return;
float raw_input[8];
std::iota(raw_input, raw_input + 8, -4);
auto input = *reinterpret_cast<__m256*>(raw_input);
- auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX2>(ReLU());
+ auto postproc = PostprocessImpl<ReLU, CPUType::AVX2>(ReLU());
auto output = postproc.run(input, 0);
auto raw_output = reinterpret_cast<float*>(&output);
@@ -55,14 +55,14 @@ INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) {
#ifndef INTGEMM_NO_AVX512
INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
- if (kCPU < CPU_AVX512BW)
+ if (kCPU < CPUType::AVX512BW)
return;
float raw_input[16];
std::iota(raw_input, raw_input + 16, -8);
auto input = *reinterpret_cast<__m512*>(raw_input);
- auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX512BW>(ReLU());
+ auto postproc = PostprocessImpl<ReLU, CPUType::AVX512BW>(ReLU());
auto output = postproc.run(input, 0);
auto raw_output = reinterpret_cast<float*>(&output);
diff --git a/types.h b/types.h
index 2f15b73..2331181 100644
--- a/types.h
+++ b/types.h
@@ -33,7 +33,13 @@ class UnsupportedCPU : public std::exception {
typedef unsigned int Index;
// If you want to detect the CPU and dispatch yourself, here's what to use:
-typedef enum {CPU_AVX512BW = 4, CPU_AVX2 = 3, CPU_SSSE3 = 2, CPU_SSE2 = 1, CPU_UNSUPPORTED = 0} CPUType;
+enum class CPUType {
+ UNSUPPORTED = 0,
+ SSE2,
+ SSSE3,
+ AVX2,
+ AVX512BW,
+};
// Running CPU type. This is defined in intgemm.cc (as the dispatcher).
extern const CPUType kCPU;