github.com/marian-nmt/intgemm/intgemm.git
author    Nikolay Bogoychev <nheart@gmail.com>  2021-06-23 14:48:30 +0300
committer GitHub <noreply@github.com>           2021-06-23 14:48:30 +0300
commit    6228d016ecc63470d2dbb76bd4ab7b0abe097993 (patch)
tree      831a3e08ab24efa58ba608f62f0525e24ae2458b
parent    18bcba45d08bcc0d5b64334b4b6ea2188a17b4f8 (diff)
Add relu callback (#89)
-rw-r--r--  intgemm/callbacks/configs.h            |  15 +
-rw-r--r--  intgemm/callbacks/implementations.inl  |  54 +
-rw-r--r--  test/multiply_test.cc                  | 245 +-
3 files changed, 313 insertions(+), 1 deletion(-)
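
This commit adds ReLU variants of the existing write callbacks, so a multiply can clamp negative outputs to zero in the same pass that unquantizes (and, in the biased variant, adds the bias). A minimal usage sketch, modeled on the tests added below; the 8-bit entry points follow intgemm's public API, but the wrapper function and buffer names here are illustrative:

#include "intgemm/intgemm.h"
#include "intgemm/aligned.h"

using namespace intgemm;

void multiply_relu(Index A_rows, Index width, Index B_cols,
                   const float* A, const float* B, float quant_mult) {
  // Quantize and lay out the operands as the 8-bit kernels expect.
  AlignedVector<int8_t> A_prep(A_rows * width);
  AlignedVector<int8_t> B_prep(width * B_cols);
  Int8::PrepareA(A, A_prep.begin(), quant_mult, A_rows, width);
  Int8::PrepareB(B, B_prep.begin(), quant_mult, width, B_cols);

  AlignedVector<float> C(A_rows * B_cols);
  // New in this commit: write max(0, value) instead of the raw unquantized value.
  Int8::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
                 callbacks::UnquantizeAndWriteRelu(1.0f / (quant_mult * quant_mult), C.begin()));
}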
diff --git a/intgemm/callbacks/configs.h b/intgemm/callbacks/configs.h
index 1222448..d2fbe98 100644
--- a/intgemm/callbacks/configs.h
+++ b/intgemm/callbacks/configs.h
@@ -39,6 +39,13 @@ struct UnquantizeAndWrite {
UnquantizeAndWrite(float unquant_mult, float* output_addr) : unquant_mult(unquant_mult), output_addr(output_addr) {}
};
+struct UnquantizeAndWriteRelu {
+ float unquant_mult;
+ float* output_addr;
+
+ UnquantizeAndWriteRelu(float unquant_mult, float* output_addr) : unquant_mult(unquant_mult), output_addr(output_addr) {}
+};
+
struct AddBiasAndWrite {
const int* bias_addr;
int* output_addr;
@@ -54,5 +61,13 @@ struct UnquantizeAndAddBiasAndWrite {
UnquantizeAndAddBiasAndWrite(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {}
};
+struct UnquantizeAndAddBiasAndWriteRelu {
+ float unquant_mult;
+ const float* bias_addr;
+ float* output_addr;
+
+ UnquantizeAndAddBiasAndWriteRelu(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {}
+};
+
}
}
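
The new structs are plain parameter holders; the per-architecture work happens in the CallbackImpl specializations in implementations.inl below. In scalar terms, the two callbacks compute the following (a sketch for exposition only, not code from the commit):

#include <algorithm>
#include <cstddef>
#include <cstdint>

// UnquantizeAndWriteRelu: output = max(0, sum * unquant_mult)
float unquantize_relu(int32_t sum, float unquant_mult) {
  return std::max(0.0f, sum * unquant_mult);
}

// UnquantizeAndAddBiasAndWriteRelu: output = max(0, sum * unquant_mult + bias[col])
float unquantize_bias_relu(int32_t sum, float unquant_mult,
                           const float* bias, std::size_t col) {
  return std::max(0.0f, sum * unquant_mult + bias[col]);
}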
diff --git a/intgemm/callbacks/implementations.inl b/intgemm/callbacks/implementations.inl
index 9a8f9e1..126701d 100644
--- a/intgemm/callbacks/implementations.inl
+++ b/intgemm/callbacks/implementations.inl
@@ -153,6 +153,33 @@ private:
};
/*
+ * UnquantizeAndWriteRelu
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWriteRelu> {
+public:
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndWriteRelu& config) : config(config) {
+ unquant_mult = set1_ps<vf>(config.unquant_mult);
+ }
+
+ INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
+ // Work around a gcc 5 internal compiler error: it cannot read SIMD register members in unoptimized (debug) builds.
+ vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+ asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+ mult_reg = unquant_mult;
+#endif
+ auto result = kernels::relu<float>(kernels::unquantize(input, mult_reg));
+ kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+
+private:
+ vf unquant_mult;
+ UnquantizeAndWriteRelu config;
+};
+
+
+/*
* AddBiasAndWrite
*/
template <> class CallbackImpl<CPUType::CPU_NAME, AddBiasAndWrite> {
@@ -194,6 +221,33 @@ private:
UnquantizeAndAddBiasAndWrite config;
};
+/*
+ * UnquantizeAndAddBiasAndWriteRelu
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWriteRelu> {
+public:
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndAddBiasAndWriteRelu& config) : config(config) {
+ unquant_mult = set1_ps<vf>(config.unquant_mult);
+ }
+
+ INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
+ // Work around a gcc 5 internal compiler error: it cannot read SIMD register members in unoptimized (debug) builds.
+ vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+ asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+ mult_reg = unquant_mult;
+#endif
+ auto result = kernels::unquantize(input, mult_reg);
+ result = kernels::add_bias(result, config.bias_addr, info.col_idx);
+ result = kernels::relu<float>(result);
+ kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+private:
+ vf unquant_mult;
+ UnquantizeAndAddBiasAndWriteRelu config;
+};
+
}
}
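
A note on the scaling used throughout the tests below: both A and B are quantized with the same quant_mult (1024 for 16-bit, 64 for 8-bit), so every int32 product in the accumulator carries the factor quant_mult twice, and unquant_mult = 1 / (quant_mult * quant_mult) recovers the float scale before the ReLU is applied. A small worked example (the values are illustrative):

#include <cmath>
#include <cstdint>

int main() {
  const float a = 0.5f, b = -0.25f;
  const float quant_mult = 64.0f;  // 8-bit scale used in the tests
  const int32_t a_q = static_cast<int32_t>(std::roundf(a * quant_mult));  //  32
  const int32_t b_q = static_cast<int32_t>(std::roundf(b * quant_mult));  // -16
  const int32_t sum = a_q * b_q;   // -512, carries the factor 64 * 64
  const float unquant_mult = 1.0f / (quant_mult * quant_mult);
  const float recovered = sum * unquant_mult;  // -0.125 == a * b exactly
  return recovered == a * b ? 0 : 1;
}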
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 186b0f9..f72758f 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -315,6 +315,57 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
+template <class Routine> void TestMultiplyRelu(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+ using Integer = typename Routine::Integer;
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+
+ float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+ float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+ AlignedVector<Integer> A_prep(A.size());
+ AlignedVector<Integer> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+ OMPParallelWrap<callbacks::UnquantizeAndWriteRelu, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWriteRelu(unquant_mult, test_C.begin()));
+ // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
+ // callbacks::Unquantize(unquant_mult),
+ // callbacks::Write<float>(test_C.begin())
+ // ));
+
+ AlignedVector<Integer> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+ AlignedVector<float> slowint_C(test_C.size());
+ // Assumes PrepareA is plain quantization here, so A_prep can serve as the quantized A.
+ references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
+ return std::max(0.0f, sum * unquant_mult);
+ });
+
+ AlignedVector<float> float_C(test_C.size());
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
+ return static_cast<float>(std::max(0.0, sum));
+ });
+
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
//Code duplication may be avoided through some use of variadic templates, as the different WriteC symbols
//Require different number of arguments. I don't think the refactoring is worth it.
template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index B_cols,
@@ -338,7 +389,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
for (auto& it : bias) {
it = dist(gen);
}
-
+
float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
float unquant_mult = 1.0f / (quant_mult*quant_mult);
@@ -368,6 +419,57 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
+template <class Routine> void TestMultiplyBiasRelu(Index A_rows, Index width, Index B_cols,
+ float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
+ using Integer = typename Routine::Integer;
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = dist(gen);
+ }
+
+ float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+ float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+ AlignedVector<Integer> A_prep(A.size());
+ AlignedVector<Integer> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias.begin(), test_C.begin()));
+
+ AlignedVector<Integer> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+ AlignedVector<float> slowint_C(test_C.size());
+ // Assumes PrepareA is plain quantization here, so A_prep can serve as the quantized A.
+ references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return std::max(0.0f, sum * unquant_mult + bias[info.col_idx]);
+ });
+
+ AlignedVector<float> float_C(test_C.size());
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+ return std::max(0.0f, static_cast<float>(sum) + bias[info.col_idx]);
+ });
+
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
if (kCPU < CPUType::SSE2) return;
TestMultiply<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -378,6 +480,16 @@ TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
TestMultiply<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+TEST_CASE ("Multiply SSE2 16bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::SSE2) return;
+ TestMultiplyRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::SSE2) return;
TestMultiplyBias<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -388,6 +500,16 @@ TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+TEST_CASE ("Multiply SSE2 16bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::SSE2) return;
+ TestMultiplyBiasRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
if (kCPU < CPUType::SSSE3) return;
TestMultiply<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
@@ -398,6 +520,16 @@ TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
TestMultiply<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}
+TEST_CASE ("Multiply SSSE3 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+ TestMultiplyRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+ TestMultiplyRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+ TestMultiplyRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+ TestMultiplyRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+ TestMultiplyRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::SSSE3) return;
TestMultiplyBias<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
@@ -408,6 +540,16 @@ TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
TestMultiplyBias<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}
+TEST_CASE ("Multiply SSSE3 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
@@ -420,6 +562,16 @@ TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
TestMultiply<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}
+TEST_CASE ("Multiply AVX2 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+ TestMultiplyRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX2) return;
TestMultiplyBias<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
@@ -430,6 +582,16 @@ TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}
+TEST_CASE ("Multiply AVX2 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyBiasRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
if (kCPU < CPUType::AVX2) return;
TestMultiply<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -440,6 +602,16 @@ TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
TestMultiply<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+TEST_CASE ("Multiply AVX2 16bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX2) return;
TestMultiplyBias<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -449,6 +621,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
TestMultiplyBias<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+
+TEST_CASE ("Multiply AVX2 16bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyBiasRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
@@ -462,6 +644,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiply<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
+ TEST_CASE ("Multiply AVX512 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX512BW) return;
TestMultiplyBias<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
@@ -472,6 +664,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
+ TEST_CASE ("Multiply AVX512 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
TEST_CASE ("Multiply AVX512VNNI 8bit", "[multiply]") {
if (kCPU < CPUType::AVX512VNNI) return;
@@ -483,6 +685,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiply<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
+ TEST_CASE ("Multiply AVX512VNNI 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
TEST_CASE ("Multiply AVX512VNNI 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX512VNNI) return;
TestMultiplyBias<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
@@ -492,6 +704,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
TestMultiplyBias<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
+
+ TEST_CASE ("Multiply AVX512VNNI 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
#endif
TEST_CASE ("Multiply AVX512 16bit", "[multiply]") {
@@ -504,6 +726,17 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiply<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+ TEST_CASE ("Multiply AVX512 16bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ }
+
+
TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX512BW) return;
TestMultiplyBias<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
@@ -513,6 +746,16 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
TestMultiplyBias<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
TestMultiplyBias<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+
+ TEST_CASE ("Multiply AVX512 16bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ }
#endif
} // namespace intgemm