Diffstat (limited to 'test/multiply_test.cc')
-rw-r--r-- | test/multiply_test.cc | 245
1 file changed, 244 insertions(+), 1 deletion(-)
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 186b0f9..f72758f 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -315,6 +315,57 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
              int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
 }
 
+template <class Routine> void TestMultiplyRelu(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+  using Integer = typename Routine::Integer;
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+
+  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+  float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+  AlignedVector<Integer> A_prep(A.size());
+  AlignedVector<Integer> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+  OMPParallelWrap<callbacks::UnquantizeAndWriteRelu, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWriteRelu(unquant_mult, test_C.begin()));
+  // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
+  //   callbacks::Unquantize(unquant_mult),
+  //   callbacks::Write<float>(test_C.begin())
+  // ));
+
+  AlignedVector<Integer> B_quant(B.size());
+  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+  AlignedVector<float> slowint_C(test_C.size());
+  // Assuming A is just quantization here.
+  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
+    float ret = std::max(0.0f, sum * unquant_mult);
+    return ret;
+  });
+
+  AlignedVector<float> float_C(test_C.size());
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
+    return static_cast<float>(std::max(0.0,sum));
+  });
+
+  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+             int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
 //Code duplication may be avoided through some use of variadic templates, as the different WriteC symbols
 //Require different number of arguments. I don't think the refactoring is worth it.
 template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index B_cols,
@@ -338,7 +389,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
   for (auto& it : bias) {
     it = dist(gen);
   }
-  
+
   float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
   float unquant_mult = 1.0f / (quant_mult*quant_mult);
 
@@ -368,6 +419,57 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
              int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
 }
 
+template <class Routine> void TestMultiplyBiasRelu(Index A_rows, Index width, Index B_cols,
+ float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
+  using Integer = typename Routine::Integer;
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  AlignedVector<float> bias(B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+  for (auto& it : bias) {
+    it = dist(gen);
+  }
+
+  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+  float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+  AlignedVector<Integer> A_prep(A.size());
+  AlignedVector<Integer> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+
+  Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias.begin(), test_C.begin()));
+
+  AlignedVector<Integer> B_quant(B.size());
+  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+  AlignedVector<float> slowint_C(test_C.size());
+  // Assuming A is just quantization here.
+  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+    return std::max(0.0f, sum * unquant_mult + bias[info.col_idx]);
+  });
+
+  AlignedVector<float> float_C(test_C.size());
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+    return std::max(0.0f, static_cast<float>(sum) + bias[info.col_idx]);
+  });
+
+  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+             int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
 TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
   if (kCPU < CPUType::SSE2) return;
   TestMultiply<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -378,6 +480,16 @@ TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
   TestMultiply<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
 }
 
+TEST_CASE ("Multiply SSE2 16bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestMultiplyRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
 TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
   if (kCPU < CPUType::SSE2) return;
   TestMultiplyBias<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -388,6 +500,16 @@ TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
   TestMultiplyBias<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
 }
 
+TEST_CASE ("Multiply SSE2 16bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
 TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
   if (kCPU < CPUType::SSSE3) return;
   TestMultiply<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
@@ -398,6 +520,16 @@ TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
   TestMultiply<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
 }
 
+TEST_CASE ("Multiply SSSE3 8bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiplyRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+  TestMultiplyRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+  TestMultiplyRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+  TestMultiplyRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+  TestMultiplyRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+  TestMultiplyRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
 TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
   if (kCPU < CPUType::SSSE3) return;
   TestMultiplyBias<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
@@ -408,6 +540,16 @@ TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
   TestMultiplyBias<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
 }
 
+TEST_CASE ("Multiply SSSE3 8bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
 
 TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
@@ -420,6 +562,16 @@ TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
   TestMultiply<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
 }
 
+TEST_CASE ("Multiply AVX2 8bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+  TestMultiplyRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
 TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
   if (kCPU < CPUType::AVX2) return;
   TestMultiplyBias<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
@@ -430,6 +582,16 @@ TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
   TestMultiplyBias<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
 }
 
+TEST_CASE ("Multiply AVX2 8bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyBiasRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
 TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
   if (kCPU < CPUType::AVX2) return;
   TestMultiply<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -440,6 +602,16 @@ TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
   TestMultiply<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
 }
 
+TEST_CASE ("Multiply AVX2 16bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
 TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
   if (kCPU < CPUType::AVX2) return;
   TestMultiplyBias<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -449,6 +621,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
   TestMultiplyBias<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
   TestMultiplyBias<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
 }
+
+TEST_CASE ("Multiply AVX2 16bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyBiasRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
 #endif
 
 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
@@ -462,6 +644,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiply<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
   }
 
+  TEST_CASE ("Multiply AVX512 8bit with relu", "[multiply_relu]") {
+    if (kCPU < CPUType::AVX512BW) return;
+    TestMultiplyRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+  }
+
   TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") {
     if (kCPU < CPUType::AVX512BW) return;
     TestMultiplyBias<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
@@ -472,6 +664,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
    TestMultiplyBias<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
   }
 
+  TEST_CASE ("Multiply AVX512 8bit with bias and relu", "[biased_multiply_relu]") {
+    if (kCPU < CPUType::AVX512BW) return;
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+  }
+
 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
   TEST_CASE ("Multiply AVX512VNNI 8bit", "[multiply]") {
     if (kCPU < CPUType::AVX512VNNI) return;
@@ -483,6 +685,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiply<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
   }
 
+  TEST_CASE ("Multiply AVX512VNNI 8bit with relu", "[multiply_relu]") {
+    if (kCPU < CPUType::AVX512VNNI) return;
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+  }
+
   TEST_CASE ("Multiply AVX512VNNI 8bit with bias", "[biased_multiply]") {
     if (kCPU < CPUType::AVX512VNNI) return;
     TestMultiplyBias<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
@@ -492,6 +704,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiplyBias<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
     TestMultiplyBias<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
   }
+
+  TEST_CASE ("Multiply AVX512VNNI 8bit with bias and relu", "[biased_multiply_relu]") {
+    if (kCPU < CPUType::AVX512VNNI) return;
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+  }
 #endif
 
   TEST_CASE ("Multiply AVX512 16bit", "[multiply]") {
@@ -504,6 +726,17 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiply<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
   }
 
+  TEST_CASE ("Multiply AVX512 16bit with relu", "[multiply_relu]") {
+    if (kCPU < CPUType::AVX512BW) return;
+    TestMultiplyRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+  }
+
+
   TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") {
     if (kCPU < CPUType::AVX512BW) return;
     TestMultiplyBias<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -513,6 +746,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiplyBias<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
     TestMultiplyBias<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
   }
+
+  TEST_CASE ("Multiply AVX512 16bit with bias and relu", "[biased_multiply_relu]") {
+    if (kCPU < CPUType::AVX512BW) return;
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+  }
 #endif
 
 } // namespace intgemm
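Usage sketch: the tests above drive the new fused ReLU write callbacks (callbacks::UnquantizeAndWriteRelu and callbacks::UnquantizeAndAddBiasAndWriteRelu) and check the output against a slow int32 reference multiply with the same unquantize/bias/ReLU lambda, and against a pure float multiply. The snippet below is a minimal sketch of the same pattern through the CPU-dispatching 8-bit frontend, not part of the commit: the include paths and the helper name MultiplyWithBiasRelu are illustrative assumptions, while Int8, PrepareA/PrepareB/Multiply, AlignedVector, and the ReLU callback are the names appearing in the diff.

// Minimal sketch, not part of the commit: C = relu(unquant_mult * (A x B) + bias)
// using intgemm's dispatching 8-bit frontend. Include paths and the helper name
// are assumptions; the Prepare*/Multiply/callback names come from the diff above.
#include "intgemm/intgemm.h"
#include "intgemm/aligned.h"

#include <cstdint>

// Dimensions must satisfy intgemm's multiple-of requirements; the tests above
// use A_rows that are multiples of 8 and width/B_cols that are multiples of 256.
void MultiplyWithBiasRelu(const float *A, const float *B, const float *bias, float *C,
                          intgemm::Index A_rows, intgemm::Index width, intgemm::Index B_cols) {
  // Same 8-bit quantization scale the tests use; 16-bit kernels would use 1024.
  const float quant_mult = 64.0f;
  const float unquant_mult = 1.0f / (quant_mult * quant_mult);

  // Quantize A and rearrange B into intgemm's prepared layout.
  intgemm::AlignedVector<int8_t> A_prep(A_rows * width);
  intgemm::AlignedVector<int8_t> B_prep(width * B_cols);
  intgemm::Int8::PrepareA(A, A_prep.begin(), quant_mult, A_rows, width);
  intgemm::Int8::PrepareB(B, B_prep.begin(), quant_mult, width, B_cols);

  // The fused callback unquantizes each int32 accumulator, adds the bias for
  // its output column, clamps negatives to zero, and writes the float result,
  // so no second pass over C is needed.
  intgemm::Int8::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
      intgemm::callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias, C));
}

This mirrors what TestMultiplyBiasRelu verifies: the fused output must match max(0, sum * unquant_mult + bias[col]) computed by the int32 reference within the stated tolerances.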