From 6228d016ecc63470d2dbb76bd4ab7b0abe097993 Mon Sep 17 00:00:00 2001
From: Nikolay Bogoychev
Date: Wed, 23 Jun 2021 12:48:30 +0100
Subject: Add relu callback (#89)

---
 intgemm/callbacks/configs.h           |  15 +++
 intgemm/callbacks/implementations.inl |  54 ++++++++
 test/multiply_test.cc                 | 245 +++++++++++++++++++++++++++++++++-
 3 files changed, 313 insertions(+), 1 deletion(-)

diff --git a/intgemm/callbacks/configs.h b/intgemm/callbacks/configs.h
index 1222448..d2fbe98 100644
--- a/intgemm/callbacks/configs.h
+++ b/intgemm/callbacks/configs.h
@@ -39,6 +39,13 @@ struct UnquantizeAndWrite {
   UnquantizeAndWrite(float unquant_mult, float* output_addr) : unquant_mult(unquant_mult), output_addr(output_addr) {}
 };
 
+struct UnquantizeAndWriteRelu {
+  float unquant_mult;
+  float* output_addr;
+
+  UnquantizeAndWriteRelu(float unquant_mult, float* output_addr) : unquant_mult(unquant_mult), output_addr(output_addr) {}
+};
+
 struct AddBiasAndWrite {
   const int* bias_addr;
   int* output_addr;
@@ -54,5 +61,13 @@ struct UnquantizeAndAddBiasAndWrite {
   UnquantizeAndAddBiasAndWrite(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {}
 };
 
+struct UnquantizeAndAddBiasAndWriteRelu {
+  float unquant_mult;
+  const float* bias_addr;
+  float* output_addr;
+
+  UnquantizeAndAddBiasAndWriteRelu(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {}
+};
+
 }
 }
diff --git a/intgemm/callbacks/implementations.inl b/intgemm/callbacks/implementations.inl
index 9a8f9e1..126701d 100644
--- a/intgemm/callbacks/implementations.inl
+++ b/intgemm/callbacks/implementations.inl
@@ -152,6 +152,33 @@ private:
   UnquantizeAndWrite config;
 };
 
+/*
+ * UnquantizeAndWriteRelu
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWriteRelu> {
+public:
+  explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndWriteRelu& config) : config(config) {
+    unquant_mult = set1_ps(config.unquant_mult);
+  }
+
+  INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
+    // Workaround gcc 5 internal compiler error that can't read register members in debug.
+    vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+    mult_reg = unquant_mult;
+#endif
+    auto result = kernels::relu(kernels::unquantize(input, mult_reg));
+    kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+  }
+
+private:
+  vf unquant_mult;
+  UnquantizeAndWriteRelu config;
+};
+
+
 /*
  * AddBiasAndWrite
  */
@@ -194,6 +221,33 @@ private:
   UnquantizeAndAddBiasAndWrite config;
 };
 
+/*
+ * UnquantizeAndAddBiasAndWriteRelu
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWriteRelu> {
+public:
+  explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndAddBiasAndWriteRelu& config) : config(config) {
+    unquant_mult = set1_ps(config.unquant_mult);
+  }
+
+  INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
+    // Workaround gcc 5 internal compiler error that can't read register members in debug.
+    vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+    asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+    mult_reg = unquant_mult;
+#endif
+    auto result = kernels::unquantize(input, mult_reg);
+    result = kernels::add_bias(result, config.bias_addr, info.col_idx);
+    result = kernels::relu(result);
+    kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+  }
+private:
+  vf unquant_mult;
+  UnquantizeAndAddBiasAndWriteRelu config;
+};
+
 }
 }
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 186b0f9..f72758f 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -315,6 +315,57 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
   int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
 }
 
+template <class Routine> void TestMultiplyRelu(Index A_rows, Index width, Index B_cols,
+  float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+  using Integer = typename Routine::Integer;
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+
+  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+  float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+  AlignedVector<Integer> A_prep(A.size());
+  AlignedVector<Integer> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+  OMPParallelWrap<callbacks::UnquantizeAndWriteRelu, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWriteRelu(unquant_mult, test_C.begin()));
+  // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
+  //   callbacks::Unquantize(unquant_mult),
+  //   callbacks::Write<float>(test_C.begin())
+  // ));
+
+  AlignedVector<Integer> B_quant(B.size());
+  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+  AlignedVector<float> slowint_C(test_C.size());
+  // Assuming A is just quantization here.
+  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
+    float ret = std::max(0.0f, sum * unquant_mult);
+    return ret;
+  });
+
+  AlignedVector<float> float_C(test_C.size());
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
+    return static_cast<float>(std::max(0.0, sum));
+  });
+
+  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+    int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
 //Code duplication may be avoided through some use of variadic templates, as the different WriteC symbols
 //Require different number of arguments. I don't think the refactoring is worth it.
 template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index B_cols,
@@ -338,7 +389,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
   for (auto& it : bias) {
     it = dist(gen);
   }
-  
+
   float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
   float unquant_mult = 1.0f / (quant_mult*quant_mult);
 
@@ -368,6 +419,57 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
   int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
 }
 
+template <class Routine> void TestMultiplyBiasRelu(Index A_rows, Index width, Index B_cols,
+  float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
+  using Integer = typename Routine::Integer;
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  AlignedVector<float> bias(B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+  for (auto& it : bias) {
+    it = dist(gen);
+  }
+
+  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+  float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+  AlignedVector<Integer> A_prep(A.size());
+  AlignedVector<Integer> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+
+  Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias.begin(), test_C.begin()));
+
+  AlignedVector<Integer> B_quant(B.size());
+  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+  AlignedVector<float> slowint_C(test_C.size());
+  // Assuming A is just quantization here.
+  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+    return std::max(0.0f, sum * unquant_mult + bias[info.col_idx]);
+  });
+
+  AlignedVector<float> float_C(test_C.size());
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+    return std::max(0.0f, static_cast<float>(sum) + bias[info.col_idx]);
+  });
+
+  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+    int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
 TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
   if (kCPU < CPUType::SSE2) return;
   TestMultiply<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -378,6 +480,16 @@ TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
   TestMultiply<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
 }
 
+TEST_CASE ("Multiply SSE2 16bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestMultiplyRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
 TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
   if (kCPU < CPUType::SSE2) return;
   TestMultiplyBias<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -388,6 +500,16 @@ TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
   TestMultiplyBias<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
 }
 
+TEST_CASE ("Multiply SSE2 16bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
 TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
   if (kCPU < CPUType::SSSE3) return;
   TestMultiply<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
@@ -398,6 +520,16 @@ TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
   TestMultiply<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
 }
 
+TEST_CASE ("Multiply SSSE3 8bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiplyRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+  TestMultiplyRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+  TestMultiplyRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+  TestMultiplyRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+  TestMultiplyRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+  TestMultiplyRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
 TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
   if (kCPU < CPUType::SSSE3) return;
   TestMultiplyBias<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
@@ -408,6 +540,16 @@ TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
   TestMultiplyBias<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
 }
 
+TEST_CASE ("Multiply SSSE3 8bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
 
 TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
   if (kCPU < CPUType::AVX2) return;
   TestMultiply<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
@@ -420,6 +562,16 @@ TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
   TestMultiply<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
 }
 
+TEST_CASE ("Multiply AVX2 8bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+  TestMultiplyRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
 TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
   if (kCPU < CPUType::AVX2) return;
   TestMultiplyBias<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
@@ -430,6 +582,16 @@ TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
   TestMultiplyBias<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
 }
 
+TEST_CASE ("Multiply AVX2 8bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyBiasRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+  TestMultiplyBiasRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
 TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
   if (kCPU < CPUType::AVX2) return;
   TestMultiply<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -440,6 +602,16 @@ TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
   TestMultiply<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
 }
 
+TEST_CASE ("Multiply AVX2 16bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
 TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
   if (kCPU < CPUType::AVX2) return;
   TestMultiplyBias<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -449,6 +621,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
   TestMultiplyBias<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
   TestMultiplyBias<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
 }
+
+TEST_CASE ("Multiply AVX2 16bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyBiasRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
 #endif
 
 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
@@ -462,6 +644,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiply<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
   }
 
+  TEST_CASE ("Multiply AVX512 8bit with relu", "[multiply_relu]") {
+    if (kCPU < CPUType::AVX512BW) return;
+    TestMultiplyRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+  }
+
   TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") {
     if (kCPU < CPUType::AVX512BW) return;
     TestMultiplyBias<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
@@ -472,6 +664,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiplyBias<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
   }
 
+  TEST_CASE ("Multiply AVX512 8bit with bias and relu", "[biased_multiply_relu]") {
+    if (kCPU < CPUType::AVX512BW) return;
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+  }
+
 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
   TEST_CASE ("Multiply AVX512VNNI 8bit", "[multiply]") {
     if (kCPU < CPUType::AVX512VNNI) return;
     TestMultiply<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
@@ -483,6 +685,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiply<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
   }
 
+  TEST_CASE ("Multiply AVX512VNNI 8bit with relu", "[multiply_relu]") {
+    if (kCPU < CPUType::AVX512VNNI) return;
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+  }
+
   TEST_CASE ("Multiply AVX512VNNI 8bit with bias", "[biased_multiply]") {
     if (kCPU < CPUType::AVX512VNNI) return;
     TestMultiplyBias<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
@@ -492,6 +704,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiplyBias<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
     TestMultiplyBias<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
   }
+
+  TEST_CASE ("Multiply AVX512VNNI 8bit with bias and relu", "[biased_multiply_relu]") {
+    if (kCPU < CPUType::AVX512VNNI) return;
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+    TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+  }
 #endif
 
   TEST_CASE ("Multiply AVX512 16bit", "[multiply]") {
@@ -504,6 +726,17 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
     TestMultiply<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
   }
 
+  TEST_CASE ("Multiply AVX512 16bit with relu", "[multiply_relu]") {
+    if (kCPU < CPUType::AVX512BW) return;
+    TestMultiplyRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+  }
+
+
   TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") {
     if (kCPU < CPUType::AVX512BW) return;
     TestMultiplyBias<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -513,6 +746,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
    TestMultiplyBias<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
    TestMultiplyBias<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
   }
+
+  TEST_CASE ("Multiply AVX512 16bit with bias and relu", "[biased_multiply_relu]") {
+    if (kCPU < CPUType::AVX512BW) return;
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+    TestMultiplyBiasRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+  }
 #endif
 
 } // namespace intgemm
-- 
cgit v1.2.3
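
Usage note (not part of the patch): the sketch below shows how a caller might drive the new relu callbacks through intgemm's high-level 8-bit interface. It assumes the usual intgemm entry points (intgemm::Int8::PrepareA/PrepareB/Multiply and intgemm::AlignedVector from intgemm/aligned.h) and reuses the quantization-multiplier convention from the tests above; treat it as an illustration under those assumptions, not as code taken from this commit.

    // Hypothetical example: compute C = relu(A * B + bias) in one fused pass.
    // A is A_rows x width (row major), B is width x B_cols, bias has B_cols entries.
    #include "intgemm/intgemm.h"
    #include "intgemm/aligned.h"

    void GemmBiasRelu(const float *A, const float *B, const float *bias, float *C,
                      intgemm::Index A_rows, intgemm::Index width, intgemm::Index B_cols) {
      const float quant_mult = 64.0f;  // 8-bit quantization scale, as in the tests
      const float unquant_mult = 1.0f / (quant_mult * quant_mult);

      intgemm::AlignedVector<int8_t> A_prep(A_rows * width);
      intgemm::AlignedVector<int8_t> B_prep(width * B_cols);
      intgemm::Int8::PrepareA(A, A_prep.begin(), quant_mult, A_rows, width);
      intgemm::Int8::PrepareB(B, B_prep.begin(), quant_mult, width, B_cols);

      // The new callback unquantizes each output, adds bias[col], clamps negatives
      // to zero (relu) and writes the float result to C.
      intgemm::Int8::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
          intgemm::callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias, C));
    }

When no bias row is needed, callbacks::UnquantizeAndWriteRelu(unquant_mult, C) fills the same role without the add_bias step.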