| author | Young Jin Kim <youki@microsoft.com> | 2019-12-03 22:53:14 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-12-03 22:53:14 +0300 |
| commit | 84e66a976046180187724aff60a236c5378fde7c (patch) | |
| tree | f2c4e39fe4d46df1b7a23602d244d21c9f9ee35b | |
| parent | f0b354327aaf2330c65340725b1981040c8bec9e (diff) | |
| parent | e6e9b167426c12cd048c3d7d76651492f818daec (diff) | |
Merge pull request #1 from marian-nmt/youki/win-jit-debug-int8
Youki/win jit debug int8
66 files changed, 8610 insertions, 4995 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index e6c7419..c06b60b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,8 +37,10 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc src/FbgemmI8Spmdm.cc src/GenerateKernelU8S8S32ACC16.cc src/GenerateKernelU8S8S32ACC16Avx512.cc + src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc src/GenerateKernelU8S8S32ACC32.cc src/GenerateKernelU8S8S32ACC32Avx512.cc + src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc src/GroupwiseConvAcc32Avx2.cc src/PackAMatrix.cc src/PackAWithIm2Col.cc @@ -87,8 +89,10 @@ endif() #All the source files that either use avx2 instructions statically set(FBGEMM_AVX2_SRCS src/FbgemmFP16UKernelsAvx2.cc + src/FbgemmI8Depthwise3DAvx2.cc src/FbgemmI8DepthwiseAvx2.cc src/OptimizedKernelsAvx2.cc + src/PackDepthwiseConvMatrixAvx2.cc src/QuantUtilsAvx2.cc src/UtilsAvx2.cc) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 0f7ad8b..d1abc70 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,5 +1,77 @@ # Code of Conduct -Facebook has adopted a Code of Conduct that we expect project participants to adhere to. -Please read the [full text](https://code.fb.com/codeofconduct/) -so that you can understand what actions will and will not be tolerated. +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. 
Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at <opensource-conduct@fb.com>. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq + @@ -12,9 +12,9 @@ row-wise quantization and outlier-aware quantization. FBGEMM also exploits fusion opportunities in order to overcome the unique challenges of matrix multiplication at lower precision with bandwidth-bound operations. -FBGEMM is used as a backend of Caffe2 quantized operators for x86 machines -(https://github.com/pytorch/pytorch/tree/master/caffe2/quantization/server). -We also plan to integrate FBGEMM into PyTorch. +FBGEMM is used as a backend of Caffe2 and PyTorch quantized operators for x86 machines: +* Caffe2: https://github.com/pytorch/pytorch/tree/master/caffe2/quantization/server +* PyTorch: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native/quantized/cpu ## Examples @@ -64,6 +64,9 @@ General build instructions are as follows: ``` git clone --recursive https://github.com/pytorch/FBGEMM.git cd FBGEMM +# if you are updating an existing checkout +git submodule sync +git submodule update --init --recursive mkdir build && cd build cmake .. 
make diff --git a/bench/ConvUnifiedBenchmark.cc b/bench/ConvUnifiedBenchmark.cc index 6bc2cf4..35a3b2a 100644 --- a/bench/ConvUnifiedBenchmark.cc +++ b/bench/ConvUnifiedBenchmark.cc @@ -24,6 +24,7 @@ using namespace std; using namespace fbgemm; +// clang-format off // 2D conv shapes vector<conv_param_t<2>> shapes_2d = { // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, @@ -31,22 +32,31 @@ vector<conv_param_t<2>> shapes_2d = { // 2D convolutions // regular conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + // regular with dilation + conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), // groupwise conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), - // DW conv_param_t<>(1, 272, 272, {47, 125}, 272, {3, 3}, {1, 1}, {1, 1, 1, 1}), + // Pointwise + conv_param_t<>(1, 128, 128, {56, 56}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}) + }; // 3D conv shapes vector<conv_param_t<3>> shapes_3d = { - // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h, stride_w}, - // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right} - // Regular - conv_param_t<3>(1, 64, 64, {32, 56, 56}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}), - // Depthwise - conv_param_t<3>(1, 64, 64, {32, 56, 56}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}) -}; + // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h, + // stride_w}, + // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right} + // Regular + conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}), + //With dilations + conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}, {2, 2, 2}), + // Depthwise + conv_param_t<3>(1, 64, 64, {8, 14, 14}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}), + // Pointwise + conv_param_t<3>(1, 128, 128, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0, 0})}; +// clang-format on template <int SPATIAL_DIM, typename Acc_t> void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { @@ -77,6 +87,10 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { header += "pad_t, "; } header += "pad_h, pad_w, "; + if (SPATIAL_DIM == 3) { + header += "dilation_t, "; + } + header += "dilation_h, dilation_w, "; header += "Type, M, N, K, "; @@ -110,6 +124,9 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { aligned_vector<int8_t> Bint8( kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); + aligned_vector<int8_t> Bint8_tr( + kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); + int im_out_dim = accumulate( conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies<int>()); aligned_vector<int32_t> Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC); @@ -132,14 +149,14 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2); int32_t C_zero_point = 5; - aligned_vector<float> Bfp32(Bint8.begin(), Bint8.end()); - // reference implementation + // conv_ref expects weights to be in G (R S C/G) K/G + transposeConvWeights<SPATIAL_DIM>(conv_p, Bint8.data(), Bint8_tr.data()); conv_ref( conv_p, Aint8.data(), Aint8_zero_point, - Bint8.data(), + Bint8_tr.data(), Cint32_ref.data()); // matrix dimensions after im2col @@ -162,7 +179,7 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { KDimPerGroup, OC_per_G, OC_per_G, - Bint8.data() + g * KDimPerGroup * OC_per_G, + Bint8_tr.data() + g * KDimPerGroup * OC_per_G, Bint8_zero_point.data(), col_offsets.data() + g 
* OC_per_G, conv_p.OC); @@ -271,7 +288,9 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { for (int i = 0; i < SPATIAL_DIM; ++i) { cout << conv_p.pad[i] << ", "; } - + for (int i = 0; i < SPATIAL_DIM; ++i) { + cout << conv_p.dilation[i] << ", "; + } cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5) << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6) << KDim << ", "; diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc index 0efdcac..ff2be6f 100644 --- a/bench/Depthwise3DBenchmark.cc +++ b/bench/Depthwise3DBenchmark.cc @@ -4,7 +4,6 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "test/I8DepthwiseTest.h" #include <algorithm> #include <chrono> @@ -19,8 +18,8 @@ #include "AlignedVec.h" #include "BenchUtils.h" -#include "fbgemm/Utils.h" #include "fbgemm/FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/Utils.h" #include "src/RefImplementations.h" using namespace std; @@ -35,6 +34,34 @@ int main() { } #endif + // From ResNeXt-3D-101 + // clang-format off + vector<vector<int>> shapes_3d = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. + // N, K, T_in, H_in, W_in, stride + { 1, 64, 32, 56, 56, 1, }, + { 1, 128, 16, 28, 28, 1, }, + { 1, 256, 8, 14, 14, 1, }, + { 1, 512, 4, 7, 7, 1, }, + + { 1, 128, 32, 56, 56, 2, }, + { 1, 256, 16, 28, 28, 2, }, + { 1, 512, 8, 14, 14, 2, }, + + { 5, 64, 32, 56, 56, 1, }, + { 5, 128, 16, 28, 28, 1, }, + { 5, 256, 8, 14, 14, 1, }, + { 5, 512, 4, 7, 7, 1, }, + + { 5, 128, 32, 56, 56, 2, }, + { 5, 256, 16, 28, 28, 2, }, + { 5, 512, 8, 14, 14, 2, }, + + { 1, 8, 4, 4, 4, 1, }, + }; + // clang-format on + // Depthwise is memory BW bound so we want to flush LLC. bool flush = true; std::vector<char> llc; @@ -61,14 +88,28 @@ int main() { constexpr int K_T = 3, K_H = 3, K_W = 3; constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + conv_param_t<3> conv_p( + N, + K, + K, + {T, H, W}, + K, + {K_T, K_H, K_W}, + {stride_t, stride_h, stride_w}, + {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R}); + int T_OUT = conv_p.OUT_DIM[0]; + int H_OUT = conv_p.OUT_DIM[1]; + int W_OUT = conv_p.OUT_DIM[2]; + + int MDim = N * T_OUT * H_OUT * W_OUT; + int KDim = K_T * K_H * K_W * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * T * H * W * K); - aligned_vector<int8_t> B(K * K_T * K_H * K_W); - aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K), - C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = 43; @@ -76,52 +117,49 @@ int main() { randFill<int8_t>(B, -16, 16); int32_t B_zero_point = 5; - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); - int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + aligned_vector<float> C_multiplier(1); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); + int32_t C_zero_point = 5; - float C_multiplier = 255. 
/ (maximum - minimum); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); aligned_vector<int32_t> col_offsets(K); aligned_vector<int32_t> bias(K); randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - int32_t C_zero_point = 5; - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point, - B.data(), - C_multiplier, - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); - - Packed3x3x3ConvMatrix Bp(K, B.data()); + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data(), + C_zero_point, + A_zero_point, + &B_zero_point, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } + + PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data()); double ttot = 0; double bytes = double(NITER) * @@ -153,7 +191,7 @@ int main() { A.data(), B_zero_point, Bp, - C_multiplier, + C_multiplier[0], C_zero_point, C_uint8.data(), col_offsets.data(), diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc index 96921a1..6c2ee17 100644 --- a/bench/DepthwiseBenchmark.cc +++ b/bench/DepthwiseBenchmark.cc @@ -17,8 +17,8 @@ #include "AlignedVec.h" #include "BenchUtils.h" -#include "fbgemm/Utils.h" #include "fbgemm/FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/Utils.h" #include "src/RefImplementations.h" using namespace std; @@ -34,10 +34,11 @@ int main() { #endif // From Xray OCR + // clang-format off vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. - // N, G, H_in, W_in, stride + // N, K, H_in, W_in, stride { 1, 272, 47, 125, 1, }, { 1, 272, 64, 125, 1, }, { 1, 272, 66, 125, 1, }, @@ -138,6 +139,7 @@ int main() { { 96, 544, 14, 14, 2, }, { 100, 544, 14, 14, 2, }, }; + // clang-format on // Depthwise is memory BW bound so we want to flush LLC. 
bool flush = true; @@ -155,19 +157,35 @@ int main() { for (auto shape : shapes) { int N = shape[0]; - int G = shape[1]; + int K = shape[1]; int H = shape[2]; int W = shape[3]; int stride_h = shape[4]; int stride_w = stride_h; constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2, + PAD_R = (S - 1) / 2; + + conv_param_t<2> conv_p( + N, + K, + K, + {H, W}, + K, + {R, S}, + {stride_h, stride_w}, + {PAD_T, PAD_L, PAD_B, PAD_R}); + int H_OUT = conv_p.OUT_DIM[0]; + int W_OUT = conv_p.OUT_DIM[1]; + + int MDim = N * H_OUT * W_OUT; + int KDim = R * S * K; + int KDimPerGroup = KDim / conv_p.G; - aligned_vector<uint8_t> A(N * H * W * G); - aligned_vector<int8_t> B(G * R * S); - aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * G), C(C_ref.size()); + aligned_vector<uint8_t> A(N * H * W * K); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = 43; @@ -175,53 +193,54 @@ int main() { randFill<int8_t>(B, -16, 16); int32_t B_zero_point = 5; - depthwise_3x3_pad_1_ref( - N, - H, - W, - G, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); - int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + aligned_vector<float> C_multiplier(1); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); + int32_t C_zero_point = 5; - float C_multiplier = 255. / (maximum - minimum); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); - aligned_vector<int32_t> col_offsets(G); - aligned_vector<int32_t> bias(G); + aligned_vector<int32_t> col_offsets(K); + aligned_vector<int32_t> bias(K); randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - int32_t C_zero_point = 5; - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3_pad_1_ref( - N, - H, - W, - G, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point, - B.data(), - C_multiplier, - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data(), + C_zero_point, + A_zero_point, + &B_zero_point, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } - Packed3x3ConvMatrix Bp(G, B.data()); + PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data()); double ttot = 0; double bytes = double(NITER) * - (G * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S)); - double ops = double(NITER) * N * H_OUT * W_OUT * G * R * S * 2; + (K * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S)); + double ops = double(NITER) * N * H_OUT * W_OUT * K * R * S * 2; chrono::time_point<chrono::system_clock> t_begin, t_end; for (int i = 0; i < NWARMUP + NITER; ++i) { llc_flush(); @@ -235,19 +254,20 @@ int main() { N, H, 
W, - G, + K, stride_h, stride_w, A_zero_point, A.data(), B_zero_point, Bp, - C_multiplier, + C_multiplier[0], C_zero_point, C_uint8.data(), col_offsets.data(), bias.data(), false, /* fuse_relu */ + 1.0f, /* act_scale * w_scale */ tid, num_threads); } @@ -262,10 +282,10 @@ int main() { for (int n = 0; n < N; ++n) { for (int h = 0; h < H_OUT; ++h) { for (int w = 0; w < W_OUT; ++w) { - for (int g = 0; g < G; ++g) { + for (int g = 0; g < K; ++g) { uint8_t expected = - C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * G + g]; - uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * G + g]; + C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * K + g]; + uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * K + g]; if (expected != actual) { cerr << "Depthwise 3x3 results differ at (" << n << ", " << h << ", " << w << ", " << g << "). expected " << (int)expected @@ -280,9 +300,9 @@ int main() { // Report performance printf( - "N = %d G = %d H = %d W = %d stride = %d with requantization fused\n", + "N = %d K = %d H = %d W = %d stride = %d with requantization fused\n", N, - G, + K, H, W, stride_h); diff --git a/bench/GEMMsBenchmark.cc b/bench/GEMMsBenchmark.cc index b404d8b..f493a96 100644 --- a/bench/GEMMsBenchmark.cc +++ b/bench/GEMMsBenchmark.cc @@ -28,6 +28,7 @@ using namespace std; using namespace fbgemm; void performance_test() { + // clang-format off static const vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -39,6 +40,7 @@ void performance_test() { {256, 512, 256}, {1024, 1024, 1024}, }; + // clang-format on bool flush = true; std::vector<char> llc; diff --git a/bench/GEMMsTunableBenchmark.cc b/bench/GEMMsTunableBenchmark.cc index a65b51f..2adc556 100644 --- a/bench/GEMMsTunableBenchmark.cc +++ b/bench/GEMMsTunableBenchmark.cc @@ -218,7 +218,8 @@ int main(int /* unused */, char** /* unused */) { } #endif - vector<vector<int>> shapes = { + // clang-format off + vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. // m, n, k @@ -266,7 +267,8 @@ int main(int /* unused */, char** /* unused */) { {128, 128, 128}, {256, 512, 256}, {1024, 1024, 1024}, -}; + }; + // clang-format on vector<int> MCBs; vector<int> NCBs; diff --git a/bench/PackedFloatInOutBenchmark.cc b/bench/PackedFloatInOutBenchmark.cc index 66ca67e..dcea65c 100644 --- a/bench/PackedFloatInOutBenchmark.cc +++ b/bench/PackedFloatInOutBenchmark.cc @@ -28,6 +28,7 @@ using namespace std; using namespace fbgemm; void performance_test() { + // clang-format off vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -66,6 +67,7 @@ void performance_test() { {1, 128, 2722}, {16, 256, 512}, }; + // clang-format on bool flush = true; std::vector<char> llc; diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc index 40ff662..c6e2869 100644 --- a/bench/PackedRequantizeAcc16Benchmark.cc +++ b/bench/PackedRequantizeAcc16Benchmark.cc @@ -37,6 +37,7 @@ enum class BenchmarkType { }; void performance_test() { + // clang-format off vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. 
@@ -67,6 +68,7 @@ void performance_test() { {392, 2048, 512}, {392, 512, 2048}, }; + // clang-format on bool flush = true; std::vector<char> llc; diff --git a/bench/PackedRequantizeAcc32Benchmark.cc b/bench/PackedRequantizeAcc32Benchmark.cc index 2f04795..a61ef5a 100644 --- a/bench/PackedRequantizeAcc32Benchmark.cc +++ b/bench/PackedRequantizeAcc32Benchmark.cc @@ -28,6 +28,7 @@ using namespace std; using namespace fbgemm; void performance_test() { + // clang-format off vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -70,6 +71,7 @@ void performance_test() { {1, 128, 2722}, {16, 256, 512}, }; + // clang-format on bool flush = true; std::vector<char> llc; diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h index 11f3dcc..5431958 100644 --- a/include/fbgemm/ConvUtils.h +++ b/include/fbgemm/ConvUtils.h @@ -8,9 +8,24 @@ #include <array> #include <string> +#include <type_traits> namespace fbgemm { +template <int N, int... Vals> +constexpr + typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type + array_of_ones() { + return std::array<int, N>{{Vals...}}; +} + +template <int N, int... Vals> +constexpr + typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type + array_of_ones() { + return array_of_ones<N, Vals..., 1>(); +} + /** * @brief A struct to conveniently store all convolution parameters. */ @@ -34,7 +49,6 @@ struct conv_param_t { /** * @brief Constructor for initializing the convolution parameters. - * TODO: Dilation is not handled correctly. */ conv_param_t( int mb, @@ -44,7 +58,8 @@ struct conv_param_t { int g, std::array<int, SPATIAL_DIM> k, std::array<int, SPATIAL_DIM> strd, - std::array<int, SPATIAL_DIM * 2> pd) + std::array<int, SPATIAL_DIM * 2> pd, + std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>()) : MB(mb), IC(ic), OC(oc), @@ -52,7 +67,8 @@ struct conv_param_t { G(g), K(k), stride(strd), - pad(pd) { + pad(pd), + dilation(dilations) { if (ic % g != 0) { throw std::runtime_error( "groups = " + std::to_string(g) + @@ -63,10 +79,10 @@ struct conv_param_t { "groups = " + std::to_string(g) + " does not divide number of output channels = " + std::to_string(oc)); } + for (int d = 0; d < SPATIAL_DIM; ++d) { - dilation[d] = 1; IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d]; - OUT_DIM[d] = (IN_DIMP[d] - K[d]) / stride[d] + 1; + OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1; } } @@ -102,8 +118,12 @@ struct conv_param_t { } for (int d = 0; d < SPATIAL_DIM * 2; ++d) { out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" + - std::to_string(pad[d]); - if (d < SPATIAL_DIM * 2 - 1) { + std::to_string(pad[d]) + ", "; + } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" + + std::to_string(dilation[d]); + if (d < SPATIAL_DIM - 1) { out += ", "; } } @@ -121,6 +141,10 @@ struct conv_param_t { out += ", "; } } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "dilation_" + std::to_string(d) + ":" + + std::to_string(dilation[d]) + ", "; + } } return out; } diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 90d1ee9..668bd42 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -416,6 +416,19 @@ class FBGEMM_API PackBMatrix final const BlockingFactors* params = nullptr); /** + * This constructor accepts pre-packed matrix as an input. + * And, it skips the actual packing procedure. 
+ */ + PackBMatrix( + matrix_op_t trans, + std::int32_t nRow, + std::int32_t nCol, + inpType* prepackedmat, + std::int32_t ld, + int groups = 1, + const BlockingFactors* params = nullptr); + + /** * Weight matrices are usually constant so worth pre-packing. */ bool isPrePacked() const { @@ -445,14 +458,17 @@ class FBGEMM_API PackBMatrix final std::int32_t addr(std::int32_t i, std::int32_t j) const; /** - * @brief Packs a block of source matrix into pmat buffer. + * @brief Packs a block of source matrix into pmat buffer. The blocking + * parameters are needed to compute the buffer size of each group. + * It will use default blocking parameters if params is not provided. */ - void pack(const block_type_t& block); + void pack(const block_type_t& block, const BlockingFactors* params = nullptr); /** * @brief Print the packed block. */ - void printPackedMatrix(std::string name); + void printPackedMatrix(std::string name, + const BlockingFactors* params = nullptr); /** * @return true if meta information like matrix shape is the same. @@ -467,7 +483,7 @@ class FBGEMM_API PackBMatrix final * @brief Unpack pmat buffer to the origin_buf (Used for the serialization to * recover weight matrix). */ - void unpack(T* origin_buf); + void unpack(T* origin_buf, const BlockingFactors* params = nullptr); ~PackBMatrix() {} @@ -476,6 +492,16 @@ class FBGEMM_API PackBMatrix final const T* smat_; std::int32_t ld_; std::int32_t row_interleave_; + + /** + * @brief Internal function performing both pack & unpack + */ + void pack_unpack_( + const block_type_t& block, + T* unpack_buf, + T* pack_buf, + bool ispack, + const BlockingFactors* params = nullptr); }; /** @@ -508,6 +534,11 @@ class FBGEMM_API PackWeightMatrixForGConv { void pack(); /** + * @brief Unpacks a pmat buffer into source matrix. 
+ */ + void unpack(T* origin_buf); + + /** * @brief Return packed data */ inpType* getBuf() { @@ -530,6 +561,22 @@ class FBGEMM_API PackWeightMatrixForGConv { const T* sdata_; T* pdata_; bool bufAllocatedHere_; + + /** + * @brief Internal function performing both pack & unpack + */ + void pack_unpack_(const T* src, T* dst, bool ispack); + + /** + * @brief Get the index of the unpacked data + */ + int unpacked_index_(int r, int s, int k, int g, int c, bool tr); + + /** + * @brief Get the index of the packed data + */ + int packed_index_(int r, int s, int k, int g, int c); + }; /** @@ -562,12 +609,16 @@ class FBGEMM_API PackWeightsForConv { return W_im2col_packed_; } - std::shared_ptr<Packed3x3ConvMatrix> getPackedWFor2DDW() { - return W_dw_2D_packed_; + std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWForDepthwise() { + return W_dw_packed_; } - std::shared_ptr<Packed3x3x3ConvMatrix> getPackedWFor3DDW() { - return W_dw_3D_packed_; + std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor2DDW() { + return W_dw_packed_; + } + + std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor3DDW() { + return W_dw_packed_; } std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>> @@ -575,17 +626,55 @@ class FBGEMM_API PackWeightsForConv { return W_gconv_packed_; } + std::shared_ptr<PackBMatrix<T, accT>> getPackedWForPointwise() { + return W_pointwise_packed_; + } + + int inputChannels() { + return conv_param_.IC; + } + + int outputChannels() { + return conv_param_.OC; + } + + std::array<int, SPATIAL_DIM> kernelDims() { + return conv_param_.K; + } + + int groups() { + return conv_param_.G; + } + + /** + * @brief Returns true if the packed weights would work for the given + * convolution parameters, and false otherwise + */ + bool isPackingCompliant(const conv_param_t<SPATIAL_DIM>& conv_p); + + /** + * @brief Returns a string of mismatching parameters + */ + std::string mismatchingParams(const conv_param_t<SPATIAL_DIM>& conv_p); + + /** + * @brief Unpack packed matric into origin_buf (Used for the serialization to + * recover weight matrix). 
+ */ + void unpack(T* origin_buf); + private: + const conv_param_t<SPATIAL_DIM> conv_param_; // Packed weights if we use im2col based convolution implementation std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_; - // Packed weights if we use 2D depthwise convolution implementation - std::shared_ptr<Packed3x3ConvMatrix> W_dw_2D_packed_; - // Packed weights if we use 3D depthwise convolution implementation - std::shared_ptr<Packed3x3x3ConvMatrix> W_dw_3D_packed_; + // Packed weights if we use depthwise convolution implementation + std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_; // Packed weights if we use groupwise (small channels per group) convolution // implementation std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>> W_gconv_packed_; + // Packed weights if we use direct gemm for pointwise convolution + std::shared_ptr<PackBMatrix<T, accT>> W_pointwise_packed_; }; /** @@ -661,7 +750,11 @@ class FBGEMM_API PackAWithIm2Col ~PackAWithIm2Col() { if (rowOffsetAllocatedHere) { +#ifdef _MSC_VER + _aligned_free(row_offset_); +#else free(row_offset_); +#endif } } @@ -752,7 +845,11 @@ class FBGEMM_API PackAWithRowOffset final ~PackAWithRowOffset() { if (rowOffsetAllocatedHere) { +#ifdef _MSC_VER + _aligned_free(row_offset_); +#else free(row_offset_); +#endif } } @@ -845,7 +942,11 @@ class FBGEMM_API PackAWithQuantRowOffset final ~PackAWithQuantRowOffset() { if (rowOffsetAllocatedHere) { +#ifdef _MSC_VER + _aligned_free(row_offset_); +#else free(row_offset_); +#endif } } @@ -1062,12 +1163,15 @@ class FBGEMM_API DoSConvOnInpBuffer { template < bool FUSE_RELU, QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR, + typename BIAS_TYPE = std::int32_t, typename outT = std::uint8_t, typename inT = std::int32_t, typename nextOPType = DoNothing<outT, outT>> class FBGEMM_API ReQuantizeOutput { public: static constexpr int RELU_FUSED = FUSE_RELU; + static constexpr QuantizationGranularity QGRANType = Q_GRAN; + using BIAS_T = BIAS_TYPE; using outType = outT; using inpType = inT; /** @@ -1088,6 +1192,8 @@ class FBGEMM_API ReQuantizeOutput { * See PackedRequantizeTest.cc for an example. * TODO: if Aq_zero_point == 0, allow passing nullptr. * @params bias can be nullptr otherwise the length should be nCol + * @params act_times_w_scale activation_scale * weight_scale. This is only + * used if bias is unquantized (i.e., float). 
*/ ReQuantizeOutput( nextOPType& nextop, @@ -1097,9 +1203,10 @@ class FBGEMM_API ReQuantizeOutput { const std::int32_t* Bq_zero_point, const std::int32_t* row_offsets, const std::int32_t* col_offsets, - const std::int32_t* bias, + const BIAS_T* bias, std::uint32_t nCol, - int groups = 1) + int groups = 1, + const float* act_times_w_scale = nullptr) : nextop_(nextop), C_multiplier_(C_multiplier), C_zero_point_(C_zero_point), @@ -1109,7 +1216,8 @@ class FBGEMM_API ReQuantizeOutput { q_col_offsets_(col_offsets), bias_(bias), ncols_(nCol), - groups_(groups) {} + groups_(groups), + act_times_w_scale_(act_times_w_scale) {} template <inst_set_t instSet> inline int f( @@ -1137,12 +1245,15 @@ class FBGEMM_API ReQuantizeOutput { const std::int32_t* getColOffsets() const { return q_col_offsets_; } - const std::int32_t* getBias() const { + const BIAS_T* getBias() const { return bias_; } std::uint32_t getNCols() const { return ncols_; } + const float* getActWScale() const { + return act_times_w_scale_; + } void setRowOffsets(const std::int32_t* row_offsets) { q_row_offsets_ = row_offsets; @@ -1156,9 +1267,10 @@ class FBGEMM_API ReQuantizeOutput { const std::int32_t* Bq_zero_point_; const std::int32_t* q_row_offsets_; const std::int32_t* q_col_offsets_; - const std::int32_t* bias_; + const BIAS_T* bias_; std::uint32_t ncols_; int groups_; + const float* act_times_w_scale_; }; /** @@ -1311,7 +1423,8 @@ template < typename outType, bool FUSE_RELU, QuantizationGranularity Q_GRAN, - int SPATIAL_DIM = 2> + int SPATIAL_DIM = 2, + typename BIAS_TYPE = std::int32_t> FBGEMM_API void fbgemmGroupwiseConv( const conv_param_t<SPATIAL_DIM>& conv_param, const std::uint8_t* activations, @@ -1320,7 +1433,7 @@ FBGEMM_API void fbgemmGroupwiseConv( packed_W& packed_weights, outType* out, std::int32_t* outBuffer, - const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess, + const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess, int thread_id, int num_threads); @@ -1361,6 +1474,13 @@ template <int SPATIAL_DIM> FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p); /** + * @brief Is this convolution a direct matrix-matrix multiplication, i.e., 1x1 + * (aka pointwise) with right paddings etc.? + */ +template <int SPATIAL_DIM> +FBGEMM_API bool takePointWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p); + +/** * @brief Allocate __size bytes of uninitialized storage whose alignment is * specified by __align. */ diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h index 3d84977..8da0b56 100644 --- a/include/fbgemm/FbgemmFP16.h +++ b/include/fbgemm/FbgemmFP16.h @@ -104,6 +104,14 @@ class PackedGemmMatrixFP16 { } } + void setPacked(bool p) { + packed_ = p; + } + + bool packed() const { + return packed_; + } + void initializeMemory() { // allocate and initialize packed memory const int padding = 1024; // required by sw pipelined kernels @@ -128,6 +136,16 @@ class PackedGemmMatrixFP16 { #endif } + void unpackFromSrc(const matrix_op_t trans, float16* src_mat) { + bool tr = (trans == matrix_op_t::Transpose); + for (int i = 0; i < numRows(); i++) { + for (int j = 0; j < numCols(); j++) { + pmat_[tr ? 
i + numRows() * j : i * numCols() + j] = src_mat[addr(i, j)]; + } + } + packed_ = false; + } + // protected: // blocked row-major format address arithmetic uint64_t addr(const int r_, const int c_) const { @@ -163,6 +181,19 @@ class PackedGemmMatrixFP16 { pmat_[addr(i, j)]); } } + packed_ = true; + } + + // This function takes in an unpacked float16 matrix of the same size and + // packs it. There is no floating type conversion. + void packFromSrc(const matrix_op_t trans, const float16* smat) { + bool tr = (trans == matrix_op_t::Transpose); + for (int i = 0; i < numRows(); ++i) { + for (int j = 0; j < numCols(); ++j) { + pmat_[addr(i, j)] = smat[tr ? i + numRows() * j : i * numCols() + j]; + } + } + packed_ = true; } const float16& operator()(const int r, const int c) const { @@ -210,6 +241,7 @@ class PackedGemmMatrixFP16 { uint64_t size_; int kernel_ncol_blocks_; float16* pmat_; + bool packed_{false}; friend void cblas_gemm_compute( const matrix_op_t transa, diff --git a/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/include/fbgemm/FbgemmI8DepthwiseAvx2.h index 069ff77..c454b16 100644 --- a/include/fbgemm/FbgemmI8DepthwiseAvx2.h +++ b/include/fbgemm/FbgemmI8DepthwiseAvx2.h @@ -11,31 +11,58 @@ namespace fbgemm { -// KERNEL_PROD is the product of all kernels. -// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3. -template <int KERNEL_PROD> class FBGEMM_API PackedDepthWiseConvMatrix { public: - // smat in RSG layout - PackedDepthWiseConvMatrix(int K, const std::int8_t* smat); + /** + * @params K the number of channels (same as the number of groups because + * depth-wise convolution has one input/output channel per group) + * @params kernel_prod the product of all kernels. For example, kernel_prod = + * 9 for 3x3 conv, and 27 for 3x3x3 conv. + * @param smat the source unpacked weight in GRS layout + */ + PackedDepthWiseConvMatrix(int K, int kernel_prod, const std::int8_t* smat); virtual ~PackedDepthWiseConvMatrix(); const std::int8_t* PackedMat() const { return pmat_; } + int GetKernelProduct() const { + return kernel_prod_; + } + + /** + * @brief Unpacks pmat_ into unpack_data. + * Used for recovering the weight matrix into the original format + */ + void unpack(std::int8_t* unpacked_data); + + /** + * @brief returns the index into pmat_ given the row and column for smat + */ + int addr(int r, int c); + private: - int K_; - std::int8_t* pmat_; -}; // Packed3x3ConvMatrix + const int K_; /**< the number of channels */ + const int kernel_prod_; /** the product of all kernel dims */ + std::int8_t* pmat_; /** packed weight */ +}; // PackedDepthWiseConvMatrix -using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>; -using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>; +class FBGEMM_API Packed3x3ConvMatrix : public PackedDepthWiseConvMatrix { + public: + Packed3x3ConvMatrix(int K, const std::int8_t* smat) + : PackedDepthWiseConvMatrix(K, 3 * 3, smat) {} +}; -/** - * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 - * @params A The input image in NHWK layout - * @params Bp The pre-packed filter +class FBGEMM_API Packed3x3x3ConvMatrix : public PackedDepthWiseConvMatrix { + public: + Packed3x3x3ConvMatrix(int K, const std::int8_t* smat) + : PackedDepthWiseConvMatrix(K, 3 * 3 * 3, smat) {} +}; + +/** To be removed. 
Keeping it just to make sure we don't change C2 files and + * fbgemm files in a single diff + * */ FBGEMM_API void depthwise_3x3_pad_1( int N, @@ -46,8 +73,14 @@ FBGEMM_API void depthwise_3x3_pad_1( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, - const Packed3x3ConvMatrix& Bp, - std::int32_t* C, + std::int32_t B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false, int thread_id = 0, int num_threads = 1); @@ -56,7 +89,10 @@ FBGEMM_API void depthwise_3x3_pad_1( * This version is fused with requantization. * * @col_offsets nullptr if col_offsets are folded into bias + * @act_times_w_scale Only used if BIAS_TYPE is float, i.e., bias is + * unquantized. */ +template <typename BIAS_TYPE = std::int32_t> FBGEMM_API void depthwise_3x3_pad_1( int N, int H, @@ -67,22 +103,24 @@ FBGEMM_API void depthwise_3x3_pad_1( std::int32_t A_zero_point, const std::uint8_t* A, std::int32_t B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, float C_multiplier, std::int32_t C_zero_point, std::uint8_t* C, const std::int32_t* col_offsets, - const std::int32_t* bias, + const BIAS_TYPE* bias, bool fuse_relu = false, + float act_times_w_scale = 1.0f, int thread_id = 0, int num_threads = 1); /** - * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 - * This version is fused with requantization and uses per-channel quantization. + * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8, fused with + * requantization, and using per-channel quantization. * * @col_offsets nullptr if col_offsets are folded into bias */ +template <typename BIAS_TYPE = std::int32_t> FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( int N, int H, @@ -93,7 +131,31 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( std::int32_t A_zero_point, const std::uint8_t* A, const std::int32_t* B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu = false, + const float* act_times_w_scale = nullptr, + int thread_id = 0, + int num_threads = 1); + +/** To be removed. Keeping it just to make sure we don't change C2 files and + * fbgemm files in a single diff + */ +FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, std::int32_t C_zero_point, std::uint8_t* C, @@ -103,6 +165,10 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( int thread_id = 0, int num_threads = 1); +/** To be removed. 
Keeping it just to make sure we don't change C2 files and + * fbgemm files in a single diff + * + */ FBGEMM_API void depthwise_3x3x3_pad_1( int N, int T, @@ -114,14 +180,20 @@ FBGEMM_API void depthwise_3x3x3_pad_1( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, - const Packed3x3x3ConvMatrix& Bp, - std::int32_t* C, + std::int32_t B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false, int thread_id = 0, int num_threads = 1); - /** * @col_offsets nullptr if col_offsets are folded into bias */ +template <typename BIAS_TYPE = std::int32_t> FBGEMM_API void depthwise_3x3x3_pad_1( int N, int T, @@ -134,11 +206,38 @@ FBGEMM_API void depthwise_3x3x3_pad_1( std::int32_t A_zero_point, const std::uint8_t* A, std::int32_t B_zero_point, - const Packed3x3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, float C_multiplier, std::int32_t C_zero_point, std::uint8_t* C, const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu = false, + float act_times_w_scale = 1.0f, + int thread_id = 0, + int num_threads = 1); + +/** To be removed. Keeping it just to make sure we don't change C2 files and + * fbgemm files in a single diff + * + */ +FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, const std::int32_t* bias, bool fuse_relu = false, int thread_id = 0, @@ -147,6 +246,7 @@ FBGEMM_API void depthwise_3x3x3_pad_1( /** * @col_offsets nullptr if col_offsets are folded into bias */ +template <typename BIAS_TYPE = std::int32_t> FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( int N, int T, @@ -159,13 +259,14 @@ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( std::int32_t A_zero_point, const std::uint8_t* A, const std::int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, std::int32_t C_zero_point, std::uint8_t* C, const std::int32_t* col_offsets, - const std::int32_t* bias, + const BIAS_TYPE* bias, bool fuse_relu = false, + const float* act_times_w_scale = nullptr, int thread_id = 0, int num_threads = 1); diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h index d984c60..04ae100 100644 --- a/include/fbgemm/OutputProcessing-inl.h +++ b/include/fbgemm/OutputProcessing-inl.h @@ -59,11 +59,13 @@ inline int DoSConvOnInpBuffer<outT, inT, nextOPType>::f( template < bool FUSE_RELU, QuantizationGranularity Q_GRAN, + typename BIAS_TYPE, typename outT, typename inT, typename nextOPType> template <inst_set_t instSet> -inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f( +inline int +ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f( outT* out, const inT* inp, const block_type_t& block, @@ -98,11 +100,20 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f( raw -= q_row_offsets_[i - block.row_start] * Bq_zero_point_[Bq_zero_point_idx]; } + float raw_f; if (bias_) { - raw += bias_[j]; + if (std::is_same<BIAS_TYPE, float>::value) { + raw_f = raw; + raw_f += bias_[j] / 
act_times_w_scale_[Bq_zero_point_idx]; + } else { + raw += bias_[j]; + raw_f = raw; + } + } else { + raw_f = raw; } - float ab = raw * C_multiplier_[Bq_zero_point_idx]; + float ab = raw_f * C_multiplier_[Bq_zero_point_idx]; long rounded = std::lrintf(ab) + C_zero_point_; out[i * ld_out + j] = std::max( @@ -115,15 +126,16 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f( Bq_zero_point_[0] == 0) || q_row_offsets_ == nullptr; - requantizationParams_t r = {Aq_zero_point_, - Bq_zero_point_, - C_zero_point_, - C_multiplier_, - q_row_offsets_, - q_col_offsets_, - bias_, - ncols_, - groups_}; + requantizationParams_t<BIAS_TYPE> r = {Aq_zero_point_, + Bq_zero_point_, + C_zero_point_, + C_multiplier_, + q_row_offsets_, + q_col_offsets_, + bias_, + ncols_, + groups_, + act_times_w_scale_}; if (Aq_zero_point_ == 0) { if (b_symmetric) { diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h index 76eb425..baccfad 100644 --- a/include/fbgemm/PackingTraits-inl.h +++ b/include/fbgemm/PackingTraits-inl.h @@ -222,3 +222,53 @@ struct PackingTraits< 128}; ///< Cache block for N dimension (multiple of NR). static constexpr int KCB{256}; ///< Cache block for K dimension. }; + +/** + * @brief Helper struct to type specialize for int16_t and int32_t together. + */ +template <typename T> +struct is_16or32bit { + static constexpr bool value = + std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value; +}; + +/** + * @brief Packing parameter specialization for accumulation into 32-bit/16-bit + * integers. + * + * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t + * to int32_t accumulation and use the same blocking parameters as int32_t. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512_vnni. + */ +template <typename T, typename accT> +struct PackingTraits< + T, + accT, + inst_set_t::avx512_vnni, + typename std::enable_if< + is_8bit<T>::value && is_16or32bit<accT>::value>::type> { + static constexpr int MR{8}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. + static constexpr int NR{ + 32}; ///< Register block for N dimension. + ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. Total registers used for + ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x + ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers + ///< for C accumulations. + + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 128}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 32}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{256}; ///< Cache block for K dimension. 
+}; diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h index 43855d8..508ce7d 100644 --- a/include/fbgemm/QuantUtils.h +++ b/include/fbgemm/QuantUtils.h @@ -7,6 +7,7 @@ #include <limits> #include "FbgemmBuild.h" #include "QuantUtilsAvx2.h" +#include "Utils.h" namespace fbgemm { @@ -78,6 +79,40 @@ FBGEMM_API void Quantize( int len, const TensorQuantizationParams& qparams); +/* + * @brief Quantize floating point data in src to type T + * + * @tparam T output quantized data type (int8_t, uint8_t and int32_t are + * supported) + * + * @tparam T LAYOUT layout of input tensor in src. (KCX and KXC are supported) + * KCX corresponds to KCRS or KCTRS (for weight tensors with + * time dimension) + * KXC corresponds to KRSC or KTRSC (for weight tensors with + * time dimension) + * + * @params K Output channels for weight tensors + * @params C Number of channels + * @params X R*S or T*R*S + * @params G Groups (if G == C the function performs channelwise quantization; + * if 1 < G < C the function performs groupwise quantization; + * if G == 1 the function performs per tensor quantization;) + * @params scales floating point scales. + * Size should be equal G + * @params zero_points zero points (should be reprsentable in type T). + * Size should be equal G + */ +template <typename T, layout_t LAYOUT = layout_t::KCX> +FBGEMM_API void QuantizeGroupwise( + const float* src, + int K, + int C, + int X, + int G, + const float* scales, + const std::int32_t* zero_points, + T* dst); + template <typename T> FBGEMM_API float Dequantize(T src, const TensorQuantizationParams& qparams) { return qparams.scale * (src - qparams.zero_point); diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index 47f33a8..c7f3f35 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -40,9 +40,10 @@ struct FBGEMM_API RequantizationParams { //////////////////////////////////////////////////////////////////////////////// // Utility functions +template <typename T=std::uint8_t> void QuantizeAvx2( const float* src, - std::uint8_t* dst, + T* dst, int len, const TensorQuantizationParams& qparams); @@ -71,14 +72,15 @@ template < bool B_SYMMETRIC, QuantizationGranularity Q_GRAN, bool HAS_BIAS, - bool FUSE_RELU> + bool FUSE_RELU, + typename BIAS_TYPE = std::int32_t> FBGEMM_API void requantizeOutputProcessingAvx2( std::uint8_t* out, const std::int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r); + const requantizationParams_t<BIAS_TYPE>& r); template < bool A_SYMMETRIC, @@ -86,14 +88,15 @@ template < QuantizationGranularity Q_GRAN, bool HAS_BIAS, bool FUSE_RELU, - int C_PER_G> + int C_PER_G, + typename BIAS_TYPE = std::int32_t> FBGEMM_API void requantizeOutputProcessingGConvAvx2( std::uint8_t* out, const std::int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r); + const requantizationParams_t<BIAS_TYPE>& r); template < bool A_SYMMETRIC, diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index 9f8e1ee..3976790 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -5,6 +5,7 @@ * LICENSE file in the root directory of this source tree. */ #pragma once +#include <array> #include <string> #include <type_traits> #include "FbgemmBuild.h" @@ -39,12 +40,12 @@ enum class matrix_op_t { NoTranspose, Transpose }; /** * @brief Typed enum for supported instruction sets. 
*/ -enum class inst_set_t { anyarch, avx2, avx512 }; +enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni }; /** * @brief Typed enum for optimized paths for convolutions */ -enum class optimized_conv_t { depthwise, groupwise, im2col }; +enum class optimized_conv_t { depthwise, groupwise, pointwise, im2col }; /** * @brief Typed enum for implementation type. @@ -54,6 +55,13 @@ enum class optimized_conv_t { depthwise, groupwise, im2col }; enum class impl_type_t { ref, opt }; /** + * @brief Typed enum to specify data layout. + * KCX can be KCRS format or KCTRS format (e.g., for 3-D convolutions) + * KXC can be KRSC format or KTRSC format (e.g., for 3-D convolutions) + */ +enum class layout_t { KCX, KXC }; + +/** * @brief A function to compare data in two buffers for closeness/equality. */ template <typename T> @@ -103,6 +111,11 @@ FBGEMM_API bool fbgemmHasAvx512Support(); FBGEMM_API bool fbgemmHasAvx2Support(); /** + * @brief Are we running on a AVX512_VNNI supported cpu? + */ +FBGEMM_API bool fbgemmHasAvx512VnniSupport(); + +/** * @brief Helper struct to enable autotuning of FBGEMM packing and kernels. * * This structure is optional. If not used, the default values for these @@ -119,6 +132,16 @@ struct FBGEMM_API BlockingFactors { int NCB; }; +template <int SIZE, typename T = std::int32_t> +FBGEMM_API std::string arrayToString(const std::array<T, SIZE>& inp) { + std::string out = "["; + for (int i = 0; i < SIZE; ++i) { + out += std::to_string(inp[i]); + out += (i != SIZE - 1) ? std::string(", ") : std::string("]"); + } + return out; +} + template <typename accT = std::int32_t> FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { constexpr bool is_32bit = std::is_same<accT, int32_t>::value; @@ -129,10 +152,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { return false; if (fbgemmHasAvx512Support()) { - if (param->NR != 16) + if (param->NR_MIN != 16 || param->NR % param->NR_MIN) return false; } else if (fbgemmHasAvx2Support()) { - if (param->NR != 8) + if (param->NR_MIN != 8 || param->NR % param->NR_MIN) return false; } } else if (is_16bit) { @@ -140,10 +163,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { return false; if (fbgemmHasAvx512Support()) { - if (param->NR != 32) + if (param->NR_MIN != 32 || param->NR % param->NR_MIN) return false; } else if (fbgemmHasAvx2Support()) { - if (param->NR != 16) + if (param->NR_MIN != 16 || param->NR % param->NR_MIN) return false; } } @@ -153,10 +176,19 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { if (param->NCB % param->NR) return false; if (fbgemmHasAvx512Support()) { - if (param->MR * (param->NCB / param->NR) > 24) - return false; + if (is_32bit) { + // Zmm register usage for C + if (param->MR * (param->NR / param->NR_MIN) > 28) + return false; + } else if (is_16bit) { + // Zmm register usage for C + one row for loading B + if ((param->MR * (param->NR / param->NR_MIN) + + (param->NR / param->NR_MIN)) > 28) + return false; + } + } else if (fbgemmHasAvx2Support()) { - if (param->MR * (param->NCB / param->NR) > 16) + if (param->MR * (param->NR / param->NR_MIN) > 12) return false; } return true; diff --git a/include/fbgemm/UtilsAvx2.h b/include/fbgemm/UtilsAvx2.h index 082edc1..3bac909 100644 --- a/include/fbgemm/UtilsAvx2.h +++ b/include/fbgemm/UtilsAvx2.h @@ -44,16 +44,19 @@ struct block_type_t { * QuantUtilsAvx2.h as it combines all the parameters needed for various * quantization granularities */ +template<typename BIAS_TYPE = std::int32_t> struct requantizationParams_t { 
+ using BIAS_T = BIAS_TYPE; std::int32_t A_zero_point; const std::int32_t* B_zero_point; std::int32_t C_zero_point; const float* C_multiplier; const std::int32_t* row_offsets; const std::int32_t* col_offsets; - const std::int32_t* bias; + const BIAS_T* bias; std::uint32_t ncols; int groups; + const float* act_times_w_scale; }; /** diff --git a/src/CodeCache.h b/src/CodeCache.h new file mode 100644 index 0000000..08e9c9b --- /dev/null +++ b/src/CodeCache.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include <condition_variable> +#include <future> +#include <map> +#include <mutex> + +namespace fbgemm { + +/** + * @brief Thread safe cache for microkernels, ensures single creation per key. + * @tparam Key Type of unique key (typically a tuple) + * @tparam Value Type of the microkernel function (Typically a function pointer) + */ +template <typename KEY, typename VALUE> class CodeCache { +private: + std::map<KEY, std::shared_future<VALUE>> values_; + std::mutex mutex_; + +public: + CodeCache(const CodeCache &) = delete; + CodeCache &operator=(const CodeCache &) = delete; + + CodeCache(){}; + + VALUE getOrCreate(const KEY &key, std::function<VALUE()> generatorFunction) { + std::shared_future<VALUE> returnFuture; + std::promise<VALUE> returnPromise; + bool needsToGenerate = false; + + // Check for existance of the key + { + std::unique_lock<std::mutex> lock(mutex_); + + auto it = values_.find(key); + if (it != values_.end()) { + returnFuture = it->second; + } else { + values_[key] = returnFuture = returnPromise.get_future().share(); + needsToGenerate = true; + } + } + + // The value (code) generation is not happening under a lock + if (needsToGenerate) { + returnPromise.set_value(generatorFunction()); + } + + // Wait for the future and return the value + return returnFuture.get(); + } +}; + +} // namespace fbgemm diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index f7292fd..4ae1b50 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -49,7 +49,8 @@ ExecuteKernel< throw std::runtime_error("Failed to initialize cpuinfo!"); } if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() || + fbgemmHasAvx2Support()) { mbSize_ = params->MCB; nbSize_ = params->NCB; nrMinSize_ = params->NR_MIN; @@ -59,7 +60,20 @@ ExecuteKernel< assert(0 && "unsupported architecure"); } } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::NCB; + nrMinSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::NR_MIN; + } else if (fbgemmHasAvx512Support()) { mbSize_ = PackingTraits< int8_t, typename packingAMatrix::accType, @@ -118,7 +132,25 @@ void ExecuteKernel< typename BaseType::jit_micro_kernel_fp fn; - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) { + // For AVX512VNNI, we redirect int16_t to int32_t accumulation. 
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } else { + fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } + } else if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, @@ -148,7 +180,10 @@ void ExecuteKernel< if (jb == bColBlocks - 1) { int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_; if (nc != nbSize_) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( + accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); + } else if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); } else if (fbgemmHasAvx2Support()) { @@ -213,7 +248,7 @@ void ExecuteKernel< int32_t nSize = C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols(); if (nSize) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( @@ -238,7 +273,7 @@ void ExecuteKernel< if (C_buffer_start == C_tile_) { // When C_tile_ scratchpad was used to avoid accessing memory past // C_buffer_ . - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( @@ -280,19 +315,23 @@ void ExecuteKernel< //////////////////////////////////////////////////////////////////////////////// // ReQuantizeOutput -#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \ - template class ExecuteKernel< \ - PACK_A<uint8_t, ACC_T>, \ - PackBMatrix<int8_t, ACC_T>, \ - uint8_t, \ - ReQuantizeOutput<RELU, Q_GRAN>>; +#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \ + template class ExecuteKernel< \ + PACK_A<uint8_t, ACC_T>, \ + PackBMatrix<int8_t, ACC_T>, \ + uint8_t, \ + ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>; + +#define INSTANTIATE_REQUANT_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \ + INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \ + INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t); #define INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, RELU) \ - INSTANTIATE_REQUANT_BASE( \ + INSTANTIATE_REQUANT_BIAS_T( \ PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \ - INSTANTIATE_REQUANT_BASE( \ + INSTANTIATE_REQUANT_BIAS_T( \ PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \ - INSTANTIATE_REQUANT_BASE( \ + INSTANTIATE_REQUANT_BIAS_T( \ PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_REQUANT_RELU(PACK_A, ACC_T) \ @@ -309,21 +348,27 @@ INSTANTIATE_REQUANT_ACC_T(PackAWithRowOffset); #undef INSTANTIATE_REQUANT_ACC_T #undef INSTANTIATE_REQUANT_RELU #undef INSTANTIATE_REQUANT_Q_GRANS +#undef INSTANTIATE_REQUANT_BIAS_T #undef INSTANTIATE_REQUANT_BASE -#define INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ - template class ExecuteKernel< \ - PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ - PackBMatrix<int8_t, ACC_T>, \ - uint8_t, \ - ReQuantizeOutput<RELU, Q_GRAN>>; +#define INSTANTIATE_IM2COL_REQUANT_BASE( \ + ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \ + template 
class ExecuteKernel< \ + PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ + PackBMatrix<int8_t, ACC_T>, \ + uint8_t, \ + ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>; + +#define INSTANTIATE_IM2COL_REQUANT_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ + INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \ + INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t); #define INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ - INSTANTIATE_IM2COL_REQUANT_BASE( \ + INSTANTIATE_IM2COL_REQUANT_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ - INSTANTIATE_IM2COL_REQUANT_BASE( \ + INSTANTIATE_IM2COL_REQUANT_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \ - INSTANTIATE_IM2COL_REQUANT_BASE( \ + INSTANTIATE_IM2COL_REQUANT_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, RELU) \ @@ -340,6 +385,7 @@ INSTANTIATE_IM2COL_REQUANT_RELU(int16_t); #undef INSTANTIATE_IM2COL_REQUANT_RELU #undef INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM #undef INSTANTIATE_IM2COL_REQUANT_Q_GRANS +#undef INSTANTIATE_IM2COL_REQUANT_BIAS_T #undef INSTANTIATE_IM2COL_REQUANT_BASE //////////////////////////////////////////////////////////////////////////////// diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index 2f641ee..b691b88 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -48,7 +48,8 @@ void fbgemmPacked( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -62,7 +63,20 @@ void fbgemmPacked( MR = blocking_params->MR; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::KCB; + MR = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MR; + } else if (fbgemmHasAvx512Support()) { MCB = PackingTraits< typename packingAMatrix::inpType, typename packingAMatrix::accType, @@ -223,22 +237,26 @@ bool fbgemmSupportedCPU() { //////////////////////////////////////////////////////////////////////////////// // ReQuantizeOutput -#define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \ +#define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \ template void fbgemmPacked( \ PackMatrix<PACK_A<uint8_t, ACC_T>, uint8_t, ACC_T>& packA, \ PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ uint8_t* C, \ int32_t* C_buffer, \ uint32_t ldc, \ - const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ int thread_id, \ int num_threads, \ const BlockingFactors* blocking_params); -#define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \ - INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \ - INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \ - INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL); +#define INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \ + INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \ + INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t); + +#define 
INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \ + INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \ + INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \ + INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_RELU(PACK_A, ACC_T) \ INSTANTIATE_Q_GRANS(PACK_A, ACC_T, false); \ @@ -254,27 +272,34 @@ INSTANTIATE_ACC_T(PackAWithRowOffset); #undef INSTANTIATE_ACC_T #undef INSTANTIATE_RELU #undef INSTANTIATE_Q_GRANS +#undef INSTANTIATE_BIAS_T #undef INSTANTIATE_BASE -#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ - template void fbgemmPacked( \ - PackMatrix< \ - PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ - uint8_t, \ - ACC_T>& packA, \ - PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ - uint8_t* C, \ - int32_t* C_buffer, \ - uint32_t ldc, \ - const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ - int thread_id, \ - int num_threads, \ +#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \ + template void fbgemmPacked( \ + PackMatrix< \ + PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ + uint8_t, \ + ACC_T>& packA, \ + PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ + uint8_t* C, \ + int32_t* C_buffer, \ + uint32_t ldc, \ + const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ + int thread_id, \ + int num_threads, \ const BlockingFactors* blocking_params); -#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ - INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ - INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \ - INSTANTIATE_BASE( \ +#define INSTANTIATE_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ + INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \ + INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t); + +#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ + INSTANTIATE_BIAS_T( \ + ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ + INSTANTIATE_BIAS_T( \ + ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \ + INSTANTIATE_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_SPATIAL_DIM(ACC_T, RELU) \ @@ -291,6 +316,7 @@ INSTANTIATE_RELU(int16_t); #undef INSTANTIATE_RELU #undef INSTANTIATE_SPATIAL_DIM #undef INSTANTIATE_Q_GRANS +#undef INSTANTIATE_BIAS_T #undef INSTANTIATE_BASE //////////////////////////////////////////////////////////////////////////////// diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc index 5db63f6..de833d2 100644 --- a/src/FbgemmConv.cc +++ b/src/FbgemmConv.cc @@ -6,8 +6,9 @@ */ #include <algorithm> -#include <iostream> +#include <numeric> #include <vector> +#include <functional> #include "fbgemm/Fbgemm.h" namespace fbgemm { @@ -33,12 +34,24 @@ bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { }); } +template <int SPATIAL_DIM> +bool takePointWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { + return std::accumulate(conv_p.K.begin(), conv_p.K.end(), 0) == SPATIAL_DIM && + std::accumulate(conv_p.stride.begin(), conv_p.stride.end(), 0) == + SPATIAL_DIM && + std::accumulate(conv_p.dilation.begin(), conv_p.dilation.end(), 0) == + SPATIAL_DIM && + std::accumulate(conv_p.pad.begin(), conv_p.pad.end(), 0) == 0; +} + template <int SPATIAL_DIM, typename ACC_T> optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { if (takeDepthWiseFastPath<SPATIAL_DIM, ACC_T>(conv_p)) { return optimized_conv_t::depthwise; } else if 
(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_p)) { return optimized_conv_t::groupwise; + } else if (takePointWiseFastPath<SPATIAL_DIM>(conv_p)) { + return optimized_conv_t::pointwise; } else { return optimized_conv_t::im2col; } @@ -58,58 +71,139 @@ int fbgemmConv( static_assert( SPATIAL_DIM == 2 || SPATIAL_DIM == 3, "Only 2D and 3D convolutions are supported"); + + if (!packed_weights.isPackingCompliant(conv_p)) { + std::string msg = + "[FBGEMM_CONV_ERROR] Convolution parameters " + "mismatch between pre-packed weights and conv invocation! "; + msg += packed_weights.mismatchingParams(conv_p); + msg += std::string( + " Please pack weights using the same parameters " + "with which convolution operation is invoked!"); + throw std::logic_error(msg); + } + switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) { case optimized_conv_t::depthwise: { // 2D and 3D depthwise fast path // std::cout << "Depthwise fast path" << std::endl; const std::int32_t* B_zero_point = outProcess.getBZeroPoint(); const float* C_multiplier = outProcess.getCMultiplier(); + const float* act_times_w_scale = outProcess.getActWScale(); if (SPATIAL_DIM == 3) { static_assert( std::is_same<typename processOutputType::outType, std::uint8_t>:: value, "For depthwise, only requantized output is supported"); - depthwise_3x3x3_pad_1( - conv_p.MB, // mini batch - conv_p.IN_DIM[0], // T - conv_p.IN_DIM[1], // H - conv_p.IN_DIM[2], // W - conv_p.OC, // output channels - conv_p.stride[0], // stride_t - conv_p.stride[1], // stride_h - conv_p.stride[2], // stride_w - outProcess.getAZeroPoint(), - activations, - B_zero_point[0], - *(packed_weights.getPackedWFor3DDW()), - C_multiplier[0], - outProcess.getCZeroPoint(), - out, - outProcess.getColOffsets(), - outProcess.getBias(), - outProcess.RELU_FUSED, // fuse_relu - thread_id, - num_threads); + + if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) { + depthwise_3x3x3_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // T + conv_p.IN_DIM[1], // H + conv_p.IN_DIM[2], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_t + conv_p.stride[1], // stride_h + conv_p.stride[2], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point[0], + *(packed_weights.getPackedWForDepthwise()), + C_multiplier[0], + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + act_times_w_scale ? 
act_times_w_scale[0] : 1.0f, + thread_id, + num_threads); + } else if ( + processOutputType::QGRANType == + QuantizationGranularity::OUT_CHANNEL || + processOutputType::QGRANType == QuantizationGranularity::GROUP) { + depthwise_3x3x3_per_channel_quantization_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // T + conv_p.IN_DIM[1], // H + conv_p.IN_DIM[2], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_t + conv_p.stride[1], // stride_h + conv_p.stride[2], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point, + *(packed_weights.getPackedWForDepthwise()), + C_multiplier, + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + outProcess.getActWScale(), // act_scale * weight_scale + thread_id, + num_threads); + } else { + std::string msg = + "[FBGEMM_CONV_ERROR] This quantization granularity is " + "not supported"; + throw std::runtime_error(msg); + } } else { - depthwise_3x3_pad_1( - conv_p.MB, // mini batch - conv_p.IN_DIM[0], // H - conv_p.IN_DIM[1], // W - conv_p.OC, // output channels - conv_p.stride[0], // stride_h - conv_p.stride[1], // stride_w - outProcess.getAZeroPoint(), - activations, - B_zero_point[0], - *(packed_weights.getPackedWFor2DDW()), - C_multiplier[0], - outProcess.getCZeroPoint(), - out, - outProcess.getColOffsets(), - outProcess.getBias(), - outProcess.RELU_FUSED, // fuse_relu - thread_id, - num_threads); + if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) { + depthwise_3x3_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // H + conv_p.IN_DIM[1], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_h + conv_p.stride[1], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point[0], + *(packed_weights.getPackedWForDepthwise()), + C_multiplier[0], + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + act_times_w_scale ? 
act_times_w_scale[0] : 1.0f, + thread_id, + num_threads); + } else if ( + processOutputType::QGRANType == + QuantizationGranularity::OUT_CHANNEL || + processOutputType::QGRANType == QuantizationGranularity::GROUP) { + // The number of channels == groups for depthwise convolutions + depthwise_3x3_per_channel_quantization_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // H + conv_p.IN_DIM[1], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_h + conv_p.stride[1], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point, + *(packed_weights.getPackedWForDepthwise()), + C_multiplier, + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + outProcess.getActWScale(), // act_scale * weight_scale + thread_id, + num_threads); + } else { + std::string msg = + "[FBGEMM_CONV_ERROR] This quantization granularity is " + "not supported"; + throw std::runtime_error(msg); + } } break; } @@ -134,14 +228,68 @@ int fbgemmConv( num_threads); break; } + case optimized_conv_t::pointwise: { + std::vector<int32_t> row_offset_buf( + PackAWithRowOffset<uint8_t>::rowOffsetBufferSize(blocking_params)); + int image_dim = std::accumulate( + conv_p.IN_DIM.begin(), + conv_p.IN_DIM.end(), + 1, + std::multiplies<int>()); + PackAWithRowOffset<uint8_t, ACC_T> packA( + matrix_op_t::NoTranspose, + conv_p.MB * image_dim, + conv_p.IC, + activations, + conv_p.IC, + nullptr, + conv_p.G, + row_offset_buf.data(), + blocking_params); + + outProcess.setRowOffsets(row_offset_buf.data()); + fbgemmPacked( + packA, + *(packed_weights.getPackedWForPointwise()), + out, + outBuffer, + conv_p.OC, + outProcess, + thread_id, + num_threads, + blocking_params); + break; + } case optimized_conv_t::im2col: { // All other convolutions go through im2col-based implementation // std::cout << "Im2col path" << std::endl; std::vector<int32_t> row_offset_buf( - PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize()); + PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize( + blocking_params)); const std::int32_t* b_zero_point = outProcess.getBZeroPoint(); - bool b_symmetric = b_zero_point[0] == 0; + bool b_symmetric = false; + if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) { + b_symmetric = b_zero_point[0] == 0; + } else if ( + processOutputType::QGRANType == QuantizationGranularity::GROUP) { + b_symmetric = + std::all_of(b_zero_point, b_zero_point + conv_p.G, [](int i) { + return i == 0; + }); + } else if ( + processOutputType::QGRANType == + QuantizationGranularity::OUT_CHANNEL) { + b_symmetric = + std::all_of(b_zero_point, b_zero_point + conv_p.OC, [](int i) { + return i == 0; + }); + } else { + std::string msg = + "[FBGEMM_CONV_ERROR] This quantization granularity is " + "not supported"; + throw std::runtime_error(msg); + } PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> packA( conv_p, activations, @@ -169,21 +317,25 @@ int fbgemmConv( return 0; } -#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \ +#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, BIAS_TYPE) \ template int fbgemmConv( \ const conv_param_t<SPATIAL_DIM>& conv_p, \ const std::uint8_t* activations, \ PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, \ std::uint8_t* out, \ std::int32_t* outBuffer, \ - ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ int thread_id, \ int num_threads, \ const BlockingFactors* blocking_params); +#define 
INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \ + INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, float); \ + INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, int32_t); + #define INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, RELU) \ - INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 2); \ - INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 3); + INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 2); \ + INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 3); #define INSTANTIATE_RELU(ACC_T, Q_GRAN) \ INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, true); \ @@ -199,6 +351,7 @@ INSTANTIATE_Q_GRANS(std::int32_t); #undef INSTANTIATE_Q_GRANS #undef INSTANTIATE_RELU #undef INSTANTIATE_SPATIAL_DIM +#undef INSTANTIATE_BIAS_T #undef INSTANTIATE_BASE template bool takeDepthWiseFastPath<2, std::int32_t>( diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc index f357966..b034f2c 100644 --- a/src/FbgemmFP16.cc +++ b/src/FbgemmFP16.cc @@ -50,6 +50,7 @@ struct KernelInfo { // autotuned kernel splits for various cases m = 1:mb_max // may need re-autotuning for new uarch + // clang-format off static constexpr int partition[121][2][2] = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -175,6 +176,7 @@ struct KernelInfo { { { 6, 19 }, { 5, 1 } }, // 119 { { 6, 20 }, { 0, 0 } }, // 120 }; + // clang-format on }; constexpr KernelInfo::knl_ptr KernelInfo::kernel[7];; constexpr int KernelInfo::partition[121][2][2]; diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc new file mode 100644 index 0000000..2114b20 --- /dev/null +++ b/src/FbgemmI8Depthwise3DAvx2.cc @@ -0,0 +1,1423 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "fbgemm/FbgemmI8DepthwiseAvx2.h" + +#include <string> +#include <tuple> // for tie + +#include "FbgemmI8DepthwiseAvx2-inl.h" + +using namespace std; + +namespace fbgemm { + +template < + bool SUM_A, + bool REMAINDER = false, + bool PER_CHANNEL_QUANTIZATION = false> +static inline ALWAYS_INLINE void inner_prod_3x3x3_packed_( + int T, + int H, + int W, + int K, + int t_in, + int h_in, + int w_in, + const uint8_t* A, + int32_t A_zero_point, + const int8_t* Bp, + const int32_t* B_zero_point, + int32_t* C, + int remainder, + int32_t* row_offsets) { + __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point)); + __m256i mask_v = _mm256_setzero_si256(); + if (REMAINDER) { + mask_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(masks[remainder / 4])); + } + + // The code below can be written as a simple R*S loop but the compiler + // doesn't unroll so we're manually unrolling it. 
+ // constexpr int R = 3, S = 3; + // array<__m256i, R * S> a_v; + // for (int r = 0; r < R; ++r) { + // for (int s = 0; s < S; ++s) { + // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { + // if (REMAINDER) { + // a_v[r * S + s] = + // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), + // mask_v); + // } else { + // a_v[r * S + s] = + // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); + // } + // } else { + // a_v[r * S + s] = A_zero_point_v; + // } + // } + // } + __m256i a_v[8]; + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in >= 0 && t_in < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v); + } + } + } + + __m256i a_sum[4]; + inner_prod_packed_<8, SUM_A, REMAINDER>( + a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in >= 0 && t_in < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v); + } + } + } + + if (t_in + 1 >= 0 && t_in + 1 < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v); + } + } + } + + __m256i a_sum_temp[4]; + inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( + a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp); + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = 
_mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + } + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in + 1 >= 0 && t_in + 1 < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v); + } + } + } + + if (t_in + 2 >= 0 && t_in + 2 < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v); + } + } + } + + inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( + a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp); + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + } + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + + if (t_in + 2 >= 0 && t_in + 2 < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v); + } + } + } + + inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>( + a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp); + + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + + __m256i B_zero_point_v; + for (int i = 0; i < (REMAINDER ? 
(remainder / 8) : 4); ++i) { + if (PER_CHANNEL_QUANTIZATION) { + B_zero_point_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); + } else { + B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); + } + _mm256_store_si256( + reinterpret_cast<__m256i*>(&row_offsets[i * 8]), + _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); + } + } +} + +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> +static inline ALWAYS_INLINE void depthwise_3x3x3_kernel_( + int T, + int H, + int W, + int K, + int t, + int h, + int w, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const int8_t* Bp, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + int32_t* row_offsets, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale) { + constexpr int R = 3, S = 3; + constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int t_in = -PAD_P + t * stride_t; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + &B_zero_point, + C_int32 + k, + 0, + B_SYMMETRIC ? nullptr : &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + &B_zero_point, + C_int32 + k, + remainder, + B_SYMMETRIC ? 
nullptr : &row_offsets[k]); + } + + requantize_< + FUSE_RELU, + HAS_BIAS, + false, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + B_SYMMETRIC>( + A_zero_point, + &C_multiplier, + C_zero_point, + C_int32, + C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, + K, + row_offsets, + col_offsets, + bias, + &act_times_w_scale); +} + +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> +static inline ALWAYS_INLINE void +depthwise_3x3x3_per_channel_quantization_kernel_( + int T, + int H, + int W, + int K, + int t, + int h, + int w, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const int8_t* Bp, + const float* C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + int32_t* row_offsets, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale) { + constexpr int R = 3, S = 3; + constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int t_in = -PAD_P + t * stride_t; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3x3_packed_< + true, /*SUM_A*/ + false, /*remainder*/ + true /*per-channel*/>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + B_zero_point + k, + C_int32 + k, + 0, + &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3x3_packed_< + true, /*SUM_A*/ + true, /*remainder*/ + true /*per-channel*/>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + B_zero_point + k, + C_int32 + k, + remainder, + &row_offsets[k]); + } + requantize_< + FUSE_RELU, + HAS_BIAS, + true, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + false /*B_SYMM*/>( + A_zero_point, + C_multiplier, + C_zero_point, + C_int32, + C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, + K, + row_offsets, + col_offsets, + bias, + act_times_w_scale); +} + +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> +static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale, + int thread_id, + int num_threads) { + assert(K % 8 == 0); + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + //int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + int32_t* row_offsets + = static_cast<int32_t*>(ALIGNED_MALLOC((K + 31) / 32 * 32 * sizeof(int32_t), 64)); + + int n_begin, n_end; + int t_begin, t_end, h_begin, h_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + t_begin = 0; + t_end = T_OUT; + h_begin = 0; + 
h_end = H_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_t * num_threads_h 2D grid of threads + int num_threads_t, num_threads_h; + // num_threads_w <= num_threads_h + tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_t = tid_within_n / num_threads_h; + int tid_h = tid_within_n % num_threads_h; + + int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; + t_begin = std::min(tid_t * t_per_thread, T_OUT); + t_end = std::min(t_begin + t_per_thread, T_OUT); + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * T * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; + + for (int t = t_begin; t < t_end; ++t) { + for (int h = h_begin; h < h_end; ++h) { + for (int w = 0; w < W_OUT; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + } // h + } // t + } // for each n + FREE(row_offsets); +}; + +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> +static inline ALWAYS_INLINE void +depthwise_3x3x3_per_channel_quantization_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + assert(K % 8 == 0); + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + //int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + int32_t* row_offsets + = static_cast<int32_t*>(ALIGNED_MALLOC((K + 31) / 32 * 32 * sizeof(int32_t), 64)); + + int n_begin, n_end; + int t_begin, t_end, h_begin, h_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + t_begin = 0; + t_end = T_OUT; + h_begin = 0; + h_end = H_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = 
tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_t * num_threads_h 2D grid of threads + int num_threads_t, num_threads_h; + // num_threads_w <= num_threads_h + tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_t = tid_within_n / num_threads_h; + int tid_h = tid_within_n % num_threads_h; + + int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; + t_begin = std::min(tid_t * t_per_thread, T_OUT); + t_end = std::min(t_begin + t_per_thread, T_OUT); + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * T * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; + + for (int t = t_begin; t < t_end; ++t) { + for (int h = h_begin; h < h_end; ++h) { + for (int w = 0; w < W_OUT; ++w) { + depthwise_3x3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + } // h + } // t + } // for each n + FREE(row_offsets); +}; + +// Dispatch A_SYMMETRIC and B_SYMMETRIC +template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> +static void depthwise_3x3x3_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale, + int thread_id, + int num_threads) { + int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + if (B_zero_point == 0) { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + true /*B_symmetric*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + false /*B_symmetric*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + } else { + if (B_zero_point == 0) { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + true /*B_symmetric*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + false /*B_symmetric*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + 
act_times_w_scale, + thread_id, + num_threads); + } + } + delete[] C_int32_temp; +} + +// Dispatch HAS_BIAS +template <bool FUSE_RELU, typename BIAS_TYPE> +static void depthwise_3x3x3_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale, + int thread_id, + int num_threads) { + if (bias) { + depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// Dispatch FUSE_RELU +template <typename BIAS_TYPE> +void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads) { + if (B.GetKernelProduct() != 3 * 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert( + 0 && + "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. 
+ return; + } + if (fuse_relu) { + depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/, BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/, BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// Dispatch A_SYMMETRIC +template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> +static void depthwise_3x3x3_per_channel_quantization_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_SYMM*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_SYMM*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + delete[] C_int32_temp; +} + +// Dispatch HAS_BIAS +template <bool FUSE_RELU, typename BIAS_TYPE> +static void depthwise_3x3x3_per_channel_quantization_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + if (bias) { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + true /* HAS_BIAS */, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + false /* HAS_BIAS */, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// Dispatch FUSE_RELU +template <typename BIAS_TYPE> +void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + 
const int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + if (B.GetKernelProduct() != 3 * 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert( + 0 && + "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. + return; + } + if (fuse_relu) { + depthwise_3x3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// To be removed +void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + int thread_id, + int num_threads) { + depthwise_3x3x3_pad_1<int32_t>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + fuse_relu, + 1.0f, // act_scale * weight_scale + thread_id, + num_threads); +} + +void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + int thread_id, + int num_threads) { + depthwise_3x3x3_per_channel_quantization_pad_1( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + fuse_relu, + nullptr, // act_scale * weight_scale + thread_id, + num_threads); +} + +template void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + float 
act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads); + +} // namespace fbgemm diff --git a/src/FbgemmI8DepthwiseAvx2-inl.h b/src/FbgemmI8DepthwiseAvx2-inl.h new file mode 100644 index 0000000..aee9ab3 --- /dev/null +++ b/src/FbgemmI8DepthwiseAvx2-inl.h @@ -0,0 +1,710 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include <algorithm> // for min and max +#include <cassert> +#include <cmath> // for lrintf and sqrt +#include <cstdint> +#include <type_traits> // for is_same + +#include <immintrin.h> +#include "fbgemm/Utils.h" + +namespace fbgemm { + +// clang-format off +static int masks[8][8] = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. 
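+    // Row r has its first r words set to -1 (all bits on); masks[remainder / 4] is the
+    // argument to _mm256_maskload_epi32 so only the remaining 32-bit lanes are loaded.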
+ { 0, 0, 0, 0, 0, 0, 0, 0, }, + { -1, 0, 0, 0, 0, 0, 0, 0, }, + { -1, -1, 0, 0, 0, 0, 0, 0, }, + { -1, -1, -1, 0, 0, 0, 0, 0, }, + { -1, -1, -1, -1, 0, 0, 0, 0, }, + { -1, -1, -1, -1, -1, 0, 0, 0, }, + { -1, -1, -1, -1, -1, -1, 0, 0, }, + { -1, -1, -1, -1, -1, -1, -1, 0, }, +}; +// clang-format on + +// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[16:20] +// c1_v: c[4:8], c[20:24] +// c2_v: c[8:12], c[24:28] +// c3_v: c[12:16], c[28:32] +template <bool SUM_A = false> +static inline ALWAYS_INLINE void madd_epi16x4_packed( + __m256i a0_v, + __m256i a1_v, + __m256i a2_v, + __m256i a3_v, + const __m256i* b, + __m256i* c0_v, + __m256i* c1_v, + __m256i* c2_v, + __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v); + __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); + __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); + __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); + __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + __m256i b2_v = _mm256_load_si256(b + 2); + __m256i b3_v = _mm256_load_si256(b + 3); + + __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); + __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); + __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); + __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); + + __m256i one_v = _mm256_set1_epi16(1); + *c0_v = _mm256_madd_epi16(ab0, one_v); + *c1_v = _mm256_madd_epi16(ab1, one_v); + *c2_v = _mm256_madd_epi16(ab2, one_v); + *c3_v = _mm256_madd_epi16(ab3, one_v); +} + +// c = a0 * b0 + a1 * b1 + a2 * b2 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[16:20] +// c1_v: c[4:8], c[20:24] +// c2_v: c[8:12], c[24:28] +// c3_v: c[12:16], c[28:32] +template <bool SUM_A = false> +static inline ALWAYS_INLINE void madd_epi16x3_packed( + __m256i a0_v, + __m256i a1_v, + __m256i a2_v, + const __m256i* b, + __m256i* c0_v, + __m256i* c1_v, + __m256i* c2_v, + __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i zero_v = _mm256_setzero_si256(); + + __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v); + __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); + 
a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); + __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); + __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); + __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + __m256i b2_v = _mm256_load_si256(b + 2); + __m256i b3_v = _mm256_load_si256(b + 3); + + __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); + __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); + __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); + __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); + + __m256i one_v = _mm256_set1_epi16(1); + *c0_v = _mm256_madd_epi16(ab0, one_v); + *c1_v = _mm256_madd_epi16(ab1, one_v); + *c2_v = _mm256_madd_epi16(ab2, one_v); + *c3_v = _mm256_madd_epi16(ab3, one_v); +} + +// c = a0 * b0 + a1 * b1 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[4:8] +// c1_v: c[8:12], c[12:16] +// c2_v: c[16:20], c[20:24] +// c3_v: c[24:28], c[28:32] +template <bool SUM_A = false> +static inline ALWAYS_INLINE void madd_epi16x2_packed( + __m256i a0_v, + __m256i a1_v, + const __m256i* b, + __m256i* c0_v, + __m256i* c1_v, + __m256i* c2_v, + __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + + __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); + __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); + + *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); + *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); + *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); + *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); +} + +// c = a0 * b0 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[4:8] +// c1_v: c[8:12], c[12:16] +// c2_v: c[16:20], c[20:24] +// c3_v: c[24:28], c[28:32] +template <bool SUM_A = false> +static inline ALWAYS_INLINE void madd_epi16_packed( + __m256i a_v, + const __m256i* b, + __m256i* c0_v, + __m256i* c1_v, + __m256i* c2_v, + __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i zero_v = _mm256_setzero_si256(); + + __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v); + __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + + __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); + __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); + + *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); + *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); + 
*c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); + *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); +} + +// K is the number of accumulations we're doing +template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false> +static inline ALWAYS_INLINE void inner_prod_packed_( + const __m256i* a_v, + const __m256i* Bp, + std::int32_t* C, + int remainder, + __m256i* a_sum = nullptr) { + __m256i c[4], c_temp[4]; + __m256i a_sum_temp[2] = {0, 0}; + + int k = 0; + if (K >= 4) { + madd_epi16x4_packed<SUM_A>( + a_v[0], + a_v[1], + a_v[2], + a_v[3], + Bp, + &c[0], + &c[1], + &c[2], + &c[3], + a_sum_temp); + + for (k = 4; k < K / 4 * 4; k += 4) { + madd_epi16x4_packed<SUM_A>( + a_v[k + 0], + a_v[k + 1], + a_v[k + 2], + a_v[k + 3], + Bp + k, + &c_temp[0], + &c_temp[1], + &c_temp[2], + &c_temp[3], + a_sum_temp); + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + } else { + c[0] = _mm256_setzero_si256(); + c[1] = _mm256_setzero_si256(); + c[2] = _mm256_setzero_si256(); + c[3] = _mm256_setzero_si256(); + } + + if (K - k == 3) { + madd_epi16x3_packed<SUM_A>( + a_v[k], + a_v[k + 1], + a_v[k + 2], + Bp + k, + &c_temp[0], + &c_temp[1], + &c_temp[2], + &c_temp[3], + a_sum_temp); + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + + c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20); + c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20); + c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31); + c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31); + + if (K - k == 0 || K - k == 3) { + c[0] = c_temp[0]; + c[1] = c_temp[1]; + c[2] = c_temp[2]; + c[3] = c_temp[3]; + } else { + if (K - k == 1) { + madd_epi16_packed<SUM_A>( + a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp); + } else if (K - k == 2) { + madd_epi16x2_packed<SUM_A>( + a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp); + } + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + + if (REMAINDER) { + for (int r = 0; r < remainder / 8; ++r) { + if (ACC) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + r * 8), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)), + c[r])); + } else { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]); + } + } + } else { + if (ACC) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0])); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + 8), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1])); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + 16), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2])); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + 24), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3])); + } else { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]); + } + } + + if (SUM_A) { + a_sum[0] = 
_mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0])); + a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1])); + a_sum[2] = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1)); + a_sum[3] = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1)); + } +} + +// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different +// row_offsets for each row because of depth-wise convolution +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool PER_CHANNEL_QUANTIZATION, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> +static inline ALWAYS_INLINE void requantize_( + std::int32_t A_zero_point, + const float* C_multiplier, + std::int32_t C_zero_point, + const std::int32_t* C_int32, + std::uint8_t* C_uint8, + int n, + const std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale = nullptr) { + __m256 multiplier_v = _mm256_setzero_ps(); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v = _mm256_setzero_ps(); + if (!PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_set1_ps(*C_multiplier); + if (std::is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = _mm256_set1_ps(1.0f / (*act_times_w_scale)); + } + } + + __m256i min_v = _mm256_set1_epi8(static_cast<std::uint8_t>(0)); + __m256i max_v = _mm256_set1_epi8(static_cast<std::uint8_t>(255)); + + if (A_SYMMETRIC) { + assert(A_zero_point == 0 || col_offsets == nullptr); + } + __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point); + __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point); + __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point); + + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + + constexpr int VLEN = 8; + int j = 0; + for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) { + __m256i x_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); + __m256i y_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(C_int32 + j + VLEN)); + __m256i z_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN)); + __m256i w_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN)); + + __m256i row_offset_v; + if (!B_SYMMETRIC) { + row_offset_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); + x_v = _mm256_sub_epi32(x_v, row_offset_v); + } + __m256i col_off_v; + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + VLEN)); + y_v = _mm256_sub_epi32(y_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + VLEN))); + y_v = _mm256_sub_epi32(y_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN)); + z_v = _mm256_sub_epi32(z_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN))); + z_v = _mm256_sub_epi32(z_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN)); + w_v 
= _mm256_sub_epi32(w_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN))); + w_v = _mm256_sub_epi32(w_v, col_off_v); + } + + // convert to float + __m256 xf_v, yf_v, zf_v, wf_v; + if (HAS_BIAS) { // static if + if (std::is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v; + if (PER_CHANNEL_QUANTIZATION) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 0 * VLEN)), + _mm256_loadu_ps(act_times_w_scale + j + 0 * VLEN)); + y_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 1 * VLEN)), + _mm256_loadu_ps(act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 2 * VLEN)), + _mm256_loadu_ps(act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 3 * VLEN)), + _mm256_loadu_ps(act_times_w_scale + j + 3 * VLEN)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 0 * VLEN)), + act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 1 * VLEN)), + act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 2 * VLEN)), + act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 3 * VLEN)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 1 * VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN); + } + __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN); + } + __m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN); + } + __m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN); + } + __m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); + + __m256i xy_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, 
y_rounded_v), C_zero_point_epi16_v); + __m256i zw_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v); + __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); + __m256i xyzw_clamped_v = _mm256_max_epu8( + FUSE_RELU ? C_zero_point_epi8_v : min_v, + _mm256_min_epu8(xyzw_packed_v, max_v)); + + xyzw_clamped_v = + _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v); + } // j loop vectorized and unrolled 4x + + for (; j < n / VLEN * VLEN; j += VLEN) { + __m256i x_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); + + if (!B_SYMMETRIC) { + __m256i row_offset_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); + x_v = _mm256_sub_epi32(x_v, row_offset_v); + } + if (!A_SYMMETRIC) { + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + } + + // Convert to float + __m256 xf_v; + if (HAS_BIAS) { // static if + if (std::is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (PER_CHANNEL_QUANTIZATION) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)), + _mm256_loadu_ps(act_times_w_scale + j)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j))); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + } + + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j); + } + __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + + __m256i x_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), + C_zero_point_epi16_v); + x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); + __m256i x_clamped_v = _mm256_max_epu8( + FUSE_RELU ? C_zero_point_epi8_v : min_v, + _mm256_min_epu8(x_packed_v, max_v)); + + x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); + + _mm_storel_epi64( + reinterpret_cast<__m128i*>(C_uint8 + j), + _mm256_castsi256_si128(x_clamped_v)); + } // j loop vectorized + + for (; j < n; ++j) { + std::int32_t raw = C_int32[j]; + if (!B_SYMMETRIC) { + raw -= row_offsets[j]; + } + if (!A_SYMMETRIC) { + raw -= A_zero_point * col_offsets[j]; + } + float raw_f; + if (HAS_BIAS) { // static if + if (std::is_same<BIAS_TYPE, float>::value) { + raw_f = raw; + raw_f += bias[j] / act_times_w_scale[PER_CHANNEL_QUANTIZATION ? j : 0]; + } else { + raw += bias[j]; + raw_f = raw; + } + } else { + raw_f = raw; + } + + float ab = raw_f * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0]; + long rounded = lrintf(ab) + C_zero_point; + + C_uint8[j] = std::max( + FUSE_RELU ? 
static_cast<long>(C_zero_point) : 0l, + std::min(255l, rounded)); + } +} + +template <bool REMAINDER> +static inline ALWAYS_INLINE __m256i load_a( + const std::uint8_t* A, + __m256i mask_v) { + if (REMAINDER) { + return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v); + } else { + return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A)); + } +} + +static inline std::pair<int, int> closest_factors_(int n) { + int a = static_cast<int>(std::sqrt(n)); + while (n % a != 0) { + a--; + } + return {a, n / a}; // a <= n / a +} + +} // namespace fbgemm diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index f96d1d2..994f206 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -7,523 +7,15 @@ #include "fbgemm/FbgemmI8DepthwiseAvx2.h" #include "fbgemm/Utils.h" -#include <algorithm> // for min and max -#include <cassert> -#include <cmath> // for lrintf and sqrt +#include <string> #include <tuple> // for tie -#include <immintrin.h> +#include "FbgemmI8DepthwiseAvx2-inl.h" using namespace std; namespace fbgemm { -static int masks[8][8] = { - // NOTE: clang-format wants to use a different formatting but the current - // formatting should be easier to read. - { 0, 0, 0, 0, 0, 0, 0, 0, }, - { -1, 0, 0, 0, 0, 0, 0, 0, }, - { -1, -1, 0, 0, 0, 0, 0, 0, }, - { -1, -1, -1, 0, 0, 0, 0, 0, }, - { -1, -1, -1, -1, 0, 0, 0, 0, }, - { -1, -1, -1, -1, -1, 0, 0, 0, }, - { -1, -1, -1, -1, -1, -1, 0, 0, }, - { -1, -1, -1, -1, -1, -1, -1, 0, }, -}; - -template <int KERNEL_PROD> -PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix( - int K, - const int8_t* smat) - : K_(K) { - // Transpose the input matrix to make packing faster. - int8_t* smat_transposed = static_cast<int8_t *>(ALIGNED_MALLOC( - K * KERNEL_PROD * sizeof(int8_t), 64)); - for (int i = 0; i < KERNEL_PROD; ++i) { - for (int j = 0; j < K; ++j) { - smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD]; - } - } - - // Allocate packed arrays - constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2; -#ifdef _MSC_VER - pmat_ = static_cast<int8_t *>(_aligned_malloc( - ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t), 64)); -#else - posix_memalign( - (void**)&pmat_, - 64, - ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)); -#endif - - // Pack input matrix - // The layout is optimized to use vpmaddubsw efficiently (see - // madd_epi16x4_packed function). - // For a group of 32 channels, we have 10 32B SIMD registers. - // Denote ith channel jth filter as (i, j) - // 0th SIMD register: - // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3) - // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3) - // 1st SIMD register: - // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3) - // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3) - // 2nd SIMD register: - // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3) - // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3) - // 3rd SIMD register: - // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3) - // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3) - // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter - // coefficients - // ... 
- // - // REMAINDER - // If KERNEL_PROD % 4 == 1 for example when KERNEL_PROD == 9 - // 8th SIMD register: - // (0, 8), zero, ..., (7, 8), zero - // (16, 8), zero, ..., (23, 8), zero - // 9th SIMD register: - // (8, 8), zero, ..., (15, 8), zero - // (24, 8), zero, ..., (31, 8), zero - // We use madd_epi16_packed for this case - // - // If KERNEL_PROD % 4 == 2 for example when KERNEL_PROD == 10 - // 8th SIMD register: - // (0, 8), (0, 9), ..., (7, 8), (7, 9) - // (16, 8), (16, 9), ..., (23, 8), (23, 9) - // 9th SIMD register: - // (8, 8), (8, 9), ..., (15, 8), (15, 9) - // (24, 8), (24, 9), ..., (31, 8), (31, 9) - // - // If KERNEL_PROD % 4 == 3 for example when KERNEL_PROD == 11 - // 8th SIMD register: - // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero - // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero - // 9th SIMD register: - // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero - // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero - // 10th SIMD register: - // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero - // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero - // 11th SIMD register: - // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero - // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero - for (int k1 = 0; k1 < K; k1 += 32) { - __m256i b_v[KERNEL_PROD]; - int remainder = K - k1; - if (remainder < 32) { - __m256i mask_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(masks[remainder / 4])); - for (int i = 0; i < KERNEL_PROD; ++i) { - b_v[i] = _mm256_maskload_epi32( - reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v); - } - } else { - for (int i = 0; i < KERNEL_PROD; ++i) { - b_v[i] = _mm256_lddqu_si256( - reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1)); - } - } - - // Interleave 2 SIMD registers - __m256i b_interleaved_epi16[KERNEL_PROD_ALIGNED]; - __m256i zero_v = _mm256_setzero_si256(); - for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) { - if (2 * i + 1 >= KERNEL_PROD) { - b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v); - b_interleaved_epi16[2 * i + 1] = - _mm256_unpackhi_epi8(b_v[2 * i], zero_v); - } else { - b_interleaved_epi16[2 * i] = - _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]); - b_interleaved_epi16[2 * i + 1] = - _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]); - } - } - - // Interleave 4 SIMD registers - __m256i b_interleaved_epi32[KERNEL_PROD_ALIGNED]; - for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) { - b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16( - b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); - b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16( - b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); - b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16( - b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); - b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16( - b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); - } - for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) { - b_interleaved_epi32[i] = b_interleaved_epi16[i]; - } - - for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) { - _mm256_storeu_si256( - reinterpret_cast<__m256i*>( - &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]), - b_interleaved_epi32[i]); - } - } - - FREE(smat_transposed); -} - -template <int KERNEL_PROD> 
-PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() { -#ifdef _MSC_VER - _aligned_free(pmat_); -#else - free(pmat_); -#endif -} - -template class PackedDepthWiseConvMatrix<3 * 3>; -template class PackedDepthWiseConvMatrix<3 * 3 * 3>; - -// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[16:20] -// c1_v: c[4:8], c[20:24] -// c2_v: c[8:12], c[24:28] -// c3_v: c[12:16], c[28:32] -template <bool SUM_A = false> -static inline ALWAYS_INLINE void madd_epi16x4_packed( - __m256i a0_v, - __m256i a1_v, - __m256i a2_v, - __m256i a3_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); - __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); - __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v); - __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); - __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); - __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); - __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - __m256i b2_v = _mm256_load_si256(b + 2); - __m256i b3_v = _mm256_load_si256(b + 3); - - __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); - __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); - __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); - __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); - - __m256i one_v = _mm256_set1_epi16(1); - *c0_v = _mm256_madd_epi16(ab0, one_v); - *c1_v = _mm256_madd_epi16(ab1, one_v); - *c2_v = _mm256_madd_epi16(ab2, one_v); - *c3_v = _mm256_madd_epi16(ab3, one_v); -} - -// c = a0 * b0 + a1 * b1 + a2 * b2 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[16:20] -// c1_v: c[4:8], c[20:24] -// c2_v: c[8:12], c[24:28] -// c3_v: c[12:16], c[28:32] -template <bool SUM_A = false> -static inline ALWAYS_INLINE void madd_epi16x3_packed( - __m256i a0_v, - __m256i a1_v, - __m256i a2_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i zero_v = _mm256_setzero_si256(); - - __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); - __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); - __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v); - __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - 
_mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); - __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); - __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); - __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - __m256i b2_v = _mm256_load_si256(b + 2); - __m256i b3_v = _mm256_load_si256(b + 3); - - __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); - __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); - __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); - __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); - - __m256i one_v = _mm256_set1_epi16(1); - *c0_v = _mm256_madd_epi16(ab0, one_v); - *c1_v = _mm256_madd_epi16(ab1, one_v); - *c2_v = _mm256_madd_epi16(ab2, one_v); - *c3_v = _mm256_madd_epi16(ab3, one_v); -} - -// c = a0 * b0 + a1 * b1 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[4:8] -// c1_v: c[8:12], c[12:16] -// c2_v: c[16:20], c[20:24] -// c3_v: c[24:28], c[28:32] -template <bool SUM_A = false> -static inline ALWAYS_INLINE void madd_epi16x2_packed( - __m256i a0_v, - __m256i a1_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); - __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - - __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); - __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); - - *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); - *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); - *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); - *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); -} - -// c = a0 * b0 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[4:8] -// c1_v: c[8:12], c[12:16] -// c2_v: c[16:20], c[20:24] -// c3_v: c[24:28], c[28:32] -template <bool SUM_A = false> -static inline ALWAYS_INLINE void madd_epi16_packed( - __m256i a_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i zero_v = _mm256_setzero_si256(); - - __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v); - __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - - __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); - __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); - - *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); - *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); - *c2_v = 
_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); - *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); -} - -// K is the number of accumulations we're doing -template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false> -static inline ALWAYS_INLINE void inner_prod_packed_( - const __m256i* a_v, - const __m256i* Bp, - int32_t* C, - int remainder, - __m256i* a_sum = nullptr) { - __m256i c[4], c_temp[4]; - __m256i a_sum_temp[2] = {0, 0}; - - int k = 0; - if (K >= 4) { - madd_epi16x4_packed<SUM_A>( - a_v[0], - a_v[1], - a_v[2], - a_v[3], - Bp, - &c[0], - &c[1], - &c[2], - &c[3], - a_sum_temp); - - for (k = 4; k < K / 4 * 4; k += 4) { - madd_epi16x4_packed<SUM_A>( - a_v[k + 0], - a_v[k + 1], - a_v[k + 2], - a_v[k + 3], - Bp + k, - &c_temp[0], - &c_temp[1], - &c_temp[2], - &c_temp[3], - a_sum_temp); - - c[0] = _mm256_add_epi32(c[0], c_temp[0]); - c[1] = _mm256_add_epi32(c[1], c_temp[1]); - c[2] = _mm256_add_epi32(c[2], c_temp[2]); - c[3] = _mm256_add_epi32(c[3], c_temp[3]); - } - } else { - c[0] = _mm256_setzero_si256(); - c[1] = _mm256_setzero_si256(); - c[2] = _mm256_setzero_si256(); - c[3] = _mm256_setzero_si256(); - } - - if (K - k == 3) { - madd_epi16x3_packed<SUM_A>( - a_v[k], - a_v[k + 1], - a_v[k + 2], - Bp + k, - &c_temp[0], - &c_temp[1], - &c_temp[2], - &c_temp[3], - a_sum_temp); - - c[0] = _mm256_add_epi32(c[0], c_temp[0]); - c[1] = _mm256_add_epi32(c[1], c_temp[1]); - c[2] = _mm256_add_epi32(c[2], c_temp[2]); - c[3] = _mm256_add_epi32(c[3], c_temp[3]); - } - - c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20); - c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20); - c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31); - c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31); - - if (K - k == 0 || K - k == 3) { - c[0] = c_temp[0]; - c[1] = c_temp[1]; - c[2] = c_temp[2]; - c[3] = c_temp[3]; - } else { - if (K - k == 1) { - madd_epi16_packed<SUM_A>( - a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp); - } else if (K - k == 2) { - madd_epi16x2_packed<SUM_A>( - a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp); - } - - c[0] = _mm256_add_epi32(c[0], c_temp[0]); - c[1] = _mm256_add_epi32(c[1], c_temp[1]); - c[2] = _mm256_add_epi32(c[2], c_temp[2]); - c[3] = _mm256_add_epi32(c[3], c_temp[3]); - } - - if (REMAINDER) { - for (int r = 0; r < remainder / 8; ++r) { - if (ACC) { - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C + r * 8), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)), - c[r])); - } else { - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]); - } - } - } else { - if (ACC) { - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0])); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C + 8), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1])); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C + 16), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2])); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C + 24), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3])); - } else { - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]); - } - } - - if (SUM_A) { - a_sum[0] = 
_mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0])); - a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1])); - a_sum[2] = - _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1)); - a_sum[3] = - _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1)); - } -} - template <bool SUM_A = false, bool REMAINDER = false> static inline ALWAYS_INLINE void inner_prod_3x3_packed_( const __m256i* a_v, @@ -534,238 +26,6 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_( return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum); } -// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different -// row_offsets for each row because of depth-wise convolution -template < - bool FUSE_RELU, - bool HAS_BIAS, - bool PER_CHANNEL_QUANTIZATION, - bool A_SYMMETRIC, - bool B_SYMMETRIC> -static inline ALWAYS_INLINE void requantize_( - int32_t A_zero_point, - const float* C_multiplier, - int32_t C_zero_point, - const int32_t* C_int32, - uint8_t* C_uint8, - int n, - const int32_t* row_offsets, - const int32_t* col_offsets, - const int32_t* bias) { - __m256 multiplier_v = _mm256_setzero_ps(); - if (!PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_set1_ps(*C_multiplier); - } - - __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); - __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); - - if (A_SYMMETRIC) { - assert(A_zero_point == 0 || col_offsets == nullptr); - } - __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point); - __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point); - __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point); - - __m256i permute_mask_v = - _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); - - constexpr int VLEN = 8; - int j = 0; - for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) { - __m256i x_v = - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); - __m256i y_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(C_int32 + j + VLEN)); - __m256i z_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN)); - __m256i w_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN)); - - __m256i row_offset_v; - if (!B_SYMMETRIC) { - row_offset_v = - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); - x_v = _mm256_sub_epi32(x_v, row_offset_v); - } - __m256i col_off_v; - if (!A_SYMMETRIC) { - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j))); - x_v = _mm256_sub_epi32(x_v, col_off_v); - } - - if (!B_SYMMETRIC) { - row_offset_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(row_offsets + j + VLEN)); - y_v = _mm256_sub_epi32(y_v, row_offset_v); - } - if (!A_SYMMETRIC) { - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j + VLEN))); - y_v = _mm256_sub_epi32(y_v, col_off_v); - } - - if (!B_SYMMETRIC) { - row_offset_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN)); - z_v = _mm256_sub_epi32(z_v, row_offset_v); - } - if (!A_SYMMETRIC) { - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN))); - z_v = _mm256_sub_epi32(z_v, col_off_v); - } - - if (!B_SYMMETRIC) { - row_offset_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN)); - w_v = _mm256_sub_epi32(w_v, 
row_offset_v); - } - if (!A_SYMMETRIC) { - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN))); - w_v = _mm256_sub_epi32(w_v, col_off_v); - } - - if (HAS_BIAS) { // static if - x_v = _mm256_add_epi32( - x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN))); - } - - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j); - } - __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j + VLEN); - } - __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN); - } - __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN); - } - __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); - - __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); - __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); - __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); - __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); - - __m256i xy_packed_v = _mm256_adds_epi16( - _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v); - __m256i zw_packed_v = _mm256_adds_epi16( - _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v); - __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); - __m256i xyzw_clamped_v = _mm256_max_epu8( - FUSE_RELU ? C_zero_point_epi8_v : min_v, - _mm256_min_epu8(xyzw_packed_v, max_v)); - - xyzw_clamped_v = - _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); - - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v); - } // j loop vectorized and unrolled 4x - - for (; j < n / VLEN * VLEN; j += VLEN) { - __m256i x_v = - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); - - if (!B_SYMMETRIC) { - __m256i row_offset_v = - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); - x_v = _mm256_sub_epi32(x_v, row_offset_v); - } - if (!A_SYMMETRIC) { - __m256i col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j))); - x_v = _mm256_sub_epi32(x_v, col_off_v); - } - - if (HAS_BIAS) { // static if - x_v = _mm256_add_epi32( - x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j))); - } - - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j); - } - __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); - - __m256i x_packed_v = _mm256_adds_epi16( - _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), - C_zero_point_epi16_v); - x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); - __m256i x_clamped_v = _mm256_max_epu8( - FUSE_RELU ? 
C_zero_point_epi8_v : min_v, - _mm256_min_epu8(x_packed_v, max_v)); - - x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); - - _mm_storel_epi64( - reinterpret_cast<__m128i*>(C_uint8 + j), - _mm256_castsi256_si128(x_clamped_v)); - } // j loop vectorized - - for (; j < n; ++j) { - int32_t raw = C_int32[j]; - if (!B_SYMMETRIC) { - raw -= row_offsets[j]; - } - if (!A_SYMMETRIC) { - raw -= A_zero_point * col_offsets[j]; - } - if (HAS_BIAS) { // static if - raw += bias[j]; - } - - float ab = raw * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0]; - long rounded = lrintf(ab) + C_zero_point; - - C_uint8[j] = std::max( - FUSE_RELU ? static_cast<long>(C_zero_point) : 0l, - std::min(255l, rounded)); - } -} - -template <bool REMAINDER> -static inline ALWAYS_INLINE __m256i load_a( - const uint8_t* A, - __m256i mask_v) { - if (REMAINDER) { - return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v); - } else { - return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A)); - } -} - template < bool SUM_A, bool REMAINDER = false, @@ -878,257 +138,11 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_( } template < - bool SUM_A, - bool REMAINDER = false, - bool PER_CHANNEL_QUANTIZATION = false> -static inline ALWAYS_INLINE void inner_prod_3x3x3_packed_( - int T, - int H, - int W, - int K, - int t_in, - int h_in, - int w_in, - const uint8_t* A, - int32_t A_zero_point, - const int8_t* Bp, - const int32_t* B_zero_point, - int32_t* C, - int remainder, - int32_t* row_offsets) { - __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point)); - __m256i mask_v = _mm256_setzero_si256(); - if (REMAINDER) { - mask_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(masks[remainder / 4])); - } - - // The code below can be written as a simple R*S loop but the compiler - // doesn't unroll so we're manually unrolling it. 
- // constexpr int R = 3, S = 3; - // array<__m256i, R * S> a_v; - // for (int r = 0; r < R; ++r) { - // for (int s = 0; s < S; ++s) { - // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { - // if (REMAINDER) { - // a_v[r * S + s] = - // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), - // mask_v); - // } else { - // a_v[r * S + s] = - // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); - // } - // } else { - // a_v[r * S + s] = A_zero_point_v; - // } - // } - // } - __m256i a_v[8]; - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - a_v[3] = A_zero_point_v; - a_v[4] = A_zero_point_v; - a_v[5] = A_zero_point_v; - a_v[6] = A_zero_point_v; - a_v[7] = A_zero_point_v; - - if (t_in >= 0 && t_in < T) { - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v); - } - } - - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v); - } - } - } - - __m256i a_sum[4]; - inner_prod_packed_<8, SUM_A, REMAINDER>( - a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); - - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - a_v[3] = A_zero_point_v; - a_v[4] = A_zero_point_v; - a_v[5] = A_zero_point_v; - a_v[6] = A_zero_point_v; - a_v[7] = A_zero_point_v; - - if (t_in >= 0 && t_in < T) { - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v); - } - } - } - - if (t_in + 1 >= 0 && t_in + 1 < T) { - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v); - } - } - - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v); - } - } - } - - __m256i a_sum_temp[4]; - inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( - a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp); - if (SUM_A) { - a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); - a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); - a_sum[2] = 
_mm256_add_epi32(a_sum[2], a_sum_temp[2]); - a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); - } - - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - a_v[3] = A_zero_point_v; - a_v[4] = A_zero_point_v; - a_v[5] = A_zero_point_v; - a_v[6] = A_zero_point_v; - a_v[7] = A_zero_point_v; - - if (t_in + 1 >= 0 && t_in + 1 < T) { - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v); - } - } - } - - if (t_in + 2 >= 0 && t_in + 2 < T) { - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v); - } - } - } - - inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( - a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp); - if (SUM_A) { - a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); - a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); - a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); - a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); - } - - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - - if (t_in + 2 >= 0 && t_in + 2 < T) { - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v); - } - } - } - - inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>( - a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp); - - if (SUM_A) { - a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); - a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); - a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); - a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); - - __m256i B_zero_point_v; - for (int i = 0; i < (REMAINDER ? 
(remainder / 8) : 4); ++i) { - if (PER_CHANNEL_QUANTIZATION) { - B_zero_point_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); - } else { - B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); - } - _mm256_store_si256( - reinterpret_cast<__m256i*>(&row_offsets[i * 8]), - _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); - } - } -} - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> static inline ALWAYS_INLINE void depthwise_3x3_kernel_( int H, int W, @@ -1147,7 +161,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_( uint8_t* C_uint8, int32_t* row_offsets, const int32_t* col_offsets, - const int32_t* bias) { + const BIAS_TYPE* bias, + float act_times_w_scale) { constexpr int S = 3; constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; @@ -1192,7 +207,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_( HAS_BIAS, false, /*PER_CHAN_QUANT*/ A_SYMMETRIC, - B_SYMMETRIC>( + B_SYMMETRIC, + BIAS_TYPE>( A_zero_point, &C_multiplier, C_zero_point, @@ -1201,95 +217,11 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_( K, row_offsets, col_offsets, - bias); -} - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> -static inline ALWAYS_INLINE void depthwise_3x3x3_kernel_( - int T, - int H, - int W, - int K, - int t, - int h, - int w, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const int8_t* Bp, - float C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - int32_t* row_offsets, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int R = 3, S = 3; - constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - int t_in = -PAD_P + t * stride_t; - int h_in = -PAD_T + h * stride_h; - int w_in = -PAD_L + w * stride_w; - - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - &B_zero_point, - C_int32 + k, - 0, - B_SYMMETRIC ? nullptr : &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - &B_zero_point, - C_int32 + k, - remainder, - B_SYMMETRIC ? 
nullptr : &row_offsets[k]); - } - - requantize_< - FUSE_RELU, - HAS_BIAS, - false, /*PER_CHAN_QUANT*/ - A_SYMMETRIC, - B_SYMMETRIC>( - A_zero_point, - &C_multiplier, - C_zero_point, - C_int32, - C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias); + bias, + &act_times_w_scale); } -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> static inline ALWAYS_INLINE void depthwise_3x3_per_channel_quantization_kernel_( int H, @@ -1309,7 +241,8 @@ depthwise_3x3_per_channel_quantization_kernel_( uint8_t* C_uint8, int32_t* row_offsets, const int32_t* col_offsets, - const int32_t* bias) { + const BIAS_TYPE* bias, + const float* act_times_w_scale) { constexpr int S = 3; constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; @@ -1360,7 +293,8 @@ depthwise_3x3_per_channel_quantization_kernel_( HAS_BIAS, true, /*PER_CHAN_QUANT*/ A_SYMMETRIC, - false /*B_SYMM*/>( + false, /*B_SYMM*/ + BIAS_TYPE>( A_zero_point, C_multiplier, C_zero_point, @@ -1369,113 +303,20 @@ depthwise_3x3_per_channel_quantization_kernel_( K, row_offsets, col_offsets, - bias); -} - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> -static inline ALWAYS_INLINE void -depthwise_3x3x3_per_channel_quantization_kernel_( - int T, - int H, - int W, - int K, - int t, - int h, - int w, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const int8_t* Bp, - const float* C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - int32_t* row_offsets, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int R = 3, S = 3; - constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - int t_in = -PAD_P + t * stride_t; - int h_in = -PAD_T + h * stride_h; - int w_in = -PAD_L + w * stride_w; - - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_< - true, /*SUM_A*/ - false, /*remainder*/ - true /*per-channel*/>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - B_zero_point + k, - C_int32 + k, - 0, - &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_3x3x3_packed_< - true, /*SUM_A*/ - true, /*remainder*/ - true /*per-channel*/>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - B_zero_point + k, - C_int32 + k, - remainder, - &row_offsets[k]); - } - requantize_< - FUSE_RELU, - HAS_BIAS, - true, /*PER_CHAN_QUANT*/ - A_SYMMETRIC, - false /*B_SYMM*/>( - A_zero_point, - C_multiplier, - C_zero_point, - C_int32, - C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias); -} - -static pair<int, int> closest_factors_(int n) { - int a = (int)std::sqrt(n); - while (n % a != 0) { - a--; - } - return {a, n / a}; // a <= n / a + bias, + act_times_w_scale); } // TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0 // This implemntation should be general enough to handle not just 3x3 but other // filter shapes by parameterizing with R and S but restricting it to just 3x3 // for now. 
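closest_factors_ above returns the factor pair {a, n / a} of n with a <= n / a, i.e. the most nearly square split, and the 3D depthwise driver uses it to carve the threads that share one image into a 2D grid over the (t, h) output dimensions. The sketch below condenses that partitioning into a standalone helper; the name tile_for_thread is hypothetical, but the arithmetic mirrors the driver code in this diff:

#include <algorithm>
#include <cmath>
#include <tuple>
#include <utility>

// Same contract as closest_factors_ above: returns {a, n / a} with a <= n / a.
static std::pair<int, int> closest_factors(int n) {
  int a = static_cast<int>(std::sqrt(n));
  while (n % a != 0) {
    --a;
  }
  return {a, n / a};
}

// Hypothetical condensation of the per-image work split used by the 3x3x3
// driver: thread tid_within_n (out of nthreads_of_n sharing one image) gets
// the output tile [t_begin, t_end) x [h_begin, h_end).
static void tile_for_thread(
    int tid_within_n,
    int nthreads_of_n,
    int T_OUT,
    int H_OUT,
    int* t_begin,
    int* t_end,
    int* h_begin,
    int* h_end) {
  int num_threads_t, num_threads_h; // num_threads_t <= num_threads_h
  std::tie(num_threads_t, num_threads_h) = closest_factors(nthreads_of_n);
  int tid_t = tid_within_n / num_threads_h;
  int tid_h = tid_within_n % num_threads_h;

  int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; // ceil div
  *t_begin = std::min(tid_t * t_per_thread, T_OUT);
  *t_end = std::min(*t_begin + t_per_thread, T_OUT);

  int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
  *h_begin = std::min(tid_h * h_per_thread, H_OUT);
  *h_end = std::min(*h_begin + h_per_thread, H_OUT);
}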
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( int N, int H, @@ -1486,13 +327,14 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( int32_t A_zero_point, const uint8_t* A, int32_t B_zero_point, - const Packed3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& B, float C_multiplier, int32_t C_zero_point, int32_t* C_int32, uint8_t* C_uint8, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + float act_times_w_scale, int thread_id, int num_threads) { assert(K % 8 == 0); @@ -1551,7 +393,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( if (h_begin == 0) { if (w_begin == 0) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1569,11 +416,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1591,12 +444,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1614,14 +473,20 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { if (w_begin == 0) { w = 0; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1639,11 +504,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1661,12 +532,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1684,7 +561,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } @@ -1692,7 +570,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( h = H_OUT - 1; w = 0; if (w_begin == 0) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1710,11 +593,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + 
bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1732,12 +621,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1755,126 +650,15 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } } // for each n FREE(row_offsets); }; -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> -static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const Packed3x3x3ConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - const int32_t* col_offsets, - const int32_t* bias, - int thread_id, - int num_threads) { - assert(K % 8 == 0); - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - const int8_t* Bp = B.PackedMat(); - - int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64))); - - int n_begin, n_end; - int t_begin, t_end, h_begin, h_end; - if (N >= num_threads) { - int n_per_thread = (N + num_threads - 1) / num_threads; - n_begin = std::min(thread_id * n_per_thread, N); - n_end = std::min(n_begin + n_per_thread, N); - t_begin = 0; - t_end = T_OUT; - h_begin = 0; - h_end = H_OUT; - } else { - int nthreads_per_n = num_threads / N; - n_begin = std::min(thread_id / nthreads_per_n, N); - n_end = std::min(n_begin + 1, N); - - int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); - int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); - int nthreads_of_n = tid_of_n_end - tid_of_n_begin; - int tid_within_n = thread_id - tid_of_n_begin; - assert(tid_within_n >= 0); - assert(tid_within_n < nthreads_of_n); - - // n is processed by num_threads_t * num_threads_h 2D grid of threads - int num_threads_t, num_threads_h; - // num_threads_w <= num_threads_h - tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); - int tid_t = tid_within_n / num_threads_h; - int tid_h = tid_within_n % num_threads_h; - - int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; - t_begin = std::min(tid_t * t_per_thread, T_OUT); - t_end = std::min(t_begin + t_per_thread, T_OUT); - - int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; - h_begin = std::min(tid_h * h_per_thread, H_OUT); - h_end = std::min(h_begin + h_per_thread, H_OUT); - } - - for (int n = n_begin; n < n_end; ++n) { - const uint8_t* A_base = A + n * T * H * W * K; - uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; - - for (int t = t_begin; t < t_end; ++t) { - for (int h = h_begin; h < h_end; ++h) { - for (int w = 0; w < W_OUT; ++w) { - 
depthwise_3x3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC>( - T, - H, - W, - K, - t, - h, - w, - stride_t, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias); - } // w - } // h - } // t - } // for each n - - FREE(row_offsets); -}; - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> static inline ALWAYS_INLINE void depthwise_3x3_per_channel_quantization_pad_1_( int N, @@ -1886,13 +670,14 @@ depthwise_3x3_per_channel_quantization_pad_1_( int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& B, const float* C_multiplier, int32_t C_zero_point, int32_t* C_int32, uint8_t* C_uint8, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + const float* act_times_w_scale, int thread_id, int num_threads) { assert(K % 8 == 0); @@ -1954,7 +739,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1972,14 +758,16 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1997,7 +785,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { @@ -2005,7 +794,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2023,7 +813,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } @@ -2033,7 +824,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2051,14 +843,16 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2076,7 +870,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { @@ -2084,7 +879,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2102,7 +898,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } @@ -2113,7 +910,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2131,14 +929,16 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, 
w_end); ++w) { depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2156,7 +956,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { @@ -2164,7 +965,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2182,128 +984,15 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } } // for each n - - FREE(row_offsets); -}; - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> -static inline ALWAYS_INLINE void -depthwise_3x3x3_per_channel_quantization_pad_1_( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& B, - const float* C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - const int32_t* col_offsets, - const int32_t* bias, - int thread_id, - int num_threads) { - assert(K % 8 == 0); - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - const int8_t* Bp = B.PackedMat(); - - int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64))); - - int n_begin, n_end; - int t_begin, t_end, h_begin, h_end; - if (N >= num_threads) { - int n_per_thread = (N + num_threads - 1) / num_threads; - n_begin = std::min(thread_id * n_per_thread, N); - n_end = std::min(n_begin + n_per_thread, N); - t_begin = 0; - t_end = T_OUT; - h_begin = 0; - h_end = H_OUT; - } else { - int nthreads_per_n = num_threads / N; - n_begin = std::min(thread_id / nthreads_per_n, N); - n_end = std::min(n_begin + 1, N); - - int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); - int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); - int nthreads_of_n = tid_of_n_end - tid_of_n_begin; - int tid_within_n = thread_id - tid_of_n_begin; - assert(tid_within_n >= 0); - assert(tid_within_n < nthreads_of_n); - - // n is processed by num_threads_t * num_threads_h 2D grid of threads - int num_threads_t, num_threads_h; - // num_threads_w <= num_threads_h - tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); - int tid_t = tid_within_n / num_threads_h; - int tid_h = tid_within_n % num_threads_h; - - int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; - t_begin = std::min(tid_t * t_per_thread, T_OUT); - t_end = std::min(t_begin + t_per_thread, T_OUT); - - int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; - h_begin = std::min(tid_h * h_per_thread, H_OUT); - h_end = std::min(h_begin + h_per_thread, H_OUT); - } - - for (int n = n_begin; n < n_end; ++n) { - const uint8_t* A_base = A + n * T * H * W * K; - uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; - - for (int t = t_begin; t < t_end; ++t) { - for (int h = h_begin; h < h_end; ++h) { - for (int w = 0; w < W_OUT; ++w) { - depthwise_3x3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC>( - T, - H, - W, - K, - t, - h, - w, - 
stride_t, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias); - } // w - } // h - } // t - } // for each n - - FREE(row_offsets); }; // Dispatch A_SYMMETRIC and B_SYMMETRIC -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> static void depthwise_3x3_pad_1_( int N, int H, @@ -2314,12 +1003,13 @@ static void depthwise_3x3_pad_1_( int32_t A_zero_point, const uint8_t* A, int32_t B_zero_point, - const Packed3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& B, float C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + float act_times_w_scale, int thread_id, int num_threads) { int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; @@ -2329,7 +1019,8 @@ static void depthwise_3x3_pad_1_( FUSE_RELU, HAS_BIAS, true /*A_symmetric*/, - true /*B_symmetric*/>( + true /*B_symmetric*/, + BIAS_TYPE>( N, H, W, @@ -2346,6 +1037,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { @@ -2353,7 +1045,8 @@ static void depthwise_3x3_pad_1_( FUSE_RELU, HAS_BIAS, true /*A_symmetric*/, - false /*B_symmetric*/>( + false /*B_symmetric*/, + BIAS_TYPE>( N, H, W, @@ -2370,6 +1063,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } @@ -2379,7 +1073,8 @@ static void depthwise_3x3_pad_1_( FUSE_RELU, HAS_BIAS, false /*A_symmetric*/, - true /*B_symmetric*/>( + true /*B_symmetric*/, + BIAS_TYPE>( N, H, W, @@ -2396,6 +1091,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { @@ -2403,7 +1099,8 @@ static void depthwise_3x3_pad_1_( FUSE_RELU, HAS_BIAS, false /*A_symmetric*/, - false /*B_symmetric*/>( + false /*B_symmetric*/, + BIAS_TYPE>( N, H, W, @@ -2420,6 +1117,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } @@ -2428,7 +1126,7 @@ static void depthwise_3x3_pad_1_( } // Dispatch HAS_BIAS -template <bool FUSE_RELU> +template <bool FUSE_RELU, typename BIAS_TYPE> static void depthwise_3x3_pad_1_( int N, int H, @@ -2439,16 +1137,17 @@ static void depthwise_3x3_pad_1_( int32_t A_zero_point, const uint8_t* A, int32_t B_zero_point, - const Packed3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& B, float C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + float act_times_w_scale, int thread_id, int num_threads) { if (bias) { - depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>( + depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>( N, H, W, @@ -2464,10 +1163,11 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>( + depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>( N, H, W, @@ -2483,6 +1183,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } @@ -2490,6 +1191,7 @@ static void depthwise_3x3_pad_1_( // Dispatch input shape and FUSE_RELU // assumption: W > 3 and H > 3 +template <typename BIAS_TYPE> void depthwise_3x3_pad_1( int N, int H, @@ -2500,18 +1202,33 @@ void depthwise_3x3_pad_1( int32_t A_zero_point, const uint8_t* A, int32_t B_zero_point, - const Packed3x3ConvMatrix& B, + const 
PackedDepthWiseConvMatrix& B, float C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, bool fuse_relu, + float act_times_w_scale, int thread_id, int num_threads) { + if (B.GetKernelProduct() != 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. + return; + } if (fuse_relu) { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2527,10 +1244,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2546,10 +1264,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2565,10 +1284,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2584,10 +1304,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2603,12 +1324,13 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } else { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2624,10 +1346,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2643,10 +1366,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2662,10 +1386,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2681,10 +1406,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2700,283 +1426,15 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } } -// Dispatch A_SYMMETRIC and B_SYMMETRIC -template <bool 
FUSE_RELU, bool HAS_BIAS> -static void depthwise_3x3x3_pad_1_( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const Packed3x3x3ConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias, - int thread_id, - int num_threads) { - int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; - if (A_zero_point == 0 || col_offsets == nullptr) { - if (B_zero_point == 0) { - depthwise_3x3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - true /*A_symmetric*/, - true /*B_symmetric*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - true /*A_symmetric*/, - false /*B_symmetric*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } - } else { - if (B_zero_point == 0) { - depthwise_3x3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - false /*A_symmetric*/, - true /*B_symmetric*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - false /*A_symmetric*/, - false /*B_symmetric*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } - } - delete[] C_int32_temp; -} - -// Dispatch HAS_BIAS -template <bool FUSE_RELU> -static void depthwise_3x3x3_pad_1_( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const Packed3x3x3ConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias, - int thread_id, - int num_threads) { - if (bias) { - depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } -} - -// Dispatch FUSE_RELU -void depthwise_3x3x3_pad_1( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const Packed3x3x3ConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias, - bool fuse_relu, - int thread_id, - int num_threads) { - if (fuse_relu) { - depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - 
depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } -} - // Dispatch A_SYMMETRIC -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> static void depthwise_3x3_per_channel_quantization_pad_1_( int N, int H, @@ -2987,12 +1445,13 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + const float* act_times_w_scale, int thread_id, int num_threads) { int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; @@ -3000,7 +1459,8 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, HAS_BIAS, - true /*A_SYMM*/>( + true /*A_SYMM*/, + BIAS_TYPE>( N, H, W, @@ -3017,13 +1477,15 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, HAS_BIAS, - false /*A_SYMM*/>( + false /*A_SYMM*/, + BIAS_TYPE>( N, H, W, @@ -3040,6 +1502,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } @@ -3047,7 +1510,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( } // Dispatch HAS_BIAS -template <bool FUSE_RELU> +template <bool FUSE_RELU, typename BIAS_TYPE> static void depthwise_3x3_per_channel_quantization_pad_1_( int N, int H, @@ -3058,18 +1521,20 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + const float* act_times_w_scale, int thread_id, int num_threads) { if (bias) { depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, - true /* HAS_BIAS */>( + true /* HAS_BIAS */, + BIAS_TYPE>( N, H, W, @@ -3085,12 +1550,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, - false /* HAS_BIAS */>( + false /* HAS_BIAS */, + BIAS_TYPE>( N, H, W, @@ -3106,12 +1573,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } // Dispatch input shape and FUSE_RELU +template <typename BIAS_TYPE> void depthwise_3x3_per_channel_quantization_pad_1( int N, int H, @@ -3122,18 +1591,35 @@ void depthwise_3x3_per_channel_quantization_pad_1( int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, bool fuse_relu, + const float* act_times_w_scale, int thread_id, int num_threads) { + if (Bp.GetKernelProduct() != 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + 
to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. + return; + } if (fuse_relu) { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3149,10 +1635,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3168,10 +1657,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3187,10 +1679,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3206,10 +1701,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3225,12 +1723,15 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } else { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3246,10 +1747,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3265,10 +1769,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3284,10 +1791,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3303,10 +1813,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, 
bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3322,225 +1835,179 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } } -// Dispatch A_SYMMETRIC -template <bool FUSE_RELU, bool HAS_BIAS> -static void depthwise_3x3x3_per_channel_quantization_pad_1_( +// To be removed +void depthwise_3x3_pad_1( int N, - int T, int H, int W, int K, - int stride_t, int stride_h, int stride_w, int32_t A_zero_point, const uint8_t* A, - const int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& B, - const float* C_multiplier, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, const int32_t* bias, + bool fuse_relu, int thread_id, int num_threads) { - int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; - if (A_zero_point == 0 || col_offsets == nullptr) { - depthwise_3x3x3_per_channel_quantization_pad_1_< - FUSE_RELU, - HAS_BIAS, - true /*A_SYMM*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_per_channel_quantization_pad_1_< - FUSE_RELU, - HAS_BIAS, - false /*A_SYMM*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } - delete[] C_int32_temp; + depthwise_3x3_pad_1<std::int32_t>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + fuse_relu, + 1.0f, + thread_id, + num_threads); } -// Dispatch HAS_BIAS -template <bool FUSE_RELU> -static void depthwise_3x3x3_per_channel_quantization_pad_1_( +// To be removed +void depthwise_3x3_per_channel_quantization_pad_1( int N, - int T, int H, int W, int K, - int stride_t, int stride_h, int stride_w, int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, const int32_t* bias, + bool fuse_relu, int thread_id, int num_threads) { - if (bias) { - depthwise_3x3x3_per_channel_quantization_pad_1_< - FUSE_RELU, - true /* HAS_BIAS */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_per_channel_quantization_pad_1_< - FUSE_RELU, - false /* HAS_BIAS */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } + depthwise_3x3_per_channel_quantization_pad_1<std::int32_t>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + fuse_relu, + nullptr, + thread_id, + num_threads); } -// Dispatch FUSE_RELU -void depthwise_3x3x3_per_channel_quantization_pad_1( +template void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, 
+ int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3_per_channel_quantization_pad_1( int N, - int T, int H, int W, int K, - int stride_t, int stride_h, int stride_w, int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, const int32_t* bias, bool fuse_relu, + const float* act_times_w_scale, int thread_id, - int num_threads) { - if (fuse_relu) { - depthwise_3x3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } -} + int num_threads); + +template void depthwise_3x3_per_channel_quantization_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads); } // namespace fbgemm diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index dccdfc5..c0fece4 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -8,8 +8,11 @@ #include <asmjit/asmjit.h> #include <cpuinfo.h> #include <map> +#include <mutex> +#include <sstream> #include <string> #include <tuple> +#include "CodeCache.h" #include "fbgemm/Fbgemm.h" /*#define FBGEMM_LOG_CODE 1*/ @@ -18,7 +21,7 @@ namespace fbgemm { namespace x86 = asmjit::x86; /** - * @brief AVX2/AVX512 JIT assembly code generator. + * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator. * @tparam TA Type of matrix A. * @tparam TB Type of matrix B. * @tparam TC Type of matrix C. @@ -40,35 +43,7 @@ class CodeGenBase { * @brief Constructor for initializing AVX2/AVX512 registers. 
*/ CodeGenBase(const BlockingFactors* params = nullptr) - : blocking_params(params), - CRegs_avx2_{x86::ymm0, - x86::ymm1, - x86::ymm2, - x86::ymm3, - x86::ymm4, - x86::ymm5, - x86::ymm6, - x86::ymm7, - x86::ymm8, - x86::ymm9, - x86::ymm10, - x86::ymm11}, - CRegs_avx512_{ - x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4, - x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9, - x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14, - x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19, - x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24, - x86::zmm25, x86::zmm26, x86::zmm27, - }, - AllRegs_avx512_{x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, - x86::zmm4, x86::zmm5, x86::zmm6, x86::zmm7, - x86::zmm8, x86::zmm9, x86::zmm10, x86::zmm11, - x86::zmm12, x86::zmm13, x86::zmm14, x86::zmm15, - x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19, - x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, - x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27, - x86::zmm28, x86::zmm29, x86::zmm30, x86::zmm31} { + : blocking_params(params) { // vector width in bits if (cpuinfo_initialize()) { if (fbgemmHasAvx512Support()) { @@ -104,7 +79,7 @@ class CodeGenBase { */ template <inst_set_t instSet> void initCRegs( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCRegAssign = 4); @@ -114,10 +89,10 @@ class CodeGenBase { */ template <inst_set_t instSet> void genComputeBlock( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, int rowRegs, int colRegs, int lda, @@ -129,11 +104,11 @@ class CodeGenBase { */ template <inst_set_t instSet> void storeCRegs( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCRegAssign = 4); @@ -143,7 +118,7 @@ class CodeGenBase { * (debug-only) */ template <inst_set_t instSet> - std::string getCodeLoggingFile( + static std::string getCodeLoggingFile( bool accum, int mc, int nc, @@ -152,48 +127,60 @@ class CodeGenBase { int MR, int NR, int NR_MIN) { - std::string fileName = "gemm_"; + std::ostringstream oss; + oss << "gemm_"; if (std::is_same<accT, std::int16_t>::value) { - fileName += "acc16_"; + oss << "acc16_"; } else if (std::is_same<accT, std::int32_t>::value) { - fileName += "acc32_"; + oss << "acc32_"; } else { - fileName += "unknown_"; + oss << "unknown_"; } - fileName += "accum-" + std::to_string(accum); - fileName += "_MC-" + std::to_string(mc); - fileName += "_NC-" + std::to_string(nc); - fileName += "_NCB-" + std::to_string(NCB); - fileName += "_NCB-" + std::to_string(KCB); - fileName += "_MR-" + std::to_string(MR); - fileName += "_NR-" + std::to_string(NR); - fileName += "_NR_MIN-" + std::to_string(NR_MIN); - if (instSet == inst_set_t::avx512) { - fileName += "_avx512"; + oss << "accum-" + std::to_string(accum) + << "_MC-" + std::to_string(mc) + << "_NC-" + std::to_string(nc) + << "_NCB-" + std::to_string(NCB) + << "_NCB-" + std::to_string(KCB) + << "_MR-" + std::to_string(MR) + << "_NR-" + std::to_string(NR) + << "_NR_MIN-" + std::to_string(NR_MIN); + if (instSet == inst_set_t::avx512_vnni) { + oss << "_avx512vnni"; + } else if (instSet == inst_set_t::avx512) { + oss << "_avx512"; } else if (instSet == inst_set_t::avx2) { - fileName += "_avx2"; + oss << "_avx2"; } - fileName += ".txt"; - return fileName; + oss << ".txt"; + return oss.str(); } private: - asmjit::X86Ymm - 
CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. - asmjit::X86Zmm - CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel. - asmjit::X86Zmm - AllRegs_avx512_[32]; ///< all AVX512 zmm registers. - int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. - static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. - static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. + + static asmjit::JitRuntime &runtime() { + static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, + // depents on other static + // variables. Required to prevent + // initialization order fiasco + return rt; + } + + static std::mutex rtMutex_; ///< Controll access to runtime; + // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min - static thread_local std::map< - std::tuple<bool, int, int, int, int, int, int, int>, - jit_micro_kernel_fp> + static CodeCache<std::tuple<bool, int, int, int, int, int, int, int>, + jit_micro_kernel_fp> codeCache_; ///< JIT Code Cache for reuse. }; +template <typename TA, typename TB, typename TC, typename accT> +std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_; + +template <typename TA, typename TB, typename TC, typename accT> +CodeCache<std::tuple<bool, int, int, int, int, int, int, int>, + typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> + CodeGenBase<TA, TB, TC, accT>::codeCache_; + } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index 082518c..205af14 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -9,18 +9,6 @@ namespace fbgemm { -template <typename TA, typename TB, typename TC, typename accT> -thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_; - -template <typename TA, typename TB, typename TC, typename accT> -thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_; - -template <typename TA, typename TB, typename TC, typename accT> -thread_local std::map< - std::tuple<bool, int, int, int, int, int, int, int>, - typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> - CodeGenBase<TA, TB, TC, accT>::codeCache_; - namespace x86 = asmjit::x86; /** @@ -31,16 +19,17 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - int leadingDimCRegAssign) { + int leadingDimCReg) { + using CRegs = x86::Ymm; for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { a->vxorps( - CRegs_avx2_[i * leadingDimCRegAssign + j], - CRegs_avx2_[i * leadingDimCRegAssign + j], - CRegs_avx2_[i * leadingDimCRegAssign + j]); + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); } } } @@ -53,18 +42,20 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp /* unused (reserved for prefetching)*/, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, - int leadingDimCRegAssign) { + int leadingDimCReg) { // used for matrix A - asmjit::X86Ymm AReg = x86::ymm12; + x86::Ymm AReg = x86::ymm12; - asmjit::X86Ymm tmpReg = x86::ymm14; + x86::Ymm tmpReg = x86::ymm14; + + using CRegs = x86::Ymm; for (int i = 0; i < rowRegs; ++i) { // broadcast A @@ 
-74,9 +65,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< a->vpmaddubsw( tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); a->vpaddsw( - CRegs_avx2_[i * leadingDimCRegAssign + j], + CRegs(i * leadingDimCReg + j), tmpReg, - CRegs_avx2_[i * leadingDimCRegAssign + j]); + CRegs(i * leadingDimCReg + j)); // Prefetching is hurting performance in some cases // because prefetch instructions itself consumes a slot // in pipeline issue thus slowing down the kernel. @@ -95,25 +86,30 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, - int leadingDimCRegAssign) { - asmjit::X86Xmm extractDest128 = x86::xmm15; - asmjit::X86Ymm extractDest256 = x86::ymm15; + int leadingDimCReg) { + x86::Xmm extractDest128 = x86::xmm15; + x86::Ymm extractDest256 = x86::ymm15; + using CRegs = x86::Ymm; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); for (int j = 0; j < colRegs; ++j) { for (int idx = 0; idx < 2; ++idx) { a->vextracti128( - extractDest128, CRegs_avx2_[i * leadingDimCRegAssign + j], idx); + extractDest128, CRegs(i * leadingDimCReg + j), idx); a->vpmovsxwd(extractDest256, extractDest128); - asmjit::X86Mem destAddr = x86::dword_ptr( + x86::Mem destAddr = x86::dword_ptr( +#ifdef _MSC_VER + a->gpz(9), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t)); +#else a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t)); +#endif if (accum) { a->vpaddd(extractDest256, extractDest256, destAddr); } @@ -172,192 +168,195 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx2>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - // assert((nc == nRegBlockSize) && - //"nc must be equal to the number of register blocks"); - - // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); - - 
asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); - - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFrameInfo(ffi); - - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); - - asmjit::Label Loopk = a->newLabel(); - asmjit::Label LoopMBlocks = a->newLabel(); - - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - // asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); - - int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - // a->mov(B_pf_saved, B_pf); - - a->bind(LoopMBlocks); - a->inc(iIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx2>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); - - // increment A for next block - a->sub(buffer_A, kSize); - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - // increment C for next block - a->imul( - C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); - a->add(CBase, C_Offset); - // reset B - a->mov(buffer_B, buffer_B_saved); - // a->mov(B_pf, B_pf_saved); - - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + // assert((nc == nRegBlockSize) && + //"nc must be equal to the number of register blocks"); + + // arguments to the function created +#ifdef _MSC_VER + x86::Gp buffer_A = a->zcx(); + x86::Gp buffer_B = a->zdx(); + x86::Gp B_pf = a->gpz(8); + x86::Gp CBase = a->gpz(9); + x86::Gp kSize = a->zdi(); + x86::Gp ldcReg = a->zsi(); +#else + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); +#endif + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, 
int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp kIdx = a->gpz(14); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + // increment C for next block + a->imul(C_Offset, ldcReg, + static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); + a->add(CBase, C_Offset); + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; - // init C registers - initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); + // init C registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - genComputeBlock<inst_set_t::avx2>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock); - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - // update buffer_B address for next k iteration - a->add( - buffer_B, - 
static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); - a->cmp(kIdx, kSize); - a->jl(LoopkRem); + a->cmp(kIdx, kSize); + a->jl(LoopkRem); - // store C matrix - storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); - } + // store C matrix + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum); + } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index 505fec1..819f33b 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -19,16 +19,17 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - int leadingDimCRegAssign) { + int leadingDimCReg) { + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { a->vxorps( - CRegs_avx512_[i * leadingDimCRegAssign + j], - CRegs_avx512_[i * leadingDimCRegAssign + j], - CRegs_avx512_[i * leadingDimCRegAssign + j]); + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); } } } @@ -41,37 +42,38 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< inst_set_t::avx512>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp /* unused (reserved for prefetching)*/, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, - int leadingDimCRegAssign) { + int leadingDimCReg) { // used for matrix A - asmjit::X86Zmm AReg = x86::zmm29; + x86::Zmm AReg = x86::zmm29; - asmjit::X86Zmm tmpReg = x86::zmm30; + x86::Zmm tmpReg = x86::zmm30; // We start allocating BRegs from zmm27 and then allocate zmm26 and so on. 
for (int j = 0; j < colRegs; ++j) { a->vmovups( - AllRegs_avx512_[27 - j], + x86::Zmm(27 - j), x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); } + using CRegs = x86::Zmm; + for (int i = 0; i < rowRegs; ++i) { // broadcast A a->vpbroadcastw( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); for (int j = 0; j < colRegs; ++j) { - a->vpmaddubsw( - tmpReg, AReg, AllRegs_avx512_[27-j]); + a->vpmaddubsw(tmpReg, AReg, x86::Zmm(27 - j)); a->vpaddsw( - CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs(i * leadingDimCReg + j), tmpReg, - CRegs_avx512_[i * leadingDimCRegAssign + j]); + CRegs(i * leadingDimCReg + j)); // Prefetching is hurting performance in some cases // because prefetch instructions itself consumes a slot // in pipeline issue thus slowing down the kernel. @@ -90,25 +92,31 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, - int leadingDimCRegAssign) { - asmjit::X86Ymm extractDest256 = x86::ymm31; - asmjit::X86Zmm extractDest512 = x86::zmm31; + int leadingDimCReg) { + x86::Ymm extractDest256 = x86::ymm31; + x86::Zmm extractDest512 = x86::zmm31; + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); for (int j = 0; j < colRegs; ++j) { for (int idx = 0; idx < 2; ++idx) { a->vextracti32x8( - extractDest256, CRegs_avx512_[i * leadingDimCRegAssign + j], idx); + extractDest256, CRegs(i * leadingDimCReg + j), idx); a->vpmovsxwd(extractDest512, extractDest256); - asmjit::X86Mem destAddr = x86::dword_ptr( + x86::Mem destAddr = x86::dword_ptr( +#ifdef _MSC_VER + a->gpz(9), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); +#else a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); +#endif if (accum) { a->vpaddd(extractDest512, extractDest512, destAddr); } @@ -167,261 +175,256 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - - code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx512>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); - int maxMRegs = mRegBlockSize; - int maxNRegs = nRegBlockSize * 
row_interleave / VLEN_; - assert( - maxMRegs * maxNRegs <= 24 && - "MR*(NR*ROW_INTERLEAVE*8/512) \ - must be <= 24(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFrameInfo(ffi); - - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); - - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); - - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - // asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp jIdx = a->gpzRef(14); - asmjit::X86Gp kIdx = a->gpzRef(15); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - // a->mov(B_pf_saved, B_pf); - - int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - int colRegs = std::min(currColRegs, maxNRegs); - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - a->bind(LoopMBlocks); - a->inc(iIdx); - a->mov(jIdx, 0); - - a->bind(LoopNBlocks); - a->inc(jIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - - // increment C for next block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNBlocks); - - // increment A for next block - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next A block - a->sub( - CBase, - static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul( - C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); - a->add(CBase, C_Offset); - - // reset B - a->mov(buffer_B, buffer_B_saved); - // a->mov(B_pf, B_pf_saved); - - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - a->mov(jIdx, 0); - a->bind(LoopNRem); - a->inc(jIdx); - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment C for next block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNRem); - } + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert((maxMRegs + 1) * maxNRegs <= 28 && + "number of zmm registers for C + one row for loading B: \ + MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \ + must be <= 28(available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created +#ifdef _MSC_VER + x86::Gp buffer_A = a->zcx(); + x86::Gp buffer_B = a->zdx(); + x86::Gp B_pf = a->gpz(8); + x86::Gp CBase = a->gpz(9); + x86::Gp kSize = a->zdi(); + x86::Gp ldcReg = a->zsi(); +#else + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); +#endif + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + 
static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * + sizeof(int32_t))); + a->imul(C_Offset, ldcReg, + static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc new file mode 100644 index 0000000..f559aba --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. 
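// Note on the #ifdef _MSC_VER register choices used throughout the kernels above: the
// JIT-ed micro-kernel is called with the host C calling convention, and Windows x64 and
// System V AMD64 home the six arguments differently, so the generator picks matching
// registers per platform (mapping as read from the code; the reloads of the stack-passed
// arguments are presumably emitted by a->emitArgsAssignment()):
//
//   argument          System V AMD64   Windows x64
//   buffer_A          rdi              rcx
//   buffer_B          rsi              rdx
//   B_pf              rdx              r8
//   CBase (C ptr)     rcx              r9   <- why storeCRegs addresses through a->gpz(9) on MSVC
//   kSize             r8               stack, moved into rdi
//   ldcReg            r9               stack, moved into rsi
//
// A plain restatement of the generated kernel's signature (an illustrative alias matching
// the FuncSignatureT above; the real typedef lives in GenerateKernel.h):
#include <cstdint>
using jit_micro_kernel_sig =
    void (*)(uint8_t *A, int8_t *B, int8_t *B_pf, int32_t *C, int kSize, int ldc);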
+ * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + codeObj.initCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, leadingDimCReg); +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 16-bit Accmulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, + int rowRegs, + int colRegs, + int lda, + int leadingDimCReg) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + codeObj.genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, buffer_B, rowRegs, colRegs, lda, leadingDimCReg); +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 16-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, + int leadingDimCReg) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + codeObj.storeCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, leadingDimCReg); +} + +/** + * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel. + * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate< + inst_set_t::avx512_vnni>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. 
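// Context for the redirect above: AVX512-VNNI's vpdpbusd multiplies four unsigned 8-bit
// activations by four signed 8-bit weights and accumulates their sum directly into a
// 32-bit lane, so a 16-bit accumulation stage never exists on this path and the int16
// specializations simply forward to the int32 code generator. A scalar reference for what
// one vpdpbusd lane computes (illustrative helper, not FBGEMM code):
static inline int32_t vpdpbusd_lane(int32_t acc, const uint8_t a[4], const int8_t b[4]) {
  for (int k = 0; k < 4; ++k) {
    // u8 * s8 products, summed and added to the accumulator without saturation
    acc += static_cast<int32_t>(a[k]) * static_cast<int32_t>(b[k]);
  }
  return acc;
}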
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + return codeObj.getOrCreate<inst_set_t::avx512_vnni>(accum, mc, nc, kc, kc); +} + +} // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index d044530..dc9c534 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -9,18 +9,6 @@ namespace fbgemm { -template <typename TA, typename TB, typename TC, typename accT> -thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_; - -template <typename TA, typename TB, typename TC, typename accT> -thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_; - -template <typename TA, typename TB, typename TC, typename accT> -thread_local std::map< - std::tuple<bool, int, int, int, int, int, int, int>, - typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> - CodeGenBase<TA, TB, TC, accT>::codeCache_; - namespace x86 = asmjit::x86; /** @@ -31,16 +19,17 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { + using CRegs = x86::Ymm; for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { a->vxorps( - CRegs_avx2_[i * leadingDimCReg + j], - CRegs_avx2_[i * leadingDimCReg + j], - CRegs_avx2_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); } } } @@ -53,25 +42,27 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, int rowRegs, int colRegs, int lda, - int leadingDimCRegAssign) { + int leadingDimCReg) { // used for matrix A - asmjit::X86Ymm AReg = x86::ymm12; + x86::Ymm AReg = x86::ymm12; // used for matrix B - asmjit::X86Ymm BReg = x86::ymm13; + x86::Ymm BReg = x86::ymm13; // Contains 16-bit 1s - asmjit::X86Ymm oneReg = x86::ymm15; + x86::Ymm oneReg = x86::ymm15; // temporary register - asmjit::X86Ymm res1 = x86::ymm14; + x86::Ymm res1 = x86::ymm14; + + using CRegs = x86::Ymm; for (int j = 0; j < colRegs; ++j) { // load B @@ -83,9 +74,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< a->vpmaddubsw(res1, AReg, BReg); a->vpmaddwd(res1, oneReg, res1); a->vpaddd( - CRegs_avx2_[i * leadingDimCRegAssign + j], + CRegs(i * leadingDimCReg + j), res1, - CRegs_avx2_[i * leadingDimCRegAssign + j]); + CRegs(i * leadingDimCReg + j)); } a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); } @@ -99,16 +90,14 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, - int leadingDimCRegAssign) { - // temp register - asmjit::X86Ymm tmpReg = x86::ymm14; - + int leadingDimCReg) { + using CRegs = x86::Ymm; for (int i = 0; i < rowRegs; ++i) { if (i != 0) { a->add(C_Offset, ldcReg); @@ -116,13 +105,21 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< for (int j = 0; j < colRegs; ++j) { if (accum) { a->vpaddd( - CRegs_avx2_[i * leadingDimCRegAssign + j], - CRegs_avx2_[i * leadingDimCRegAssign + j], + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), +#ifdef 
_MSC_VER + x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 8 * sizeof(int32_t))); +#else x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t))); +#endif } a->vmovups( +#ifdef _MSC_VER + x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 8 * sizeof(int32_t)), +#else x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)), - CRegs_avx2_[i * leadingDimCRegAssign + j]); +#endif + CRegs(i * leadingDimCReg + j)); } } } @@ -176,207 +173,178 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx2>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); - - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFrameInfo(ffi); - - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); - - asmjit::Label Loopk = a->newLabel(); - asmjit::Label LoopMBlocks = a->newLabel(); - - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); - // asmjit::X86Gp B_pf = a->gpzRef(8); - - asmjit::X86Ymm oneReg = x86::ymm15; - // create 16-bit 1s - // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 - // and so on - a->vpcmpeqw(oneReg, oneReg, oneReg); - a->vpsrlw(oneReg, oneReg, 15); - a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); - a->mov(C_Offset, 0); - - int colRegs = nc * row_interleave * 
sizeof(int8_t) / VLEN_; - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - a->mov(B_pf_saved, B_pf); - - a->bind(LoopMBlocks); - a->inc(iIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx2>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - // a->add(B_pf, 32*sizeof(float)); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment A for next block - a->sub(buffer_A, kSize); - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next block - a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); - a->add(CBase, C_Offset); + // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created +#ifdef _MSC_VER + x86::Gp buffer_A = a->zcx(); + x86::Gp buffer_B = a->zdx(); + x86::Gp B_pf = a->gpz(8); + x86::Gp CBase = a->gpz(9); + x86::Gp kSize = a->zdi(); + x86::Gp ldcReg = a->zsi(); +#else + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); +#endif + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp kIdx = a->gpz(14); + // x86::Gp B_pf = a->gpz(8); + + x86::Ymm oneReg = x86::ymm15; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); a->mov(C_Offset, 0); - // reset B - a->mov(buffer_B, buffer_B_saved); - a->mov(B_pf, B_pf_saved); - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - // init C 
registers - initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx2>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - } + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - asmjit::FuncUtils::emitEpilog(a, layout); + auto issueLoopOverK = [&](int rowRegs) { + asmjit::Label LoopKLabel = a->newLabel(); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + // Init C (result) vector registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs); + + // Loops over K + a->mov(kIdx, 0); + a->bind(LoopKLabel); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopKLabel); + + // store C matrix + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum, + colRegs); + }; + + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + issueLoopOverK(mRegBlockSize); + + int rowRegs = mRegBlockSize; + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next block + a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); + a->add(CBase, C_Offset); + a->mov(C_Offset, 0); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + issueLoopOverK(mRegBlocksRem); + } + + a->emitEpilog(frame); + + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index d1729e4..5037292 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ 
b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -19,16 +19,17 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { a->vxorps( - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); } } } @@ -41,26 +42,27 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< inst_set_t::avx512>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, int rowRegs, int colRegs, int lda, - int leadingDimCRegAssign) { + int leadingDimCReg) { // used for matrix A - asmjit::X86Zmm AReg = x86::zmm31; + x86::Zmm AReg = x86::zmm31; // used for matrix B - asmjit::X86Zmm BReg = x86::zmm30; + x86::Zmm BReg = x86::zmm30; // Contains 16-bit 1s - asmjit::X86Zmm oneReg = x86::zmm29; + x86::Zmm oneReg = x86::zmm29; // temporary register - asmjit::X86Zmm res1 = x86::zmm28; + x86::Zmm res1 = x86::zmm28; + using CRegs = x86::Zmm; for (int j = 0; j < colRegs; ++j) { // load B a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); @@ -71,9 +73,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< a->vpmaddubsw(res1, AReg, BReg); a->vpmaddwd(res1, oneReg, res1); a->vpaddd( - CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs(i * leadingDimCReg + j), res1, - CRegs_avx512_[i * leadingDimCRegAssign + j]); + CRegs(i * leadingDimCReg + j)); } a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); } @@ -87,33 +89,38 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, - int leadingDimCRegAssign) { - // temp register - asmjit::X86Zmm tmpReg = x86::zmm28; - + int leadingDimCReg) { + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { if (i != 0) { a->add(C_Offset, ldcReg); - } - else { + } else { a->mov(C_Offset, static_cast<asmjit::Imm>(0)); } for (int j = 0; j < colRegs; ++j) { if (accum) { a->vpaddd( - CRegs_avx512_[i * leadingDimCRegAssign + j], - CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), +#ifdef _MSC_VER + x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t))); +#else x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); +#endif } a->vmovups( +#ifdef _MSC_VER + x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)), +#else x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), - CRegs_avx512_[i * leadingDimCRegAssign + j]); +#endif + CRegs(i * leadingDimCReg + j)); } } } @@ -167,278 +174,269 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); - + return 
codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx512>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); - int maxMRegs = mRegBlockSize; - int maxNRegs = nRegBlockSize * row_interleave / VLEN_; - assert( - maxMRegs * maxNRegs <= 28 && - "MR*(NR*ROW_INTERLEAVE*8/512) \ + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \ must be <= 28(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFrameInfo(ffi); - - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); - - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); - - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp jIdx = a->gpzRef(14); - asmjit::X86Gp kIdx = a->gpzRef(15); - // asmjit::X86Gp B_pf = a->gpzRef(8); - - asmjit::X86Zmm oneReg = x86::zmm29; - // create 16-bit 1s - // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 - // and so on - // a->vpcmpeqw(oneReg, oneReg, oneReg); - a->vpternlogd(oneReg, oneReg, oneReg, 0xff); - a->vpsrlw(oneReg, oneReg, 15); - a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - a->mov(B_pf_saved, B_pf); - - int currColRegs = nc * 
row_interleave * sizeof(int8_t) / VLEN_; - int colRegs = std::min(currColRegs, maxNRegs); - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - a->bind(LoopMBlocks); - a->inc(iIdx); - a->mov(jIdx, 0); - - a->bind(LoopNBlocks); - a->inc(jIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // increment C for next B block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNBlocks); - - // increment A for next block - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next A block - a->sub( - CBase, - static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); - a->add(CBase, C_Offset); - - // reset B - a->mov(buffer_B, buffer_B_saved); - a->mov(B_pf, B_pf_saved); - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - a->mov(jIdx, 0); - a->bind(LoopNRem); - a->inc(jIdx); - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->mov(buffer_B, buffer_B_saved); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment C for next B block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNRem); - } + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created +#ifdef _MSC_VER + x86::Gp buffer_A = a->zcx(); + x86::Gp buffer_B = a->zdx(); + x86::Gp B_pf = a->gpz(8); + x86::Gp CBase = a->gpz(9); + x86::Gp kSize = a->zdi(); + x86::Gp ldcReg = a->zsi(); +#else + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); +#endif - asmjit::FuncUtils::emitEpilog(a, layout); + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); + + x86::Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + 
static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * + sizeof(int32_t))); + a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->mov(buffer_B, buffer_B_saved); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + a->emitEpilog(frame); + + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc new file mode 100644 index 0000000..bd8be1f --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc @@ -0,0 +1,435 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
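// The new file below ports the 32-bit-accumulation kernel to AVX512-VNNI. For contrast,
// the non-VNNI kernel above needs three instructions per C register in the inner loop:
// vpmaddubsw, vpmaddwd against the vector of 16-bit ones (oneReg), then vpaddd. A scalar
// model of that sequence for one 32-bit lane (illustrative only, not FBGEMM code):
static inline int32_t dot4_maddubs_path(int32_t c, const uint8_t a[4], const int8_t b[4]) {
  auto sat16 = [](int32_t v) { return v > 32767 ? 32767 : (v < -32768 ? -32768 : v); };
  int32_t w0 = sat16(int32_t(a[0]) * b[0] + int32_t(a[1]) * b[1]); // vpmaddubsw: u8*s8 pairs, saturated to 16 bits
  int32_t w1 = sat16(int32_t(a[2]) * b[2] + int32_t(a[3]) * b[3]);
  int32_t d = w0 + w1;  // vpmaddwd with oneReg widens and sums the adjacent words
  return c + d;         // vpaddd accumulates into the C register
}
// vpdpbusd performs the same computation in a single instruction and without the
// intermediate 16-bit saturation, which is what the VNNI specialization exploits.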
+ * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + using CRegs = x86::Zmm; + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); + } + } +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 32-bit Accmulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCReg) { + // used for matrix A + x86::Zmm AReg = x86::zmm31; + + // used for matrix B + x86::Zmm BReg = x86::zmm30; + + using CRegs = x86::Zmm; + + for (int j = 0; j < colRegs; ++j) { + // load B + a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + // load A, broadcast and fmas + for (int i = 0; i < rowRegs; ++i) { + a->vpbroadcastd( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + a->vpdpbusd(CRegs(i * leadingDimCReg + j), AReg, BReg); + } + a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); + } +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 32-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, + int leadingDimCReg) { + using CRegs = x86::Zmm; + for (int i = 0; i < rowRegs; ++i) { + if (i != 0) { + a->add(C_Offset, ldcReg); + } else { + a->mov(C_Offset, static_cast<asmjit::Imm>(0)); + } + for (int j = 0; j < colRegs; ++j) { + if (accum) { + a->vpaddd( + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), +#ifdef _MSC_VER + x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t))); +#else + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); +#endif + } + a->vmovups( +#ifdef _MSC_VER + x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)), +#else + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), +#endif + CRegs(i * leadingDimCReg + j)); + } + } +} + +/** + * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel. 
+ * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate< + inst_set_t::avx512_vnni>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + std::tuple<bool, int, int, int, int, int, int, int> kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::KCB; + nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NCB; + mRegBlockSize = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::MR; + nRegBlockSize = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR; + nRegBlockSizeMin = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR_MIN; + row_interleave = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>:: + ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); + +#if defined(FBGEMM_LOG_CODE) + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512_vnni>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } +#endif + + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \ + must be <= 28(available registers constraint)"); + + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created +#ifdef _MSC_VER + x86::Gp buffer_A = a->zcx(); + x86::Gp buffer_B = a->zdx(); + x86::Gp B_pf = a->gpz(8); + x86::Gp CBase = a->gpz(9); + x86::Gp kSize = a->zdi(); + x86::Gp ldcReg = a->zsi(); +#else + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); +#endif + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label 
LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); + + x86::Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * + sizeof(int32_t))); + a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + // B for next block + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->mov(buffer_B, buffer_B_saved); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } + + a->emitEpilog(frame); + + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + +#if defined(FBGEMM_LOG_CODE) + fclose(codeLogfile); + delete codeLogger; +#endif + + return fn; + }); +} + +} // namespace fbgemm diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 1e6324e..58ee24d 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -10,8 +10,10 @@ #include <cassert> #include <cstdint> #include <map> +#include <mutex> #include <string> #include <tuple> +#include "CodeCache.h" #include "fbgemm/ConvUtils.h" #include "fbgemm/Fbgemm.h" #include "fbgemm/Utils.h" @@ -128,60 +130,58 @@ class GenConvKernel { const conv_param_t<SPATIAL_DIM>& conv_param); template <inst_set_t instSet> - void createVector16BitOne(asmjit::X86Emitter* a); + void createVector16BitOne(x86::Emitter* a); template <inst_set_t instSet> - void createVector8BitOne(asmjit::X86Emitter* a); + void createVector8BitOne(x86::Emitter* a); template <inst_set_t instSet> - void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg); + void setToZeroPt(x86::Emitter* a, x86::Ymm destReg); template <inst_set_t instSet> - void - gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg); + void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg); template <inst_set_t instSet> - void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset); + void genForLoadingWeights(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genConstForPermutations(asmjit::X86Emitter* a); + void genConstForPermutations(x86::Emitter* a); template <inst_set_t instSet> - void genForTopEdge(asmjit::X86Emitter* a, int c_offset); + void genForTopEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genForLeftEdge(asmjit::X86Emitter* a, int c_offset); + void genForLeftEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genForRightEdge(asmjit::X86Emitter* a, int c_offset); + void genForRightEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genForBottomEdge(asmjit::X86Emitter* a, int c_offset); + void genForBottomEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genCoreInsts(asmjit::X86Emitter* a, int c_offset); + void genCoreInsts(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void storeResult(asmjit::X86Emitter* a); + void storeResult(x86::Emitter* a); // for Rowoffset kernel // Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit template <inst_set_t instSet> - void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg); + void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg); // Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit template <inst_set_t instSet> - void - gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg); + void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg); // Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit template <inst_set_t instSet> void gen8BitSumX16( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg, - asmjit::X86Ymm cReg, - asmjit::X86Ymm dReg); + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg, + x86::Ymm cReg, + x86::Ymm dReg); // Generate instruction sequence that loads 8-bit values and sum them up. 
// Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16 @@ -191,73 +191,78 @@ class GenConvKernel { // Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_, // and resultRegAvx2_ are used. template <inst_set_t instSet> - void gen8BitSum( - asmjit::X86Emitter* a, - int act_offset, - bool use_scratch_reg1 = true); + void + gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true); // Use scratchReg1_ and tmpReg1Avx2_ internally template <inst_set_t instSet> - void genZeroPtSum(asmjit::X86Emitter* a, int multiplier); + void genZeroPtSum(x86::Emitter* a, int multiplier); template <inst_set_t instSet> - void genForTopEdgeRowoffset(asmjit::X86Emitter* a); + void genForTopEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genForLeftEdgeRowoffset(asmjit::X86Emitter* a); + void genForLeftEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genForRightEdgeRowoffset(asmjit::X86Emitter* a); + void genForRightEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genForBottomEdgeRowoffset(asmjit::X86Emitter* a); + void genForBottomEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genRowoffsetCorners(asmjit::X86Emitter* a); + void genRowoffsetCorners(x86::Emitter* a); template <inst_set_t instSet> - void genRowoffsetCore(asmjit::X86Emitter* a); + void genRowoffsetCore(x86::Emitter* a); template <inst_set_t instSet> - void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0); - - static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. - static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. - static thread_local std:: - map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> - codeCache_; ///< JIT Code Cache for reuse. - static thread_local std:: - map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> - codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel. - - private: + void storeResultRowoffset(x86::Emitter* a, int offset = 0); + + + static asmjit::JitRuntime &runtime() { + static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, + // depents on other static + // variables. Required to prevent + // initialization order fiasco + return rt; + } + + static std::mutex rtMutex_; ///< Controll access to runtime; + + static CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> + codeCache_; ///< JIT Code Cache for reuse. + static CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> + codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel. + +private: int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. // avx2 specific - asmjit::X86Ymm + x86::Ymm WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel. 
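Note on the caching change in this header: the per-thread `rt_`/`code_` statics and the raw `std::map` caches are replaced by a function-local `runtime()` singleton (constructed on first use, which sidesteps the static initialization order fiasco), an `rtMutex_` that serializes `runtime().add()`, and a `CodeCache` keyed by the kernel signature tuple. The real `CodeCache` lives in the newly included `CodeCache.h`, which is not shown in this hunk; the sketch below only illustrates the `getOrCreate` pattern, and its class name and signature are assumptions for the example, not the FBGEMM implementation.

```cpp
#include <functional>
#include <map>
#include <mutex>

// Minimal sketch of a thread-safe "get or generate" kernel cache.
template <typename KEY, typename VALUE>
class SimpleCodeCache {
 public:
  // Returns the cached value for `key`, or invokes `generator` once,
  // caches its result, and returns it. The lock makes concurrent
  // first-time generation of the same kernel safe.
  VALUE getOrCreate(const KEY& key, std::function<VALUE()> generator) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    VALUE value = generator();
    cache_.emplace(key, value);
    return value;
  }

 private:
  std::mutex mutex_;
  std::map<KEY, VALUE> cache_;
};
```

In the diff, the cache is instantiated with `std::tuple<bool, int, int, int>` as the key (the kernel signature) and the generated function pointer (`jit_conv_kernel_fp` or `jit_rowoffset_kernel_fp`) as the value, and callers pass a lambda that builds the kernel only on a cache miss.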
- asmjit::X86Ymm zeroPTRegAvx2_; - asmjit::X86Ymm tmpReg1Avx2_; - asmjit::X86Ymm stPermRegAvx2_; - asmjit::X86Ymm actRegAvx2_; - asmjit::X86Ymm resultRegAvx2_; - asmjit::X86Ymm oneReg8BitAvx2_; - asmjit::X86Ymm oneReg16BitAvx2_; + x86::Ymm zeroPTRegAvx2_; + x86::Ymm tmpReg1Avx2_; + x86::Ymm stPermRegAvx2_; + x86::Ymm actRegAvx2_; + x86::Ymm resultRegAvx2_; + x86::Ymm oneReg8BitAvx2_; + x86::Ymm oneReg16BitAvx2_; // arguments to the function created - asmjit::X86Gp in_acts_R_; - asmjit::X86Gp wghts_R_; - asmjit::X86Gp out_acts_R_; - asmjit::X86Gp a_zero_pt_R_; - asmjit::X86Gp H_R_; - asmjit::X86Gp W_R_; - asmjit::X86Gp row_offset_R_; + x86::Gp in_acts_R_; + x86::Gp wghts_R_; + x86::Gp out_acts_R_; + x86::Gp a_zero_pt_R_; + x86::Gp H_R_; + x86::Gp W_R_; + x86::Gp row_offset_R_; // Used registers - asmjit::X86Gp loopR1_; - asmjit::X86Gp loopR2_; - asmjit::X86Gp scratchReg1_; - asmjit::X86Gp scratchReg2_; + x86::Gp loopR1_; + x86::Gp loopR2_; + x86::Gp scratchReg1_; + x86::Gp scratchReg2_; // Other parameters bool isAZeroPointZero_; @@ -276,4 +281,15 @@ class GenConvKernel { int W_PAD_; ///< Padding for width (left and right) }; +template <int SPATIAL_DIM, typename accT> +std::mutex GenConvKernel<SPATIAL_DIM, accT>::rtMutex_; + +template <int SPATIAL_DIM, typename accT> +CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> + GenConvKernel<SPATIAL_DIM, accT>::codeCache_; + +template <int SPATIAL_DIM, typename accT> +CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> + GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_; + } // namespace fbgemm diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index e789695..396e792 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -21,20 +21,6 @@ namespace fbgemm { using namespace std; -template <int SPATIAL_DIM, typename accT> -thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_; - -template <int SPATIAL_DIM, typename accT> -thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_; - -template <int SPATIAL_DIM, typename accT> -thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> - GenConvKernel<SPATIAL_DIM, accT>::codeCache_; - -template <int SPATIAL_DIM, typename accT> -thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> - GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_; - namespace x86 = asmjit::x86; template <int SPATIAL_DIM> @@ -91,20 +77,19 @@ jit_conv_kernel_fp getOrCreateConvKernel( // Note: Wrong code is generated if it's not one of the supported convolution assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) != - GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) { - return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig]; - } else { - auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point); - // TODO: Instruction set based dispatch - return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); - } + return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate( + kernelSig, [&]() { + auto genObj = + GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); + }); } template <> template <> void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // 
create 8-bit 1s // i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains // 0x01 and so on @@ -115,7 +100,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // create 16-bit 1s // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31] // contains 0x0001 and so on @@ -125,11 +110,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm destReg) { + x86::Emitter* a, + x86::Ymm destReg) { // make destReg all zeros a->vxorps(destReg, destReg, destReg); - asmjit::X86Xmm const_reg_xmm = x86::xmm10; + x86::Xmm const_reg_xmm = x86::xmm10; // move zero point to xmm10 a->movq(const_reg_xmm, a_zero_pt_R_); // make copies of zero point @@ -143,9 +128,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( - asmjit::X86Emitter* a) { - asmjit::X86Gp permute_const_reg = a->gpzRef(12); - asmjit::X86Xmm const_reg_xmm = x86::xmm10; + x86::Emitter* a) { + x86::Gp permute_const_reg = a->gpz(12); + x86::Xmm const_reg_xmm = x86::xmm10; // We have 1st group in even lanes and 2nd group in odd lanes. // Permute to put 1st group to lower 128-bit and 2nd group in upper // 128-bit. @@ -159,8 +144,7 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( template <> template <> -void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( - asmjit::X86Emitter* a) { +void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) { if (C_per_G_ == 4) { // store with permutation a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_); @@ -171,7 +155,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int offset) { // store if (C_per_G_ == 4) { @@ -198,7 +182,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // load weights for (int r = 0; r < R_; ++r) { @@ -225,9 +209,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm wReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm wReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg); a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -236,8 +220,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg) { + x86::Emitter* a, + x86::Ymm aReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_); a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -246,9 +230,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, 
int32_t>::gen8BitSumX8<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // Let a[0] denote 0th (LSB) 8-bit of aReg // After vpsadbw, a[0:2] = a[0] + ... + a[7] @@ -267,11 +251,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg, - asmjit::X86Ymm cReg, - asmjit::X86Ymm dReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg, + x86::Ymm cReg, + x86::Ymm dReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // After vpsadbw, a[0:2] = a[0] + ... + a[7] // a[8:10] = a[8] + ... + a[15] @@ -319,7 +303,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int act_offset, bool use_scratch_reg1 /*=true*/) { if (use_scratch_reg1) { @@ -385,11 +369,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int multiplier) { a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier)); // tmpReg1Avx2_ also uses xmm11 - asmjit::X86Xmm const_reg_xmm = x86::xmm11; + x86::Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, scratchReg1_); a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm); a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_); @@ -399,7 +383,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // top-left corner code if (c_offset == 0) { @@ -559,7 +543,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); @@ -626,7 +610,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -714,7 +698,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // bottom-left corner // we updating the last row @@ -906,7 +890,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // main compute asmjit::Label LoopH = a->newLabel(); @@ -1010,10 +994,10 @@ template <> template <> jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { - code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + 
x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1021,25 +1005,34 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { - code_.setLogger(codeLogger); + code.setLogger(codeLogger); } #endif // arguments to the function created +#ifdef _MSC_VER + in_acts_R_ = a->zcx(); + wghts_R_ = a->zdx(); + out_acts_R_ = a->gpz(8); + a_zero_pt_R_ = a->gpz(9); + H_R_ = a->zdi(); + W_R_ = a->zsi(); +#else in_acts_R_ = a->zdi(); wghts_R_ = a->zsi(); out_acts_R_ = a->zdx(); a_zero_pt_R_ = a->zcx(); - H_R_ = a->gpzRef(8); - W_R_ = a->gpzRef(9); - row_offset_R_ = a->gpzRef(10); + H_R_ = a->gpz(8); + W_R_ = a->gpz(9); +#endif + row_offset_R_ = a->gpz(10); // register for temporary use - scratchReg1_ = a->gpzRef(12); - scratchReg2_ = a->gpzRef(13); + scratchReg1_ = a->gpz(12); + scratchReg2_ = a->gpz(13); asmjit::FuncDetail func; - func.init(asmjit::FuncSignature6< + func.init(asmjit::FuncSignatureT< void, uint8_t*, int8_t*, @@ -1048,29 +1041,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( int32_t, int32_t>(asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); createVector16BitOne<inst_set_t::avx2>(a); - loopR1_ = a->gpzRef(14); - loopR2_ = a->gpzRef(15); + loopR1_ = a->gpz(14); + loopR2_ = a->gpz(15); if (!isAZeroPointZero_) { setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_); @@ -1095,16 +1088,18 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( genCoreInsts<inst_set_t::avx2>(a, c); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_conv_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } if (err) { std::cout << "Error: in fn add" << std::endl; return nullptr; } - auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_); - codeCache_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) fclose(codeLogfile); @@ -1117,7 +1112,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // top-left corner code // zero out the results register 
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1213,7 +1208,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); @@ -1256,7 +1251,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -1326,7 +1321,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // bottom-left corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1429,7 +1424,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // number of uint8 elements in input channels should be a multiple of 32 assert(C_ % 32 == 0); @@ -1490,10 +1485,10 @@ template <> jit_rowoffset_kernel_fp GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { - code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1501,54 +1496,62 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { - code_.setLogger(codeLogger); + code.setLogger(codeLogger); } #endif // arguments to the function created +#ifdef _MSC_VER + in_acts_R_ = a->zcx(); + a_zero_pt_R_ = a->zdx(); + H_R_ = a->gpz(8); + W_R_ = a->gpz(9); + row_offset_R_ = a->zdi(); +#else in_acts_R_ = a->zdi(); a_zero_pt_R_ = a->zsi(); H_R_ = a->zdx(); W_R_ = a->zcx(); - row_offset_R_ = a->gpzRef(8); + row_offset_R_ = a->gpz(8); +#endif // register for temporary use - scratchReg1_ = a->gpzRef(12); - scratchReg2_ = a->gpzRef(13); + scratchReg1_ = a->gpz(12); + scratchReg2_ = a->gpz(13); - loopR1_ = a->gpzRef(14); - loopR2_ = a->gpzRef(15); + loopR1_ = a->gpz(14); + loopR2_ = a->gpz(15); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( + FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 
2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); // This uses xmm10 register temporarily. Should come before // createVector8BitOne if (!isAZeroPointZero_) { // we can use xmm11 because ymm11 is used by tmpReg1Avx2_ - asmjit::X86Xmm const_reg_xmm = x86::xmm11; + x86::Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, a_zero_pt_R_); a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm); @@ -1569,16 +1572,18 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( genRowoffsetCore<inst_set_t::avx2>(a); - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); + asmjit::Error err; jit_rowoffset_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } if (err) { std::cout << "Error: in fn add" << std::endl; return nullptr; } - auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_); - codeCacheRowOffset_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) delete codeLogger; @@ -1781,7 +1786,8 @@ template < typename outType, bool FUSE_RELU, QuantizationGranularity Q_GRAN, - int SPATIAL_DIM> + int SPATIAL_DIM, + typename BIAS_TYPE> void fbgemmGroupwiseConv( const conv_param_t<SPATIAL_DIM>& conv_param, const std::uint8_t* activations, @@ -1790,10 +1796,10 @@ void fbgemmGroupwiseConv( packed_W& packed_weights, outType* out, int32_t* outBuffer, - const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess, + const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess, int thread_id, int num_threads) { - typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType; + typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE> processOutputType; if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); @@ -1884,15 +1890,17 @@ void fbgemmGroupwiseConv( outProcess.getBZeroPoint()[0] == 0) || rowOffsetBuf == nullptr; - requantizationParams_t r = {a_zero_point, - outProcess.getBZeroPoint(), - outProcess.getCZeroPoint(), - outProcess.getCMultiplier(), - rowOffsetBuf, - outProcess.getColOffsets(), - outProcess.getBias(), - outProcess.getNCols(), - G}; + requantizationParams_t<typename processOutputType::BIAS_T> r = { + a_zero_point, + outProcess.getBZeroPoint(), + outProcess.getCZeroPoint(), + outProcess.getCMultiplier(), + rowOffsetBuf, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.getNCols(), + G, + outProcess.getActWScale()}; const std::int32_t* inp = outBuffer; block_type_t block{i * oh_ow, oh_ow, gOuter * K_per_G, 8 * K_per_G}; @@ -2163,15 +2171,14 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel( // Note: Wrong code is generated if it's not one of the supported convolution assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find( - kernelSig) != - GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) { - return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig]; 
- } else { - auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point); - // TODO: Instruction set based dispatch - return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param); - } + return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate( + kernelSig, [&]() { + auto genObj = + GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreateRowOffset<inst_set_t::avx2>( + conv_param); + }); } template <int SPATIAL_DIM> @@ -2215,7 +2222,7 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) { template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param); template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param); -#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM) \ +#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, BIAS_TYPE) \ template void fbgemmGroupwiseConv( \ const conv_param_t<SPATIAL_DIM>& conv_param, \ const uint8_t* activations, \ @@ -2224,13 +2231,17 @@ template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param); PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>& packed_weights, \ uint8_t* out, \ int32_t* outBuffer, \ - const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ int thread_id, \ int num_threads); +#define INSTANTIATE_BIAS_T(RELU, Q_GRAN, SPATIAL_DIM) \ + INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, float); \ + INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, int32_t); + #define INSTANTIATE_SPATIAL_DIM(RELU, Q_GRAN) \ - INSTANTIATE_BASE(RELU, Q_GRAN, 2); \ - INSTANTIATE_BASE(RELU, Q_GRAN, 3); + INSTANTIATE_BIAS_T(RELU, Q_GRAN, 2); \ + INSTANTIATE_BIAS_T(RELU, Q_GRAN, 3); #define INSTANTIATE_Q_GRANS(RELU) \ INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::TENSOR); \ @@ -2242,6 +2253,7 @@ INSTANTIATE_Q_GRANS(true); #undef INSTANTIATE_Q_GRANS #undef INSTANTIATE_SPATIAL_DIM +#undef INSTANTIATE_BIAS_T #undef INSTANTIATE_BASE template void fbgemmGroupwiseConv( diff --git a/src/OptimizedKernelsAvx2.cc b/src/OptimizedKernelsAvx2.cc index e8c65c3..326bd72 100644 --- a/src/OptimizedKernelsAvx2.cc +++ b/src/OptimizedKernelsAvx2.cc @@ -7,6 +7,7 @@ #include "OptimizedKernelsAvx2.h" #include <immintrin.h> +#include "fbgemm/Utils.h" using namespace std; @@ -14,37 +15,37 @@ namespace fbgemm { int32_t reduceAvx2(const uint8_t* A, int len) { int32_t row_sum = 0; -#if defined(__AVX2__) - __m256i sum_v = _mm256_setzero_si256(); - __m256i one_epi16_v = _mm256_set1_epi16(1); - __m256i one_epi8_v = _mm256_set1_epi8(1); + if (fbgemm::fbgemmHasAvx2Support()) { + __m256i sum_v = _mm256_setzero_si256(); + __m256i one_epi16_v = _mm256_set1_epi16(1); + __m256i one_epi8_v = _mm256_set1_epi8(1); - int i; - // vectorized - for (i = 0; i < len / 32 * 32; i += 32) { - __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); - sum_v = _mm256_add_epi32( + int i; + // vectorized + for (i = 0; i < len / 32 * 32; i += 32) { + __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); + sum_v = _mm256_add_epi32( sum_v, _mm256_madd_epi16( - _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v)); - } - - alignas(64) int32_t temp[8]; - _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v); - for (int k = 0; k < 8; ++k) { - row_sum += temp[k]; - } + _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v)); + } - // scalar - for (; i < len; ++i) { - row_sum += A[i]; - } + alignas(64) int32_t temp[8]; + 
_mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v); + for (int k = 0; k < 8; ++k) { + row_sum += temp[k]; + } -#else - for (int i = 0; i < len; ++i) { - row_sum += A[i]; + // scalar + for (; i < len; ++i) { + row_sum += A[i]; + } + } else { + for (int i = 0; i < len; ++i) { + row_sum += A[i]; + } } -#endif + return row_sum; } diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc index 87adaba..5fabf97 100644 --- a/src/PackAMatrix.cc +++ b/src/PackAMatrix.cc @@ -34,31 +34,35 @@ PackAMatrix<T, accT>::PackAMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { + assert(0 && "unknown architecure"); + } + if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { - BaseType::brow_ = params->MCB; - BaseType::bcol_ = params->KCB; - row_interleave_B_ = params->ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); - } + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { + } else { + // AVX2 BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; row_interleave_B_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); } } + if (BaseType::numCols() % groups != 0) { throw std::runtime_error( "groups = " + std::to_string(groups) + diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc index e55dd4e..6101fef 100644 --- a/src/PackAWithIm2Col.cc +++ b/src/PackAWithIm2Col.cc @@ -49,32 +49,35 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { + assert(0 && "unknown architecure"); + } if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { - BaseType::brow_ = params->MCB; - BaseType::bcol_ = params->KCB; - row_interleave_B_ = params->ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); - } + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = 
PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { + } else { + // AVX2 BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; row_interleave_B_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); } } + if (BaseType::numCols() % conv_p.G != 0) { throw std::runtime_error( "groups = " + std::to_string(conv_p.G) + @@ -272,6 +275,7 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) { conv_p_.K[1] == 7 && conv_p_.stride[0] == 2 && conv_p_.stride[1] == 2 && conv_p_.pad[0] == 3 && conv_p_.pad[1] == 3 && block.col_size == 147 && block_p.col_size == 148 && block.col_start == 0 && + conv_p_.dilation[0] == 1 && conv_p_.dilation[1] == 1 && std::is_same<T, uint8_t>::value) { if (BaseType::blockColSize() == 256) { pack_a_with_im2col_opt< @@ -347,8 +351,10 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) { int r = grs / conv_p_.K[1] % conv_p_.K[0]; int g = grs / conv_p_.K[1] / conv_p_.K[0]; - int h_in = -conv_p_.pad[0] + h * conv_p_.stride[0] + r; - int w_in = -conv_p_.pad[1] + w * conv_p_.stride[1] + s; + int h_in = + -conv_p_.pad[0] + h * conv_p_.stride[0] + r * conv_p_.dilation[0]; + int w_in = + -conv_p_.pad[1] + w * conv_p_.stride[1] + s * conv_p_.dilation[1]; if (h_in < 0 || h_in >= conv_p_.IN_DIM[0] || w_in < 0 || w_in >= conv_p_.IN_DIM[1]) { @@ -396,9 +402,12 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) { int q = gqrs / conv_p_.K[2] / conv_p_.K[1] % conv_p_.K[0]; int g = gqrs / conv_p_.K[2] / conv_p_.K[1] / conv_p_.K[0]; - int t_in = -conv_p_.pad[0] + t * conv_p_.stride[0] + q; - int h_in = -conv_p_.pad[1] + h * conv_p_.stride[1] + r; - int w_in = -conv_p_.pad[2] + w * conv_p_.stride[2] + s; + int t_in = + -conv_p_.pad[0] + t * conv_p_.stride[0] + q * conv_p_.dilation[0]; + int h_in = + -conv_p_.pad[1] + h * conv_p_.stride[1] + r * conv_p_.dilation[1]; + int w_in = + -conv_p_.pad[2] + w * conv_p_.stride[2] + s * conv_p_.dilation[2]; if (t_in < 0 || t_in >= conv_p_.IN_DIM[0] || h_in < 0 || h_in >= conv_p_.IN_DIM[1] || w_in < 0 || @@ -481,7 +490,9 @@ int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index 305a298..13a8fad 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -45,32 +45,37 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - rowOffsetAllocatedHere = false; + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { + assert(0 && "unknown architecure"); + } + if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { - BaseType::brow_ = params->MCB; - BaseType::bcol_ = params->KCB; - row_interleave_B_ = params->ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); - } + BaseType::brow_ = 
params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { + } else { + // AVX2 BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; row_interleave_B_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unknown architecure"); } } + + rowOffsetAllocatedHere = false; + if (BaseType::numCols() % groups != 0) { throw std::runtime_error( "groups = " + std::to_string(groups) + @@ -202,7 +207,9 @@ int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc index b791817..e84c67b 100644 --- a/src/PackAWithRowOffset.cc +++ b/src/PackAWithRowOffset.cc @@ -39,32 +39,37 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - rowOffsetAllocatedHere = false; + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { + assert(0 && "unknown architecure"); + } + if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { - BaseType::brow_ = params->MCB; - BaseType::bcol_ = params->KCB; - row_interleave_B_ = params->ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); - } + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { + } else { + // AVX2 BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; row_interleave_B_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unknown architecure"); } } + + rowOffsetAllocatedHere = false; + if (BaseType::numCols() % groups != 0) { throw std::runtime_error( "groups = " + 
std::to_string(groups) + @@ -190,7 +195,9 @@ int PackAWithRowOffset<T, accT>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index b6d06ca..c237ac4 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -188,6 +188,76 @@ PackBMatrix<T, accT>::PackBMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { + assert(0 && "unknown architecure"); + } + + if (params) { + BaseType::brow_ = params->KCB; + BaseType::bcol_ = params->NCB; + row_interleave_ = params->ROW_INTERLEAVE; + } else { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else { + // AVX2 + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } + } + + if (BaseType::numRows() % groups != 0) { + throw std::runtime_error( + "groups = " + std::to_string(groups) + + " does not divide numRows = " + std::to_string(BaseType::numRows())); + } + + // blocking for one group + block_type_t block{ + 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols()}; + BaseType::packedBlock(block); + if (!pmat) { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = (T*)fbgemmAlignedAlloc( + 64, + BaseType::numGroups() * BaseType::blockRows() * BaseType::brow_ * + BaseType::blockCols() * BaseType::bcol_ * sizeof(T)); + } + pack(block, params); +} + +template <typename T, typename accT> +PackBMatrix<T, accT>::PackBMatrix( + matrix_op_t trans, + int32_t nRow, + int32_t nCol, + inpType* prepackedmat, + int32_t ld, + int groups, + const BlockingFactors* params) + : PackMatrix<PackBMatrix<T, accT>, T, accT>( + nRow, + nCol, + prepackedmat, + groups, + params), + trans_(trans), + smat_(nullptr), + ld_(ld) { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } if (params) { if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { BaseType::brow_ = params->KCB; @@ -221,20 +291,17 @@ PackBMatrix<T, accT>::PackBMatrix( // blocking for one group block_type_t block{ - 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols()}; + 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols() }; BaseType::packedBlock(block); - if (!pmat) { - BaseType::bufAllocatedHere_ = true; - BaseType::buf_ = (T*)fbgemmAlignedAlloc( - 64, - BaseType::numGroups() * BaseType::blockRows() * BaseType::brow_ * - BaseType::blockCols() * BaseType::bcol_ * sizeof(T)); - } - pack(block); } template <typename T, typename accT> -void 
PackBMatrix<T, accT>::pack(const block_type_t& block) { +void PackBMatrix<T, accT>::pack_unpack_( + const block_type_t& block, + T* unpack_buf, + T* pack_buf, + bool ispack, + const BlockingFactors* params) { assert((BaseType::blockRowSize() % row_interleave_) == 0); assert((block.row_start % BaseType::blockRowSize()) == 0); assert((block.col_start % BaseType::blockColSize()) == 0); @@ -242,8 +309,8 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { BaseType::packedBlock(block); bool tr = (trans_ == matrix_op_t::Transpose); for (int g = 0; g < BaseType::numGroups(); ++g) { - T* out = BaseType::getBuf() + - g * BaseType::packedBufferSize(block.row_size, block.col_size); + T* pack_buf_cur = pack_buf + + g * BaseType::packedBufferSize(block.row_size, block.col_size, params); for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * (BaseType::blockRowSize() * BaseType::blockColSize()) + @@ -268,10 +335,16 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() + c_idx_offset * row_interleave_; - int out_idx = r_offset + c_offset; - T val = tr ? smat_[i + (g * block.col_size + j) * ld_] - : smat_[(g * block.row_size + i) * ld_ + j]; - out[out_idx] = val; + if (ispack) { + pack_buf_cur[r_offset + c_offset] = tr + ? unpack_buf[i + (g * block.col_size + j) * ld_] + : unpack_buf[(g * block.row_size + i) * ld_ + j]; + } else { + T* unpack_buf_cur = tr + ? &(unpack_buf[i + (g * block.col_size + j) * ld_]) + : &(unpack_buf[(g * block.row_size + i) * ld_ + j]); + *unpack_buf_cur = pack_buf_cur[r_offset + c_offset]; + } c_idx_offset++; if (c_idx_offset == BaseType::blockColSize()) { @@ -280,78 +353,49 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { } } } - // fill the remaining with zero. - // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. - for (int i = block.row_start + block.row_size; - i < (block.row_start + block.row_size + row_interleave_ - 1) / - row_interleave_ * row_interleave_; - ++i) { - int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * - (BaseType::blockRowSize() * BaseType::blockColSize()) + - (i % BaseType::blockRowSize() / row_interleave_) * - BaseType::blockColSize() * row_interleave_ + - i % row_interleave_; - for (int j = block.col_start; j < block.col_start + block.col_size; j++) { - int c_offset = (j / BaseType::blockColSize()) * - BaseType::blockRowSize() * BaseType::blockColSize() + - (j % BaseType::blockColSize()) * row_interleave_; + if (ispack) { + // fill the remaining with zero. + // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. 
+ for (int i = block.row_start + block.row_size; + i < (block.row_start + block.row_size + row_interleave_ - 1) / + row_interleave_ * row_interleave_; + ++i) { + int r_offset = + ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()) + + (i % BaseType::blockRowSize() / row_interleave_) * + BaseType::blockColSize() * row_interleave_ + + i % row_interleave_; + for (int j = block.col_start; j < block.col_start + block.col_size; + j++) { + int c_offset = (j / BaseType::blockColSize()) * + BaseType::blockRowSize() * BaseType::blockColSize() + + (j % BaseType::blockColSize()) * row_interleave_; - int out_idx = r_offset + c_offset; - out[out_idx] = 0; + int out_idx = r_offset + c_offset; + pack_buf_cur[out_idx] = 0; + } } } } // for each group } template <typename T, typename accT> -void PackBMatrix<T, accT>::unpack(T* origin_buf) { - bool tr = (trans_ == matrix_op_t::Transpose); - for (int g = 0; g < this->numGroups(); ++g) { - T* out = BaseType::getBuf() + - g * - BaseType::packedBufferSize( - BaseType::numPackedRows(), BaseType::numPackedCols()); - for (int i = BaseType::packedRowStart(); - i < BaseType::packedRowStart() + BaseType::numPackedRows(); - ++i) { - int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * - (BaseType::blockRowSize() * BaseType::blockColSize()) + - (i % BaseType::blockRowSize() / row_interleave_) * - BaseType::blockColSize() * row_interleave_ + - i % row_interleave_; - - int c_start_offset = - (BaseType::packedColStart() / BaseType::blockColSize()) * - BaseType::blockRowSize() * BaseType::blockColSize() + - (BaseType::packedColStart() % BaseType::blockColSize()) * - row_interleave_; - - int c_idx_offset = 0; - int c_blk_offset = 0; - for (int j = BaseType::packedColStart(); - j < BaseType::packedColStart() + BaseType::numPackedCols(); - ++j) { - int c_offset = c_start_offset + - c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() + - c_idx_offset * row_interleave_; - - int out_idx = r_offset + c_offset; - - T val = out[out_idx]; - if (tr) { - origin_buf[i + (g * BaseType::numPackedCols() + j) * ld_] = val; - } else { - origin_buf[(g * BaseType::numPackedRows() + i) * ld_ + j] = val; - } +void PackBMatrix<T, accT>::pack( + const block_type_t& block, + const BlockingFactors* params) { + pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params); +} - c_idx_offset++; - if (c_idx_offset == BaseType::blockColSize()) { - c_idx_offset = 0; - c_blk_offset++; - } - } - } - } // for each group +template <typename T, typename accT> +void PackBMatrix<T, accT>::unpack( + T* origin_buf, + const BlockingFactors* params) { + block_type_t blockB{BaseType::packedRowStart(), + BaseType::numPackedRows(), + BaseType::packedColStart(), + BaseType::numPackedCols()}; + pack_unpack_(blockB, origin_buf, BaseType::getBuf(), false, params); } template <typename T, typename accT> @@ -374,7 +418,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const { } template <typename T, typename accT> -void PackBMatrix<T, accT>::printPackedMatrix(std::string name) { +void PackBMatrix<T, accT>::printPackedMatrix( + std::string name, + const BlockingFactors* params) { std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", " << BaseType::numPackedCols() << "]" << std::endl; @@ -382,33 +428,39 @@ void PackBMatrix<T, accT>::printPackedMatrix(std::string name) { << "[" << BaseType::blockRowSize() << ", " << BaseType::blockColSize() << "]" << std::endl; - T* out = 
BaseType::getBuf(); - - for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { - auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() - : BaseType::blockRowSize(); - for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { - std::cout << "block:" << nr << ", " << nc << std::endl; - auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol() - : BaseType::blockColSize(); - for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; - ++r) { - for (auto c = 0; c < cols * row_interleave_; ++c) { - T val = - out[nr * BaseType::blockCols() * BaseType::blockRowSize() * - BaseType::blockColSize() + - nc * BaseType::blockRowSize() * BaseType::blockColSize() + - r * BaseType::blockColSize() * row_interleave_ + c]; - if (std::is_integral<T>::value) { - // cast to int64 because cout doesn't print int8_t type directly - std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; - } else { - std::cout << std::setw(5) << val << " "; + for (int g = 0; g < BaseType::numGroups(); ++g) { + T* out = BaseType::getBuf() + + g * + BaseType::packedBufferSize( + BaseType::numPackedRows(), BaseType::numPackedCols(), params); + std::cout << "group: " << g << std::endl; + for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { + auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() + : BaseType::blockRowSize(); + for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { + std::cout << "block:" << nr << ", " << nc << std::endl; + auto cols = (nc == BaseType::blockCols() - 1) + ? BaseType::lastBcol() + : BaseType::blockColSize(); + for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; + ++r) { + for (auto c = 0; c < cols * row_interleave_; ++c) { + T val = + out[nr * BaseType::blockCols() * BaseType::blockRowSize() * + BaseType::blockColSize() + + nc * BaseType::blockRowSize() * BaseType::blockColSize() + + r * BaseType::blockColSize() * row_interleave_ + c]; + if (std::is_integral<T>::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } } + std::cout << std::endl; } std::cout << std::endl; } - std::cout << std::endl; } } } diff --git a/src/PackDepthwiseConvMatrixAvx2.cc b/src/PackDepthwiseConvMatrixAvx2.cc new file mode 100644 index 0000000..a84c469 --- /dev/null +++ b/src/PackDepthwiseConvMatrixAvx2.cc @@ -0,0 +1,211 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "fbgemm/FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/Utils.h" +#include "fbgemm/Fbgemm.h" + +#include <immintrin.h> + +using namespace std; + +namespace fbgemm { + +// clang-format off +static int masks[8][8] = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. + { 0, 0, 0, 0, 0, 0, 0, 0, }, + { -1, 0, 0, 0, 0, 0, 0, 0, }, + { -1, -1, 0, 0, 0, 0, 0, 0, }, + { -1, -1, -1, 0, 0, 0, 0, 0, }, + { -1, -1, -1, -1, 0, 0, 0, 0, }, + { -1, -1, -1, -1, -1, 0, 0, 0, }, + { -1, -1, -1, -1, -1, -1, 0, 0, }, + { -1, -1, -1, -1, -1, -1, -1, 0, }, +}; +// clang-format on + +PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix( + int K, + int kernel_prod, + const int8_t* smat) + : K_(K), kernel_prod_(kernel_prod) { + // Transpose the input matrix to make packing faster. 
+ int8_t* smat_transposed + = static_cast<int8_t*>(ALIGNED_MALLOC(K * kernel_prod * sizeof(int8_t), 64)); + + for (int i = 0; i < kernel_prod; ++i) { + for (int j = 0; j < K; ++j) { + smat_transposed[i * K + j] = smat[i + j * kernel_prod]; + } + } + + // Allocate packed arrays + int kernel_prod_aligned = (kernel_prod + 1) / 2 * 2; + pmat_ = static_cast<int8_t *>(fbgemmAlignedAlloc(64, ((K + 31) / 32) * kernel_prod_aligned * 32 * sizeof(int8_t))); + + // Pack input matrix + // The layout is optimized to use vpmaddubsw efficiently (see + // madd_epi16x4_packed function). + // For a group of 32 channels, we have 10 32B SIMD registers. + // Denote ith channel jth filter as (i, j) + // 0th SIMD register: + // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3) + // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3) + // 1st SIMD register: + // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3) + // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3) + // 2nd SIMD register: + // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3) + // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3) + // 3rd SIMD register: + // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3) + // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3) + // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter + // coefficients + // ... + // + // REMAINDER + // If kernel_prod % 4 == 1 for example when kernel_prod == 9 + // 8th SIMD register: + // (0, 8), zero, ..., (7, 8), zero + // (16, 8), zero, ..., (23, 8), zero + // 9th SIMD register: + // (8, 8), zero, ..., (15, 8), zero + // (24, 8), zero, ..., (31, 8), zero + // We use madd_epi16_packed for this case + // + // If kernel_prod % 4 == 2 for example when kernel_prod == 10 + // 8th SIMD register: + // (0, 8), (0, 9), ..., (7, 8), (7, 9) + // (16, 8), (16, 9), ..., (23, 8), (23, 9) + // 9th SIMD register: + // (8, 8), (8, 9), ..., (15, 8), (15, 9) + // (24, 8), (24, 9), ..., (31, 8), (31, 9) + // + // If kernel_prod % 4 == 3 for example when kernel_prod == 11 + // 8th SIMD register: + // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero + // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero + // 9th SIMD register: + // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero + // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero + // 10th SIMD register: + // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero + // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero + // 11th SIMD register: + // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero + // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero + for (int k1 = 0; k1 < K; k1 += 32) { + __m256i* b_v = static_cast<__m256i*>(ALIGNED_MALLOC(kernel_prod * sizeof(__m256i), 64)); + int remainder = K - k1; + if (remainder < 32) { + __m256i mask_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(masks[remainder / 4])); + for (int i = 0; i < kernel_prod; ++i) { + b_v[i] = _mm256_maskload_epi32( + reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v); + } + } else { + for (int i = 0; i < kernel_prod; ++i) { + b_v[i] = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1)); + } + } + + // Interleave 2 SIMD registers + __m256i* 
b_interleaved_epi16 = static_cast<__m256i*>(ALIGNED_MALLOC(kernel_prod_aligned * sizeof(__m256i), 64)); + __m256i zero_v = _mm256_setzero_si256(); + for (int i = 0; i < kernel_prod_aligned / 2; ++i) { + if (2 * i + 1 >= kernel_prod) { + b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v); + b_interleaved_epi16[2 * i + 1] = + _mm256_unpackhi_epi8(b_v[2 * i], zero_v); + } else { + b_interleaved_epi16[2 * i] = + _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]); + b_interleaved_epi16[2 * i + 1] = + _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]); + } + } + + // Interleave 4 SIMD registers + __m256i* b_interleaved_epi32 = static_cast<__m256i*>(ALIGNED_MALLOC(kernel_prod_aligned * sizeof(__m256i), 64)); + for (int i = 0; i < kernel_prod_aligned / 4; ++i) { + b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16( + b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); + b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16( + b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); + b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16( + b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); + b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16( + b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); + } + for (int i = kernel_prod_aligned / 4 * 4; i < kernel_prod_aligned; ++i) { + b_interleaved_epi32[i] = b_interleaved_epi16[i]; + } + + for (int i = 0; i < kernel_prod_aligned; ++i) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>( + &pmat_[((k1 / 32) * kernel_prod_aligned + i) * 32]), + b_interleaved_epi32[i]); + } + + FREE(b_v); + FREE(b_interleaved_epi16); + FREE(b_interleaved_epi32); + } + FREE(smat_transposed); +} + +int PackedDepthWiseConvMatrix::addr(int r, int c) { + int kernel_prod_aligned = (kernel_prod_ + 1) / 2 * 2; + if (c >= kernel_prod_ / 4 * 4 && + (kernel_prod_ % 4 == 1 || kernel_prod_ % 4 == 2)) { + int kBlock = r / 32; + int reg_idx = (r % 16) / 8 + c / 4 * 4; + + int blk_idx = kBlock * kernel_prod_aligned + reg_idx; + + int r_ = r % 8; + int c_ = c % 4; + + int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_; + return blk_idx * 32 + in_blk_idx; + + } else { + int kBlock = r / 32; + int reg_idx = (r % 16) / 4 + c / 4 * 4; + + int blk_idx = kBlock * kernel_prod_aligned + reg_idx; + + int r_ = r % 4; + int c_ = c % 4; + + int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_; + return blk_idx * 32 + in_blk_idx; + } +} + +void PackedDepthWiseConvMatrix::unpack(int8_t* unpacked_data) { + for (int r = 0; r < K_; ++r) { + for (int c = 0; c < kernel_prod_; ++c) { + unpacked_data[r * kernel_prod_ + c] = pmat_[addr(r, c)]; + } + } +} + +PackedDepthWiseConvMatrix::~PackedDepthWiseConvMatrix() { +#ifdef _MSC_VER + _aligned_free(pmat_); +#else + free(pmat_); +#endif +} + +} // namespace fbgemm diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc index c9a68a6..ff7b842 100644 --- a/src/PackMatrix.cc +++ b/src/PackMatrix.cc @@ -36,54 +36,42 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { + assert(0 && "unknown architecure"); + } + int MCB, KCB, NCB; if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { - MCB = params->MCB; - NCB = params->NCB; - KCB = params->KCB; - } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); - } + MCB = params->MCB; + NCB = params->NCB; + KCB = params->KCB; } else { 
- if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB; + NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB; + KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB; + } else if (fbgemmHasAvx512Support()) { MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB; NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB; KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; - } else if (fbgemmHasAvx2Support()) { + } else { + // AVX2 MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB; NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB; KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; - } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); - return -1; } } - if (fbgemmHasAvx512Support()) { - if (isA()) { - return MCB * KCB; - } else { - int rowBlock = KCB; - int colBlock = NCB; - return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * - (((cols + colBlock - 1) / colBlock) * colBlock); - } - } else if (fbgemmHasAvx2Support()) { - if (isA()) { - return MCB * KCB; - } else { - int rowBlock = KCB; - int colBlock = NCB; - return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * - (((cols + colBlock - 1) / colBlock) * colBlock); - } + if (isA()) { + return MCB * KCB; } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); + int rowBlock = KCB; + int colBlock = NCB; + return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * + (((cols + colBlock - 1) / colBlock) * colBlock); } + return -1; } diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc index 0fb0e2c..f6ad59e 100644 --- a/src/PackWeightMatrixForGConv.cc +++ b/src/PackWeightMatrixForGConv.cc @@ -36,8 +36,61 @@ PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv( } /** - * @brief Pack weight tensor in a suitable format required for the optimized - * kernel. + * @brief Get the index of the unpacked data for a given <r, s, k, g, c, tr> + * + * Non-transposed: G (R S C/G) K/G + * Transposed: G K/G (R S C/G) + * Using inline as this will be called frequently + */ +template <typename T, typename accT, int SPATIAL_DIM> +inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::unpacked_index_( + int r, int s, int k, int g, int c, bool tr) { + // Get the full dimensions + int R = conv_param_.K[0]; + int S = conv_param_.K[1]; + int G = conv_param_.G; + int IC_per_G = conv_param_.IC / G; + int OC_per_G = conv_param_.OC / G; + + int idx; + if (tr) { + idx = (((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c; + } else { + idx = (((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k; + } + return idx; +} + +/** + * @brief Get the index of the packed data for a given <r, s, k, g, c> + * + * The index may differ depending on IC_per_G. 
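 * A sketch of the two orderings the index formulas below implement (a
 * reading aid, not normative): for IC_per_G == 4, two groups are interleaved
 * into a (G/2) R S K/G 2 4 layout, i.e.
 *   idx = (((((g/2)*R + r)*S + s)*OC_per_G + k)*2 + g%2)*4 + c;
 * otherwise the layout is G (IC_per_G/4) R S K/G 4, with the innermost 4
 * taken from c % 4.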
+ * Using inline as this will be called frequently + */ +template <typename T, typename accT, int SPATIAL_DIM> +inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::packed_index_( + int r, int s, int k, int g, int c) { + // Get the full dimensions + int R = conv_param_.K[0]; + int S = conv_param_.K[1]; + int G = conv_param_.G; + int IC_per_G = conv_param_.IC / G; + int OC_per_G = conv_param_.OC / G; + + int idx; + // For IC_per_G == 4, we need to work on 2 groups at a time + if (IC_per_G == 4) { + idx = (((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 + (g % 2)) + * IC_per_G + c; + } else { + idx = ((((g * (IC_per_G / 4) + (c / 4)) * R + r) * S + s) * OC_per_G + k) + * 4 + (c % 4); + } + return idx; +} + +/** + * @ brief Pack or unpack matrix * * Let IC_per_G be number of input channels per group and OC_per_G be number of * output channels per group. @@ -54,14 +107,16 @@ PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv( * while working on 1 group at a time. * In this case, the layout is G (C/4) R S K 4 */ + template <typename T, typename accT, int SPATIAL_DIM> -void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() { +void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_( + const T* src, T* dst, bool ispack) { // filters are assumed to be in G RS C/G K/G format int R = conv_param_.K[0]; int S = conv_param_.K[1]; int G = conv_param_.G; - int IC_per_G = conv_param_.IC / conv_param_.G; - int OC_per_G = conv_param_.OC / conv_param_.G; + int IC_per_G = conv_param_.IC / G; + int OC_per_G = conv_param_.OC / G; // If transpose option is set, the weight matrix is in layout G K/G (R S C/G) // instead of G (R S C/G) K/G @@ -73,25 +128,13 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() { for (int k = 0; k < OC_per_G; ++k) { for (int g = 0; g < G; ++g) { for (int c = 0; c < IC_per_G; ++c) { - inpType b = tr - ? sdata_ - [(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c] - : sdata_ - [(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k]; - if (IC_per_G == 4) { - // For IC_per_G == 4, we need to work on 2 groups at a time - pdata_ - [(((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 + - (g % 2)) * - IC_per_G + - c] = b; + int p_idx = packed_index_(r, s, k, g, c); + int up_idx = unpacked_index_(r, s, k, g, c, tr); + // Pack: src (unpacked) -> dst (packed) + if (ispack) { + dst[p_idx] = src[up_idx]; } else { - pdata_ - [((((g * (IC_per_G / 4) + (c / 4)) * R + r) * S + s) * - OC_per_G + - k) * - 4 + - (c % 4)] = b; + dst[up_idx] = src[p_idx]; } } } @@ -99,14 +142,54 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() { } } } else { + // For pack & transposed, call transposeConvWeights() + // G K/G (R S C/G) => G (R S C/G) K/G if (tr) { - // conv_ref expects weights to be in G (R S C/G) K/G format - transposeConvWeights(conv_param_, sdata_, pdata_); + if (ispack) { + transposeConvWeights(conv_param_, src, dst); + } else { + // TODO: Wrap this as a inverseTransposeConvWeights()? 
+ // For unpack & transposed, call transposeConvWeights() + // G (R S C/G) K/G => G K/G (R S C/G) + for (int r = 0; r < R; ++r) { + for (int s = 0; s < S; ++s) { + for (int k = 0; k < OC_per_G; ++k) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < IC_per_G; ++c) { + dst[(((g * OC_per_G + k) * R + r) * S + s) + * IC_per_G + c] = + src[(((g * R + r) * S + s) * IC_per_G + c) + * OC_per_G + k]; + } + } + } + } + } + } // end if(ispack) } else { // just copy the data for not supported cases - memcpy(pdata_, sdata_, G * R * S * OC_per_G * IC_per_G * sizeof(inpType)); - } - } + memcpy(dst, src, + G * R * S * OC_per_G * IC_per_G * sizeof(inpType)); + } //end if(tr) + } // end if(fbgemmOptimizedGConv(conv_param_) +} + +/** + * @brief Pack weight tensor in a suitable format required for the optimized + * kernel. + */ +template <typename T, typename accT, int SPATIAL_DIM> +void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() { + pack_unpack_(sdata_, pdata_, true); +} + +/** + * @brief Unpack the packed weight tensor (for the optimized kernel) + * to the original form. + */ +template <typename T, typename accT, int SPATIAL_DIM> +void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::unpack(T* origin_buf) { + pack_unpack_(const_cast<const T*>(pdata_), origin_buf, false); } template class PackWeightMatrixForGConv<int8_t, int32_t, 2>; diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index c811144..192fb00 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -4,6 +4,7 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ +#include <algorithm> #include <memory> #include "fbgemm/Fbgemm.h" @@ -13,7 +14,8 @@ template <int SPATIAL_DIM, typename T, typename accT> PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( const conv_param_t<SPATIAL_DIM>& conv_p, const T* sdata, - const BlockingFactors* blocking_params) { + const BlockingFactors* blocking_params) + : conv_param_(conv_p) { static_assert( SPATIAL_DIM == 2 || SPATIAL_DIM == 3, "Only 2D and 3D convolutions are supported"); @@ -21,50 +23,153 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( // FbgemmConv.cc switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) { case optimized_conv_t::depthwise: { - if (SPATIAL_DIM == 3) { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = - std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata); - W_gconv_packed_ = nullptr; - } else { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = - std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata); - W_dw_3D_packed_ = nullptr; - W_gconv_packed_ = nullptr; - } + W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>( + conv_p.G, SPATIAL_DIM == 3 ? 
3 * 3 * 3 : 3 * 3, sdata); break; } case optimized_conv_t::groupwise: { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = nullptr; W_gconv_packed_ = std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>( - matrix_op_t::NoTranspose, conv_p, sdata, nullptr); + matrix_op_t::Transpose, conv_p, sdata, nullptr); + break; + } + case optimized_conv_t::pointwise: { + int NDim = conv_p.OC / conv_p.G; + int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC; + W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>( + matrix_op_t::Transpose, + KDim, + NDim, + sdata, + KDim / conv_p.G, + nullptr, + conv_p.G, + blocking_params); break; } case optimized_conv_t::im2col: { int NDim = conv_p.OC / conv_p.G; int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC; W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>( - matrix_op_t::NoTranspose, + matrix_op_t::Transpose, KDim, NDim, sdata, - NDim, + KDim / conv_p.G, nullptr, conv_p.G, blocking_params); - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = nullptr; - W_gconv_packed_ = nullptr; break; } } // switch } +template <int SPATIAL_DIM, typename T, typename accT> +void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) { + if (W_dw_packed_) { + W_dw_packed_->unpack(origin_buf); + } else if (W_gconv_packed_) { + W_gconv_packed_->unpack(origin_buf); + } else if (W_im2col_packed_) { + W_im2col_packed_->unpack(origin_buf); + } else if (W_pointwise_packed_) { + W_pointwise_packed_->unpack(origin_buf); + } else { + assert(false && "At least one packed weights object should exist"); + } +} + +template <int SPATIAL_DIM, typename T, typename accT> +bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant( + const conv_param_t<SPATIAL_DIM>& test_conv_p) { + return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC && + conv_param_.G == test_conv_p.G && + std::equal( + conv_param_.K.begin(), + conv_param_.K.end(), + test_conv_p.K.begin()) && + std::equal( + conv_param_.stride.begin(), + conv_param_.stride.end(), + test_conv_p.stride.begin()) && + std::equal( + conv_param_.pad.begin(), + conv_param_.pad.end(), + test_conv_p.pad.begin()) && + std::equal( + conv_param_.dilation.begin(), + conv_param_.dilation.end(), + test_conv_p.dilation.begin()); +} + +template <int SPATIAL_DIM, typename T, typename accT> +std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams( + const conv_param_t<SPATIAL_DIM>& test_conv_p) { + std::string msg = ""; + + auto combineStr = [](std::string id, std::string str1, std::string str2) { + std::string out = id + std::string(" "); + out += str1; + out += std::string(" vs ") + str2; + out += std::string(";"); + return out; + }; + + auto combineInt = [&combineStr](std::string id, int int1, int int2) { + return combineStr(id, std::to_string(int1), std::to_string(int2)); + }; + + if (conv_param_.IC != test_conv_p.IC) { + msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC); + } + if (conv_param_.OC != test_conv_p.OC) { + msg += combineInt("output_channels", conv_param_.IC, test_conv_p.IC); + } + if (conv_param_.G != test_conv_p.G) { + msg += combineInt("groups", conv_param_.G, test_conv_p.G); + } + + if (!std::equal( + conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) { + msg += combineStr( + "kernel", + arrayToString<SPATIAL_DIM>(conv_param_.K), + arrayToString<SPATIAL_DIM>(test_conv_p.K)); + } + + if (!std::equal( + conv_param_.stride.begin(), + conv_param_.stride.end(), + test_conv_p.stride.begin())) { + msg += combineStr( 
+ "stride", + arrayToString<SPATIAL_DIM>(conv_param_.stride), + arrayToString<SPATIAL_DIM>(test_conv_p.stride)); + } + + if (!std::equal( + conv_param_.pad.begin(), + conv_param_.pad.end(), + test_conv_p.pad.begin())) { + msg += combineStr( + "pad", + arrayToString<2 * SPATIAL_DIM>(conv_param_.pad), + arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad)); + } + + if (!std::equal( + conv_param_.dilation.begin(), + conv_param_.dilation.end(), + test_conv_p.dilation.begin())) { + msg += combineStr( + "dilation", + arrayToString<SPATIAL_DIM>(conv_param_.dilation), + arrayToString<SPATIAL_DIM>(test_conv_p.dilation)); + } + + return msg; +} + template class PackWeightsForConv<2, int8_t, int32_t>; template class PackWeightsForConv<3, int8_t, int32_t>; diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index 1ab00d1..a209efc 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -164,30 +164,143 @@ void ChooseRequantizationMultiplier( dst[i] = Quantize<T>(src[i], qparams); \ } \ } -FBGEMM_SPECIALIZED_QUANTIZE(int8_t) FBGEMM_SPECIALIZED_QUANTIZE(uint16_t) FBGEMM_SPECIALIZED_QUANTIZE(int16_t) FBGEMM_SPECIALIZED_QUANTIZE(int32_t) #undef FBGEMM_SPECIALIZED_QUANTIZE +#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T) \ +template <> \ +void Quantize<T>( \ + const float* src, \ + T* dst, \ + int len, \ + const TensorQuantizationParams& qparams) { \ + bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \ + bool fma_support = cpuinfo_has_x86_fma3(); \ + if (avx2_support && fma_support && qparams.precision == 8) { \ + /* fast path */ \ + QuantizeAvx2<T>(src, dst, len, qparams); \ + } else { \ + for (std::size_t i = 0; i < len; ++i) { \ + dst[i] = Quantize<T>(src[i], qparams); \ + } \ + } \ +} + +FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t) +FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t) +#undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2 + +#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \ + template <> \ + void QuantizeGroupwise<T, layout_t::KCX>( \ + const float* src, \ + int N, \ + int C, \ + int X, \ + int G, \ + const float* scales, \ + const std::int32_t* zero_points, \ + T* dst) { \ + assert(C % G == 0); \ + int C_per_G = C / G; \ + for (int i = 0; i < N; ++i) { \ + for (int g = 0; g < G; ++g) { \ + float scale = scales[g]; \ + int32_t zero_point = zero_points[g]; \ + for (int c = 0; c < C / G; ++c) { \ + for (int x = 0; x < X; ++x) { \ + dst[(i * C + g * C_per_G + c) * X + x] = Quantize<T>( \ + src[(i * C + g * C_per_G + c) * X + x], \ + zero_point, \ + scale, \ + 8 * sizeof(T)); \ + } \ + } \ + } \ + } \ + } +FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int8_t) +FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int32_t) +#undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX + template <> -void Quantize<uint8_t>( +void QuantizeGroupwise<uint8_t, layout_t::KCX>( const float* src, - uint8_t* dst, - int len, - const TensorQuantizationParams& qparams) { - bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); - bool fma_support = cpuinfo_has_x86_fma3(); - if (avx2_support && fma_support && qparams.precision == 8) { - // fast path - QuantizeAvx2(src, dst, len, qparams); - } else { - for (std::size_t i = 0; i < len; ++i) { - dst[i] = Quantize<uint8_t>(src[i], qparams); + int K, + int C, + int X, + int G, + const float* scales, + const std::int32_t* zero_points, + uint8_t* dst) { + assert(C % G == 0); + int C_per_G = C / G; + fbgemm::TensorQuantizationParams qparams; + qparams.precision = 8 * sizeof(uint8_t); + bool takeFastPath = + cpuinfo_initialize() && fbgemmHasAvx2Support() && cpuinfo_has_x86_fma3(); + + for (int i = 0; 
i < K; ++i) { + for (int g = 0; g < G; ++g) { + qparams.scale = scales[g]; + qparams.zero_point = zero_points[g]; + if (takeFastPath) { + QuantizeAvx2( + src + (i * C + g * C_per_G) * X, + dst + (i * C + g * C_per_G) * X, + C_per_G * X, + qparams); + } else { + for (int c = 0; c < C / G; ++c) { + for (int x = 0; x < X; ++x) { + dst[(i * C + g * C_per_G + c) * X + x] = Quantize<uint8_t>( + src[(i * C + g * C_per_G + c) * X + x], + qparams.zero_point, + qparams.scale, + qparams.precision); + } + } + } } } } +#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(T) \ + template <> \ + void QuantizeGroupwise<T, layout_t::KXC>( \ + const float* src, \ + int K, \ + int C, \ + int X, \ + int G, \ + const float* scales, \ + const std::int32_t* zero_points, \ + T* dst) { \ + assert(C % G == 0); \ + int C_per_G = C / G; \ + for (int i = 0; i < K; ++i) { \ + for (int x = 0; x < X; ++x) { \ + for (int g = 0; g < G; ++g) { \ + float scale = scales[g]; \ + int32_t zero_point = zero_points[g]; \ + for (int c = 0; c < C / G; ++c) { \ + dst[(i * X + x) * C + g * C_per_G + c] = Quantize<T>( \ + src[(i * X + x) * C + g * C_per_G + c], \ + zero_point, \ + scale, \ + 8 * sizeof(T)); \ + } \ + } \ + } \ + } \ + } +FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int8_t) +FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(uint8_t) +FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int32_t) +#undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC + //////////////////////////////////////////////////////////////////////////////// // Requantization (pure fixed-point) diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 821999e..66828ae 100755..100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -18,16 +18,20 @@ using namespace std; //////////////////////////////////////////////////////////////////////////////// // Utility functions +template <typename T> void QuantizeAvx2( const float* src, - uint8_t* dst, + T* dst, int len, const TensorQuantizationParams& qparams) { -#if defined(__AVX2__) && defined(__FMA__) - constexpr int VLEN = 8; - std::size_t i = 0; - __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale); - __m256i shuffle_mask_v = _mm256_set_epi8( + // original compile condition - #if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER)) + if (fbgemm::fbgemmHasAvx2Support()) { + constexpr int VLEN = 8; + constexpr float min_val = std::numeric_limits<T>::min(); + constexpr float max_val = std::numeric_limits<T>::max(); + std::size_t i = 0; + __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale); + __m256i shuffle_mask_v = _mm256_set_epi8( 0xff, 0xff, 0xff, @@ -60,41 +64,56 @@ void QuantizeAvx2( 0x08, 0x04, 0x00); - __m256i permute_mask_v = + __m256i permute_mask_v = _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); - for (; i < len / VLEN * VLEN; i += VLEN) { - __m256 src_v = _mm256_loadu_ps(src + i); - __m256 transformed_v = _mm256_fmadd_ps( + for (; i < len / VLEN * VLEN; i += VLEN) { + __m256 src_v = _mm256_loadu_ps(src + i); + __m256 transformed_v = _mm256_fmadd_ps( src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point)); - __m256 clipped_v = _mm256_min_ps( - _mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)), - _mm256_set1_ps(255.f)); - __m256i rounded_v = _mm256_cvtps_epi32(clipped_v); - - // An instruction sequence to save 8 32-bit integers as 8 8-bit integers - rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v); - rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v); - _mm_storel_epi64( + __m256 clipped_v = _mm256_min_ps( + _mm256_max_ps(transformed_v, 
_mm256_set1_ps(min_val)), + _mm256_set1_ps(max_val)); + __m256i rounded_v = _mm256_cvtps_epi32(clipped_v); + + // An instruction sequence to save 8 32-bit integers as 8 8-bit integers + rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v); + rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v); + _mm_storel_epi64( reinterpret_cast<__m128i*>(dst + i), _mm256_castsi256_si128(rounded_v)); - } + } - for (; i < len; ++i) { - float transformed = qparams.zero_point + src[i] / qparams.scale; - float clipped = std::min(std::max(transformed, 0.f), 255.f); - // Not exactly the same behavior as the vectorized code. - // The vectorized code above always rounds to even in halfway cases - // (https://software.intel.com/en-us/node/523819), but std::nearbyint - // does the same only when the current rounding mode is FE_TONEAREST. - // However, in practice, this should not be a problem because most cases - // use the default rounding mode FE_TONEAREST. - // Note that we cannot implement the same behavior as the vectorized code - // using std::round because it does rounding away from zero in halfway - // cases. - dst[i] = nearbyint(clipped); + for (; i < len; ++i) { + float transformed = qparams.zero_point + src[i] / qparams.scale; + float clipped = std::min(std::max(transformed, min_val), max_val); + // Not exactly the same behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. + // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. 
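      // A concrete halfway case (illustrative): under the default FE_TONEAREST
      // mode, nearbyint(2.5f) == 2.0f (round to even), matching the vectorized
      // path above, while std::round(2.5f) == 3.0f (round away from zero).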
+ dst[i] = nearbyint(clipped); + } } -#endif } +// Instantiate QuantizeAvx2 for known datatypes +template +void QuantizeAvx2<uint8_t>( + const float* src, + uint8_t* dst, + int len, + const TensorQuantizationParams& qparams); +template +void QuantizeAvx2<int8_t>( + const float* src, + int8_t* dst, + int len, + const TensorQuantizationParams& qparams); + + void FindMinMax(const float* a, float* min, float* max, int len) { if (len <= 0) { *min = 0.0f; @@ -105,24 +124,24 @@ void FindMinMax(const float* a, float* min, float* max, int len) { float temp_min = *a, temp_max = *a; int i = 0; -#ifdef __AVX__ - __m256 min_v = _mm256_set1_ps(*a), max_v = _mm256_set1_ps(*a); - constexpr int VLEN = 8; - if (len >= VLEN) { - for (; i < len / VLEN * VLEN; i += VLEN) { - min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(a + i)); - max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(a + i)); - } + if (fbgemm::fbgemmHasAvx2Support()) { + __m256 min_v = _mm256_set1_ps(*a), max_v = _mm256_set1_ps(*a); + constexpr int VLEN = 8; + if (len >= VLEN) { + for (; i < len / VLEN * VLEN; i += VLEN) { + min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(a + i)); + max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(a + i)); + } - float min_buf[VLEN], max_buf[VLEN]; - _mm256_storeu_ps(min_buf, min_v); - _mm256_storeu_ps(max_buf, max_v); - for (int j = 0; j < VLEN; ++j) { - temp_min = std::min(temp_min, min_buf[j]); - temp_max = std::max(temp_max, max_buf[j]); + float min_buf[VLEN], max_buf[VLEN]; + _mm256_storeu_ps(min_buf, min_v); + _mm256_storeu_ps(max_buf, max_v); + for (int j = 0; j < VLEN; ++j) { + temp_min = std::min(temp_min, min_buf[j]); + temp_max = std::max(temp_max, max_buf[j]); + } } } -#endif for (; i < len; i++) { temp_min = std::min(temp_min, a[i]); @@ -135,15 +154,15 @@ void FindMinMax(const float* a, float* min, float* max, int len) { //////////////////////////////////////////////////////////////////////////////// // Requantization (with floats) -#ifdef __AVX2__ void RequantizeAvx2( const int32_t* src, uint8_t* dst, int len, const RequantizationParams& params) { - DoNothing<> doNothingObj{}; - int32_t Bq_zero_point[] = { 0 }; - ReQuantizeOutput<false /* FUSE_RELU */> requantizeObj( + if (fbgemm::fbgemmHasAvx2Support()) { + DoNothing<> doNothingObj{}; + int32_t Bq_zero_point[] = { 0 }; + ReQuantizeOutput<false /* FUSE_RELU */> requantizeObj( doNothingObj, ¶ms.real_multiplier, params.target_qparams.zero_point, @@ -153,7 +172,8 @@ void RequantizeAvx2( nullptr, // col_offsets nullptr, // bias len); // ncol - requantizeObj.f<inst_set_t::avx2>(dst, src, {0, 1, 0, len}, 0, 0); + requantizeObj.f<inst_set_t::avx2>(dst, src, { 0, 1, 0, len }, 0, 0); + } } void RequantizeFixedPointAvx2( @@ -161,24 +181,26 @@ void RequantizeFixedPointAvx2( uint8_t* dst, int len, const RequantizationParams& params) { - constexpr int VLEN = 8; + if (fbgemm::fbgemmHasAvx2Support()) + { + constexpr int VLEN = 8; - __m256i b = _mm256_set1_epi32(params.multiplier); + __m256i b = _mm256_set1_epi32(params.multiplier); - // AVX2 doesn't support arithmetic right shift. - // As a work around, we convert 64-bit multiplied results to uint64_t by - // adding 0x8000000000000000ULL, logical right shift, and subtract by - // (0x8000000000000000ULL >> right_shift). - __m256i pre_shift_nudge = _mm256_set1_epi64x( + // AVX2 doesn't support arithmetic right shift. + // As a work around, we convert 64-bit multiplied results to uint64_t by + // adding 0x8000000000000000ULL, logical right shift, and subtract by + // (0x8000000000000000ULL >> right_shift). 
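    // Sketch of the identity relied on here (assuming 1 <= right_shift <= 63):
    //   arithmetic_shift_right(x, s)
    //       == ((uint64_t(x) + 0x8000000000000000ULL) >> s)
    //          - (0x8000000000000000ULL >> s)
    // The 2^63 offset moves x into unsigned range and, being divisible by 2^s,
    // can be subtracted back out after the logical shift. pre_shift_nudge
    // below also folds in the rounding term (1 << (right_shift - 1)), and
    // post_shift_nudge folds in the target zero point.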
+ __m256i pre_shift_nudge = _mm256_set1_epi64x( (1ll << (params.right_shift - 1)) + 0x8000000000000000ULL); - __m256i post_shift_nudge = _mm256_set1_epi64x( + __m256i post_shift_nudge = _mm256_set1_epi64x( params.target_qparams.zero_point - (0x8000000000000000ULL >> params.right_shift)); - __m256i min_v = _mm256_set1_epi32(numeric_limits<uint8_t>::min()); - __m256i max_v = _mm256_set1_epi32(numeric_limits<uint8_t>::max()); + __m256i min_v = _mm256_set1_epi32(numeric_limits<uint8_t>::min()); + __m256i max_v = _mm256_set1_epi32(numeric_limits<uint8_t>::max()); - __m256i shuffle_mask_v = _mm256_set_epi8( + __m256i shuffle_mask_v = _mm256_set_epi8( 0xff, 0xff, 0xff, @@ -211,75 +233,68 @@ void RequantizeFixedPointAvx2( 0x08, 0x04, 0x00); - __m256i permute_mask_v = + __m256i permute_mask_v = _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); - int i = 0; - for (; i < len / VLEN * VLEN; i += VLEN) { - __m256i a_v = _mm256_loadu_si256((const __m256i*)(src + i)); + int i = 0; + for (; i < len / VLEN * VLEN; i += VLEN) { + __m256i a_v = _mm256_loadu_si256((const __m256i*)(src + i)); - // a = a0 | a1 | a2 | a3 | a4 | a5 | a6 | a7 - // b = b0 | b1 | b3 | b3 | b4 | b5 | b6 | b7 - __m256i a_even_v = a_v; - __m256i a_odd_v = _mm256_srli_si256(a_v, 4); + // a = a0 | a1 | a2 | a3 | a4 | a5 | a6 | a7 + // b = b0 | b1 | b3 | b3 | b4 | b5 | b6 | b7 + __m256i a_even_v = a_v; + __m256i a_odd_v = _mm256_srli_si256(a_v, 4); - __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b); - __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b); + __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b); + __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b); - __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, pre_shift_nudge); - __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, pre_shift_nudge); + __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, pre_shift_nudge); + __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, pre_shift_nudge); - __m256i even_result_v = _mm256_add_epi64( + __m256i even_result_v = _mm256_add_epi64( _mm256_srli_epi64(even_rounded_v, params.right_shift), post_shift_nudge); - __m256i odd_result_v = _mm256_add_epi64( + __m256i odd_result_v = _mm256_add_epi64( _mm256_srli_epi64(odd_rounded_v, params.right_shift), post_shift_nudge); - odd_result_v = _mm256_slli_si256(odd_result_v, 4); + odd_result_v = _mm256_slli_si256(odd_result_v, 4); - // even_result_v has numbers we want in its even 32-bit SIMD lanes, and - // odd_result_v has numbers we want in its odd 32-bit SIMD lanes. - // Use blend to combine them. - __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa); - __m256i clipped_v = + // even_result_v has numbers we want in its even 32-bit SIMD lanes, and + // odd_result_v has numbers we want in its odd 32-bit SIMD lanes. + // Use blend to combine them. 
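      // Illustrative detail: the 0xaa (0b10101010) immediate selects the odd
      // 32-bit lanes from odd_result_v and the even lanes from even_result_v,
      // reassembling the eight requantized values in their original order.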
+ __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa); + __m256i clipped_v = _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, result_v)); - clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v); - clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v); - *(int64_t*)(dst + i) = _mm256_extract_epi64(clipped_v, 0); - } + clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v); + clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v); + *(int64_t*)(dst + i) = _mm256_extract_epi64(clipped_v, 0); + } - for (; i < len; ++i) { - int64_t ab_64 = + for (; i < len; ++i) { + int64_t ab_64 = static_cast<int64_t>(src[i]) * static_cast<int64_t>(params.multiplier); - int64_t nudge = 1ll << std::max(0, params.right_shift - 1); - int64_t quantized_down = params.target_qparams.zero_point + + int64_t nudge = 1ll << std::max(0, params.right_shift - 1); + int64_t quantized_down = params.target_qparams.zero_point + ((ab_64 + nudge) >> params.right_shift); - dst[i] = std::min<int64_t>(std::max<int64_t>(quantized_down, 0l), 255l); + dst[i] = std::min<int64_t>(std::max<int64_t>(quantized_down, 0l), 255l); + } } } -#else -// dummy implementations to avoid link errors -void RequantizeAvx2(const int32_t* src, uint8_t* dst, int len, const RequantizationParams& params) { - assert(false && "RequantizeAvx2() was called unexpectedly in non-AVX2 build"); -} -void RequantizeFixedPointAvx2(const int32_t* src, uint8_t* dst, int len, const RequantizationParams& params) { - assert(false && "RequantizeFixedPointAvx2() was called unexpectedly in non-AVX2 build"); -} -#endif template < bool A_SYMMETRIC, bool B_SYMMETRIC, QuantizationGranularity Q_GRAN, bool HAS_BIAS, - bool FUSE_RELU> + bool FUSE_RELU, + typename BIAS_TYPE> void requantizeOutputProcessingAvx2( uint8_t* out, const int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r) { + const requantizationParams_t<BIAS_TYPE>& r) { // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c // using AVX2 instructions int quant_param_idx = 0; @@ -290,6 +305,15 @@ void requantizeOutputProcessingAvx2( } __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v; + if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { + if (is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = + _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); + } + } + __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); @@ -399,22 +423,76 @@ void requantizeOutputProcessingAvx2( } w_v = _mm256_sub_epi32(w_v, row_offset_v); } + __m256 xf_v, yf_v, zf_v, wf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN)); + y_bias_v = _mm256_div_ps( + 
_mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); } /* @@ -431,22 +509,19 @@ void requantizeOutputProcessingAvx2( */ __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); - y_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(y_v), - _mm256_loadu_ps(r.C_multiplier + j + VLEN)); - z_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(z_v), - _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); - w_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(w_v), - _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); + x_scaled_v = + _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN)); + y_scaled_v = + _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN)); + z_scaled_v = + _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); + w_scaled_v = + _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } /* @@ -533,18 +608,35 @@ void requantizeOutputProcessingAvx2( } x_v = _mm256_sub_epi32(x_v, row_offset_v); } + __m256 xf_v; if (HAS_BIAS) { - x_v = 
_mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)), + _mm256_loadu_ps(r.act_times_w_scale + j)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); } __m256 x_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); + x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); } __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); @@ -582,6 +674,7 @@ void requantizeOutputProcessingAvx2( int remainder = block.col_start + block.col_size - j; if (remainder > 0) { + // clang-format off alignas(64) const int masks[8][8] = { // NOTE: clang-format wants to use a different formatting but the // current formatting should be easier to read. @@ -594,6 +687,7 @@ void requantizeOutputProcessingAvx2( { -1, -1, -1, -1, -1, -1, 0, 0, }, { -1, -1, -1, -1, -1, -1, -1, 0, }, }; + // clang-format on __m256i mask_v = _mm256_load_si256( reinterpret_cast<const __m256i*>(masks[remainder])); @@ -615,17 +709,40 @@ void requantizeOutputProcessingAvx2( } x_v = _mm256_sub_epi32(x_v, row_offset_v); } + + __m256 xf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v)); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_maskload_ps( + reinterpret_cast<const float*>(r.bias + j), mask_v), + _mm256_maskload_ps(r.act_times_w_scale + j, mask_v)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_maskload_ps( + reinterpret_cast<const float*>(r.bias + j), mask_v), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_maskload_epi32( + reinterpret_cast<const int*>(r.bias + j), mask_v)); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); } __m256 x_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), - _mm256_maskload_ps(r.C_multiplier + j, mask_v)); + x_scaled_v = + _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); } __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); @@ -767,6 +884,7 @@ void requantizeForFloatAvx2( int remainder = block.col_start + block.col_size - j; if (remainder > 0) { + // clang-format off alignas(64) const int masks[8][8] = { // NOTE: clang-format wants to use a different formatting but the // current formatting should be easier to read. 
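// A brief gloss on the masks table (inferred from the rows shown above):
// masks[r] holds -1 in its first r entries and 0 elsewhere, so passing
// masks[remainder] to _mm256_maskload_epi32 / _mm256_maskload_ps loads just
// the `remainder` columns left past the last full vector and leaves the
// remaining lanes zeroed.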
@@ -779,6 +897,7 @@ void requantizeForFloatAvx2( { -1, -1, -1, -1, -1, -1, 0, 0, }, { -1, -1, -1, -1, -1, -1, -1, 0, }, }; + // clang-format on __m256i mask_v = _mm256_load_si256( reinterpret_cast<const __m256i*>(masks[remainder])); @@ -831,14 +950,15 @@ template < QuantizationGranularity Q_GRAN, bool HAS_BIAS, bool FUSE_RELU, - int C_PER_G> + int C_PER_G, + typename BIAS_TYPE> void requantizeOutputProcessingGConvAvx2( uint8_t* out, const int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r) { + const requantizationParams_t<BIAS_TYPE>& r) { // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c // using AVX2 instructions int quant_param_idx = 0; @@ -849,6 +969,14 @@ void requantizeOutputProcessingGConvAvx2( } __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v; + if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { + if (is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = + _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); + } + } __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); @@ -1095,22 +1223,135 @@ void requantizeOutputProcessingGConvAvx2( } w_v = _mm256_sub_epi32(w_v, row_offset_v); } + __m256 xf_v, yf_v, zf_v, wf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v = _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)); + __m256 y_bias_v = _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)); + __m256 z_bias_v = _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)); + __m256 w_bias_v = _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)); + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN)); + y_bias_v = _mm256_div_ps( + y_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + z_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + w_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN)); + } else if (Q_GRAN == QuantizationGranularity::GROUP) { + __m256 diviser_v; + if (C_PER_G == 4) { + diviser_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])), + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]), + 1); + x_bias_v = _mm256_div_ps(x_bias_v, diviser_v); + + diviser_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 2])), + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 3]), + 1); + y_bias_v = _mm256_div_ps(y_bias_v, diviser_v); + + diviser_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 4])), + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 5]), + 1); + z_bias_v = _mm256_div_ps(z_bias_v, diviser_v); + + diviser_v = _mm256_insertf128_ps( + 
_mm256_castps128_ps256( + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 6])), + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 7]), + 1); + w_bias_v = _mm256_div_ps(w_bias_v, diviser_v); + + } else if (C_PER_G == 8) { + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 0]); + x_bias_v = _mm256_div_ps(x_bias_v, diviser_v); + + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 1]); + y_bias_v = _mm256_div_ps(y_bias_v, diviser_v); + + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 2]); + z_bias_v = _mm256_div_ps(z_bias_v, diviser_v); + + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 3]); + w_bias_v = _mm256_div_ps(w_bias_v, diviser_v); + + } else { + assert(C_PER_G == 16); + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 2 + 0]); + x_bias_v = _mm256_div_ps(x_bias_v, diviser_v); + y_bias_v = _mm256_div_ps(y_bias_v, diviser_v); + + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 2 + 1]); + z_bias_v = _mm256_div_ps(z_bias_v, diviser_v); + w_bias_v = _mm256_div_ps(w_bias_v, diviser_v); + } + } else { + x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps(y_bias_v, act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps(z_bias_v, act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps(w_bias_v, act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); } /* @@ -1127,17 +1368,13 @@ void requantizeOutputProcessingGConvAvx2( */ __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); - y_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(y_v), - _mm256_loadu_ps(r.C_multiplier + j + VLEN)); - z_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(z_v), - _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); - w_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(w_v), - _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); + x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); + y_scaled_v = + _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + VLEN)); + z_scaled_v = + _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); + w_scaled_v = + _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * 
VLEN)); } else if (Q_GRAN == QuantizationGranularity::GROUP) { if (C_PER_G == 4) { multiplier_v = _mm256_insertf128_ps( @@ -1145,70 +1382,70 @@ void requantizeOutputProcessingGConvAvx2( _mm_set1_ps(r.C_multiplier[quant_param_idx])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]), 1); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]), 1); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]), 1); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 6])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]), 1); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } else if (C_PER_G == 8) { multiplier_v = _mm256_set1_ps( r.C_multiplier [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 1]); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 2]); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 3]); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 1]); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 2]); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 3]); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } else { multiplier_v = _mm256_set1_ps( r.C_multiplier [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 + - 1]); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 2 + 1]); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = 
_mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } /* @@ -1279,46 +1516,69 @@ void requantizeOutputProcessingGConvAvx2( } // i loop } -#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ - template void \ - requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - float* out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationForFloatParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); +#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \ + A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \ + template void \ + requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 4, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 8, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 16, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); + +#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ + INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \ + INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \ + template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ + float* out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationForFloatParams_t& r); #define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \ INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \ diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc 
index b4b0c2b..dc40d44 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -181,8 +181,7 @@ void cblas_sgemm_ref( int ldb, float beta, float* Cfp32, - int ldc - ) { + int ldc) { for (int i = 0; i < m; ++i) { for (int j = 0; j < n; ++j) { float sum = 0; @@ -204,7 +203,6 @@ void cblas_sgemm_ref( } } - void row_offsets_u8acc32_ref( int M, int K, @@ -302,9 +300,11 @@ void im2col_ref( for (int h = 0; h < OUT_DIM[0]; ++h) { for (int w = 0; w < OUT_DIM[1]; ++w) { for (int r = 0; r < K[0]; ++r) { - int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r; + int h_in = + -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0]; for (int s = 0; s < K[1]; ++s) { - int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s; + int w_in = + -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1]; if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 || w_in >= IN_DIM[1]) { for (int g = 0; g < G; ++g) { @@ -365,11 +365,14 @@ void im2col_ref( for (int h = 0; h < OUT_DIM[1]; ++h) { for (int w = 0; w < OUT_DIM[2]; ++w) { for (int q = 0; q < K[0]; ++q) { - int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q; + int t_in = + -conv_p.pad[0] + t * conv_p.stride[0] + q * conv_p.dilation[0]; for (int r = 0; r < K[1]; ++r) { - int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r; + int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + + r * conv_p.dilation[1]; for (int s = 0; s < K[2]; ++s) { - int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s; + int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + + s * conv_p.dilation[2]; if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 || h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) { for (int g = 0; g < G; ++g) { @@ -449,9 +452,11 @@ void conv_ref( for (int m = 0; m < OC / G; ++m) { int sum = 0; for (int r = 0; r < K[0]; ++r) { - int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r; + int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + + r * conv_p.dilation[0]; for (int s = 0; s < K[1]; ++s) { - int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s; + int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + + s * conv_p.dilation[1]; for (int c = 0; c < IC / G; ++c) { int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 || w_in >= IN_DIM[1] @@ -501,11 +506,14 @@ void conv_ref( for (int m = 0; m < OC / G; ++m) { int sum = 0; for (int q = 0; q < K[0]; ++q) { - int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q; + int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + + q * conv_p.dilation[0]; for (int r = 0; r < K[1]; ++r) { - int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r; + int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + + r * conv_p.dilation[1]; for (int s = 0; s < K[2]; ++s) { - int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s; + int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + + s * conv_p.dilation[2]; for (int c = 0; c < IC / G; ++c) { int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 || h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2] @@ -542,427 +550,55 @@ void transposeConvWeights( const conv_param_t<SPATIAL_DIM>& conv_p, const std::int8_t* src, std::int8_t* dest) { - assert(SPATIAL_DIM == 2 && "Only 2D supported currently"); - int R = conv_p.K[0]; - int S = conv_p.K[1]; int G = conv_p.G; int IC_per_G = conv_p.IC / conv_p.G; int OC_per_G = conv_p.OC / conv_p.G; - // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format. 
- for (int r = 0; r < R; ++r) { - for (int s = 0; s < S; ++s) { - for (int k = 0; k < OC_per_G; ++k) { - for (int g = 0; g < G; ++g) { - for (int c = 0; c < IC_per_G; ++c) { - dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] = - src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c]; - } - } - } - } - } -} - -void depthwise_3x3_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int8_t* B, - int32_t* C) { - constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - - for (int n = 0; n < N; ++n) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int r = 0; r < R; ++r) { - int h_in = -PAD_T + h * stride_h + r; - for (int s = 0; s < S; ++s) { - int w_in = -PAD_L + w * stride_w + s; - int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W - ? A_zero_point - : A[((n * H + h_in) * W + w_in) * K + k]; - int b = B[(k * R + r) * S + s]; - sum += a * b; - } - } - C[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; - } - } - } - } // for each n -}; - -void depthwise_3x3_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const int8_t* B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - - vector<int32_t> C_int32(N * H_OUT * W_OUT * K); - depthwise_3x3_pad_1_ref( - N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data()); - - vector<int32_t> row_offsets(N * H_OUT * W_OUT * K); - for (int n = 0; n < N; ++n) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int r = 0; r < R; ++r) { - int h_in = -PAD_T + h * stride_h + r; - for (int s = 0; s < S; ++s) { - int w_in = -PAD_L + w * stride_w + s; - int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W - ? A_zero_point - : A[((n * H + h_in) * W + w_in) * K + k]; - sum += a; + assert( + (SPATIAL_DIM == 3 || SPATIAL_DIM == 2) && + "Only 2D and 3D convolutions are supported"); + if (SPATIAL_DIM == 2) { + int R = conv_p.K[0]; + int S = conv_p.K[1]; + // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format. + for (int r = 0; r < R; ++r) { + for (int s = 0; s < S; ++s) { + for (int k = 0; k < OC_per_G; ++k) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < IC_per_G; ++c) { + dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] = + src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c]; } } - row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; } } } - } // for each n - - for (int i = 0; i < N * H_OUT * W_OUT; ++i) { - for (int k = 0; k < K; ++k) { - requantize_u8acc32_ref( - 1, - 1, - 1, - C_int32.data() + i * K + k, - C + i * K + k, - &C_multiplier, - C_zero_point, - A_zero_point, - &B_zero_point, - &row_offsets[i * K + k], - col_offsets + k, - bias ? 
bias + k : nullptr, - 1); - } - } -}; - -void depthwise_3x3_per_channel_quantization_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const int8_t* B, - const float* C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - - vector<int32_t> C_int32(N * H_OUT * W_OUT * K); - depthwise_3x3_pad_1_ref( - N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data()); - - vector<int32_t> row_offsets(N * H_OUT * W_OUT * K); - for (int n = 0; n < N; ++n) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int r = 0; r < R; ++r) { - int h_in = -PAD_T + h * stride_h + r; - for (int s = 0; s < S; ++s) { - int w_in = -PAD_L + w * stride_w + s; - int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W - ? A_zero_point - : A[((n * H + h_in) * W + w_in) * K + k]; - sum += a; + } else { + // Transforms weights from G K/G (T R S C/G) to G (T R S C/G) K/G format. + int T = conv_p.K[0]; + int R = conv_p.K[1]; + int S = conv_p.K[2]; + for (int t = 0; t < T; ++t) { + for (int r = 0; r < R; ++r) { + for (int s = 0; s < S; ++s) { + for (int k = 0; k < OC_per_G; ++k) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < IC_per_G; ++c) { + dest + [((((g * T + t) * R + r) * S + s) * IC_per_G + c) * + OC_per_G + + k] = + src[((((g * OC_per_G + k) * T + t) * R + r) * S + s) * + IC_per_G + + c]; + } } } - row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; } } } - } // for each n - - for (int i = 0; i < N * H_OUT * W_OUT; ++i) { - for (int k = 0; k < K; ++k) { - requantize_u8acc32_ref( - 1, - 1, - 1, - C_int32.data() + i * K + k, - C + i * K + k, - &C_multiplier[k], - C_zero_point, - A_zero_point, - &B_zero_point[k], - &row_offsets[i * K + k], - col_offsets + k, - bias ? bias + k : nullptr, - 1); - } } -}; - -void depthwise_3x3x3_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int8_t* B, - int32_t* C) { - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - - for (int n = 0; n < N; ++n) { - for (int t = 0; t < T_OUT; ++t) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int k_t = 0; k_t < K_T; ++k_t) { - int t_in = -PAD_P + t * stride_t + k_t; - for (int k_h = 0; k_h < K_H; ++k_h) { - int h_in = -PAD_T + h * stride_h + k_h; - for (int k_w = 0; k_w < K_W; ++k_w) { - int w_in = -PAD_L + w * stride_w + k_w; - int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || - w_in < 0 || w_in >= W - ? 
A_zero_point - : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; - int b = B[((k * K_T + k_t) * K_H + k_h) * K_W + k_w]; - sum += a * b; - } - } - } - C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = sum; - } - } // w - } // h - } // t - } // for each n -}; - -void depthwise_3x3x3_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const int8_t* B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - - vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K); - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B, - C_int32.data()); - - vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K); - for (int n = 0; n < N; ++n) { - for (int t = 0; t < T_OUT; ++t) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int k_t = 0; k_t < K_T; ++k_t) { - int t_in = -PAD_P + t * stride_t + k_t; - for (int k_h = 0; k_h < K_H; ++k_h) { - int h_in = -PAD_T + h * stride_h + k_h; - for (int k_w = 0; k_w < K_W; ++k_w) { - int w_in = -PAD_L + w * stride_w + k_w; - int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || - w_in < 0 || w_in >= W - ? A_zero_point - : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; - sum += a; - } - } - } - row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = - sum; - } - } // w - } // h - } // t - } // for each n - - for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) { - for (int k = 0; k < K; ++k) { - requantize_u8acc32_ref( - 1, - 1, - 1, - C_int32.data() + i * K + k, - C + i * K + k, - &C_multiplier, - C_zero_point, - A_zero_point, - &B_zero_point, - &row_offsets[i * K + k], - col_offsets + k, - bias ? 
bias + k : nullptr, - 1); - } - } -}; - -void depthwise_3x3x3_per_channel_quantization_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const int8_t* B, - const float* C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - - vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K); - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B, - C_int32.data()); - - vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K); - for (int n = 0; n < N; ++n) { - for (int t = 0; t < T_OUT; ++t) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int k_t = 0; k_t < K_T; ++k_t) { - int t_in = -PAD_P + t * stride_t + k_t; - for (int k_h = 0; k_h < K_H; ++k_h) { - int h_in = -PAD_T + h * stride_h + k_h; - for (int k_w = 0; k_w < K_W; ++k_w) { - int w_in = -PAD_L + w * stride_w + k_w; - int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || - w_in < 0 || w_in >= W - ? A_zero_point - : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; - sum += a; - } - } - } - row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = - sum; - } - } // w - } // h - } // t - } // for each n - - for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) { - for (int k = 0; k < K; ++k) { - requantize_u8acc32_ref( - 1, - 1, - 1, - C_int32.data() + i * K + k, - C + i * K + k, - &C_multiplier[k], - C_zero_point, - A_zero_point, - &B_zero_point[k], - &row_offsets[i * K + k], - col_offsets + k, - bias ? bias + k : nullptr, - 1); - } - } -}; +} template void transposeConvWeights( const conv_param_t<2>& conv_p, diff --git a/src/RefImplementations.h b/src/RefImplementations.h index 082bdf1..a20e348 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -215,124 +215,4 @@ FBGEMM_API void im2col_ref( std::int32_t A_zero_point, std::uint8_t* Ao); -/* - * @brief Reference implementation of depthwise convolution with a 3x3 filter - * and padding size 1. - */ -FBGEMM_API void depthwise_3x3_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int8_t* B, - std::int32_t* C); - -/* - * @brief Reference implementation of depthwise convolution with a 3x3 filter - * and padding size 1, followed by requantization. (the same scaling factors and - * zero points for each channel). - */ -FBGEMM_API void depthwise_3x3_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - std::int32_t B_zero_point, - const std::int8_t* B, - float C_multiplier, - std::int32_t C_zero_point, - std::uint8_t* C, - const std::int32_t* col_offsets, - const std::int32_t* bias); - -/* - * @brief Reference implementation of depthwise convolution with a 3x3 filter - * and padding size 1, followed by requantization. (different scaling factors - * and zero points for each channel). 
- */ -FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int32_t* B_zero_point, - const std::int8_t* B, - const float* C_multiplier, - std::int32_t C_zero_point, - std::uint8_t* C, - const std::int32_t* col_offsets, - const std::int32_t* bias); - -/* - * @brief Reference implementation of 3D depthwise convolution with a 3x3x3 - * filter and padding size 1. - */ -FBGEMM_API void depthwise_3x3x3_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int8_t* B, - std::int32_t* C); - -/* - * @brief Reference implementation of 3D depthwise convolution with a 3x3x3 - * filter and padding size 1, followed by requantization. - */ -FBGEMM_API void depthwise_3x3x3_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - std::int32_t B_zero_point, - const std::int8_t* B, - float C_multiplier, - std::int32_t C_zero_point, - std::uint8_t* C, - const std::int32_t* col_offsets, - const std::int32_t* bias); - -FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int32_t* B_zero_point, - const std::int8_t* B, - const float* C_multiplier, - std::int32_t C_zero_point, - std::uint8_t* C, - const std::int32_t* col_offsets, - const std::int32_t* bias); - } // namespace fbgemm diff --git a/src/Utils.cc b/src/Utils.cc index 355a5cb..2e88561 100755 --- a/src/Utils.cc +++ b/src/Utils.cc @@ -180,11 +180,7 @@ void transpose_simd( // Run time CPU detection if (cpuinfo_initialize()) { if (fbgemmHasAvx512Support()) { -#ifdef _MSC_VER - internal::transpose_8x8(M, N, src, ld_src, dst, ld_dst); -#else internal::transpose_16x16(M, N, src, ld_src, dst, ld_dst); -#endif } else if (fbgemmHasAvx2Support()) { internal::transpose_8x8(M, N, src, ld_src, dst, ld_dst); } else { @@ -206,4 +202,7 @@ bool fbgemmHasAvx2Support() { return (cpuinfo_initialize() && cpuinfo_has_x86_avx2()); } +bool fbgemmHasAvx512VnniSupport() { + return (cpuinfo_has_x86_avx512vnni()); +} } // namespace fbgemm diff --git a/test/FP16Test.cc b/test/FP16Test.cc index eb49086..3267655 100644 --- a/test/FP16Test.cc +++ b/test/FP16Test.cc @@ -27,7 +27,26 @@ using namespace fbgemm; namespace { // The template parameter is transpose of A and B class FBGemmFP16Test - : public testing::TestWithParam<pair<matrix_op_t, matrix_op_t>> {}; + : public testing::TestWithParam<pair<matrix_op_t, matrix_op_t>> { + protected: + vector<vector<int>> GenShapes() const { + vector<vector<int>> shapes; + random_device r; + default_random_engine generator(r()); + uniform_int_distribution<int> dm(1, 256); + uniform_int_distribution<int> dnk(1, 1024); + for (int i = 0; i < 10; i++) { + int m = dm(generator); + int n = dnk(generator); + int k = dnk(generator); + shapes.push_back({m, n, k}); + if (m > 10) { + shapes.push_back({(m / 10) * 10, n, k}); + } + } + return shapes; + } +}; }; // namespace INSTANTIATE_TEST_CASE_P( @@ -44,21 +63,75 @@ INSTANTIATE_TEST_CASE_P( matrix_op_t::Transpose, matrix_op_t::Transpose)*/)); TEST_P(FBGemmFP16Test, Test) { - vector<vector<int>> shapes; - random_device r; - default_random_engine generator(r()); - 
uniform_int_distribution<int> dm(1, 256); - uniform_int_distribution<int> dnk(1, 1024); - for (int i = 0; i < 10; i++) { - int m = dm(generator); - int n = dnk(generator); - int k = dnk(generator); - shapes.push_back({m, n, k}); - if (m > 10) { - shapes.push_back({(m / 10) * 10, n, k}); + auto shapes = GenShapes(); + float alpha = 1.f, beta = 0.f; + matrix_op_t atrans, btrans; + tie(atrans, btrans) = GetParam(); + + for (auto s : shapes) { + int m = s[0]; + int n = s[1]; + int k = s[2]; + + cerr << "m = " << m << " n = " << n << " k = " << k; + if (atrans == matrix_op_t::Transpose) { + cerr << " A_transposed"; + } + if (btrans == matrix_op_t::Transpose) { + cerr << " B_transposed"; + } + cerr << endl; + + // initialize with small numbers + aligned_vector<int> Aint(m * k); + aligned_vector<int> Bint(k * n); + randFill(Aint, 0, 4); + randFill(Bint, 0, 4); + aligned_vector<float> A(Aint.begin(), Aint.end()); + aligned_vector<float> B(Bint.begin(), Bint.end()); + + aligned_vector<float> C(m * n, NAN); + + aligned_vector<float> A_ref(A), B_ref(B), C_ref(C); + + if (atrans == matrix_op_t::Transpose) { + transpose_matrix(A_ref.data(), k, m); + } + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(B_ref.data(), n, k); + } + + // Gold via reference sgemm + matmul_fp_ref(m, n, k, k, n, n, A_ref.data(), B_ref.data(), C_ref.data()); + + // fbgemm fp16 + PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data()); +#ifdef _OPENMP +#pragma omp parallel +#endif + { + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + + cblas_gemm_compute( + atrans, m, A.data(), Bp, beta, C.data(), tid, num_threads); + } + + // correctness check + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + float expected = C_ref[i * n + j]; + float actual = C[i * n + j]; + EXPECT_EQ(expected, actual) + << "GEMM results differ at (" << i << ", " << j << "). 
ref " + << expected << " FBGemm " << actual; + } } } +} +TEST_P(FBGemmFP16Test, Unpack) { + auto shapes = GenShapes(); float alpha = 1.f, beta = 0.f; matrix_op_t atrans, btrans; tie(atrans, btrans) = GetParam(); @@ -101,6 +174,23 @@ TEST_P(FBGemmFP16Test, Test) { // fbgemm fp16 PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data()); + EXPECT_TRUE(Bp.packed()); + + // Test unpack + aligned_vector<float16> tmp(Bp.matSize()); + memcpy(tmp.data(), Bp.pmat(), Bp.matSize() * sizeof(float16)); + Bp.unpackFromSrc(btrans, tmp.data()); + EXPECT_FALSE(Bp.packed()); + memcpy(tmp.data(), Bp.pmat(), Bp.matSize() * sizeof(float16)); + for (int i = 0; i < k; ++i) { + for (int j = 0; j < n; ++j) { + EXPECT_EQ(B[i * n + j], cpu_half2float(tmp[i * n + j])); + } + } + + // Pack it back + Bp.packFromSrc(btrans, tmp.data()); + EXPECT_TRUE(Bp.packed()); #ifdef _OPENMP #pragma omp parallel diff --git a/test/GConvTest.cc b/test/GConvTest.cc index 84f0d52..982208b 100644 --- a/test/GConvTest.cc +++ b/test/GConvTest.cc @@ -25,14 +25,6 @@ using namespace std; using namespace fbgemm; -vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose, - matrix_op_t::Transpose}; - -vector<QuantizationGranularity> qGranularityVals{ - QuantizationGranularity::TENSOR, - QuantizationGranularity::GROUP, - QuantizationGranularity::OUT_CHANNEL}; - namespace { class fbgemmGConvAcc32Test : public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t>> {}; @@ -43,6 +35,8 @@ class fbgemmGConvAcc32WithQuantGranularityTest QuantizationGranularity, bool, bool>> {}; +class fbgemmGConvPackTest + : public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t>> {}; }; // namespace INSTANTIATE_TEST_CASE_P( @@ -61,6 +55,13 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(qGranularityVals), ::testing::Bool(), // A symmetric ::testing::Bool())); // B symmetric + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + fbgemmGConvPackTest, + ::testing::Combine( + ::testing::Values(matrix_op_t::NoTranspose), + ::testing::ValuesIn(transposeVals))); /** * @brief Shapes for unit test. */ @@ -413,3 +414,51 @@ TEST_P(fbgemmGConvAcc32Test, NoRequantizeTest) { static_cast<int32_t>(0)); } // for each shape } + +/** + * @brief Unit test for packing and unpacking the weight tensor + */ +TEST_P(fbgemmGConvPackTest, PackUnpackTest) { + vector<conv_param_t<>> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + tie(atrans, btrans) = GetParam(); + + for (auto conv_p : shapes) { + int R = conv_p.K[0]; + int S = conv_p.K[1]; + int IC_per_G = conv_p.IC / conv_p.G; + int OC_per_G = conv_p.OC / conv_p.G; + + // Weights -- test the packing/unpacking of only the weights + // when btrans == Transpose, the weight matrix is in layout G K/G (R S C/G) + // instead of G (R S C/G) K/G + int weight_len = R * S * conv_p.G * IC_per_G * OC_per_G; + aligned_vector<int8_t> Bint8(weight_len, 0); + + // Random fill the weights + randFill<int8_t>(Bint8, -4, 4); + + // Instantiate the object + PackWeightMatrixForGConv<int8_t> packedWeights( + btrans, conv_p, Bint8.data(), nullptr); + + // Setup a buffer to get pack -> unpacked results + aligned_vector<int8_t> unpack_buf(weight_len, 0); + + // START Actual pack-unpack operations + // Perform packing first. 
This should populate pdata_ of packedWeights + packedWeights.pack(); + + // Next perform unpacking + packedWeights.unpack(unpack_buf.data()); + // END actual pack-unpack operations + + // Sanity check + for (int i = 0; i < weight_len; ++i) { + EXPECT_EQ(Bint8.data()[i], unpack_buf.data()[i]) + << "Pack/Unpack results differ at index " << i + << ", Reference: " << static_cast<int>(Bint8.data()[i]) + << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i]); + } + } // for each shape +} diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc index 11bd625..9de6943 100644 --- a/test/I8DepthwiseTest.cc +++ b/test/I8DepthwiseTest.cc @@ -4,7 +4,6 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "I8DepthwiseTest.h" #include <cmath> #include <cstdio> @@ -22,6 +21,7 @@ using namespace std; namespace fbgemm { // From Xray OCR +// clang-format off static vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -68,9 +68,28 @@ static vector<vector<int>> shapes = { { 1, 8, 4, 4, 1, }, }; +static vector<vector<int>> shapes_3d = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. + // N, K, T_in, H_in, W_in, stride + { 1, 32, 16, 28, 28, 1, }, + { 1, 128, 8, 14, 14, 2, }, + { 5, 16, 32, 56, 56, 1, }, + { 1, 8, 4, 4, 4, 1, }, +}; +// clang-format on + namespace { -class FBGemmDepthWiseTest - : public testing::TestWithParam<tuple<bool, bool>> {}; + +class FBGemmDepthWiseTest : public testing::TestWithParam<tuple<bool, bool>> {}; + +// Two parameters are K (or Groups) and kernel_prod, i.e., +// (output_channels)(kernel_prod) +// output_channels == Groups. +// For example, kernel_prod for 3x3 convolution is 9 +class FBGemmDepthWisePackUnpackTest + : public testing::TestWithParam<tuple<int, int>> {}; + } // namespace INSTANTIATE_TEST_CASE_P( @@ -78,6 +97,13 @@ INSTANTIATE_TEST_CASE_P( FBGemmDepthWiseTest, ::testing::Combine(::testing::Bool(), ::testing::Bool())); +INSTANTIATE_TEST_CASE_P( + InstantiationName, + FBGemmDepthWisePackUnpackTest, + ::testing::Combine( + ::testing::ValuesIn({8, 16, 24, 32, 40, 64, 72}), + ::testing::ValuesIn({1, 2, 3, 4, 5, 9, 10, 11, 27}))); + TEST_P(FBGemmDepthWiseTest, Test3x3) { bool a_symmetric, b_symmetric; tie(a_symmetric, b_symmetric) = GetParam(); @@ -90,13 +116,29 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) { int stride_h = shape[4]; int stride_w = stride_h; constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2, + PAD_R = (S - 1) / 2; + + conv_param_t<2> conv_p( + N, + K, + K, + {H, W}, + K, + {R, S}, + {stride_h, stride_w}, + {PAD_T, PAD_L, PAD_B, PAD_R}); + int H_OUT = conv_p.OUT_DIM[0]; + int W_OUT = conv_p.OUT_DIM[1]; + + int MDim = N * H_OUT * W_OUT; + int KDim = R * S * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * H * W * K); - aligned_vector<int8_t> B(K * R * S); - aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * K), C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = a_symmetric ? 
0 : 43; @@ -104,48 +146,54 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) { randFill<int8_t>(B, -16, 16); int32_t B_zero_point = b_symmetric ? 0 : 5; - depthwise_3x3_pad_1_ref( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); - int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); - - float C_multiplier = 255. / (maximum - minimum); + aligned_vector<float> C_multiplier(1); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); + int32_t C_zero_point = 5; aligned_vector<int32_t> col_offsets(K); aligned_vector<int32_t> bias(K); randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - int32_t C_zero_point = 5; - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3_pad_1_ref( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point, - B.data(), - C_multiplier, - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col; + if (!b_symmetric) { + A_im2col.resize(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); + } - Packed3x3ConvMatrix Bp(K, B.data()); + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + if (!b_symmetric) { + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + } + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data(), + C_zero_point, + A_zero_point, + &B_zero_point, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } + + PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data()); depthwise_3x3_pad_1( N, @@ -158,12 +206,13 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) { A.data(), B_zero_point, Bp, - C_multiplier, + C_multiplier[0], C_zero_point, C_uint8.data(), a_symmetric ? nullptr : col_offsets.data(), bias.data(), false, /* fuse_relu */ + 1.0f, /* act_scale * w_scale */ 0, 1); @@ -205,67 +254,83 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) { constexpr int K_T = 3, K_H = 3, K_W = 3; constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + conv_param_t<3> conv_p( + N, + K, + K, + {T, H, W}, + K, + {K_T, K_H, K_W}, + {stride_t, stride_h, stride_w}, + {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R}); + int T_OUT = conv_p.OUT_DIM[0]; + int H_OUT = conv_p.OUT_DIM[1]; + int W_OUT = conv_p.OUT_DIM[2]; + + int MDim = N * T_OUT * H_OUT * W_OUT; + int KDim = K_T * K_H * K_W * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * T * H * W * K); - aligned_vector<int8_t> B(K * K_T * K_H * K_W); - aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K), - C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = a_symmetric ? 
0 : 43; randFill<int8_t>(B, -16, 16); - int32_t B_zero_point = 5; - - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); - int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + int32_t B_zero_point = b_symmetric ? 0 : 5; - float C_multiplier = 255. / (maximum - minimum); + aligned_vector<float> C_multiplier(1); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); + int32_t C_zero_point = 5; aligned_vector<int32_t> col_offsets(K); aligned_vector<int32_t> bias(K); randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - int32_t C_zero_point = 5; - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point, - B.data(), - C_multiplier, - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col; + if (!b_symmetric) { + A_im2col.resize(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); + } - Packed3x3x3ConvMatrix Bp(K, B.data()); + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + if (!b_symmetric) { + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + } + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data(), + C_zero_point, + A_zero_point, + &B_zero_point, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } + + PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data()); depthwise_3x3x3_pad_1( N, @@ -280,10 +345,10 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) { A.data(), B_zero_point, Bp, - C_multiplier, + C_multiplier[0], C_zero_point, C_uint8.data(), - col_offsets.data(), + a_symmetric ? 
nullptr : col_offsets.data(), bias.data(), false, /* fuse_relu */ 0, @@ -297,8 +362,8 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) { for (int k = 0; k < K; ++k) { int32_t expected = C_uint8_ref [(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; - int32_t actual = C_uint8 - [(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; + int32_t actual = + C_uint8[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; EXPECT_EQ(expected, actual) << "Depthwise 3x3 results differ at (" << n << ", " << t << ", " << h << ", " << w << ", " << k << ")."; @@ -319,14 +384,29 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) { int stride_h = shape[4]; int stride_w = stride_h; constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2, + PAD_R = (S - 1) / 2; + + conv_param_t<2> conv_p( + N, + K, + K, + {H, W}, + K, + {R, S}, + {stride_h, stride_w}, + {PAD_T, PAD_L, PAD_B, PAD_R}); + int H_OUT = conv_p.OUT_DIM[0]; + int W_OUT = conv_p.OUT_DIM[1]; + + int MDim = N * H_OUT * W_OUT; + int KDim = R * S * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * H * W * K); - aligned_vector<int8_t> B(K * R * S); - int32_t C_num_rows = N * H_OUT * W_OUT; - aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = 43; @@ -342,28 +422,8 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) { B_zero_point[k] = 5 + k; } - depthwise_3x3_pad_1_ref( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - aligned_vector<int32_t> C_ref_transpose(C_ref); - transpose_matrix(C_ref.data(), C_num_rows, K); - vector<float> C_multiplier(K); - for (auto k = 0; k < K; ++k) { - auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows; - auto C_ref_k_end = C_ref_k_begin + C_num_rows; - int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end); - int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end); - C_multiplier[k] = 255. 
/ (maximum - minimum); - } + aligned_vector<float> C_multiplier(K); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); int32_t C_zero_point = 5; aligned_vector<int32_t> col_offsets(K); @@ -371,25 +431,40 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) { randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3_per_channel_quantization_pad_1_ref( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point.data(), - B.data(), - C_multiplier.data(), - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + // im2col to compute row offset later + vector<int32_t> row_offsets(MDim); + vector<uint8_t> A_im2col(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); + + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data() + g, + C_zero_point, + A_zero_point, + B_zero_point.data() + g, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } - Packed3x3ConvMatrix Bp(K, B.data()); + PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data()); depthwise_3x3_per_channel_quantization_pad_1( N, @@ -442,14 +517,28 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) { constexpr int K_T = 3, K_H = 3, K_W = 3; constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + conv_param_t<3> conv_p( + N, + K, + K, + {T, H, W}, + K, + {K_T, K_H, K_W}, + {stride_t, stride_h, stride_w}, + {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R}); + int T_OUT = conv_p.OUT_DIM[0]; + int H_OUT = conv_p.OUT_DIM[1]; + int W_OUT = conv_p.OUT_DIM[2]; + + int MDim = N * T_OUT * H_OUT * W_OUT; + int KDim = K_T * K_H * K_W * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * T * H * W * K); - aligned_vector<int8_t> B(K * K_T * K_H * K_W); - int32_t C_num_rows = N * T_OUT * H_OUT * W_OUT; - aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = 43; @@ -465,30 +554,8 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) { B_zero_point[k] = 5 + k; } - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - aligned_vector<int32_t> C_ref_transpose(C_ref); - transpose_matrix(C_ref.data(), C_num_rows, K); - vector<float> C_multiplier(K); - for (auto k = 0; k < K; ++k) { - auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows; - auto C_ref_k_end = C_ref_k_begin + C_num_rows; - int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end); - int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end); - C_multiplier[k] = 255. 
/ (maximum - minimum); - } + aligned_vector<float> C_multiplier(K); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); int32_t C_zero_point = 5; aligned_vector<int32_t> col_offsets(K); @@ -496,27 +563,40 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) { randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3x3_per_channel_quantization_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point.data(), - B.data(), - C_multiplier.data(), - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); + + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data() + g, + C_zero_point, + A_zero_point, + B_zero_point.data() + g, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } - Packed3x3x3ConvMatrix Bp(K, B.data()); + PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data()); depthwise_3x3x3_per_channel_quantization_pad_1( N, @@ -561,4 +641,22 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) { } // for each shape } // Test3x3PerChannelQuantization +TEST_P(FBGemmDepthWisePackUnpackTest, TestPackUnpack) { + int K, kernel_prod; + tie(K, kernel_prod) = GetParam(); + + ASSERT_EQ(K % 8, 0) + << "output channels (== groups) should be a multiple of 8"; + aligned_vector<int8_t> B(K * kernel_prod); + randFill<int8_t>(B, -16, 16); + + aligned_vector<int8_t> BUnpacked(K * kernel_prod); + + PackedDepthWiseConvMatrix BPacked(K, kernel_prod, B.data()); + BPacked.unpack(BUnpacked.data()); + + ASSERT_EQ(B, BUnpacked) + << "Original and unpacked data elements are not the same"; +} // TestPackUnpack + } // namespace fbgemm diff --git a/test/I8DepthwiseTest.h b/test/I8DepthwiseTest.h deleted file mode 100644 index d65362a..0000000 --- a/test/I8DepthwiseTest.h +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#pragma once - -#include <vector> - -namespace fbgemm { - -// From ResNeXt-3D-101 -static std::vector<std::vector<int>> shapes_3d = { - // NOTE: clang-format wants to use a different formatting but the current - // formatting should be easier to read. 
- // N, K, T_in, H_in, W_in, stride - { 1, 64, 32, 56, 56, 1, }, - { 1, 128, 16, 28, 28, 1, }, - { 1, 256, 8, 14, 14, 1, }, - { 1, 512, 4, 7, 7, 1, }, - - { 1, 128, 32, 56, 56, 2, }, - { 1, 256, 16, 28, 28, 2, }, - { 1, 512, 8, 14, 14, 2, }, - - { 5, 64, 32, 56, 56, 1, }, - { 5, 128, 16, 28, 28, 1, }, - { 5, 256, 8, 14, 14, 1, }, - { 5, 512, 4, 7, 7, 1, }, - - { 5, 128, 32, 56, 56, 2, }, - { 5, 256, 16, 28, 28, 2, }, - { 5, 512, 8, 14, 14, 2, }, - - { 1, 8, 4, 4, 4, 1, }, -}; - -} // namespace fbgemm diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc index b14303f..56df3c8 100644 --- a/test/Im2ColFusedRequantizeTest.cc +++ b/test/Im2ColFusedRequantizeTest.cc @@ -24,11 +24,6 @@ using namespace std; using namespace fbgemm; -vector<QuantizationGranularity> qGranularityVals{ - QuantizationGranularity::TENSOR, - QuantizationGranularity::GROUP, - QuantizationGranularity::OUT_CHANNEL}; - namespace { class fbgemmIm2colTest : public testing::TestWithParam<tuple<QuantizationGranularity, bool>> {}; diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc index 20f860e..8978150 100644 --- a/test/PackedRequantizeAcc16Test.cc +++ b/test/PackedRequantizeAcc16Test.cc @@ -26,20 +26,14 @@ using namespace std; using namespace fbgemm; -vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose, - matrix_op_t::Transpose}; - -vector<QuantizationGranularity> qGranularityVals{ - QuantizationGranularity::TENSOR, - QuantizationGranularity::GROUP, - QuantizationGranularity::OUT_CHANNEL}; - namespace { class fbgemmu8s8acc16WithQuantGranularityTest : public testing::TestWithParam< tuple<matrix_op_t, matrix_op_t, bool, QuantizationGranularity>> {}; class fbgemmu8s8acc16Test : public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t, bool>> {}; +class fbgemmPackUnpackAcc16Test + : public testing::TestWithParam<tuple<matrix_op_t, bool>> {}; }; // namespace INSTANTIATE_TEST_CASE_P( @@ -59,6 +53,11 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(transposeVals), ::testing::Bool())); +INSTANTIATE_TEST_CASE_P( + InstantiationName, + fbgemmPackUnpackAcc16Test, + ::testing::Combine(::testing::ValuesIn(transposeVals), ::testing::Bool())); + /** * @brief Shapes for unit test. */ @@ -87,6 +86,8 @@ static vector<vector<int>> GetShapes_() { {102, 512, 258}, {1024, 512, 258}, + + {120, 4, 288}, }; return shapes; } @@ -810,3 +811,79 @@ TEST_P(fbgemmu8s8acc16Test, NoRequantizeTest) { } // for each groups } // for each shape } + +/** + * @brief Unit test for packing and unpacking the weight tensor. + */ +TEST_P(fbgemmPackUnpackAcc16Test, TestPackUnpack) { + vector<vector<int>> shapes(GetShapes_()); + matrix_op_t btrans; + bool test_ld; + tie(btrans, test_ld) = GetParam(); + + BlockingFactors params; + params.MCB = 48; + params.NCB = 16; + params.KCB = 256; + params.MR = 1; + params.NR = 16; + params.ROW_INTERLEAVE = 4; + params.NR_MIN = 16; + vector<BlockingFactors*> vec_params_ptr = {¶ms, nullptr}; + + for (auto shape : shapes) { + for (int groups : {1, 3, 4}) { + for (auto params_ptr : vec_params_ptr) { + int n = shape[1]; + int k = shape[2]; + + if (k % groups != 0) { + continue; + } + int k_per_group = k / groups; + + // kxn matrix + aligned_vector<int8_t> Bint8(k * n); + randFill<int8_t>(Bint8, -128, 127); + + // To test lda != k , we just reduce k by half and use the original k + // as lda. 
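// (For the NoTranspose case below it is actually n that is halved to
// n_adjusted, while the leading dimension handed to PackBMatrix stays at the
// original n, so the number of columns actually used differs from the
// leading dimension.)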
+ int n_adjusted = n; + if (test_ld) { + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + // Note that packing for weight is performed during the constructor + // stage. + PackBMatrix<int8_t, int16_t> packedWeights( + btrans, + k, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k_per_group : n, + nullptr, + groups, + params_ptr); + + // Setup a buffer to get pack -> unpacked results + aligned_vector<int8_t> unpack_buf(k * n, 0); + + // Perform unpacking + packedWeights.unpack(unpack_buf.data(), params_ptr); + + // Sanity check + for (int i = 0; i < k; i++) { + for (int j = 0; j < n_adjusted; j++) { + EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j]) + << "Pack/Unpack results differ at index (" << i << ", " << j + << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j]) + << ", Pack-Unpacked: " + << static_cast<int>(unpack_buf.data()[i * n + j]); + } + } + } + } + } +} diff --git a/test/PackedRequantizeTest.cc b/test/PackedRequantizeTest.cc index fd827b0..15e7d55 100644 --- a/test/PackedRequantizeTest.cc +++ b/test/PackedRequantizeTest.cc @@ -25,20 +25,14 @@ using namespace std; using namespace fbgemm; -vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose, - matrix_op_t::Transpose}; - -vector<QuantizationGranularity> qGranularityVals{ - QuantizationGranularity::TENSOR, - QuantizationGranularity::GROUP, - QuantizationGranularity::OUT_CHANNEL}; - namespace { class fbgemmu8s8acc32WithQuantGranularityTest : public testing::TestWithParam< tuple<matrix_op_t, matrix_op_t, bool, QuantizationGranularity>> {}; class fbgemmu8s8acc32Test : public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t, bool>> {}; +class fbgemmPackUnpackAcc32Test + : public testing::TestWithParam<tuple<matrix_op_t, bool>> {}; }; // namespace INSTANTIATE_TEST_CASE_P( @@ -58,6 +52,11 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(transposeVals), ::testing::Bool())); +INSTANTIATE_TEST_CASE_P( + InstantiationName, + fbgemmPackUnpackAcc32Test, + ::testing::Combine(::testing::ValuesIn(transposeVals), ::testing::Bool())); + /** * @brief Shapes for unit test. */ @@ -86,6 +85,8 @@ static vector<vector<int>> GetShapes_() { {102, 512, 258}, {1024, 512, 258}, + + {120, 4, 288}, }; return shapes; } @@ -749,3 +750,79 @@ TEST_P(fbgemmu8s8acc32Test, TestSymmetricQuantizedInputOutput) { } // for each groups } // for each shape } + +/** + * @brief Unit test for packing and unpacking the weight tensor. + */ +TEST_P(fbgemmPackUnpackAcc32Test, TestPackUnpack) { + vector<vector<int>> shapes(GetShapes_()); + matrix_op_t btrans; + bool test_ld; + tie(btrans, test_ld) = GetParam(); + + BlockingFactors params; + params.MCB = 48; + params.NCB = 16; + params.KCB = 256; + params.MR = 1; + params.NR = 16; + params.ROW_INTERLEAVE = 4; + params.NR_MIN = 16; + vector<BlockingFactors*> vec_params_ptr = {¶ms, nullptr}; + + for (auto shape : shapes) { + for (int groups : {1, 3, 4}) { + for (auto params_ptr : vec_params_ptr) { + int n = shape[1]; + int k = shape[2]; + + if (k % groups != 0) { + continue; + } + int k_per_group = k / groups; + + // kxn matrix + aligned_vector<int8_t> Bint8(k * n); + randFill<int8_t>(Bint8, -128, 127); + + // To test lda != k , we just reduce k by half and use the original k + // as lda. + int n_adjusted = n; + if (test_ld) { + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } + } + + // Note that packing for weight is performed during the constructor + // stage. 
+ PackBMatrix<int8_t> packedWeights( + btrans, + k, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k_per_group : n, + nullptr, + groups, + params_ptr); + + // Setup a buffer to get pack -> unpacked results + aligned_vector<int8_t> unpack_buf(k * n, 0); + + // Perform unpacking + packedWeights.unpack(unpack_buf.data(), params_ptr); + + // Sanity check + for (int i = 0; i < k; i++) { + for (int j = 0; j < n_adjusted; j++) { + EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j]) + << "Pack/Unpack results differ at index (" << i << ", " << j + << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j]) + << ", Pack-Unpacked: " + << static_cast<int>(unpack_buf.data()[i * n + j]); + } + } + } + } + } +} diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc new file mode 100644 index 0000000..ddb1f91 --- /dev/null +++ b/test/QuantUtilsTest.cc @@ -0,0 +1,183 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <algorithm> +#include <limits> +#include <random> + +#include <gtest/gtest.h> + +#include "fbgemm/QuantUtils.h" +#include "fbgemm/Utils.h" + +using namespace std; +using namespace fbgemm; + +// tuple represents K, C, X, G, layout_t +// layout_t can be KCX or KXC +class QuantizeGroupwiseTest + : public testing::TestWithParam<tuple<int, int, int, int, layout_t>> {}; + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + QuantizeGroupwiseTest, + ::testing::Combine( + ::testing::ValuesIn({4, 12, 64}), // K + ::testing::ValuesIn({12, 16, 32}), // C + ::testing::ValuesIn({1, 10, 15, 30}), // X + ::testing::ValuesIn({1, 4}), // G + ::testing::ValuesIn({layout_t::KCX, layout_t::KXC}))); + +template <typename T, layout_t LT> +void ref_impl( + const vector<float>& src, + int K, + int C, + int X, + int G, + const vector<float>& scales, + const vector<int>& zero_points, + vector<T>& dst) { + int C_per_G = C / G; + for (int i = 0; i < K; ++i) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < C / G; ++c) { + for (int x = 0; x < X; ++x) { + float num; + if (LT == layout_t::KCX) { + num = src[(i * C + g * C_per_G + c) * X + x]; + } else { + num = src[(i * X + x) * C + g * C_per_G + c]; + } + int res = nearbyint(zero_points[g] + num / scales[g]); + T final_res = min<T>( + max<T>(res, numeric_limits<T>::min()), numeric_limits<T>::max()); + if (LT == layout_t::KCX) { + dst[(i * C + g * C_per_G + c) * X + x] = final_res; + } else { + dst[(i * X + x) * C + g * C_per_G + c] = final_res; + } + } + } + } + } +} + +template <typename T, layout_t LT> +void runTests( + const vector<float>& src, + int K, + int C, + int X, + int G, + const vector<float>& scales, + const vector<int>& zero_points, + vector<T>& dst, + vector<T>& dst_ref) { + QuantizeGroupwise<T, LT>( + src.data(), K, C, X, G, scales.data(), zero_points.data(), dst.data()); + + ref_impl<T, LT>(src, K, C, X, G, scales, zero_points, dst_ref); +} + +/** + * There can be off-by-one error in quantized values due to how the mid-point + * cases are rounded-off in vectorized vs scalar codes and due to adding of + * zero_point before rounding vs after rounding. We ignore such differences + * while comparing results. 
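 * (Concretely, ref_impl above computes nearbyint(zero_point + x / scale) and
 * clamps to the target type's range; a path that rounds mid-points the other
 * way, or adds the zero point after rounding instead of before, can land one
 * step away, which is why isNear below tolerates a difference of at most 1.)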
+ */ +template <typename T> +::testing::AssertionResult isNear( + const vector<T>& res, + const vector<T>& res_ref) { + bool match = true; + if (res.size() == res_ref.size()) { + for (int i = 0; i < res.size(); ++i) { + if (!(res[i] == res_ref[i] || res[i] == res_ref[i] + 1 || + res[i] == res_ref[i] - 1)) { + match = false; + break; + } + } + } + if (match) + return ::testing::AssertionSuccess(); + else + return ::testing::AssertionFailure() << " Quantized results do not match"; +} + +/** + * Test for QuantizeGroupwise + */ +TEST_P(QuantizeGroupwiseTest, quantizeTest) { + int K, C, X, G; + layout_t layout; + tie(K, C, X, G, layout) = GetParam(); + + random_device rd; + mt19937 gen(rd()); + + uniform_real_distribution<float> disFP(0.1, 1.1); + + vector<float> inp(K * C * X); + generate(inp.begin(), inp.end(), [&, disFP]() mutable { return disFP(gen); }); + + vector<float> scales(G); + generate(scales.begin(), scales.end(), [&, disFP]() mutable { + return disFP(gen); + }); + + uniform_int_distribution<> disUInt8(0, 8); + vector<int> zero_points_uint8(G); + generate( + zero_points_uint8.begin(), + zero_points_uint8.end(), + [&, disUInt8]() mutable { return disUInt8(gen); }); + + uniform_int_distribution<> disInt8(-64, 63); + vector<int> zero_points_int8(G); + generate( + zero_points_int8.begin(), zero_points_int8.end(), [&, disInt8]() mutable { + return disInt8(gen); + }); + + uniform_int_distribution<> disInt32(-512, 512); + vector<int> zero_points_int32(G); + generate( + zero_points_int32.begin(), + zero_points_int32.end(), + [&, disInt32]() mutable { return disInt32(gen); }); + + vector<uint8_t> dstuint8(K * C * X); + vector<uint8_t> dstuint8_ref(K * C * X); + + vector<int8_t> dstint8(K * C * X); + vector<int8_t> dstint8_ref(K * C * X); + + vector<int32_t> dstint32(K * C * X); + vector<int32_t> dstint32_ref(K * C * X); + + if (layout == layout_t::KCX) { + runTests<uint8_t, layout_t::KCX>( + inp, K, C, X, G, scales, zero_points_uint8, dstuint8, dstuint8_ref); + runTests<int8_t, layout_t::KCX>( + inp, K, C, X, G, scales, zero_points_int8, dstint8, dstint8_ref); + runTests<int32_t, layout_t::KCX>( + inp, K, C, X, G, scales, zero_points_int32, dstint32, dstint32_ref); + } else { + runTests<uint8_t, layout_t::KXC>( + inp, K, C, X, G, scales, zero_points_uint8, dstuint8, dstuint8_ref); + runTests<int8_t, layout_t::KXC>( + inp, K, C, X, G, scales, zero_points_int8, dstint8, dstint8_ref); + runTests<int32_t, layout_t::KXC>( + inp, K, C, X, G, scales, zero_points_int32, dstint32, dstint32_ref); + } + + EXPECT_TRUE(isNear(dstuint8, dstuint8_ref)); + EXPECT_TRUE(isNear(dstint8, dstint8_ref)); + EXPECT_TRUE(isNear(dstint32, dstint32_ref)); +} diff --git a/test/RequantizeOnlyTest.cc b/test/RequantizeOnlyTest.cc new file mode 100644 index 0000000..94e8e7d --- /dev/null +++ b/test/RequantizeOnlyTest.cc @@ -0,0 +1,169 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include <algorithm> +#include <functional> +#include <iostream> +#include <random> +#include <stdexcept> +#include <string> + +#include <gtest/gtest.h> + +#include "TestUtils.h" +#include "bench/BenchUtils.h" +#include "fbgemm/Fbgemm.h" + +using namespace std; +using namespace fbgemm; + +vector<QuantizationGranularity> qGranularityValsLocal{ + QuantizationGranularity::TENSOR, + QuantizationGranularity::OUT_CHANNEL}; + +namespace { + +// tuple represents #rows, #cols, fuse_relu, quantization_granularity, bias_type +class FloatRequantizeTest + : public testing::TestWithParam< + tuple<int, int, bool, QuantizationGranularity>> {}; + +}; // namespace + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + FloatRequantizeTest, + ::testing::Combine( + ::testing::ValuesIn({1, 2, 3, 4}), // number of rows + ::testing::ValuesIn( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 20, 32}), // number of + // cols + ::testing::Bool(), // fuse relu + ::testing::ValuesIn(qGranularityValsLocal))); // requantization granularity + +/** + * Test for float bias + */ +TEST_P(FloatRequantizeTest, floatBiasTest) { + int rows, cols; + bool fuse_relu; + QuantizationGranularity q_gran; + tie(rows, cols, fuse_relu, q_gran) = GetParam(); + + int numElements = rows * cols; + + aligned_vector<float> act_times_w_scale(cols); + randFill<float>(act_times_w_scale, -8, 8); + + float out_scale = 2.0f; + + aligned_vector<float> C_multiplier(cols); + transform( + act_times_w_scale.begin(), + act_times_w_scale.end(), + C_multiplier.begin(), + [&out_scale](float i) { return i / out_scale; }); + + aligned_vector<int32_t> Bint8_zero_point(cols); + randFill<int32_t>(Bint8_zero_point, -8, 8); + + aligned_vector<int32_t> row_offset_buf(rows); + randFill<int32_t>(row_offset_buf, -8, 8); + + aligned_vector<int32_t> col_offsets(cols); + randFill<int32_t>(col_offsets, -8, 8); + + // quantized bias + aligned_vector<int32_t> bias_q(cols); + randFill<int32_t>(bias_q, -8, 8); + + // floating point bias + aligned_vector<float> bias_f(cols); + if (q_gran == QuantizationGranularity::TENSOR) { + transform( + bias_q.begin(), + bias_q.end(), + bias_f.begin(), + [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; }); + } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) { + transform( + act_times_w_scale.begin(), + act_times_w_scale.end(), + bias_q.begin(), + bias_f.begin(), + multiplies<float>()); + + } else { + FAIL(); + } + + aligned_vector<int32_t> input(numElements); + randFill<int32_t>(input, -8, 8); + + aligned_vector<uint8_t> output_q_bias(numElements); + aligned_vector<uint8_t> output_f_bias(numElements); + + int32_t C_zero_point = 3; + int32_t Aint8_zero_point = 3; + + block_type_t block{0, rows, 0, cols}; + + DoNothing<> doNothingObj{}; + +#define TESTCODE(FUSE_RELU, Q_GRAN) \ + ReQuantizeOutput<FUSE_RELU, Q_GRAN> reqObj_q( \ + doNothingObj, \ + C_multiplier.data(), \ + C_zero_point, \ + Aint8_zero_point, \ + Bint8_zero_point.data(), \ + row_offset_buf.data(), \ + col_offsets.data(), \ + bias_q.data(), \ + cols); \ + ReQuantizeOutput<FUSE_RELU, Q_GRAN, float> reqObj_f( \ + doNothingObj, \ + C_multiplier.data(), \ + C_zero_point, \ + Aint8_zero_point, \ + Bint8_zero_point.data(), \ + row_offset_buf.data(), \ + col_offsets.data(), \ + bias_f.data(), \ + cols, \ + 1, \ + act_times_w_scale.data()); \ + reqObj_q.f<inst_set_t::avx2>( \ + output_q_bias.data(), input.data(), block, cols, cols); \ + reqObj_f.f<inst_set_t::avx2>( \ + output_f_bias.data(), input.data(), block, cols, cols); + + if (fuse_relu) { + if (q_gran == 
QuantizationGranularity::TENSOR) { + TESTCODE(true, QuantizationGranularity::TENSOR) + + } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) { + TESTCODE(true, QuantizationGranularity::OUT_CHANNEL) + + } else { + FAIL(); + } + + } else { + if (q_gran == QuantizationGranularity::TENSOR) { + TESTCODE(false, QuantizationGranularity::TENSOR) + + } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) { + TESTCODE(false, QuantizationGranularity::OUT_CHANNEL) + + } else { + FAIL(); + } + } +#undef TESTCODE + ASSERT_EQ(output_q_bias, output_f_bias) + << "Requantization with quantized bias and float bias differs"; +} diff --git a/test/TestUtils.h b/test/TestUtils.h index 2cb7b88..d320ae2 100644 --- a/test/TestUtils.h +++ b/test/TestUtils.h @@ -7,9 +7,18 @@ #pragma once #include <cmath> #include <vector> +#include "fbgemm/Fbgemm.h" namespace fbgemm { +static std::vector<matrix_op_t> transposeVals = { matrix_op_t::NoTranspose, + matrix_op_t::Transpose }; + +static std::vector<QuantizationGranularity> qGranularityVals = { + QuantizationGranularity::TENSOR, + QuantizationGranularity::GROUP, + QuantizationGranularity::OUT_CHANNEL }; + /* * @brief Check and validate the buffers for reference and FBGEMM result. */ diff --git a/test/UniConvPackingTest.cc b/test/UniConvPackingTest.cc deleted file mode 100644 index 77552af..0000000 --- a/test/UniConvPackingTest.cc +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include <algorithm> -#include <random> -#include <iostream> - - -#include <gtest/gtest.h> - -#include "QuantizationHelpers.h" -#include "TestUtils.h" -#include "bench/BenchUtils.h" -#include "fbgemm/Fbgemm.h" -#include "src/RefImplementations.h" - -using namespace std; -using namespace fbgemm; - -namespace { - -// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad -class convPackingTest - : public testing::TestWithParam< - tuple<int, int, int, int, int, int, int, int, int, int>> {}; - -}; // namespace - -INSTANTIATE_TEST_CASE_P( - InstantiationName, - convPackingTest, - ::testing::Combine( - ::testing::ValuesIn({1, 2}), // MB - ::testing::ValuesIn({16, 32}), // IC - ::testing::ValuesIn({16, 32}), // OC - ::testing::ValuesIn({17}), // IT - ::testing::ValuesIn({10, 30, 55}), // IH - ::testing::ValuesIn({10, 30, 55}), // IW - ::testing::ValuesIn({1, 4, 16}), // G - ::testing::ValuesIn({3, 7}), // kernel - ::testing::ValuesIn({1, 2}), // stride - ::testing::ValuesIn({1, 2}))); // pad - -/** - * Test for conv packing - */ -TEST_P(convPackingTest, packingTest) { - int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; - tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); - - conv_param_t<2> conv_p_2d( - MB, - IC, - OC, - {IH, IW}, - G, - {kernel, kernel}, - {stride, stride}, - {pad, pad, pad, pad}); - - int kernel_dim_2d = kernel * kernel; - aligned_vector<int8_t> Bint8_2d( - kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); - PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); - - switch (ConvFastPath<2, int32_t>(conv_p_2d)) { - case optimized_conv_t::depthwise: { - ASSERT_NE(packedB_2D.getPackedWFor2DDW(), nullptr) - << "2D depthwise packed matrix is null"; - ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) - << "im2col packed matrix should be null"; - ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr) - << "3D depthwise packed matrix should be 
null"; - ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) - << "groupwise packed matrix should be null"; - break; - } - case optimized_conv_t::groupwise: { - ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) - << "im2col packed matrix should be null"; - ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr) - << "2D depthwise packed matrix is null"; - ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr) - << "3D depthwise packed matrix should be null"; - ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr) - << "Groupwise packed matrix is null"; - break; - } - case optimized_conv_t::im2col: { - ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr) - << "2D depthwise packed matrix is null"; - ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr) - << "3D depthwise packed matrix should be null"; - ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) - << "groupwise packed matrix should be null"; - ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr) - << "im2col packed matrix is null"; - break; - } - } - - conv_param_t<3> conv_p_3d( - MB, - IC, - OC, - {IT, IH, IW}, - G, - {kernel, kernel, kernel}, - {stride, stride, stride}, - {pad, pad, pad, pad, pad, pad}); - - int kernel_dim_3d = kernel * kernel * kernel; - aligned_vector<int8_t> Bint8_3d( - kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); - PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); - - switch (ConvFastPath<3, int32_t>(conv_p_3d)) { - case optimized_conv_t::depthwise: { - ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr) - << "2D depthwise packed matrix is null"; - ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) - << "im2col packed matrix should be null"; - ASSERT_NE(packedB_3D.getPackedWFor3DDW(), nullptr) - << "3D depthwise packed matrix should be null"; - ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) - << "groupwise packed matrix should be null"; - break; - } - case optimized_conv_t::groupwise: { - ASSERT_TRUE(false) << "groupwise are not supported for 3D"; - break; - } - case optimized_conv_t::im2col: { - ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr) - << "2D depthwise packed matrix is null"; - ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr) - << "3D depthwise packed matrix should be null"; - ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) - << "groupwise packed matrix should be null"; - ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr) - << "im2col packed matrix is null"; - break; - } - } -} diff --git a/test/UniConvTest.cc b/test/UniConvTest.cc new file mode 100644 index 0000000..e9c7ba5 --- /dev/null +++ b/test/UniConvTest.cc @@ -0,0 +1,714 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include <algorithm> +#include <iostream> +#include <random> +#include <stdexcept> + +#include <gtest/gtest.h> + +#include "QuantizationHelpers.h" +#include "TestUtils.h" +#include "bench/BenchUtils.h" +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" + +using namespace std; +using namespace fbgemm; + +// clang-format off +static vector<conv_param_t<>> GetShapes_() { + vector<conv_param_t<>> shapes = { + // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, {pad_t, pad_l, + // pad_b, pad_r} + // Regular + conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 32, {30, 10}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}, {2, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {3, 3}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {1, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {2, 1, 2, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 2, 1, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 5}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), + // groupwise + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 32, {10, 30}, 8, {3, 3}, {2, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 2}, {2, 1, 2, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 2}, {2, 1, 2, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 1}, {2, 1, 2, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 5}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), + // DW + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 2, 1, 2}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 5}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), + // Pointwise + conv_param_t<>(1, 32, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}), + conv_param_t<>(1, 16, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 2}, {0, 0, 0, 0}), + conv_param_t<>(1, 
32, 16, {10, 30}, 1, {1, 1}, {1, 2}, {0, 0, 0, 0}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 1}, {0, 0, 0, 0}), + }; + return shapes; +} +// clang-format on + +namespace { + +// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad +class uniConvTest + : public testing::TestWithParam< + tuple<int, int, int, int, int, int, int, int, int, int>> {}; + +// tuple represents QuantizationGranularity, A symmetric, B symmetric, +// test_bias, test_float_bias +class UniConvQGranTest + : public testing::TestWithParam< + tuple<QuantizationGranularity, bool, bool, bool, bool>> {}; + +}; // namespace + +// Combine only allows at most 10 generators. +INSTANTIATE_TEST_CASE_P( + InstantiationName, + uniConvTest, + ::testing::Combine( + ::testing::ValuesIn({1, 2}), // MB + ::testing::ValuesIn({16, 32}), // IC + ::testing::ValuesIn({16, 32}), // OC + ::testing::ValuesIn({17}), // IT + ::testing::ValuesIn({10, 30, 55}), // IH + ::testing::ValuesIn({10, 30, 55}), // IW + ::testing::ValuesIn({1, 4, 16}), // G + ::testing::ValuesIn({1, 3, 7}), // kernel + ::testing::ValuesIn({1, 2}), // stride + ::testing::ValuesIn({0, 1, 2}))); // pad + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + UniConvQGranTest, + ::testing::Combine( + ::testing::ValuesIn(qGranularityVals), + ::testing::Bool(), // A symmetric + ::testing::Bool(), // B symmetric + ::testing::Bool(), // test_bias + ::testing::Bool())); // test_float_bias +/** + * Test for conv packing + */ +TEST_P(uniConvTest, packingTest) { + int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; + tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); + + conv_param_t<2> conv_p_2d( + MB, + IC, + OC, + {IH, IW}, + G, + {kernel, kernel}, + {stride, stride}, + {pad, pad, pad, pad}); + + int kernel_dim_2d = kernel * kernel; + aligned_vector<int8_t> Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + switch (ConvFastPath<2, int32_t>(conv_p_2d)) { + case optimized_conv_t::depthwise: { + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix is null"; + break; + } + case optimized_conv_t::groupwise: { + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr) + << "Groupwise packed matrix is null"; + break; + } + case optimized_conv_t::pointwise: { + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should null"; + ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "Groupwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix is null"; + break; + } + case optimized_conv_t::im2col: { + ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should be null"; + 
ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix is null"; + break; + } + } + + conv_param_t<3> conv_p_3d( + MB, + IC, + OC, + {IT, IH, IW}, + G, + {kernel, kernel, kernel}, + {stride, stride, stride}, + {pad, pad, pad, pad, pad, pad}); + + int kernel_dim_3d = kernel * kernel * kernel; + aligned_vector<int8_t> Bint8_3d( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); + + switch (ConvFastPath<3, int32_t>(conv_p_3d)) { + case optimized_conv_t::depthwise: { + ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix is null"; + break; + } + case optimized_conv_t::groupwise: { + ASSERT_TRUE(false) << "groupwise are not supported for 3D"; + break; + } + case optimized_conv_t::pointwise: { + ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix is null"; + break; + } + case optimized_conv_t::im2col: { + ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix is null"; + break; + } + } +} + +/** + * Test for packing/unpacking + */ +TEST_P(uniConvTest, packUnpackTest) { + int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; + tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); + + conv_param_t<2> conv_p_2d( + MB, + IC, + OC, + {IH, IW}, + G, + {kernel, kernel}, + {stride, stride}, + {pad, pad, pad, pad}); + + int kernel_dim_2d = kernel * kernel; + + aligned_vector<int8_t> Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + aligned_vector<int8_t> Bint8_2d_unpacked( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + packedB_2D.unpack(Bint8_2d_unpacked.data()); + + ASSERT_EQ(Bint8_2d, Bint8_2d_unpacked) + << "Original and unpacked data elements are not the same [2D]"; + + conv_param_t<3> conv_p_3d( + MB, + IC, + OC, + {IT, IH, IW}, + G, + {kernel, kernel, kernel}, + {stride, stride, stride}, + {pad, pad, pad, pad, pad, pad}); + + int kernel_dim_3d = kernel * kernel * kernel; + + aligned_vector<int8_t> Bint8_3d( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + + aligned_vector<int8_t> Bint8_3d_unpacked( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + + PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); + + 
packedB_3D.unpack(Bint8_3d_unpacked.data()); + + ASSERT_EQ(Bint8_3d, Bint8_3d_unpacked) + << "Original and unpacked data elements are not the same [3D]"; +} + +TEST(uniConvTest, cornerCases) { + int stride = 1; + conv_param_t<2> conv_p_2d( + 1, // mini-batch + 16, // input channels + 32, // output channels + {28, 28}, // input height/width + 4, // groups + {3, 3}, // kernel height/width + {stride, stride}, // strides + {1, 1, 1, 1}); // padding + + int kernel_dim_2d = conv_p_2d.K[0] * conv_p_2d.K[1]; + + aligned_vector<uint8_t> Aint8( + conv_p_2d.MB * conv_p_2d.IN_DIM[0] * conv_p_2d.IN_DIM[1] * conv_p_2d.IC); + aligned_vector<int8_t> Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + aligned_vector<int32_t> Cint32_fb( + conv_p_2d.MB * conv_p_2d.OUT_DIM[0] * conv_p_2d.OUT_DIM[1] * + conv_p_2d.OC); + aligned_vector<uint8_t> Cint8_fb(Cint32_fb.size(), 0); + + // A matrix (input activations) + randFill<uint8_t>(Aint8, 0, 5); + int32_t Aint8_zero_point = 4; + + // B matrix (weights) + randFill<int8_t>(Bint8_2d, -4, 4); + aligned_vector<int32_t> Bint8_zero_point(1); + randFill(Bint8_zero_point, -3, -1); + + aligned_vector<float> C_multiplier(Bint8_zero_point.size()); + randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2); + int32_t C_zero_point = 5; + + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + vector<int32_t> col_offsets(conv_p_2d.OC); + + DoNothing<> doNothingObj{}; + ReQuantizeOutput<false, QuantizationGranularity::TENSOR> outputProcObj( + doNothingObj, + C_multiplier.data(), + C_zero_point, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, // row offsets + col_offsets.data(), + nullptr, // bias + conv_p_2d.OC, + conv_p_2d.G); + + try { + conv_p_2d.stride[0] = 2; + fbgemmConv( + conv_p_2d, + Aint8.data(), + packedB_2D, + Cint8_fb.data(), + Cint32_fb.data(), + outputProcObj, + 0, + 1); + } catch (std::logic_error const& err) { + std::string s(err.what()); + EXPECT_TRUE(s.rfind("[FBGEMM_CONV_ERROR]", 0) == 0); + } + + // reset + conv_p_2d.stride[0] = stride; + // this should run fine + fbgemmConv( + conv_p_2d, + Aint8.data(), + packedB_2D, + Cint8_fb.data(), + Cint32_fb.data(), + outputProcObj, + 0, + 1); +} + +/** + * @brief Unit test for uint8 activations, int8 weights, and 32-bit + * accumulation. Output processing: requantization -> nothing + */ +TEST_P(UniConvQGranTest, requantizeTest) { + vector<conv_param_t<>> shapes(GetShapes_()); + QuantizationGranularity q_granularity; + bool a_symmetric, b_symmetric; + bool test_bias, test_float_bias; + tie(q_granularity, a_symmetric, b_symmetric, test_bias, test_float_bias) = + GetParam(); + + for (auto conv_p : shapes) { + int R = conv_p.K[0]; + int S = conv_p.K[1]; + int G = conv_p.G; + int OC = conv_p.OC; + int OH = conv_p.OUT_DIM[0]; + int OW = conv_p.OUT_DIM[1]; + int IC_per_G = conv_p.IC / conv_p.G; + int OC_per_G = conv_p.OC / conv_p.G; + + // activations + aligned_vector<uint8_t> Aint8( + conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0); + + // weights + // The weight matrix is in layout G K/G (R S C/G) + aligned_vector<int8_t> Bint8(R * S * conv_p.G * IC_per_G * OC_per_G, 0); + aligned_vector<int8_t> Bint8_tr(R * S * G * IC_per_G * OC_per_G, 0); + + aligned_vector<int32_t> Cint32_ref(conv_p.MB * OH * OW * OC, 0); + aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0); + aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0); + aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0); + + randFill<uint8_t>(Aint8, 0, 5); + int32_t Aint8_zero_point = a_symmetric ? 
0 : 4; + + randFill<int8_t>(Bint8, -4, 4); + + // computing column offset + vector<int32_t> col_offsets(G * OC_per_G); + + int ncols_per_quant_group = G * OC_per_G; + if (q_granularity == QuantizationGranularity::GROUP) { + ncols_per_quant_group = OC_per_G; + } else if (q_granularity == QuantizationGranularity::OUT_CHANNEL) { + ncols_per_quant_group = 1; + } + + aligned_vector<int32_t> Bint8_zero_point( + G * OC_per_G / ncols_per_quant_group); + if (b_symmetric) { + randFill(Bint8_zero_point, -3, 3); + } else { + randFill(Bint8_zero_point, 0, 0); + } + + // matrix dimensions after im2col for each GEMM. + // For each group, there is one GEMM of the following dimensions + int MDim = conv_p.MB * OH * OW; + int NDim = OC_per_G; + int KDim = R * S * IC_per_G; + + vector<uint8_t> Aint8_im2col(MDim * KDim * G); + im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); + + vector<int32_t> row_offsets(MDim); + + // activation_scale * weight_scale + aligned_vector<float> act_times_w_scale(Bint8_zero_point.size()); + randFill(act_times_w_scale, 0.1234f / 2, 0.1234f * 3 / 2); + + float out_scale = 2.0f; + aligned_vector<float> C_multiplier(Bint8_zero_point.size()); + transform( + act_times_w_scale.begin(), + act_times_w_scale.end(), + C_multiplier.begin(), + [&out_scale](float i) { return i / out_scale; }); + + int32_t C_zero_pt = 5; + + // initialize bias + aligned_vector<int32_t> bias_int32(OC); + aligned_vector<float> bias_fp32(OC); + if (test_bias) { + randFill(bias_int32, -8, 8); + } + + // floating point bias + if (test_float_bias) { + if (q_granularity == QuantizationGranularity::TENSOR) { + transform( + bias_int32.begin(), + bias_int32.end(), + bias_fp32.begin(), + [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; }); + } else if (q_granularity == QuantizationGranularity::GROUP) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < OC_per_G; ++c) { + bias_fp32[g * OC_per_G + c] = act_times_w_scale[g] * + static_cast<float>(bias_int32[g * OC_per_G + c]); + } + } + } else { // OUT_CHANNEL + transform( + act_times_w_scale.begin(), + act_times_w_scale.end(), + bias_int32.begin(), + bias_fp32.begin(), + multiplies<float>()); + } + } + // reference implementation + // conv_ref expects weights to be in G (R S C/G) K/G + int8_t* rightBData = Bint8.data(); + transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data()); + rightBData = Bint8_tr.data(); + for (int g = 0; g < G; ++g) { + col_offsets_with_zero_pt_s8acc32_ref( + R * S * IC_per_G, + OC_per_G, + OC_per_G, + rightBData + g * R * S * IC_per_G * OC_per_G, + Bint8_zero_point.data() + g * OC_per_G / ncols_per_quant_group, + col_offsets.data() + g * OC_per_G, + ncols_per_quant_group); + } + conv_ref( + conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data()); + + for (int g = 0; g < G; ++g) { + row_offsets_u8acc32_ref( + MDim, + KDim, + KDim * G, + Aint8_im2col.data() + g * KDim, + row_offsets.data()); + + requantize_u8acc32_ref( + MDim, + NDim, + G * NDim, + Cint32_ref.data() + g * NDim, + Cint8_ref.data() + g * NDim, + C_multiplier.data() + g * NDim / ncols_per_quant_group, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data() + g * NDim / ncols_per_quant_group, + row_offsets.data(), + col_offsets.data() + g * NDim, + test_bias ? 
bias_int32.data() + g * NDim : nullptr, + ncols_per_quant_group); + } + + PackWeightsForConv<2> packedWeights(conv_p, Bint8.data()); + + // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv + // #ifdef _OPENMP + // #pragma omp parallel + // #endif + { + vector<int32_t> row_offset_buf(rowOffsetBufferSizeGConv(conv_p)); + + DoNothing<> doNothingObj{}; + + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + + if (q_granularity == QuantizationGranularity::TENSOR) { + if (test_float_bias) { + ReQuantizeOutput<false, QuantizationGranularity::TENSOR, float> + reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_fp32.data() : nullptr, + G * NDim, + G, + act_times_w_scale.data()); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + + } else { + ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_int32.data() : nullptr, + G * NDim, + G); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + } + + } else if (q_granularity == QuantizationGranularity::GROUP) { + if (test_float_bias) { + ReQuantizeOutput<false, QuantizationGranularity::GROUP, float> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_fp32.data() : nullptr, + G * NDim, + G, + act_times_w_scale.data()); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + + } else { + ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_int32.data() : nullptr, + G * NDim, + G); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + } + + } else { + if (test_float_bias) { + ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL, float> + reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_fp32.data() : nullptr, + G * NDim, + G, + act_times_w_scale.data()); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + + } else { + ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? 
bias_int32.data() : nullptr,
+ G * NDim,
+ G);
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+ }
+ }
+ } // omp parallel
+
+ compare_validate_buffers(
+ Cint8_ref.data(),
+ Cint8_fb.data(),
+ MDim,
+ NDim * G,
+ NDim * G,
+ static_cast<uint8_t>(0));
+ } // for each shape
+}
diff --git a/third_party/asmjit b/third_party/asmjit
-Subproject 673dcefaa048c5f5a2bf8b85daf8f7b9978d018
+Subproject 4da474ac9aa2689e88d5e40a2f37628f302d7e3
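For readers who want to drive the unified convolution interface exercised by these new tests, below is a minimal standalone sketch. It is not part of the commit: it mirrors the cornerCases test above, swaps the benchmark helpers' aligned_vector for plain std::vector, and uses placeholder quantization parameters (scale, zero points, column offsets) chosen only to make the call well-formed; the main() wrapper is added for illustration.

#include <cstdint>
#include <vector>
#include "fbgemm/Fbgemm.h"

using namespace fbgemm;

int main() {
  // 1 image, 16 -> 32 channels, 28x28 input, 4 groups, 3x3 kernel,
  // stride 1, padding 1 -- the same shape used by the cornerCases test.
  conv_param_t<2> conv_p(1, 16, 32, {28, 28}, 4, {3, 3}, {1, 1}, {1, 1, 1, 1});

  std::vector<std::uint8_t> Aint8(
      conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 1);
  std::vector<std::int8_t> Bint8(
      conv_p.K[0] * conv_p.K[1] * conv_p.IC * (conv_p.OC / conv_p.G), 1);
  std::vector<std::int32_t> Cint32(
      conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
  std::vector<std::uint8_t> Cint8(Cint32.size(), 0);

  // Unified weight packing: internally selects the depthwise, groupwise,
  // pointwise, or im2col layout -- the dispatch that packingTest asserts on.
  PackWeightsForConv<2> packedB(conv_p, Bint8.data());

  // Placeholder requantization parameters (illustrative values only).
  std::vector<std::int32_t> col_offsets(conv_p.OC, 0);
  float C_multiplier = 0.1234f;
  std::int32_t Bint8_zero_point = 0;
  DoNothing<> doNothingObj{};
  ReQuantizeOutput<false, QuantizationGranularity::TENSOR> outputProcObj(
      doNothingObj,
      &C_multiplier,
      5,                   // C (output) zero point
      4,                   // A (activation) zero point
      &Bint8_zero_point,
      nullptr,             // row offsets, left null as in cornerCases
      col_offsets.data(),
      nullptr,             // no bias
      conv_p.OC,
      conv_p.G);

  // Single-threaded invocation: thread id 0 of 1 thread.
  fbgemmConv(
      conv_p,
      Aint8.data(),
      packedB,
      Cint8.data(),
      Cint32.data(),
      outputProcObj,
      0,
      1);
  return 0;
}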