github.com/marian-nmt/intgemm/intgemm.git

Diffstat (limited to 'test/multiply_test.cc')
-rw-r--r--   test/multiply_test.cc   245
1 file changed, 244 insertions(+), 1 deletion(-)
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 186b0f9..f72758f 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -315,6 +315,57 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
+template <class Routine> void TestMultiplyRelu(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+ using Integer = typename Routine::Integer;
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+
+ float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+ float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+ AlignedVector<Integer> A_prep(A.size());
+ AlignedVector<Integer> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+ OMPParallelWrap<callbacks::UnquantizeAndWriteRelu, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWriteRelu(unquant_mult, test_C.begin()));
+ // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
+ // callbacks::Unquantize(unquant_mult),
+ // callbacks::Write<float>(test_C.begin())
+ // ));
+
+ AlignedVector<Integer> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+ AlignedVector<float> slowint_C(test_C.size());
+ // Assuming A is just quantization here.
+ references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
+ float ret = std::max(0.0f, sum * unquant_mult);
+ return ret;
+ });
+
+ AlignedVector<float> float_C(test_C.size());
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
+ return static_cast<float>(std::max(0.0,sum));
+ });
+
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
//Code duplication may be avoided through some use of variadic templates, as the different WriteC symbols
//Require different number of arguments. I don't think the refactoring is worth it.
template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index B_cols,
@@ -338,7 +389,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
for (auto& it : bias) {
it = dist(gen);
}
-
+
float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
float unquant_mult = 1.0f / (quant_mult*quant_mult);
@@ -368,6 +419,57 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
+template <class Routine> void TestMultiplyBiasRelu(Index A_rows, Index width, Index B_cols,
+ float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
+ using Integer = typename Routine::Integer;
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = dist(gen);
+ }
+
+ float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+ float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+ AlignedVector<Integer> A_prep(A.size());
+ AlignedVector<Integer> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias.begin(), test_C.begin()));
+
+ AlignedVector<Integer> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+ AlignedVector<float> slowint_C(test_C.size());
+ // Assuming A is just quantization here.
+ references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return std::max(0.0f, sum * unquant_mult + bias[info.col_idx]);
+ });
+
+ AlignedVector<float> float_C(test_C.size());
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+ return std::max(0.0f, static_cast<float>(sum) + bias[info.col_idx]);
+ });
+
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
if (kCPU < CPUType::SSE2) return;
TestMultiply<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -378,6 +480,16 @@ TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
TestMultiply<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+TEST_CASE ("Multiply SSE2 16bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::SSE2) return;
+ TestMultiplyRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::SSE2) return;
TestMultiplyBias<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -388,6 +500,16 @@ TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+TEST_CASE ("Multiply SSE2 16bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::SSE2) return;
+ TestMultiplyBiasRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
if (kCPU < CPUType::SSSE3) return;
TestMultiply<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
@@ -398,6 +520,16 @@ TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
TestMultiply<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}
+TEST_CASE ("Multiply SSSE3 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+ TestMultiplyRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+ TestMultiplyRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+ TestMultiplyRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+ TestMultiplyRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+ TestMultiplyRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::SSSE3) return;
TestMultiplyBias<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
@@ -408,6 +540,16 @@ TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
TestMultiplyBias<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}
+TEST_CASE ("Multiply SSSE3 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
@@ -420,6 +562,16 @@ TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
TestMultiply<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}
+TEST_CASE ("Multiply AVX2 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+ TestMultiplyRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX2) return;
TestMultiplyBias<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
@@ -430,6 +582,16 @@ TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}
+TEST_CASE ("Multiply AVX2 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyBiasRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
if (kCPU < CPUType::AVX2) return;
TestMultiply<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -440,6 +602,16 @@ TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
TestMultiply<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+TEST_CASE ("Multiply AVX2 16bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX2) return;
TestMultiplyBias<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -449,6 +621,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
TestMultiplyBias<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+
+TEST_CASE ("Multiply AVX2 16bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyBiasRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
@@ -462,6 +644,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiply<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
+ TEST_CASE ("Multiply AVX512 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX512BW) return;
TestMultiplyBias<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
@@ -472,6 +664,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
+ TEST_CASE ("Multiply AVX512 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
TEST_CASE ("Multiply AVX512VNNI 8bit", "[multiply]") {
if (kCPU < CPUType::AVX512VNNI) return;
@@ -483,6 +685,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiply<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
+ TEST_CASE ("Multiply AVX512VNNI 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
TEST_CASE ("Multiply AVX512VNNI 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX512VNNI) return;
TestMultiplyBias<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
@@ -492,6 +704,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
TestMultiplyBias<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
+
+ TEST_CASE ("Multiply AVX512VNNI 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
#endif
TEST_CASE ("Multiply AVX512 16bit", "[multiply]") {
@@ -504,6 +726,17 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiply<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+ TEST_CASE ("Multiply AVX512 16bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ }
+
+
TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX512BW) return;
TestMultiplyBias<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -513,6 +746,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
TestMultiplyBias<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+
+ TEST_CASE ("Multiply AVX512 16bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ }
#endif
} // namespace intgemm
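
The reference lambdas in the diff define the ReLU semantics that the new UnquantizeAndWriteRelu and UnquantizeAndAddBiasAndWriteRelu callbacks are checked against: unquantize the int32 accumulator, optionally add the per-column bias, then clamp at zero. Below is a minimal scalar sketch of that reference computation; the helper names (relu_reference, bias_relu_reference) are hypothetical and not part of intgemm.

#include <algorithm>
#include <cstdint>

// Unquantize an int32 dot-product accumulator and clamp at zero,
// mirroring the lambda passed to references::Multiply in TestMultiplyRelu.
inline float relu_reference(int32_t sum, float unquant_mult) {
  return std::max(0.0f, sum * unquant_mult);
}

// Same, but add the per-column bias before clamping,
// mirroring the lambda used in TestMultiplyBiasRelu.
inline float bias_relu_reference(int32_t sum, float unquant_mult, float bias) {
  return std::max(0.0f, sum * unquant_mult + bias);
}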