author | Young Jin Kim <youki@microsoft.com> | 2019-09-25 19:40:48 +0300
committer | Young Jin Kim <youki@microsoft.com> | 2019-09-25 19:40:48 +0300
commit | 7bd598c9e97871e42c19449fddf7bd317898eb58 (patch)
tree | a3b1fea18477b63add473037d0644a96b115e0da
parent | 08763b198ef743741560ae42a9c10a3017c7c9ce (diff)
parent | 518d8a1832cf1eb1dda2feace1a278e9e4f302ba (diff)
Merge remote-tracking branch 'upstream/master' into youki/win-jit-debug-int8
Fix for Windows build errors
43 files changed, 6032 insertions, 4613 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 0460799..c06b60b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,8 +89,10 @@ endif() #All the source files that either use avx2 instructions statically set(FBGEMM_AVX2_SRCS src/FbgemmFP16UKernelsAvx2.cc + src/FbgemmI8Depthwise3DAvx2.cc src/FbgemmI8DepthwiseAvx2.cc src/OptimizedKernelsAvx2.cc + src/PackDepthwiseConvMatrixAvx2.cc src/QuantUtilsAvx2.cc src/UtilsAvx2.cc) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 0f7ad8b..d1abc70 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,5 +1,77 @@ # Code of Conduct -Facebook has adopted a Code of Conduct that we expect project participants to adhere to. -Please read the [full text](https://code.fb.com/codeofconduct/) -so that you can understand what actions will and will not be tolerated. +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at <opensource-conduct@fb.com>. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. 
The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq + diff --git a/bench/ConvUnifiedBenchmark.cc b/bench/ConvUnifiedBenchmark.cc index b450beb..35a3b2a 100644 --- a/bench/ConvUnifiedBenchmark.cc +++ b/bench/ConvUnifiedBenchmark.cc @@ -24,33 +24,39 @@ using namespace std; using namespace fbgemm; +// clang-format off // 2D conv shapes vector<conv_param_t<2>> shapes_2d = { - // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, - // pad_h_top, pad_w_left, pad_h_bottom, pad_w_right - // 2D convolutions - // regular - conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), - // groupwise - conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), - // DW - conv_param_t<>(1, 272, 272, {47, 125}, 272, {3, 3}, {1, 1}, {1, 1, 1, 1}), - // Pointwise - conv_param_t<>(1, 128, 128, {56, 56}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}) + // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, + // pad_h_top, pad_w_left, pad_h_bottom, pad_w_right + // 2D convolutions + // regular + conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + // regular with dilation + conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), + // groupwise + conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + // DW + conv_param_t<>(1, 272, 272, {47, 125}, 272, {3, 3}, {1, 1}, {1, 1, 1, 1}), + // Pointwise + conv_param_t<>(1, 128, 128, {56, 56}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}) }; // 3D conv shapes vector<conv_param_t<3>> shapes_3d = { - // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h, stride_w}, - // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right} - // Regular - conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}), - // Depthwise - conv_param_t<3>(1, 64, 64, {8, 14, 14}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}), - // Pointwise - conv_param_t<3>(1, 128, 128, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0, 0}) -}; + // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h, + // stride_w}, + // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right} + // Regular + conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}), + //With dilations + conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}, {2, 2, 2}), + // Depthwise + conv_param_t<3>(1, 64, 64, {8, 14, 14}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}), + // Pointwise + conv_param_t<3>(1, 128, 128, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0, 0})}; +// clang-format on template <int SPATIAL_DIM, typename Acc_t> void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { @@ -81,6 +87,10 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { header += "pad_t, "; } header += "pad_h, pad_w, "; + if (SPATIAL_DIM == 3) { + header += "dilation_t, 
"; + } + header += "dilation_h, dilation_w, "; header += "Type, M, N, K, "; @@ -278,7 +288,9 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) { for (int i = 0; i < SPATIAL_DIM; ++i) { cout << conv_p.pad[i] << ", "; } - + for (int i = 0; i < SPATIAL_DIM; ++i) { + cout << conv_p.dilation[i] << ", "; + } cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5) << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6) << KDim << ", "; diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc index 0efdcac..ff2be6f 100644 --- a/bench/Depthwise3DBenchmark.cc +++ b/bench/Depthwise3DBenchmark.cc @@ -4,7 +4,6 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "test/I8DepthwiseTest.h" #include <algorithm> #include <chrono> @@ -19,8 +18,8 @@ #include "AlignedVec.h" #include "BenchUtils.h" -#include "fbgemm/Utils.h" #include "fbgemm/FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/Utils.h" #include "src/RefImplementations.h" using namespace std; @@ -35,6 +34,34 @@ int main() { } #endif + // From ResNeXt-3D-101 + // clang-format off + vector<vector<int>> shapes_3d = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. + // N, K, T_in, H_in, W_in, stride + { 1, 64, 32, 56, 56, 1, }, + { 1, 128, 16, 28, 28, 1, }, + { 1, 256, 8, 14, 14, 1, }, + { 1, 512, 4, 7, 7, 1, }, + + { 1, 128, 32, 56, 56, 2, }, + { 1, 256, 16, 28, 28, 2, }, + { 1, 512, 8, 14, 14, 2, }, + + { 5, 64, 32, 56, 56, 1, }, + { 5, 128, 16, 28, 28, 1, }, + { 5, 256, 8, 14, 14, 1, }, + { 5, 512, 4, 7, 7, 1, }, + + { 5, 128, 32, 56, 56, 2, }, + { 5, 256, 16, 28, 28, 2, }, + { 5, 512, 8, 14, 14, 2, }, + + { 1, 8, 4, 4, 4, 1, }, + }; + // clang-format on + // Depthwise is memory BW bound so we want to flush LLC. 
bool flush = true; std::vector<char> llc; @@ -61,14 +88,28 @@ int main() { constexpr int K_T = 3, K_H = 3, K_W = 3; constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + conv_param_t<3> conv_p( + N, + K, + K, + {T, H, W}, + K, + {K_T, K_H, K_W}, + {stride_t, stride_h, stride_w}, + {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R}); + int T_OUT = conv_p.OUT_DIM[0]; + int H_OUT = conv_p.OUT_DIM[1]; + int W_OUT = conv_p.OUT_DIM[2]; + + int MDim = N * T_OUT * H_OUT * W_OUT; + int KDim = K_T * K_H * K_W * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * T * H * W * K); - aligned_vector<int8_t> B(K * K_T * K_H * K_W); - aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K), - C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = 43; @@ -76,52 +117,49 @@ int main() { randFill<int8_t>(B, -16, 16); int32_t B_zero_point = 5; - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); - int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + aligned_vector<float> C_multiplier(1); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); + int32_t C_zero_point = 5; - float C_multiplier = 255. / (maximum - minimum); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); aligned_vector<int32_t> col_offsets(K); aligned_vector<int32_t> bias(K); randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - int32_t C_zero_point = 5; - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point, - B.data(), - C_multiplier, - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); - - Packed3x3x3ConvMatrix Bp(K, B.data()); + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data(), + C_zero_point, + A_zero_point, + &B_zero_point, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } + + PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data()); double ttot = 0; double bytes = double(NITER) * @@ -153,7 +191,7 @@ int main() { A.data(), B_zero_point, Bp, - C_multiplier, + C_multiplier[0], C_zero_point, C_uint8.data(), col_offsets.data(), diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc index 96921a1..6c2ee17 100644 --- a/bench/DepthwiseBenchmark.cc +++ b/bench/DepthwiseBenchmark.cc @@ -17,8 +17,8 @@ #include "AlignedVec.h" #include "BenchUtils.h" -#include "fbgemm/Utils.h" #include "fbgemm/FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/Utils.h" #include "src/RefImplementations.h" using namespace std; @@ -34,10 
+34,11 @@ int main() { #endif // From Xray OCR + // clang-format off vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. - // N, G, H_in, W_in, stride + // N, K, H_in, W_in, stride { 1, 272, 47, 125, 1, }, { 1, 272, 64, 125, 1, }, { 1, 272, 66, 125, 1, }, @@ -138,6 +139,7 @@ int main() { { 96, 544, 14, 14, 2, }, { 100, 544, 14, 14, 2, }, }; + // clang-format on // Depthwise is memory BW bound so we want to flush LLC. bool flush = true; @@ -155,19 +157,35 @@ int main() { for (auto shape : shapes) { int N = shape[0]; - int G = shape[1]; + int K = shape[1]; int H = shape[2]; int W = shape[3]; int stride_h = shape[4]; int stride_w = stride_h; constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2, + PAD_R = (S - 1) / 2; + + conv_param_t<2> conv_p( + N, + K, + K, + {H, W}, + K, + {R, S}, + {stride_h, stride_w}, + {PAD_T, PAD_L, PAD_B, PAD_R}); + int H_OUT = conv_p.OUT_DIM[0]; + int W_OUT = conv_p.OUT_DIM[1]; + + int MDim = N * H_OUT * W_OUT; + int KDim = R * S * K; + int KDimPerGroup = KDim / conv_p.G; - aligned_vector<uint8_t> A(N * H * W * G); - aligned_vector<int8_t> B(G * R * S); - aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * G), C(C_ref.size()); + aligned_vector<uint8_t> A(N * H * W * K); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = 43; @@ -175,53 +193,54 @@ int main() { randFill<int8_t>(B, -16, 16); int32_t B_zero_point = 5; - depthwise_3x3_pad_1_ref( - N, - H, - W, - G, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); - int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + aligned_vector<float> C_multiplier(1); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); + int32_t C_zero_point = 5; - float C_multiplier = 255. 
/ (maximum - minimum); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); - aligned_vector<int32_t> col_offsets(G); - aligned_vector<int32_t> bias(G); + aligned_vector<int32_t> col_offsets(K); + aligned_vector<int32_t> bias(K); randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - int32_t C_zero_point = 5; - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3_pad_1_ref( - N, - H, - W, - G, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point, - B.data(), - C_multiplier, - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data(), + C_zero_point, + A_zero_point, + &B_zero_point, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } - Packed3x3ConvMatrix Bp(G, B.data()); + PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data()); double ttot = 0; double bytes = double(NITER) * - (G * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S)); - double ops = double(NITER) * N * H_OUT * W_OUT * G * R * S * 2; + (K * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S)); + double ops = double(NITER) * N * H_OUT * W_OUT * K * R * S * 2; chrono::time_point<chrono::system_clock> t_begin, t_end; for (int i = 0; i < NWARMUP + NITER; ++i) { llc_flush(); @@ -235,19 +254,20 @@ int main() { N, H, W, - G, + K, stride_h, stride_w, A_zero_point, A.data(), B_zero_point, Bp, - C_multiplier, + C_multiplier[0], C_zero_point, C_uint8.data(), col_offsets.data(), bias.data(), false, /* fuse_relu */ + 1.0f, /* act_scale * w_scale */ tid, num_threads); } @@ -262,10 +282,10 @@ int main() { for (int n = 0; n < N; ++n) { for (int h = 0; h < H_OUT; ++h) { for (int w = 0; w < W_OUT; ++w) { - for (int g = 0; g < G; ++g) { + for (int g = 0; g < K; ++g) { uint8_t expected = - C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * G + g]; - uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * G + g]; + C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * K + g]; + uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * K + g]; if (expected != actual) { cerr << "Depthwise 3x3 results differ at (" << n << ", " << h << ", " << w << ", " << g << "). expected " << (int)expected @@ -280,9 +300,9 @@ int main() { // Report performance printf( - "N = %d G = %d H = %d W = %d stride = %d with requantization fused\n", + "N = %d K = %d H = %d W = %d stride = %d with requantization fused\n", N, - G, + K, H, W, stride_h); diff --git a/bench/GEMMsBenchmark.cc b/bench/GEMMsBenchmark.cc index b404d8b..f493a96 100644 --- a/bench/GEMMsBenchmark.cc +++ b/bench/GEMMsBenchmark.cc @@ -28,6 +28,7 @@ using namespace std; using namespace fbgemm; void performance_test() { + // clang-format off static const vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. 
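
Stepping back from the benchmark diffs for a moment: the depthwise benchmarks above now build their reference output with the generic conv_ref and then requantize each group with row_offsets_u8acc32_ref and requantize_u8acc32_ref, instead of calling the fused depthwise reference. Written out as a hedged scalar sketch (the names are illustrative, but the arithmetic follows the reference implementations used above):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// One output element of the reference requantization pipeline: remove the
// zero-point cross terms, add bias, scale, round to nearest, then clamp.
inline std::uint8_t requantize_one(
    std::int32_t acc,          // raw int32 accumulator from conv_ref
    std::int32_t A_zero_point, // activation zero point
    std::int32_t B_zero_point, // weight zero point for this group
    std::int32_t row_offset,   // sum of activations in this im2col row
    std::int32_t col_offset,   // sum of weights in this output column
    std::int32_t bias,
    float C_multiplier,        // act_scale * weight_scale / out_scale
    std::int32_t C_zero_point) {
  std::int32_t raw =
      acc - A_zero_point * col_offset - B_zero_point * row_offset + bias;
  long rounded = std::lrintf(raw * C_multiplier) + C_zero_point;
  return static_cast<std::uint8_t>(std::min(255L, std::max(0L, rounded)));
}
```

Later in this merge, ReQuantizeOutput additionally gains a BIAS_TYPE template parameter; when the bias is float rather than int32, it is first divided by act_times_w_scale so that it lands in the same units as the int32 accumulator before the multiply.
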
@@ -39,6 +40,7 @@ void performance_test() { {256, 512, 256}, {1024, 1024, 1024}, }; + // clang-format on bool flush = true; std::vector<char> llc; diff --git a/bench/GEMMsTunableBenchmark.cc b/bench/GEMMsTunableBenchmark.cc index a65b51f..2adc556 100644 --- a/bench/GEMMsTunableBenchmark.cc +++ b/bench/GEMMsTunableBenchmark.cc @@ -218,7 +218,8 @@ int main(int /* unused */, char** /* unused */) { } #endif - vector<vector<int>> shapes = { + // clang-format off + vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. // m, n, k @@ -266,7 +267,8 @@ int main(int /* unused */, char** /* unused */) { {128, 128, 128}, {256, 512, 256}, {1024, 1024, 1024}, -}; + }; + // clang-format on vector<int> MCBs; vector<int> NCBs; diff --git a/bench/PackedFloatInOutBenchmark.cc b/bench/PackedFloatInOutBenchmark.cc index 66ca67e..dcea65c 100644 --- a/bench/PackedFloatInOutBenchmark.cc +++ b/bench/PackedFloatInOutBenchmark.cc @@ -28,6 +28,7 @@ using namespace std; using namespace fbgemm; void performance_test() { + // clang-format off vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -66,6 +67,7 @@ void performance_test() { {1, 128, 2722}, {16, 256, 512}, }; + // clang-format on bool flush = true; std::vector<char> llc; diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc index 40ff662..c6e2869 100644 --- a/bench/PackedRequantizeAcc16Benchmark.cc +++ b/bench/PackedRequantizeAcc16Benchmark.cc @@ -37,6 +37,7 @@ enum class BenchmarkType { }; void performance_test() { + // clang-format off vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -67,6 +68,7 @@ void performance_test() { {392, 2048, 512}, {392, 512, 2048}, }; + // clang-format on bool flush = true; std::vector<char> llc; diff --git a/bench/PackedRequantizeAcc32Benchmark.cc b/bench/PackedRequantizeAcc32Benchmark.cc index 2f04795..a61ef5a 100644 --- a/bench/PackedRequantizeAcc32Benchmark.cc +++ b/bench/PackedRequantizeAcc32Benchmark.cc @@ -28,6 +28,7 @@ using namespace std; using namespace fbgemm; void performance_test() { + // clang-format off vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -70,6 +71,7 @@ void performance_test() { {1, 128, 2722}, {16, 256, 512}, }; + // clang-format on bool flush = true; std::vector<char> llc; diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h index 11f3dcc..5431958 100644 --- a/include/fbgemm/ConvUtils.h +++ b/include/fbgemm/ConvUtils.h @@ -8,9 +8,24 @@ #include <array> #include <string> +#include <type_traits> namespace fbgemm { +template <int N, int... Vals> +constexpr + typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type + array_of_ones() { + return std::array<int, N>{{Vals...}}; +} + +template <int N, int... Vals> +constexpr + typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type + array_of_ones() { + return array_of_ones<N, Vals..., 1>(); +} + /** * @brief A struct to conveniently store all convolution parameters. */ @@ -34,7 +49,6 @@ struct conv_param_t { /** * @brief Constructor for initializing the convolution parameters. - * TODO: Dilation is not handled correctly. 
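
The array_of_ones helper introduced above deserves a note: it recursively grows a template parameter pack of 1s until N values have accumulated, yielding a constexpr std::array of ones that serves as the default dilation argument. A self-contained sketch of the same trick:

```cpp
#include <array>
#include <type_traits>

// Base case: the parameter pack already holds N values; materialize them.
template <int N, int... Vals>
constexpr typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type
array_of_ones() {
  return std::array<int, N>{{Vals...}};
}

// Recursive case: append one more 1 to the pack and try again.
template <int N, int... Vals>
constexpr typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type
array_of_ones() {
  return array_of_ones<N, Vals..., 1>();
}

int main() {
  constexpr auto dilation = array_of_ones<3>(); // expands to {1, 1, 1}
  return dilation[0] == 1 ? 0 : 1;
}
```

With a dilation argument in place, the output size becomes OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1, which reduces to the previous (IN_DIMP[d] - K[d]) / stride[d] + 1 whenever every dilation is 1.
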
*/ conv_param_t( int mb, @@ -44,7 +58,8 @@ struct conv_param_t { int g, std::array<int, SPATIAL_DIM> k, std::array<int, SPATIAL_DIM> strd, - std::array<int, SPATIAL_DIM * 2> pd) + std::array<int, SPATIAL_DIM * 2> pd, + std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>()) : MB(mb), IC(ic), OC(oc), @@ -52,7 +67,8 @@ struct conv_param_t { G(g), K(k), stride(strd), - pad(pd) { + pad(pd), + dilation(dilations) { if (ic % g != 0) { throw std::runtime_error( "groups = " + std::to_string(g) + @@ -63,10 +79,10 @@ struct conv_param_t { "groups = " + std::to_string(g) + " does not divide number of output channels = " + std::to_string(oc)); } + for (int d = 0; d < SPATIAL_DIM; ++d) { - dilation[d] = 1; IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d]; - OUT_DIM[d] = (IN_DIMP[d] - K[d]) / stride[d] + 1; + OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1; } } @@ -102,8 +118,12 @@ struct conv_param_t { } for (int d = 0; d < SPATIAL_DIM * 2; ++d) { out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" + - std::to_string(pad[d]); - if (d < SPATIAL_DIM * 2 - 1) { + std::to_string(pad[d]) + ", "; + } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" + + std::to_string(dilation[d]); + if (d < SPATIAL_DIM - 1) { out += ", "; } } @@ -121,6 +141,10 @@ struct conv_param_t { out += ", "; } } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "dilation_" + std::to_string(d) + ":" + + std::to_string(dilation[d]) + ", "; + } } return out; } diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 09ddbe1..668bd42 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -609,12 +609,16 @@ class FBGEMM_API PackWeightsForConv { return W_im2col_packed_; } - std::shared_ptr<Packed3x3ConvMatrix> getPackedWFor2DDW() { - return W_dw_2D_packed_; + std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWForDepthwise() { + return W_dw_packed_; } - std::shared_ptr<Packed3x3x3ConvMatrix> getPackedWFor3DDW() { - return W_dw_3D_packed_; + std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor2DDW() { + return W_dw_packed_; + } + + std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor3DDW() { + return W_dw_packed_; } std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>> @@ -663,10 +667,8 @@ class FBGEMM_API PackWeightsForConv { const conv_param_t<SPATIAL_DIM> conv_param_; // Packed weights if we use im2col based convolution implementation std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_; - // Packed weights if we use 2D depthwise convolution implementation - std::shared_ptr<Packed3x3ConvMatrix> W_dw_2D_packed_; - // Packed weights if we use 3D depthwise convolution implementation - std::shared_ptr<Packed3x3x3ConvMatrix> W_dw_3D_packed_; + // Packed weights if we use depthwise convolution implementation + std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_; // Packed weights if we use groupwise (small channels per group) convolution // implementation std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>> @@ -1161,12 +1163,15 @@ class FBGEMM_API DoSConvOnInpBuffer { template < bool FUSE_RELU, QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR, + typename BIAS_TYPE = std::int32_t, typename outT = std::uint8_t, typename inT = std::int32_t, typename nextOPType = DoNothing<outT, outT>> class FBGEMM_API ReQuantizeOutput { public: static constexpr int RELU_FUSED = FUSE_RELU; + static constexpr QuantizationGranularity QGRANType = Q_GRAN; + using BIAS_T = 
BIAS_TYPE; using outType = outT; using inpType = inT; /** @@ -1187,6 +1192,8 @@ class FBGEMM_API ReQuantizeOutput { * See PackedRequantizeTest.cc for an example. * TODO: if Aq_zero_point == 0, allow passing nullptr. * @params bias can be nullptr otherwise the length should be nCol + * @params act_times_w_scale activation_scale * weight_scale. This is only + * used if bias is unquantized (i.e., float). */ ReQuantizeOutput( nextOPType& nextop, @@ -1196,9 +1203,10 @@ class FBGEMM_API ReQuantizeOutput { const std::int32_t* Bq_zero_point, const std::int32_t* row_offsets, const std::int32_t* col_offsets, - const std::int32_t* bias, + const BIAS_T* bias, std::uint32_t nCol, - int groups = 1) + int groups = 1, + const float* act_times_w_scale = nullptr) : nextop_(nextop), C_multiplier_(C_multiplier), C_zero_point_(C_zero_point), @@ -1208,7 +1216,8 @@ class FBGEMM_API ReQuantizeOutput { q_col_offsets_(col_offsets), bias_(bias), ncols_(nCol), - groups_(groups) {} + groups_(groups), + act_times_w_scale_(act_times_w_scale) {} template <inst_set_t instSet> inline int f( @@ -1236,12 +1245,15 @@ class FBGEMM_API ReQuantizeOutput { const std::int32_t* getColOffsets() const { return q_col_offsets_; } - const std::int32_t* getBias() const { + const BIAS_T* getBias() const { return bias_; } std::uint32_t getNCols() const { return ncols_; } + const float* getActWScale() const { + return act_times_w_scale_; + } void setRowOffsets(const std::int32_t* row_offsets) { q_row_offsets_ = row_offsets; @@ -1255,9 +1267,10 @@ class FBGEMM_API ReQuantizeOutput { const std::int32_t* Bq_zero_point_; const std::int32_t* q_row_offsets_; const std::int32_t* q_col_offsets_; - const std::int32_t* bias_; + const BIAS_T* bias_; std::uint32_t ncols_; int groups_; + const float* act_times_w_scale_; }; /** @@ -1410,7 +1423,8 @@ template < typename outType, bool FUSE_RELU, QuantizationGranularity Q_GRAN, - int SPATIAL_DIM = 2> + int SPATIAL_DIM = 2, + typename BIAS_TYPE = std::int32_t> FBGEMM_API void fbgemmGroupwiseConv( const conv_param_t<SPATIAL_DIM>& conv_param, const std::uint8_t* activations, @@ -1419,7 +1433,7 @@ FBGEMM_API void fbgemmGroupwiseConv( packed_W& packed_weights, outType* out, std::int32_t* outBuffer, - const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess, + const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess, int thread_id, int num_threads); diff --git a/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/include/fbgemm/FbgemmI8DepthwiseAvx2.h index e7b0ec4..c454b16 100644 --- a/include/fbgemm/FbgemmI8DepthwiseAvx2.h +++ b/include/fbgemm/FbgemmI8DepthwiseAvx2.h @@ -11,19 +11,26 @@ namespace fbgemm { -// KERNEL_PROD is the product of all kernels. -// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3. -template <int KERNEL_PROD> class FBGEMM_API PackedDepthWiseConvMatrix { public: - // smat in GRS layout - PackedDepthWiseConvMatrix(int K, const std::int8_t* smat); + /** + * @params K the number of channels (same as the number of groups because + * depth-wise convolution has one input/output channel per group) + * @params kernel_prod the product of all kernels. For example, kernel_prod = + * 9 for 3x3 conv, and 27 for 3x3x3 conv. + * @param smat the source unpacked weight in GRS layout + */ + PackedDepthWiseConvMatrix(int K, int kernel_prod, const std::int8_t* smat); virtual ~PackedDepthWiseConvMatrix(); const std::int8_t* PackedMat() const { return pmat_; } + int GetKernelProduct() const { + return kernel_prod_; + } + /** * @brief Unpacks pmat_ into unpack_data. 
* Used for recovering the weight matrix into the original format @@ -36,24 +43,26 @@ class FBGEMM_API PackedDepthWiseConvMatrix { int addr(int r, int c); private: - int K_; - std::int8_t* pmat_; -}; // Packed3x3ConvMatrix - -using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>; -using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>; -using Packed1ConvMatrix = PackedDepthWiseConvMatrix<1>; -using Packed2ConvMatrix = PackedDepthWiseConvMatrix<2>; -using Packed3ConvMatrix = PackedDepthWiseConvMatrix<3>; -using Packed4ConvMatrix = PackedDepthWiseConvMatrix<4>; -using Packed5ConvMatrix = PackedDepthWiseConvMatrix<5>; -using Packed10ConvMatrix = PackedDepthWiseConvMatrix<10>; -using Packed11ConvMatrix = PackedDepthWiseConvMatrix<11>; + const int K_; /**< the number of channels */ + const int kernel_prod_; /** the product of all kernel dims */ + std::int8_t* pmat_; /** packed weight */ +}; // PackedDepthWiseConvMatrix -/** - * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 - * @params A The input image in NHWK layout - * @params Bp The pre-packed filter +class FBGEMM_API Packed3x3ConvMatrix : public PackedDepthWiseConvMatrix { + public: + Packed3x3ConvMatrix(int K, const std::int8_t* smat) + : PackedDepthWiseConvMatrix(K, 3 * 3, smat) {} +}; + +class FBGEMM_API Packed3x3x3ConvMatrix : public PackedDepthWiseConvMatrix { + public: + Packed3x3x3ConvMatrix(int K, const std::int8_t* smat) + : PackedDepthWiseConvMatrix(K, 3 * 3 * 3, smat) {} +}; + +/** To be removed. Keeping it just to make sure we don't change C2 files and + * fbgemm files in a single diff + * */ FBGEMM_API void depthwise_3x3_pad_1( int N, @@ -64,8 +73,14 @@ FBGEMM_API void depthwise_3x3_pad_1( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, - const Packed3x3ConvMatrix& Bp, - std::int32_t* C, + std::int32_t B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false, int thread_id = 0, int num_threads = 1); @@ -74,7 +89,10 @@ FBGEMM_API void depthwise_3x3_pad_1( * This version is fused with requantization. * * @col_offsets nullptr if col_offsets are folded into bias + * @act_times_w_scale Only used if BIAS_TYPE is float, i.e., bias is + * unquantized. */ +template <typename BIAS_TYPE = std::int32_t> FBGEMM_API void depthwise_3x3_pad_1( int N, int H, @@ -85,22 +103,48 @@ FBGEMM_API void depthwise_3x3_pad_1( std::int32_t A_zero_point, const std::uint8_t* A, std::int32_t B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, float C_multiplier, std::int32_t C_zero_point, std::uint8_t* C, const std::int32_t* col_offsets, - const std::int32_t* bias, + const BIAS_TYPE* bias, bool fuse_relu = false, + float act_times_w_scale = 1.0f, int thread_id = 0, int num_threads = 1); /** - * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 - * This version is fused with requantization and uses per-channel quantization. + * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8, fused with + * requantization, and using per-channel quantization. 
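
These declarations carry the headline API change of this merge: the kernel product moves from a compile-time template parameter (Packed3x3ConvMatrix, Packed3x3x3ConvMatrix) to a constructor argument, so a single PackedDepthWiseConvMatrix type now covers every kernel shape and the old names survive only as thin subclasses. A hedged sketch of how calling code migrates (the shape numbers are illustrative):

```cpp
#include <cstdint>
#include <vector>

#include "fbgemm/FbgemmI8DepthwiseAvx2.h" // requires linking against fbgemm

int main() {
  int K = 272;                            // channels == groups for depthwise
  std::vector<std::int8_t> B(K * 3 * 3);  // unpacked weights in GRS layout

  // Before this merge the kernel product was a template parameter:
  //   Packed3x3ConvMatrix Bp(K, B.data());
  // Now it is a runtime constructor argument, so 3x3 and 3x3x3 (and any
  // other kernel product) share one class:
  fbgemm::PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
  return Bp.GetKernelProduct() == 9 ? 0 : 1;
}
```

The per-channel variant declared below differs from the tensor-granularity one only in taking B_zero_point, C_multiplier, and act_times_w_scale as per-channel arrays instead of scalars.
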
* * @col_offsets nullptr if col_offsets are folded into bias */ +template <typename BIAS_TYPE = std::int32_t> +FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu = false, + const float* act_times_w_scale = nullptr, + int thread_id = 0, + int num_threads = 1); + +/** To be removed. Keeping it just to make sure we don't change C2 files and + * fbgemm files in a single diff + */ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( int N, int H, @@ -111,7 +155,7 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( std::int32_t A_zero_point, const std::uint8_t* A, const std::int32_t* B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, std::int32_t C_zero_point, std::uint8_t* C, @@ -121,6 +165,10 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( int thread_id = 0, int num_threads = 1); +/** To be removed. Keeping it just to make sure we don't change C2 files and + * fbgemm files in a single diff + * + */ FBGEMM_API void depthwise_3x3x3_pad_1( int N, int T, @@ -132,14 +180,20 @@ FBGEMM_API void depthwise_3x3x3_pad_1( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, - const Packed3x3x3ConvMatrix& Bp, - std::int32_t* C, + std::int32_t B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false, int thread_id = 0, int num_threads = 1); - /** * @col_offsets nullptr if col_offsets are folded into bias */ +template <typename BIAS_TYPE = std::int32_t> FBGEMM_API void depthwise_3x3x3_pad_1( int N, int T, @@ -152,11 +206,38 @@ FBGEMM_API void depthwise_3x3x3_pad_1( std::int32_t A_zero_point, const std::uint8_t* A, std::int32_t B_zero_point, - const Packed3x3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, float C_multiplier, std::int32_t C_zero_point, std::uint8_t* C, const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu = false, + float act_times_w_scale = 1.0f, + int thread_id = 0, + int num_threads = 1); + +/** To be removed. 
Keeping it just to make sure we don't change C2 files and + * fbgemm files in a single diff + * + */ +FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, const std::int32_t* bias, bool fuse_relu = false, int thread_id = 0, @@ -165,6 +246,7 @@ FBGEMM_API void depthwise_3x3x3_pad_1( /** * @col_offsets nullptr if col_offsets are folded into bias */ +template <typename BIAS_TYPE = std::int32_t> FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( int N, int T, @@ -177,13 +259,14 @@ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( std::int32_t A_zero_point, const std::uint8_t* A, const std::int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, std::int32_t C_zero_point, std::uint8_t* C, const std::int32_t* col_offsets, - const std::int32_t* bias, + const BIAS_TYPE* bias, bool fuse_relu = false, + const float* act_times_w_scale = nullptr, int thread_id = 0, int num_threads = 1); diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h index d984c60..04ae100 100644 --- a/include/fbgemm/OutputProcessing-inl.h +++ b/include/fbgemm/OutputProcessing-inl.h @@ -59,11 +59,13 @@ inline int DoSConvOnInpBuffer<outT, inT, nextOPType>::f( template < bool FUSE_RELU, QuantizationGranularity Q_GRAN, + typename BIAS_TYPE, typename outT, typename inT, typename nextOPType> template <inst_set_t instSet> -inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f( +inline int +ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f( outT* out, const inT* inp, const block_type_t& block, @@ -98,11 +100,20 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f( raw -= q_row_offsets_[i - block.row_start] * Bq_zero_point_[Bq_zero_point_idx]; } + float raw_f; if (bias_) { - raw += bias_[j]; + if (std::is_same<BIAS_TYPE, float>::value) { + raw_f = raw; + raw_f += bias_[j] / act_times_w_scale_[Bq_zero_point_idx]; + } else { + raw += bias_[j]; + raw_f = raw; + } + } else { + raw_f = raw; } - float ab = raw * C_multiplier_[Bq_zero_point_idx]; + float ab = raw_f * C_multiplier_[Bq_zero_point_idx]; long rounded = std::lrintf(ab) + C_zero_point_; out[i * ld_out + j] = std::max( @@ -115,15 +126,16 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f( Bq_zero_point_[0] == 0) || q_row_offsets_ == nullptr; - requantizationParams_t r = {Aq_zero_point_, - Bq_zero_point_, - C_zero_point_, - C_multiplier_, - q_row_offsets_, - q_col_offsets_, - bias_, - ncols_, - groups_}; + requantizationParams_t<BIAS_TYPE> r = {Aq_zero_point_, + Bq_zero_point_, + C_zero_point_, + C_multiplier_, + q_row_offsets_, + q_col_offsets_, + bias_, + ncols_, + groups_, + act_times_w_scale_}; if (Aq_zero_point_ == 0) { if (b_symmetric) { diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index 47f33a8..c7f3f35 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -40,9 +40,10 @@ struct FBGEMM_API RequantizationParams { //////////////////////////////////////////////////////////////////////////////// // Utility functions +template <typename T=std::uint8_t> void 
QuantizeAvx2( const float* src, - std::uint8_t* dst, + T* dst, int len, const TensorQuantizationParams& qparams); @@ -71,14 +72,15 @@ template < bool B_SYMMETRIC, QuantizationGranularity Q_GRAN, bool HAS_BIAS, - bool FUSE_RELU> + bool FUSE_RELU, + typename BIAS_TYPE = std::int32_t> FBGEMM_API void requantizeOutputProcessingAvx2( std::uint8_t* out, const std::int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r); + const requantizationParams_t<BIAS_TYPE>& r); template < bool A_SYMMETRIC, @@ -86,14 +88,15 @@ template < QuantizationGranularity Q_GRAN, bool HAS_BIAS, bool FUSE_RELU, - int C_PER_G> + int C_PER_G, + typename BIAS_TYPE = std::int32_t> FBGEMM_API void requantizeOutputProcessingGConvAvx2( std::uint8_t* out, const std::int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r); + const requantizationParams_t<BIAS_TYPE>& r); template < bool A_SYMMETRIC, diff --git a/include/fbgemm/UtilsAvx2.h b/include/fbgemm/UtilsAvx2.h index 082edc1..3bac909 100644 --- a/include/fbgemm/UtilsAvx2.h +++ b/include/fbgemm/UtilsAvx2.h @@ -44,16 +44,19 @@ struct block_type_t { * QuantUtilsAvx2.h as it combines all the parameters needed for various * quantization granularities */ +template<typename BIAS_TYPE = std::int32_t> struct requantizationParams_t { + using BIAS_T = BIAS_TYPE; std::int32_t A_zero_point; const std::int32_t* B_zero_point; std::int32_t C_zero_point; const float* C_multiplier; const std::int32_t* row_offsets; const std::int32_t* col_offsets; - const std::int32_t* bias; + const BIAS_T* bias; std::uint32_t ncols; int groups; + const float* act_times_w_scale; }; /** diff --git a/src/CodeCache.h b/src/CodeCache.h new file mode 100644 index 0000000..08e9c9b --- /dev/null +++ b/src/CodeCache.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include <condition_variable> +#include <future> +#include <map> +#include <mutex> + +namespace fbgemm { + +/** + * @brief Thread safe cache for microkernels, ensures single creation per key. 
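
The guarantee stated in this comment, one generation per key with concurrent callers blocking on a shared_future until the first thread publishes the value, is easiest to see from the caller's side. A minimal usage sketch (the key and value types are illustrative; real callers key on kernel parameters and cache generated function pointers):

```cpp
#include <cstdio>
#include <functional>
#include <tuple>

#include "CodeCache.h" // src/CodeCache.h from this diff

using Key = std::tuple<int, int>;        // illustrative: e.g. blocking params
using Kernel = std::function<int(int)>;  // illustrative stand-in for a kernel

int main() {
  fbgemm::CodeCache<Key, Kernel> cache;

  // The generator body runs at most once per key; concurrent callers for the
  // same key block on the shared_future until the value is published.
  Kernel k = cache.getOrCreate(std::make_tuple(16, 8), []() {
    std::printf("generating kernel\n"); // expect exactly one line per key
    return Kernel([](int x) { return x * 2; });
  });
  return k(21) == 42 ? 0 : 1;
}
```
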
+ * @tparam Key Type of unique key (typically a tuple) + * @tparam Value Type of the microkernel function (Typically a function pointer) + */ +template <typename KEY, typename VALUE> class CodeCache { +private: + std::map<KEY, std::shared_future<VALUE>> values_; + std::mutex mutex_; + +public: + CodeCache(const CodeCache &) = delete; + CodeCache &operator=(const CodeCache &) = delete; + + CodeCache(){}; + + VALUE getOrCreate(const KEY &key, std::function<VALUE()> generatorFunction) { + std::shared_future<VALUE> returnFuture; + std::promise<VALUE> returnPromise; + bool needsToGenerate = false; + + // Check for existance of the key + { + std::unique_lock<std::mutex> lock(mutex_); + + auto it = values_.find(key); + if (it != values_.end()) { + returnFuture = it->second; + } else { + values_[key] = returnFuture = returnPromise.get_future().share(); + needsToGenerate = true; + } + } + + // The value (code) generation is not happening under a lock + if (needsToGenerate) { + returnPromise.set_value(generatorFunction()); + } + + // Wait for the future and return the value + return returnFuture.get(); + } +}; + +} // namespace fbgemm diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index 0a4ff55..4ae1b50 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -315,19 +315,23 @@ void ExecuteKernel< //////////////////////////////////////////////////////////////////////////////// // ReQuantizeOutput -#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \ - template class ExecuteKernel< \ - PACK_A<uint8_t, ACC_T>, \ - PackBMatrix<int8_t, ACC_T>, \ - uint8_t, \ - ReQuantizeOutput<RELU, Q_GRAN>>; +#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \ + template class ExecuteKernel< \ + PACK_A<uint8_t, ACC_T>, \ + PackBMatrix<int8_t, ACC_T>, \ + uint8_t, \ + ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>; + +#define INSTANTIATE_REQUANT_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \ + INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \ + INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t); #define INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, RELU) \ - INSTANTIATE_REQUANT_BASE( \ + INSTANTIATE_REQUANT_BIAS_T( \ PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \ - INSTANTIATE_REQUANT_BASE( \ + INSTANTIATE_REQUANT_BIAS_T( \ PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \ - INSTANTIATE_REQUANT_BASE( \ + INSTANTIATE_REQUANT_BIAS_T( \ PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_REQUANT_RELU(PACK_A, ACC_T) \ @@ -344,21 +348,27 @@ INSTANTIATE_REQUANT_ACC_T(PackAWithRowOffset); #undef INSTANTIATE_REQUANT_ACC_T #undef INSTANTIATE_REQUANT_RELU #undef INSTANTIATE_REQUANT_Q_GRANS +#undef INSTANTIATE_REQUANT_BIAS_T #undef INSTANTIATE_REQUANT_BASE -#define INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ - template class ExecuteKernel< \ - PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ - PackBMatrix<int8_t, ACC_T>, \ - uint8_t, \ - ReQuantizeOutput<RELU, Q_GRAN>>; +#define INSTANTIATE_IM2COL_REQUANT_BASE( \ + ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \ + template class ExecuteKernel< \ + PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ + PackBMatrix<int8_t, ACC_T>, \ + uint8_t, \ + ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>; + +#define INSTANTIATE_IM2COL_REQUANT_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ + INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \ + INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t); #define 
INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ - INSTANTIATE_IM2COL_REQUANT_BASE( \ + INSTANTIATE_IM2COL_REQUANT_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ - INSTANTIATE_IM2COL_REQUANT_BASE( \ + INSTANTIATE_IM2COL_REQUANT_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \ - INSTANTIATE_IM2COL_REQUANT_BASE( \ + INSTANTIATE_IM2COL_REQUANT_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, RELU) \ @@ -375,6 +385,7 @@ INSTANTIATE_IM2COL_REQUANT_RELU(int16_t); #undef INSTANTIATE_IM2COL_REQUANT_RELU #undef INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM #undef INSTANTIATE_IM2COL_REQUANT_Q_GRANS +#undef INSTANTIATE_IM2COL_REQUANT_BIAS_T #undef INSTANTIATE_IM2COL_REQUANT_BASE //////////////////////////////////////////////////////////////////////////////// diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index 1052044..b691b88 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -237,22 +237,26 @@ bool fbgemmSupportedCPU() { //////////////////////////////////////////////////////////////////////////////// // ReQuantizeOutput -#define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \ +#define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \ template void fbgemmPacked( \ PackMatrix<PACK_A<uint8_t, ACC_T>, uint8_t, ACC_T>& packA, \ PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ uint8_t* C, \ int32_t* C_buffer, \ uint32_t ldc, \ - const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ int thread_id, \ int num_threads, \ const BlockingFactors* blocking_params); -#define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \ - INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \ - INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \ - INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL); +#define INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \ + INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \ + INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t); + +#define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \ + INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \ + INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \ + INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_RELU(PACK_A, ACC_T) \ INSTANTIATE_Q_GRANS(PACK_A, ACC_T, false); \ @@ -268,27 +272,34 @@ INSTANTIATE_ACC_T(PackAWithRowOffset); #undef INSTANTIATE_ACC_T #undef INSTANTIATE_RELU #undef INSTANTIATE_Q_GRANS +#undef INSTANTIATE_BIAS_T #undef INSTANTIATE_BASE -#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ - template void fbgemmPacked( \ - PackMatrix< \ - PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ - uint8_t, \ - ACC_T>& packA, \ - PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ - uint8_t* C, \ - int32_t* C_buffer, \ - uint32_t ldc, \ - const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ - int thread_id, \ - int num_threads, \ +#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \ + template void fbgemmPacked( \ + PackMatrix< \ + PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ + uint8_t, \ + ACC_T>& packA, \ + PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ + uint8_t* C, \ + int32_t* C_buffer, \ + uint32_t ldc, \ + const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ + int thread_id, \ + int num_threads, \ const 
BlockingFactors* blocking_params); -#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ - INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ - INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \ - INSTANTIATE_BASE( \ +#define INSTANTIATE_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ + INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \ + INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t); + +#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ + INSTANTIATE_BIAS_T( \ + ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ + INSTANTIATE_BIAS_T( \ + ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \ + INSTANTIATE_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_SPATIAL_DIM(ACC_T, RELU) \ @@ -305,6 +316,7 @@ INSTANTIATE_RELU(int16_t); #undef INSTANTIATE_RELU #undef INSTANTIATE_SPATIAL_DIM #undef INSTANTIATE_Q_GRANS +#undef INSTANTIATE_BIAS_T #undef INSTANTIATE_BASE //////////////////////////////////////////////////////////////////////////////// diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc index 33d1535..de833d2 100644 --- a/src/FbgemmConv.cc +++ b/src/FbgemmConv.cc @@ -89,52 +89,121 @@ int fbgemmConv( // std::cout << "Depthwise fast path" << std::endl; const std::int32_t* B_zero_point = outProcess.getBZeroPoint(); const float* C_multiplier = outProcess.getCMultiplier(); + const float* act_times_w_scale = outProcess.getActWScale(); if (SPATIAL_DIM == 3) { static_assert( std::is_same<typename processOutputType::outType, std::uint8_t>:: value, "For depthwise, only requantized output is supported"); - depthwise_3x3x3_pad_1( - conv_p.MB, // mini batch - conv_p.IN_DIM[0], // T - conv_p.IN_DIM[1], // H - conv_p.IN_DIM[2], // W - conv_p.OC, // output channels - conv_p.stride[0], // stride_t - conv_p.stride[1], // stride_h - conv_p.stride[2], // stride_w - outProcess.getAZeroPoint(), - activations, - B_zero_point[0], - *(packed_weights.getPackedWFor3DDW()), - C_multiplier[0], - outProcess.getCZeroPoint(), - out, - outProcess.getColOffsets(), - outProcess.getBias(), - outProcess.RELU_FUSED, // fuse_relu - thread_id, - num_threads); + + if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) { + depthwise_3x3x3_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // T + conv_p.IN_DIM[1], // H + conv_p.IN_DIM[2], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_t + conv_p.stride[1], // stride_h + conv_p.stride[2], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point[0], + *(packed_weights.getPackedWForDepthwise()), + C_multiplier[0], + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + act_times_w_scale ? 
act_times_w_scale[0] : 1.0f, + thread_id, + num_threads); + } else if ( + processOutputType::QGRANType == + QuantizationGranularity::OUT_CHANNEL || + processOutputType::QGRANType == QuantizationGranularity::GROUP) { + depthwise_3x3x3_per_channel_quantization_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // T + conv_p.IN_DIM[1], // H + conv_p.IN_DIM[2], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_t + conv_p.stride[1], // stride_h + conv_p.stride[2], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point, + *(packed_weights.getPackedWForDepthwise()), + C_multiplier, + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + outProcess.getActWScale(), // act_scale * weight_scale + thread_id, + num_threads); + } else { + std::string msg = + "[FBGEMM_CONV_ERROR] This quantization granularity is " + "not supported"; + throw std::runtime_error(msg); + } } else { - depthwise_3x3_pad_1( - conv_p.MB, // mini batch - conv_p.IN_DIM[0], // H - conv_p.IN_DIM[1], // W - conv_p.OC, // output channels - conv_p.stride[0], // stride_h - conv_p.stride[1], // stride_w - outProcess.getAZeroPoint(), - activations, - B_zero_point[0], - *(packed_weights.getPackedWFor2DDW()), - C_multiplier[0], - outProcess.getCZeroPoint(), - out, - outProcess.getColOffsets(), - outProcess.getBias(), - outProcess.RELU_FUSED, // fuse_relu - thread_id, - num_threads); + if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) { + depthwise_3x3_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // H + conv_p.IN_DIM[1], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_h + conv_p.stride[1], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point[0], + *(packed_weights.getPackedWForDepthwise()), + C_multiplier[0], + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + act_times_w_scale ? 
act_times_w_scale[0] : 1.0f, + thread_id, + num_threads); + } else if ( + processOutputType::QGRANType == + QuantizationGranularity::OUT_CHANNEL || + processOutputType::QGRANType == QuantizationGranularity::GROUP) { + // The number of channels == groups for depthwise convolutions + depthwise_3x3_per_channel_quantization_pad_1( + conv_p.MB, // mini batch + conv_p.IN_DIM[0], // H + conv_p.IN_DIM[1], // W + conv_p.OC, // output channels + conv_p.stride[0], // stride_h + conv_p.stride[1], // stride_w + outProcess.getAZeroPoint(), + activations, + B_zero_point, + *(packed_weights.getPackedWForDepthwise()), + C_multiplier, + outProcess.getCZeroPoint(), + out, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.RELU_FUSED, // fuse_relu + outProcess.getActWScale(), // act_scale * weight_scale + thread_id, + num_threads); + } else { + std::string msg = + "[FBGEMM_CONV_ERROR] This quantization granularity is " + "not supported"; + throw std::runtime_error(msg); + } } break; } @@ -195,11 +264,32 @@ int fbgemmConv( // All other convolutions go through im2col-based implementation // std::cout << "Im2col path" << std::endl; std::vector<int32_t> row_offset_buf( - PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> - ::rowOffsetBufferSize(blocking_params)); + PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize( + blocking_params)); const std::int32_t* b_zero_point = outProcess.getBZeroPoint(); - bool b_symmetric = b_zero_point[0] == 0; + bool b_symmetric = false; + if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) { + b_symmetric = b_zero_point[0] == 0; + } else if ( + processOutputType::QGRANType == QuantizationGranularity::GROUP) { + b_symmetric = + std::all_of(b_zero_point, b_zero_point + conv_p.G, [](int i) { + return i == 0; + }); + } else if ( + processOutputType::QGRANType == + QuantizationGranularity::OUT_CHANNEL) { + b_symmetric = + std::all_of(b_zero_point, b_zero_point + conv_p.OC, [](int i) { + return i == 0; + }); + } else { + std::string msg = + "[FBGEMM_CONV_ERROR] This quantization granularity is " + "not supported"; + throw std::runtime_error(msg); + } PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> packA( conv_p, activations, @@ -227,21 +317,25 @@ int fbgemmConv( return 0; } -#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \ +#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, BIAS_TYPE) \ template int fbgemmConv( \ const conv_param_t<SPATIAL_DIM>& conv_p, \ const std::uint8_t* activations, \ PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, \ std::uint8_t* out, \ std::int32_t* outBuffer, \ - ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ int thread_id, \ int num_threads, \ const BlockingFactors* blocking_params); +#define INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \ + INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, float); \ + INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, int32_t); + #define INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, RELU) \ - INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 2); \ - INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 3); + INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 2); \ + INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 3); #define INSTANTIATE_RELU(ACC_T, Q_GRAN) \ INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, true); \ @@ -257,6 +351,7 @@ INSTANTIATE_Q_GRANS(std::int32_t); #undef INSTANTIATE_Q_GRANS #undef INSTANTIATE_RELU #undef INSTANTIATE_SPATIAL_DIM +#undef INSTANTIATE_BIAS_T #undef INSTANTIATE_BASE template bool takeDepthWiseFastPath<2, 
std::int32_t>( diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc index f357966..b034f2c 100644 --- a/src/FbgemmFP16.cc +++ b/src/FbgemmFP16.cc @@ -50,6 +50,7 @@ struct KernelInfo { // autotuned kernel splits for various cases m = 1:mb_max // may need re-autotuning for new uarch + // clang-format off static constexpr int partition[121][2][2] = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -175,6 +176,7 @@ struct KernelInfo { { { 6, 19 }, { 5, 1 } }, // 119 { { 6, 20 }, { 0, 0 } }, // 120 }; + // clang-format on }; constexpr KernelInfo::knl_ptr KernelInfo::kernel[7];; constexpr int KernelInfo::partition[121][2][2]; diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc new file mode 100644 index 0000000..925d265 --- /dev/null +++ b/src/FbgemmI8Depthwise3DAvx2.cc @@ -0,0 +1,1415 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "fbgemm/FbgemmI8DepthwiseAvx2.h" + +#include <string> +#include <tuple> // for tie + +#include "FbgemmI8DepthwiseAvx2-inl.h" + +using namespace std; + +namespace fbgemm { + +template < + bool SUM_A, + bool REMAINDER = false, + bool PER_CHANNEL_QUANTIZATION = false> +static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_( + int T, + int H, + int W, + int K, + int t_in, + int h_in, + int w_in, + const uint8_t* A, + int32_t A_zero_point, + const int8_t* Bp, + const int32_t* B_zero_point, + int32_t* C, + int remainder, + int32_t* row_offsets) { + __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point)); + __m256i mask_v = _mm256_setzero_si256(); + if (REMAINDER) { + mask_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(masks[remainder / 4])); + } + + // The code below can be written as a simple R*S loop but the compiler + // doesn't unroll so we're manually unrolling it. 
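
Before the unrolled intrinsics continue below, it helps to state what they compute in scalar terms, as a hedged sketch: every tap of the 3x3x3 window that falls outside the input volume contributes the activation zero point instead of a load, which is why each a_v slot is initialized to A_zero_point_v and overwritten only when its coordinates are in bounds:

```cpp
#include <cstdint>

// Scalar equivalent of the unrolled, masked loads: a tap of the 3x3x3 window
// that falls outside the T x H x W volume reads as A_zero_point, which acts
// as "virtual" padding without an explicitly padded input buffer.
inline std::int32_t tap(
    const std::uint8_t* A, int T, int H, int W, int K,
    int t_in, int h_in, int w_in,   // front-top-left corner of the window
    int r, int s, int u, int k,     // kernel offsets and channel
    std::int32_t A_zero_point) {
  int t = t_in + r, h = h_in + s, w = w_in + u;
  if (t < 0 || t >= T || h < 0 || h >= H || w < 0 || w >= W) {
    return A_zero_point; // out of bounds: contribute the zero point
  }
  return A[((t * H + h) * W + w) * K + k];
}
```
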
+ // constexpr int R = 3, S = 3; + // array<__m256i, R * S> a_v; + // for (int r = 0; r < R; ++r) { + // for (int s = 0; s < S; ++s) { + // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { + // if (REMAINDER) { + // a_v[r * S + s] = + // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), + // mask_v); + // } else { + // a_v[r * S + s] = + // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); + // } + // } else { + // a_v[r * S + s] = A_zero_point_v; + // } + // } + // } + __m256i a_v[8]; + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in >= 0 && t_in < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v); + } + } + } + + __m256i a_sum[4]; + inner_prod_packed_<8, SUM_A, REMAINDER>( + a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in >= 0 && t_in < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v); + } + } + } + + if (t_in + 1 >= 0 && t_in + 1 < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v); + } + } + } + + __m256i a_sum_temp[4]; + inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( + a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp); + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = 
_mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + } + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in + 1 >= 0 && t_in + 1 < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v); + } + } + } + + if (t_in + 2 >= 0 && t_in + 2 < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v); + } + } + } + + inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( + a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp); + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + } + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + + if (t_in + 2 >= 0 && t_in + 2 < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v); + } + } + } + + inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>( + a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp); + + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + + __m256i B_zero_point_v; + for (int i = 0; i < (REMAINDER ? 
(remainder / 8) : 4); ++i) { + if (PER_CHANNEL_QUANTIZATION) { + B_zero_point_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); + } else { + B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); + } + _mm256_store_si256( + reinterpret_cast<__m256i*>(&row_offsets[i * 8]), + _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); + } + } +} + +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> +static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( + int T, + int H, + int W, + int K, + int t, + int h, + int w, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const int8_t* Bp, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + int32_t* row_offsets, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale) { + constexpr int R = 3, S = 3; + constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int t_in = -PAD_P + t * stride_t; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + &B_zero_point, + C_int32 + k, + 0, + B_SYMMETRIC ? nullptr : &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + &B_zero_point, + C_int32 + k, + remainder, + B_SYMMETRIC ? 
nullptr : &row_offsets[k]); + } + + requantize_< + FUSE_RELU, + HAS_BIAS, + false, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + B_SYMMETRIC>( + A_zero_point, + &C_multiplier, + C_zero_point, + C_int32, + C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, + K, + row_offsets, + col_offsets, + bias, + &act_times_w_scale); +} + +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> +static inline __attribute__((always_inline)) void +depthwise_3x3x3_per_channel_quantization_kernel_( + int T, + int H, + int W, + int K, + int t, + int h, + int w, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const int8_t* Bp, + const float* C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + int32_t* row_offsets, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale) { + constexpr int R = 3, S = 3; + constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int t_in = -PAD_P + t * stride_t; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3x3_packed_< + true, /*SUM_A*/ + false, /*remainder*/ + true /*per-channel*/>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + B_zero_point + k, + C_int32 + k, + 0, + &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3x3_packed_< + true, /*SUM_A*/ + true, /*remainder*/ + true /*per-channel*/>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + B_zero_point + k, + C_int32 + k, + remainder, + &row_offsets[k]); + } + requantize_< + FUSE_RELU, + HAS_BIAS, + true, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + false /*B_SYMM*/>( + A_zero_point, + C_multiplier, + C_zero_point, + C_int32, + C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, + K, + row_offsets, + col_offsets, + bias, + act_times_w_scale); +} + +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> +static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale, + int thread_id, + int num_threads) { + assert(K % 8 == 0); + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + + int n_begin, n_end; + int t_begin, t_end, h_begin, h_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + t_begin = 0; + t_end = T_OUT; + h_begin = 0; + h_end = H_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = 
std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_t * num_threads_h 2D grid of threads + int num_threads_t, num_threads_h; + // num_threads_w <= num_threads_h + tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_t = tid_within_n / num_threads_h; + int tid_h = tid_within_n % num_threads_h; + + int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; + t_begin = std::min(tid_t * t_per_thread, T_OUT); + t_end = std::min(t_begin + t_per_thread, T_OUT); + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * T * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; + + for (int t = t_begin; t < t_end; ++t) { + for (int h = h_begin; h < h_end; ++h) { + for (int w = 0; w < W_OUT; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + } // h + } // t + } // for each n +}; + +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> +static inline __attribute__((always_inline)) void +depthwise_3x3x3_per_channel_quantization_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + assert(K % 8 == 0); + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + + int n_begin, n_end; + int t_begin, t_end, h_begin, h_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + t_begin = 0; + t_end = T_OUT; + h_begin = 0; + h_end = H_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_t * 
num_threads_h 2D grid of threads + int num_threads_t, num_threads_h; + // num_threads_w <= num_threads_h + tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_t = tid_within_n / num_threads_h; + int tid_h = tid_within_n % num_threads_h; + + int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; + t_begin = std::min(tid_t * t_per_thread, T_OUT); + t_end = std::min(t_begin + t_per_thread, T_OUT); + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * T * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; + + for (int t = t_begin; t < t_end; ++t) { + for (int h = h_begin; h < h_end; ++h) { + for (int w = 0; w < W_OUT; ++w) { + depthwise_3x3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + } // h + } // t + } // for each n +}; + +// Dispatch A_SYMMETRIC and B_SYMMETRIC +template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> +static void depthwise_3x3x3_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + if (B_zero_point == 0) { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + true /*B_symmetric*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + false /*B_symmetric*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + } else { + if (B_zero_point == 0) { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + true /*B_symmetric*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + false /*B_symmetric*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + } +} + +// Dispatch HAS_BIAS +template <bool FUSE_RELU, typename BIAS_TYPE> +static void depthwise_3x3x3_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int 
stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale, + int thread_id, + int num_threads) { + if (bias) { + depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// Dispatch FUSE_RELU +template <typename BIAS_TYPE> +void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads) { + if (B.GetKernelProduct() != 3 * 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert( + 0 && + "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. 
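+    // (e.g., an empty batch flowing through a C2 net; returning here
+    // leaves the output buffer untouched.)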
+ return; + } + if (fuse_relu) { + depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/, BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/, BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// Dispatch A_SYMMETRIC +template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> +static void depthwise_3x3x3_per_channel_quantization_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_SYMM*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_SYMM*/, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// Dispatch HAS_BIAS +template <bool FUSE_RELU, typename BIAS_TYPE> +static void depthwise_3x3x3_per_channel_quantization_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + if (bias) { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + true /* HAS_BIAS */, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + false /* HAS_BIAS */, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// Dispatch FUSE_RELU +template <typename BIAS_TYPE> +void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const 
BIAS_TYPE* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + if (B.GetKernelProduct() != 3 * 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert( + 0 && + "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. + return; + } + if (fuse_relu) { + depthwise_3x3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_3x3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// To be removed +void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + int thread_id, + int num_threads) { + depthwise_3x3x3_pad_1<int32_t>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + fuse_relu, + 1.0f, // act_scale * weight_scale + thread_id, + num_threads); +} + +void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + int thread_id, + int num_threads) { + depthwise_3x3x3_per_channel_quantization_pad_1( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + fuse_relu, + nullptr, // act_scale * weight_scale + thread_id, + num_threads); +} + +template void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int 
num_threads); + +template void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads); + +} // namespace fbgemm diff --git a/src/FbgemmI8DepthwiseAvx2-inl.h b/src/FbgemmI8DepthwiseAvx2-inl.h new file mode 100644 index 0000000..7ad39fc --- /dev/null +++ b/src/FbgemmI8DepthwiseAvx2-inl.h @@ -0,0 +1,709 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include <algorithm> // for min and max +#include <cassert> +#include <cmath> // for lrintf and sqrt +#include <cstdint> +#include <type_traits> // for is_same + +#include <immintrin.h> + +namespace fbgemm { + +// clang-format off +static int masks[8][8] = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. 
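+  // masks[n] has n leading all-ones dword lanes, so
+  // _mm256_maskload_epi32(addr, masks[remainder / 4]) loads exactly
+  // `remainder` uint8 channels (remainder / 4 dwords) and zeroes the
+  // remaining lanes; e.g., a remainder of 16 channels uses masks[4].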
+ { 0, 0, 0, 0, 0, 0, 0, 0, }, + { -1, 0, 0, 0, 0, 0, 0, 0, }, + { -1, -1, 0, 0, 0, 0, 0, 0, }, + { -1, -1, -1, 0, 0, 0, 0, 0, }, + { -1, -1, -1, -1, 0, 0, 0, 0, }, + { -1, -1, -1, -1, -1, 0, 0, 0, }, + { -1, -1, -1, -1, -1, -1, 0, 0, }, + { -1, -1, -1, -1, -1, -1, -1, 0, }, +}; +// clang-format on + +// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[16:20] +// c1_v: c[4:8], c[20:24] +// c2_v: c[8:12], c[24:28] +// c3_v: c[12:16], c[28:32] +template <bool SUM_A = false> +static inline __attribute__((always_inline)) void madd_epi16x4_packed( + __m256i a0_v, + __m256i a1_v, + __m256i a2_v, + __m256i a3_v, + const __m256i* b, + __m256i* c0_v, + __m256i* c1_v, + __m256i* c2_v, + __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v); + __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); + __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); + __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); + __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + __m256i b2_v = _mm256_load_si256(b + 2); + __m256i b3_v = _mm256_load_si256(b + 3); + + __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); + __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); + __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); + __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); + + __m256i one_v = _mm256_set1_epi16(1); + *c0_v = _mm256_madd_epi16(ab0, one_v); + *c1_v = _mm256_madd_epi16(ab1, one_v); + *c2_v = _mm256_madd_epi16(ab2, one_v); + *c3_v = _mm256_madd_epi16(ab3, one_v); +} + +// c = a0 * b0 + a1 * b1 + a2 * b2 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[16:20] +// c1_v: c[4:8], c[20:24] +// c2_v: c[8:12], c[24:28] +// c3_v: c[12:16], c[28:32] +template <bool SUM_A = false> +static inline __attribute__((always_inline)) void madd_epi16x3_packed( + __m256i a0_v, + __m256i a1_v, + __m256i a2_v, + const __m256i* b, + __m256i* c0_v, + __m256i* c1_v, + __m256i* c2_v, + __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i zero_v = _mm256_setzero_si256(); + + __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v); + __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); + a_sum[0] = + 
_mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); + __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); + __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); + __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + __m256i b2_v = _mm256_load_si256(b + 2); + __m256i b3_v = _mm256_load_si256(b + 3); + + __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); + __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); + __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); + __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); + + __m256i one_v = _mm256_set1_epi16(1); + *c0_v = _mm256_madd_epi16(ab0, one_v); + *c1_v = _mm256_madd_epi16(ab1, one_v); + *c2_v = _mm256_madd_epi16(ab2, one_v); + *c3_v = _mm256_madd_epi16(ab3, one_v); +} + +// c = a0 * b0 + a1 * b1 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[4:8] +// c1_v: c[8:12], c[12:16] +// c2_v: c[16:20], c[20:24] +// c3_v: c[24:28], c[28:32] +template <bool SUM_A = false> +static inline __attribute__((always_inline)) void madd_epi16x2_packed( + __m256i a0_v, + __m256i a1_v, + const __m256i* b, + __m256i* c0_v, + __m256i* c1_v, + __m256i* c2_v, + __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + + __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); + __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); + + *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); + *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); + *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); + *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); +} + +// c = a0 * b0 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[4:8] +// c1_v: c[8:12], c[12:16] +// c2_v: c[16:20], c[20:24] +// c3_v: c[24:28], c[28:32] +template <bool SUM_A = false> +static inline __attribute__((always_inline)) void madd_epi16_packed( + __m256i a_v, + const __m256i* b, + __m256i* c0_v, + __m256i* c1_v, + __m256i* c2_v, + __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i zero_v = _mm256_setzero_si256(); + + __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v); + __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + + __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); + __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); + + *c0_v = 
_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); + *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); + *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); + *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); +} + +// K is the number of accumulations we're doing +template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false> +static inline __attribute__((always_inline)) void inner_prod_packed_( + const __m256i* a_v, + const __m256i* Bp, + std::int32_t* C, + int remainder, + __m256i* a_sum = nullptr) { + __m256i c[4], c_temp[4]; + __m256i a_sum_temp[2] = {0, 0}; + + int k = 0; + if (K >= 4) { + madd_epi16x4_packed<SUM_A>( + a_v[0], + a_v[1], + a_v[2], + a_v[3], + Bp, + &c[0], + &c[1], + &c[2], + &c[3], + a_sum_temp); + + for (k = 4; k < K / 4 * 4; k += 4) { + madd_epi16x4_packed<SUM_A>( + a_v[k + 0], + a_v[k + 1], + a_v[k + 2], + a_v[k + 3], + Bp + k, + &c_temp[0], + &c_temp[1], + &c_temp[2], + &c_temp[3], + a_sum_temp); + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + } else { + c[0] = _mm256_setzero_si256(); + c[1] = _mm256_setzero_si256(); + c[2] = _mm256_setzero_si256(); + c[3] = _mm256_setzero_si256(); + } + + if (K - k == 3) { + madd_epi16x3_packed<SUM_A>( + a_v[k], + a_v[k + 1], + a_v[k + 2], + Bp + k, + &c_temp[0], + &c_temp[1], + &c_temp[2], + &c_temp[3], + a_sum_temp); + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + + c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20); + c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20); + c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31); + c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31); + + if (K - k == 0 || K - k == 3) { + c[0] = c_temp[0]; + c[1] = c_temp[1]; + c[2] = c_temp[2]; + c[3] = c_temp[3]; + } else { + if (K - k == 1) { + madd_epi16_packed<SUM_A>( + a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp); + } else if (K - k == 2) { + madd_epi16x2_packed<SUM_A>( + a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp); + } + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + + if (REMAINDER) { + for (int r = 0; r < remainder / 8; ++r) { + if (ACC) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + r * 8), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)), + c[r])); + } else { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]); + } + } + } else { + if (ACC) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0])); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + 8), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1])); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + 16), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2])); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + 24), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3])); + } else { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]); + 
_mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]); + } + } + + if (SUM_A) { + a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0])); + a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1])); + a_sum[2] = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1)); + a_sum[3] = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1)); + } +} + +// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different +// row_offsets for each row because of depth-wise convolution +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool PER_CHANNEL_QUANTIZATION, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> +static inline __attribute__((always_inline)) void requantize_( + std::int32_t A_zero_point, + const float* C_multiplier, + std::int32_t C_zero_point, + const std::int32_t* C_int32, + std::uint8_t* C_uint8, + int n, + const std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale = nullptr) { + __m256 multiplier_v = _mm256_setzero_ps(); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v = _mm256_setzero_ps(); + if (!PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_set1_ps(*C_multiplier); + if (std::is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = _mm256_set1_ps(1.0f / (*act_times_w_scale)); + } + } + + __m256i min_v = _mm256_set1_epi8(static_cast<std::uint8_t>(0)); + __m256i max_v = _mm256_set1_epi8(static_cast<std::uint8_t>(255)); + + if (A_SYMMETRIC) { + assert(A_zero_point == 0 || col_offsets == nullptr); + } + __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point); + __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point); + __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point); + + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + + constexpr int VLEN = 8; + int j = 0; + for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) { + __m256i x_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); + __m256i y_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(C_int32 + j + VLEN)); + __m256i z_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN)); + __m256i w_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN)); + + __m256i row_offset_v; + if (!B_SYMMETRIC) { + row_offset_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); + x_v = _mm256_sub_epi32(x_v, row_offset_v); + } + __m256i col_off_v; + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + VLEN)); + y_v = _mm256_sub_epi32(y_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + VLEN))); + y_v = _mm256_sub_epi32(y_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN)); + z_v = _mm256_sub_epi32(z_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + 2 * 
VLEN))); + z_v = _mm256_sub_epi32(z_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN)); + w_v = _mm256_sub_epi32(w_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN))); + w_v = _mm256_sub_epi32(w_v, col_off_v); + } + + // convert to float + __m256 xf_v, yf_v, zf_v, wf_v; + if (HAS_BIAS) { // static if + if (std::is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v; + if (PER_CHANNEL_QUANTIZATION) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 0 * VLEN)), + _mm256_loadu_ps(act_times_w_scale + j + 0 * VLEN)); + y_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 1 * VLEN)), + _mm256_loadu_ps(act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 2 * VLEN)), + _mm256_loadu_ps(act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 3 * VLEN)), + _mm256_loadu_ps(act_times_w_scale + j + 3 * VLEN)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 0 * VLEN)), + act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 1 * VLEN)), + act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 2 * VLEN)), + act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(bias + j + 3 * VLEN)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 1 * VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN); + } + __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN); + } + __m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN); + } + __m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN); + } + __m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); + __m256i 
z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); + + __m256i xy_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v); + __m256i zw_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v); + __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); + __m256i xyzw_clamped_v = _mm256_max_epu8( + FUSE_RELU ? C_zero_point_epi8_v : min_v, + _mm256_min_epu8(xyzw_packed_v, max_v)); + + xyzw_clamped_v = + _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v); + } // j loop vectorized and unrolled 4x + + for (; j < n / VLEN * VLEN; j += VLEN) { + __m256i x_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); + + if (!B_SYMMETRIC) { + __m256i row_offset_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); + x_v = _mm256_sub_epi32(x_v, row_offset_v); + } + if (!A_SYMMETRIC) { + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + } + + // Convert to float + __m256 xf_v; + if (HAS_BIAS) { // static if + if (std::is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (PER_CHANNEL_QUANTIZATION) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)), + _mm256_loadu_ps(act_times_w_scale + j)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j))); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + } + + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j); + } + __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + + __m256i x_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), + C_zero_point_epi16_v); + x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); + __m256i x_clamped_v = _mm256_max_epu8( + FUSE_RELU ? C_zero_point_epi8_v : min_v, + _mm256_min_epu8(x_packed_v, max_v)); + + x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); + + _mm_storel_epi64( + reinterpret_cast<__m128i*>(C_uint8 + j), + _mm256_castsi256_si128(x_clamped_v)); + } // j loop vectorized + + for (; j < n; ++j) { + std::int32_t raw = C_int32[j]; + if (!B_SYMMETRIC) { + raw -= row_offsets[j]; + } + if (!A_SYMMETRIC) { + raw -= A_zero_point * col_offsets[j]; + } + float raw_f; + if (HAS_BIAS) { // static if + if (std::is_same<BIAS_TYPE, float>::value) { + raw_f = raw; + raw_f += bias[j] / act_times_w_scale[PER_CHANNEL_QUANTIZATION ? j : 0]; + } else { + raw += bias[j]; + raw_f = raw; + } + } else { + raw_f = raw; + } + + float ab = raw_f * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0]; + long rounded = lrintf(ab) + C_zero_point; + + C_uint8[j] = std::max( + FUSE_RELU ? 
static_cast<long>(C_zero_point) : 0l, + std::min(255l, rounded)); + } +} + +template <bool REMAINDER> +static inline __attribute__((always_inline)) __m256i load_a( + const std::uint8_t* A, + __m256i mask_v) { + if (REMAINDER) { + return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v); + } else { + return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A)); + } +} + +static inline std::pair<int, int> closest_factors_(int n) { + int a = static_cast<int>(std::sqrt(n)); + while (n % a != 0) { + a--; + } + return {a, n / a}; // a <= n / a +} + +} // namespace fbgemm diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index 183a8a9..994f206 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -7,569 +7,15 @@ #include "fbgemm/FbgemmI8DepthwiseAvx2.h" #include "fbgemm/Utils.h" -#include <algorithm> // for min and max -#include <cassert> -#include <cmath> // for lrintf and sqrt +#include <string> #include <tuple> // for tie -#include <immintrin.h> +#include "FbgemmI8DepthwiseAvx2-inl.h" using namespace std; namespace fbgemm { -static int masks[8][8] = { - // NOTE: clang-format wants to use a different formatting but the current - // formatting should be easier to read. - { 0, 0, 0, 0, 0, 0, 0, 0, }, - { -1, 0, 0, 0, 0, 0, 0, 0, }, - { -1, -1, 0, 0, 0, 0, 0, 0, }, - { -1, -1, -1, 0, 0, 0, 0, 0, }, - { -1, -1, -1, -1, 0, 0, 0, 0, }, - { -1, -1, -1, -1, -1, 0, 0, 0, }, - { -1, -1, -1, -1, -1, -1, 0, 0, }, - { -1, -1, -1, -1, -1, -1, -1, 0, }, -}; - -template <int KERNEL_PROD> -PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix( - int K, - const int8_t* smat) - : K_(K) { - // Transpose the input matrix to make packing faster. - int8_t* smat_transposed = static_cast<int8_t *>(ALIGNED_MALLOC( - K * KERNEL_PROD * sizeof(int8_t), 64)); - for (int i = 0; i < KERNEL_PROD; ++i) { - for (int j = 0; j < K; ++j) { - smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD]; - } - } - - // Allocate packed arrays - constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2; -#ifdef _MSC_VER - pmat_ = static_cast<int8_t *>(_aligned_malloc( - ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t), 64)); -#else - posix_memalign( - (void**)&pmat_, - 64, - ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)); -#endif - - // Pack input matrix - // The layout is optimized to use vpmaddubsw efficiently (see - // madd_epi16x4_packed function). - // For a group of 32 channels, we have 10 32B SIMD registers. - // Denote ith channel jth filter as (i, j) - // 0th SIMD register: - // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3) - // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3) - // 1st SIMD register: - // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3) - // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3) - // 2nd SIMD register: - // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3) - // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3) - // 3rd SIMD register: - // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3) - // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3) - // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter - // coefficients - // ... 
- // - // REMAINDER - // If KERNEL_PROD % 4 == 1 for example when KERNEL_PROD == 9 - // 8th SIMD register: - // (0, 8), zero, ..., (7, 8), zero - // (16, 8), zero, ..., (23, 8), zero - // 9th SIMD register: - // (8, 8), zero, ..., (15, 8), zero - // (24, 8), zero, ..., (31, 8), zero - // We use madd_epi16_packed for this case - // - // If KERNEL_PROD % 4 == 2 for example when KERNEL_PROD == 10 - // 8th SIMD register: - // (0, 8), (0, 9), ..., (7, 8), (7, 9) - // (16, 8), (16, 9), ..., (23, 8), (23, 9) - // 9th SIMD register: - // (8, 8), (8, 9), ..., (15, 8), (15, 9) - // (24, 8), (24, 9), ..., (31, 8), (31, 9) - // - // If KERNEL_PROD % 4 == 3 for example when KERNEL_PROD == 11 - // 8th SIMD register: - // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero - // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero - // 9th SIMD register: - // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero - // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero - // 10th SIMD register: - // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero - // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero - // 11th SIMD register: - // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero - // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero - for (int k1 = 0; k1 < K; k1 += 32) { - __m256i b_v[KERNEL_PROD]; - int remainder = K - k1; - if (remainder < 32) { - __m256i mask_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(masks[remainder / 4])); - for (int i = 0; i < KERNEL_PROD; ++i) { - b_v[i] = _mm256_maskload_epi32( - reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v); - } - } else { - for (int i = 0; i < KERNEL_PROD; ++i) { - b_v[i] = _mm256_lddqu_si256( - reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1)); - } - } - - // Interleave 2 SIMD registers - __m256i b_interleaved_epi16[KERNEL_PROD_ALIGNED]; - __m256i zero_v = _mm256_setzero_si256(); - for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) { - if (2 * i + 1 >= KERNEL_PROD) { - b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v); - b_interleaved_epi16[2 * i + 1] = - _mm256_unpackhi_epi8(b_v[2 * i], zero_v); - } else { - b_interleaved_epi16[2 * i] = - _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]); - b_interleaved_epi16[2 * i + 1] = - _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]); - } - } - - // Interleave 4 SIMD registers - __m256i b_interleaved_epi32[KERNEL_PROD_ALIGNED]; - for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) { - b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16( - b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); - b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16( - b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); - b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16( - b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); - b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16( - b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); - } - for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) { - b_interleaved_epi32[i] = b_interleaved_epi16[i]; - } - - for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) { - _mm256_storeu_si256( - reinterpret_cast<__m256i*>( - &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]), - b_interleaved_epi32[i]); - } - } - - FREE(smat_transposed); -} - -template <int KERNEL_PROD> -int PackedDepthWiseConvMatrix<KERNEL_PROD>::addr(int r, int c) 
{ - constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2; - if (c >= KERNEL_PROD / 4 * 4 && - (KERNEL_PROD % 4 == 1 || KERNEL_PROD % 4 == 2)) { - int kBlock = r / 32; - int reg_idx = (r % 16) / 8 + c / 4 * 4; - - int blk_idx = kBlock * KERNEL_PROD_ALIGNED + reg_idx; - - int r_ = r % 8; - int c_ = c % 4; - - int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_; - return blk_idx * 32 + in_blk_idx; - - } else { - int kBlock = r / 32; - int reg_idx = (r % 16) / 4 + c / 4 * 4; - - int blk_idx = kBlock * KERNEL_PROD_ALIGNED + reg_idx; - - int r_ = r % 4; - int c_ = c % 4; - - int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_; - return blk_idx * 32 + in_blk_idx; - } -} - -template <int KERNEL_PROD> -void PackedDepthWiseConvMatrix<KERNEL_PROD>::unpack(int8_t* unpacked_data) { - for (int r = 0; r < K_; ++r) { - for (int c = 0; c < KERNEL_PROD; ++c) { - unpacked_data[r * KERNEL_PROD + c] = pmat_[addr(r, c)]; - } - } -} - -template <int KERNEL_PROD> -PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() { -#ifdef _MSC_VER - _aligned_free(pmat_); -#else - free(pmat_); -#endif -} - -template class PackedDepthWiseConvMatrix<3 * 3>; -template class PackedDepthWiseConvMatrix<3 * 3 * 3>; -template class PackedDepthWiseConvMatrix<1>; -template class PackedDepthWiseConvMatrix<2>; -template class PackedDepthWiseConvMatrix<3>; -template class PackedDepthWiseConvMatrix<4>; -template class PackedDepthWiseConvMatrix<5>; -template class PackedDepthWiseConvMatrix<5 * 2>; -template class PackedDepthWiseConvMatrix<11 * 1>; - -// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[16:20] -// c1_v: c[4:8], c[20:24] -// c2_v: c[8:12], c[24:28] -// c3_v: c[12:16], c[28:32] -template <bool SUM_A = false> -static inline ALWAYS_INLINE void madd_epi16x4_packed( - __m256i a0_v, - __m256i a1_v, - __m256i a2_v, - __m256i a3_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); - __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); - __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v); - __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); - __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); - __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); - __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - __m256i b2_v = _mm256_load_si256(b + 2); - __m256i b3_v = _mm256_load_si256(b + 3); - - __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); - __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); - __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); - __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); - - __m256i one_v = _mm256_set1_epi16(1); - *c0_v = _mm256_madd_epi16(ab0, one_v); - *c1_v = 
_mm256_madd_epi16(ab1, one_v); - *c2_v = _mm256_madd_epi16(ab2, one_v); - *c3_v = _mm256_madd_epi16(ab3, one_v); -} - -// c = a0 * b0 + a1 * b1 + a2 * b2 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[16:20] -// c1_v: c[4:8], c[20:24] -// c2_v: c[8:12], c[24:28] -// c3_v: c[12:16], c[28:32] -template <bool SUM_A = false> -static inline ALWAYS_INLINE void madd_epi16x3_packed( - __m256i a0_v, - __m256i a1_v, - __m256i a2_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i zero_v = _mm256_setzero_si256(); - - __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); - __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); - __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v); - __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); - __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); - __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); - __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - __m256i b2_v = _mm256_load_si256(b + 2); - __m256i b3_v = _mm256_load_si256(b + 3); - - __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); - __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); - __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); - __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); - - __m256i one_v = _mm256_set1_epi16(1); - *c0_v = _mm256_madd_epi16(ab0, one_v); - *c1_v = _mm256_madd_epi16(ab1, one_v); - *c2_v = _mm256_madd_epi16(ab2, one_v); - *c3_v = _mm256_madd_epi16(ab3, one_v); -} - -// c = a0 * b0 + a1 * b1 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[4:8] -// c1_v: c[8:12], c[12:16] -// c2_v: c[16:20], c[20:24] -// c3_v: c[24:28], c[28:32] -template <bool SUM_A = false> -static inline ALWAYS_INLINE void madd_epi16x2_packed( - __m256i a0_v, - __m256i a1_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); - __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - - __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); - __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); - - *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); - *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); - *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); - *c3_v = 
_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); -} - -// c = a0 * b0 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[4:8] -// c1_v: c[8:12], c[12:16] -// c2_v: c[16:20], c[20:24] -// c3_v: c[24:28], c[28:32] -template <bool SUM_A = false> -static inline ALWAYS_INLINE void madd_epi16_packed( - __m256i a_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i zero_v = _mm256_setzero_si256(); - - __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v); - __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - - __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); - __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); - - *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); - *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); - *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); - *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); -} - -// K is the number of accumulations we're doing -template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false> -static inline ALWAYS_INLINE void inner_prod_packed_( - const __m256i* a_v, - const __m256i* Bp, - int32_t* C, - int remainder, - __m256i* a_sum = nullptr) { - __m256i c[4], c_temp[4]; - __m256i a_sum_temp[2] = {0, 0}; - - int k = 0; - if (K >= 4) { - madd_epi16x4_packed<SUM_A>( - a_v[0], - a_v[1], - a_v[2], - a_v[3], - Bp, - &c[0], - &c[1], - &c[2], - &c[3], - a_sum_temp); - - for (k = 4; k < K / 4 * 4; k += 4) { - madd_epi16x4_packed<SUM_A>( - a_v[k + 0], - a_v[k + 1], - a_v[k + 2], - a_v[k + 3], - Bp + k, - &c_temp[0], - &c_temp[1], - &c_temp[2], - &c_temp[3], - a_sum_temp); - - c[0] = _mm256_add_epi32(c[0], c_temp[0]); - c[1] = _mm256_add_epi32(c[1], c_temp[1]); - c[2] = _mm256_add_epi32(c[2], c_temp[2]); - c[3] = _mm256_add_epi32(c[3], c_temp[3]); - } - } else { - c[0] = _mm256_setzero_si256(); - c[1] = _mm256_setzero_si256(); - c[2] = _mm256_setzero_si256(); - c[3] = _mm256_setzero_si256(); - } - - if (K - k == 3) { - madd_epi16x3_packed<SUM_A>( - a_v[k], - a_v[k + 1], - a_v[k + 2], - Bp + k, - &c_temp[0], - &c_temp[1], - &c_temp[2], - &c_temp[3], - a_sum_temp); - - c[0] = _mm256_add_epi32(c[0], c_temp[0]); - c[1] = _mm256_add_epi32(c[1], c_temp[1]); - c[2] = _mm256_add_epi32(c[2], c_temp[2]); - c[3] = _mm256_add_epi32(c[3], c_temp[3]); - } - - c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20); - c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20); - c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31); - c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31); - - if (K - k == 0 || K - k == 3) { - c[0] = c_temp[0]; - c[1] = c_temp[1]; - c[2] = c_temp[2]; - c[3] = c_temp[3]; - } else { - if (K - k == 1) { - madd_epi16_packed<SUM_A>( - a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp); - } else if (K - k == 2) { - madd_epi16x2_packed<SUM_A>( - a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp); - } - - c[0] = _mm256_add_epi32(c[0], c_temp[0]); - c[1] = _mm256_add_epi32(c[1], c_temp[1]); - c[2] = _mm256_add_epi32(c[2], c_temp[2]); - 
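// Note on the shuffle above: madd_epi16x4/x3_packed leave each accumulator
// split across 128-bit lanes, e.g. c0 = C[0:4] | C[16:20], so the two
// _mm256_permute2f128_si256 passes with selectors 0x20 (low halves) and
// 0x31 (high halves) regroup them into contiguous C[0:8], C[8:16],
// C[16:24], C[24:32]. The x1/x2 variants already emit the contiguous
// layout, which is why their results are added on top of the permuted
// c_temp instead of being permuted themselves.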
c[3] = _mm256_add_epi32(c[3], c_temp[3]); - } - - if (REMAINDER) { - for (int r = 0; r < remainder / 8; ++r) { - if (ACC) { - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C + r * 8), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)), - c[r])); - } else { - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]); - } - } - } else { - if (ACC) { - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0])); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C + 8), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1])); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C + 16), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2])); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C + 24), - _mm256_add_epi32( - _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3])); - } else { - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]); - } - } - - if (SUM_A) { - a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0])); - a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1])); - a_sum[2] = - _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1)); - a_sum[3] = - _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1)); - } -} - template <bool SUM_A = false, bool REMAINDER = false> static inline ALWAYS_INLINE void inner_prod_3x3_packed_( const __m256i* a_v, @@ -580,238 +26,6 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_( return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum); } -// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different -// row_offsets for each row because of depth-wise convolution -template < - bool FUSE_RELU, - bool HAS_BIAS, - bool PER_CHANNEL_QUANTIZATION, - bool A_SYMMETRIC, - bool B_SYMMETRIC> -static inline ALWAYS_INLINE void requantize_( - int32_t A_zero_point, - const float* C_multiplier, - int32_t C_zero_point, - const int32_t* C_int32, - uint8_t* C_uint8, - int n, - const int32_t* row_offsets, - const int32_t* col_offsets, - const int32_t* bias) { - __m256 multiplier_v = _mm256_setzero_ps(); - if (!PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_set1_ps(*C_multiplier); - } - - __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); - __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); - - if (A_SYMMETRIC) { - assert(A_zero_point == 0 || col_offsets == nullptr); - } - __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point); - __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point); - __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point); - - __m256i permute_mask_v = - _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); - - constexpr int VLEN = 8; - int j = 0; - for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) { - __m256i x_v = - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); - __m256i y_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(C_int32 + j + VLEN)); - __m256i z_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN)); - __m256i w_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN)); - - __m256i row_offset_v; - if (!B_SYMMETRIC) { - row_offset_v 
= - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); - x_v = _mm256_sub_epi32(x_v, row_offset_v); - } - __m256i col_off_v; - if (!A_SYMMETRIC) { - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j))); - x_v = _mm256_sub_epi32(x_v, col_off_v); - } - - if (!B_SYMMETRIC) { - row_offset_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(row_offsets + j + VLEN)); - y_v = _mm256_sub_epi32(y_v, row_offset_v); - } - if (!A_SYMMETRIC) { - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j + VLEN))); - y_v = _mm256_sub_epi32(y_v, col_off_v); - } - - if (!B_SYMMETRIC) { - row_offset_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN)); - z_v = _mm256_sub_epi32(z_v, row_offset_v); - } - if (!A_SYMMETRIC) { - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN))); - z_v = _mm256_sub_epi32(z_v, col_off_v); - } - - if (!B_SYMMETRIC) { - row_offset_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN)); - w_v = _mm256_sub_epi32(w_v, row_offset_v); - } - if (!A_SYMMETRIC) { - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN))); - w_v = _mm256_sub_epi32(w_v, col_off_v); - } - - if (HAS_BIAS) { // static if - x_v = _mm256_add_epi32( - x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN))); - } - - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j); - } - __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j + VLEN); - } - __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN); - } - __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN); - } - __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); - - __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); - __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); - __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); - __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); - - __m256i xy_packed_v = _mm256_adds_epi16( - _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v); - __m256i zw_packed_v = _mm256_adds_epi16( - _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v); - __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); - __m256i xyzw_clamped_v = _mm256_max_epu8( - FUSE_RELU ? 
C_zero_point_epi8_v : min_v, - _mm256_min_epu8(xyzw_packed_v, max_v)); - - xyzw_clamped_v = - _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); - - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v); - } // j loop vectorized and unrolled 4x - - for (; j < n / VLEN * VLEN; j += VLEN) { - __m256i x_v = - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); - - if (!B_SYMMETRIC) { - __m256i row_offset_v = - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); - x_v = _mm256_sub_epi32(x_v, row_offset_v); - } - if (!A_SYMMETRIC) { - __m256i col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j))); - x_v = _mm256_sub_epi32(x_v, col_off_v); - } - - if (HAS_BIAS) { // static if - x_v = _mm256_add_epi32( - x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j))); - } - - if (PER_CHANNEL_QUANTIZATION) { - multiplier_v = _mm256_loadu_ps(C_multiplier + j); - } - __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); - - __m256i x_packed_v = _mm256_adds_epi16( - _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), - C_zero_point_epi16_v); - x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); - __m256i x_clamped_v = _mm256_max_epu8( - FUSE_RELU ? C_zero_point_epi8_v : min_v, - _mm256_min_epu8(x_packed_v, max_v)); - - x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); - - _mm_storel_epi64( - reinterpret_cast<__m128i*>(C_uint8 + j), - _mm256_castsi256_si128(x_clamped_v)); - } // j loop vectorized - - for (; j < n; ++j) { - int32_t raw = C_int32[j]; - if (!B_SYMMETRIC) { - raw -= row_offsets[j]; - } - if (!A_SYMMETRIC) { - raw -= A_zero_point * col_offsets[j]; - } - if (HAS_BIAS) { // static if - raw += bias[j]; - } - - float ab = raw * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0]; - long rounded = lrintf(ab) + C_zero_point; - - C_uint8[j] = std::max( - FUSE_RELU ? static_cast<long>(C_zero_point) : 0l, - std::min(255l, rounded)); - } -} - -template <bool REMAINDER> -static inline ALWAYS_INLINE __m256i load_a( - const uint8_t* A, - __m256i mask_v) { - if (REMAINDER) { - return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v); - } else { - return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A)); - } -} - template < bool SUM_A, bool REMAINDER = false, @@ -924,257 +138,11 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_( } template < - bool SUM_A, - bool REMAINDER = false, - bool PER_CHANNEL_QUANTIZATION = false> -static inline ALWAYS_INLINE void inner_prod_3x3x3_packed_( - int T, - int H, - int W, - int K, - int t_in, - int h_in, - int w_in, - const uint8_t* A, - int32_t A_zero_point, - const int8_t* Bp, - const int32_t* B_zero_point, - int32_t* C, - int remainder, - int32_t* row_offsets) { - __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point)); - __m256i mask_v = _mm256_setzero_si256(); - if (REMAINDER) { - mask_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(masks[remainder / 4])); - } - - // The code below can be written as a simple R*S loop but the compiler - // doesn't unroll so we're manually unrolling it. 
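// Boundary-handling note: taps whose (t, h, w) indices fall outside the
// input are loaded as A_zero_point instead of being skipped, so they act
// like activations equal to the zero point and drop out once the
// zero-point corrections run, since
//   (A - A_zero_point) * (B - B_zero_point) == 0  whenever A == A_zero_point.
// This is why the unrolled loads below only need bounds checks, not a
// separate accumulation path.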
- // constexpr int R = 3, S = 3; - // array<__m256i, R * S> a_v; - // for (int r = 0; r < R; ++r) { - // for (int s = 0; s < S; ++s) { - // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { - // if (REMAINDER) { - // a_v[r * S + s] = - // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), - // mask_v); - // } else { - // a_v[r * S + s] = - // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); - // } - // } else { - // a_v[r * S + s] = A_zero_point_v; - // } - // } - // } - __m256i a_v[8]; - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - a_v[3] = A_zero_point_v; - a_v[4] = A_zero_point_v; - a_v[5] = A_zero_point_v; - a_v[6] = A_zero_point_v; - a_v[7] = A_zero_point_v; - - if (t_in >= 0 && t_in < T) { - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v); - } - } - - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v); - } - } - } - - __m256i a_sum[4]; - inner_prod_packed_<8, SUM_A, REMAINDER>( - a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); - - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - a_v[3] = A_zero_point_v; - a_v[4] = A_zero_point_v; - a_v[5] = A_zero_point_v; - a_v[6] = A_zero_point_v; - a_v[7] = A_zero_point_v; - - if (t_in >= 0 && t_in < T) { - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v); - } - } - } - - if (t_in + 1 >= 0 && t_in + 1 < T) { - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v); - } - } - - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v); - } - } - } - - __m256i a_sum_temp[4]; - inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( - a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp); - if (SUM_A) { - a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); - a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); - a_sum[2] = 
_mm256_add_epi32(a_sum[2], a_sum_temp[2]); - a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); - } - - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - a_v[3] = A_zero_point_v; - a_v[4] = A_zero_point_v; - a_v[5] = A_zero_point_v; - a_v[6] = A_zero_point_v; - a_v[7] = A_zero_point_v; - - if (t_in + 1 >= 0 && t_in + 1 < T) { - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v); - } - } - } - - if (t_in + 2 >= 0 && t_in + 2 < T) { - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v); - } - } - } - - inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( - a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp); - if (SUM_A) { - a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); - a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); - a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); - a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); - } - - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - - if (t_in + 2 >= 0 && t_in + 2 < T) { - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v); - } - } - } - - inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>( - a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp); - - if (SUM_A) { - a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); - a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); - a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); - a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); - - __m256i B_zero_point_v; - for (int i = 0; i < (REMAINDER ? 
(remainder / 8) : 4); ++i) { - if (PER_CHANNEL_QUANTIZATION) { - B_zero_point_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); - } else { - B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); - } - _mm256_store_si256( - reinterpret_cast<__m256i*>(&row_offsets[i * 8]), - _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); - } - } -} - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> static inline ALWAYS_INLINE void depthwise_3x3_kernel_( int H, int W, @@ -1193,7 +161,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_( uint8_t* C_uint8, int32_t* row_offsets, const int32_t* col_offsets, - const int32_t* bias) { + const BIAS_TYPE* bias, + float act_times_w_scale) { constexpr int S = 3; constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; @@ -1238,7 +207,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_( HAS_BIAS, false, /*PER_CHAN_QUANT*/ A_SYMMETRIC, - B_SYMMETRIC>( + B_SYMMETRIC, + BIAS_TYPE>( A_zero_point, &C_multiplier, C_zero_point, @@ -1247,95 +217,11 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_( K, row_offsets, col_offsets, - bias); -} - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> -static inline ALWAYS_INLINE void depthwise_3x3x3_kernel_( - int T, - int H, - int W, - int K, - int t, - int h, - int w, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const int8_t* Bp, - float C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - int32_t* row_offsets, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int R = 3, S = 3; - constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - int t_in = -PAD_P + t * stride_t; - int h_in = -PAD_T + h * stride_h; - int w_in = -PAD_L + w * stride_w; - - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - &B_zero_point, - C_int32 + k, - 0, - B_SYMMETRIC ? nullptr : &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - &B_zero_point, - C_int32 + k, - remainder, - B_SYMMETRIC ? 
nullptr : &row_offsets[k]); - } - - requantize_< - FUSE_RELU, - HAS_BIAS, - false, /*PER_CHAN_QUANT*/ - A_SYMMETRIC, - B_SYMMETRIC>( - A_zero_point, - &C_multiplier, - C_zero_point, - C_int32, - C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias); + bias, + &act_times_w_scale); } -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> static inline ALWAYS_INLINE void depthwise_3x3_per_channel_quantization_kernel_( int H, @@ -1355,7 +241,8 @@ depthwise_3x3_per_channel_quantization_kernel_( uint8_t* C_uint8, int32_t* row_offsets, const int32_t* col_offsets, - const int32_t* bias) { + const BIAS_TYPE* bias, + const float* act_times_w_scale) { constexpr int S = 3; constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; @@ -1406,7 +293,8 @@ depthwise_3x3_per_channel_quantization_kernel_( HAS_BIAS, true, /*PER_CHAN_QUANT*/ A_SYMMETRIC, - false /*B_SYMM*/>( + false, /*B_SYMM*/ + BIAS_TYPE>( A_zero_point, C_multiplier, C_zero_point, @@ -1415,113 +303,20 @@ depthwise_3x3_per_channel_quantization_kernel_( K, row_offsets, col_offsets, - bias); -} - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> -static inline ALWAYS_INLINE void -depthwise_3x3x3_per_channel_quantization_kernel_( - int T, - int H, - int W, - int K, - int t, - int h, - int w, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const int8_t* Bp, - const float* C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - int32_t* row_offsets, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int R = 3, S = 3; - constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - int t_in = -PAD_P + t * stride_t; - int h_in = -PAD_T + h * stride_h; - int w_in = -PAD_L + w * stride_w; - - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_< - true, /*SUM_A*/ - false, /*remainder*/ - true /*per-channel*/>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - B_zero_point + k, - C_int32 + k, - 0, - &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_3x3x3_packed_< - true, /*SUM_A*/ - true, /*remainder*/ - true /*per-channel*/>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - B_zero_point + k, - C_int32 + k, - remainder, - &row_offsets[k]); - } - requantize_< - FUSE_RELU, - HAS_BIAS, - true, /*PER_CHAN_QUANT*/ - A_SYMMETRIC, - false /*B_SYMM*/>( - A_zero_point, - C_multiplier, - C_zero_point, - C_int32, - C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias); -} - -static pair<int, int> closest_factors_(int n) { - int a = (int)std::sqrt(n); - while (n % a != 0) { - a--; - } - return {a, n / a}; // a <= n / a + bias, + act_times_w_scale); } // TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0 // This implemntation should be general enough to handle not just 3x3 but other // filter shapes by parameterizing with R and S but restricting it to just 3x3 // for now. 
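The closest_factors_ helper deleted above is what shapes the 2D thread decomposition in the 3x3x3 kernels: when one image is shared by several threads, those threads are arranged as a num_threads_t x num_threads_h grid whose factors are as close to square as possible. A minimal standalone sketch of the same idea (renamed closest_factors here; the example values are chosen purely for illustration):

#include <cassert>
#include <cmath>
#include <utility>

// Factor n into {a, n / a} with a <= n / a and a as close to sqrt(n) as
// possible, so an n-thread pool maps onto a near-square 2D grid.
static std::pair<int, int> closest_factors(int n) {
  int a = static_cast<int>(std::sqrt(n));
  while (n % a != 0) {
    --a;
  }
  return {a, n / a};
}

int main() {
  assert(closest_factors(6) == std::make_pair(2, 3));   // 6 threads -> 2x3 grid
  assert(closest_factors(7) == std::make_pair(1, 7));   // primes fall back to 1xN
  assert(closest_factors(16) == std::make_pair(4, 4));  // perfect squares split evenly
  return 0;
}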
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( int N, int H, @@ -1532,13 +327,14 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( int32_t A_zero_point, const uint8_t* A, int32_t B_zero_point, - const Packed3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& B, float C_multiplier, int32_t C_zero_point, int32_t* C_int32, uint8_t* C_uint8, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + float act_times_w_scale, int thread_id, int num_threads) { assert(K % 8 == 0); @@ -1597,7 +393,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( if (h_begin == 0) { if (w_begin == 0) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1615,11 +416,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1637,12 +444,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1660,14 +473,20 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { if (w_begin == 0) { w = 0; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1685,11 +504,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1707,12 +532,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1730,7 +561,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } @@ -1738,7 +570,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( h = H_OUT - 1; w = 0; if (w_begin == 0) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1756,11 +593,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + 
bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1778,12 +621,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( + depthwise_3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -1801,126 +650,15 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } } // for each n FREE(row_offsets); }; -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> -static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const Packed3x3x3ConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - const int32_t* col_offsets, - const int32_t* bias, - int thread_id, - int num_threads) { - assert(K % 8 == 0); - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - const int8_t* Bp = B.PackedMat(); - - int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64))); - - int n_begin, n_end; - int t_begin, t_end, h_begin, h_end; - if (N >= num_threads) { - int n_per_thread = (N + num_threads - 1) / num_threads; - n_begin = std::min(thread_id * n_per_thread, N); - n_end = std::min(n_begin + n_per_thread, N); - t_begin = 0; - t_end = T_OUT; - h_begin = 0; - h_end = H_OUT; - } else { - int nthreads_per_n = num_threads / N; - n_begin = std::min(thread_id / nthreads_per_n, N); - n_end = std::min(n_begin + 1, N); - - int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); - int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); - int nthreads_of_n = tid_of_n_end - tid_of_n_begin; - int tid_within_n = thread_id - tid_of_n_begin; - assert(tid_within_n >= 0); - assert(tid_within_n < nthreads_of_n); - - // n is processed by num_threads_t * num_threads_h 2D grid of threads - int num_threads_t, num_threads_h; - // num_threads_w <= num_threads_h - tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); - int tid_t = tid_within_n / num_threads_h; - int tid_h = tid_within_n % num_threads_h; - - int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; - t_begin = std::min(tid_t * t_per_thread, T_OUT); - t_end = std::min(t_begin + t_per_thread, T_OUT); - - int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; - h_begin = std::min(tid_h * h_per_thread, H_OUT); - h_end = std::min(h_begin + h_per_thread, H_OUT); - } - - for (int n = n_begin; n < n_end; ++n) { - const uint8_t* A_base = A + n * T * H * W * K; - uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; - - for (int t = t_begin; t < t_end; ++t) { - for (int h = h_begin; h < h_end; ++h) { - for (int w = 0; w < W_OUT; ++w) { - 
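// Per-position note: each (t, h, w) iteration runs one fused kernel call
// that accumulates all K channels into the C_int32 scratch buffer and
// immediately requantizes them into C_uint8_base at offset
// ((t * H_OUT + h) * W_OUT + w) * K, so the int32 intermediate never
// exists for more than one output position at a time.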
depthwise_3x3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC>( - T, - H, - W, - K, - t, - h, - w, - stride_t, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias); - } // w - } // h - } // t - } // for each n - - FREE(row_offsets); -}; - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> static inline ALWAYS_INLINE void depthwise_3x3_per_channel_quantization_pad_1_( int N, @@ -1932,13 +670,14 @@ depthwise_3x3_per_channel_quantization_pad_1_( int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& B, const float* C_multiplier, int32_t C_zero_point, int32_t* C_int32, uint8_t* C_uint8, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + const float* act_times_w_scale, int thread_id, int num_threads) { assert(K % 8 == 0); @@ -2000,7 +739,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2018,14 +758,16 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2043,7 +785,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { @@ -2051,7 +794,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2069,7 +813,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } @@ -2079,7 +824,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2097,14 +843,16 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2122,7 +870,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { @@ -2130,7 +879,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2148,7 +898,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } @@ -2159,7 +910,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2177,14 +929,16 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, 
w_end); ++w) { depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2202,7 +956,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } if (w_end == W_OUT) { @@ -2210,7 +965,8 @@ depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_kernel_< FUSE_RELU, HAS_BIAS, - A_SYMMETRIC>( + A_SYMMETRIC, + BIAS_TYPE>( H, W, K, @@ -2228,128 +984,15 @@ depthwise_3x3_per_channel_quantization_pad_1_( C_uint8_base, row_offsets, col_offsets, - bias); + bias, + act_times_w_scale); } } } // for each n - - FREE(row_offsets); -}; - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> -static inline ALWAYS_INLINE void -depthwise_3x3x3_per_channel_quantization_pad_1_( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& B, - const float* C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - const int32_t* col_offsets, - const int32_t* bias, - int thread_id, - int num_threads) { - assert(K % 8 == 0); - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - const int8_t* Bp = B.PackedMat(); - - int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64))); - - int n_begin, n_end; - int t_begin, t_end, h_begin, h_end; - if (N >= num_threads) { - int n_per_thread = (N + num_threads - 1) / num_threads; - n_begin = std::min(thread_id * n_per_thread, N); - n_end = std::min(n_begin + n_per_thread, N); - t_begin = 0; - t_end = T_OUT; - h_begin = 0; - h_end = H_OUT; - } else { - int nthreads_per_n = num_threads / N; - n_begin = std::min(thread_id / nthreads_per_n, N); - n_end = std::min(n_begin + 1, N); - - int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); - int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); - int nthreads_of_n = tid_of_n_end - tid_of_n_begin; - int tid_within_n = thread_id - tid_of_n_begin; - assert(tid_within_n >= 0); - assert(tid_within_n < nthreads_of_n); - - // n is processed by num_threads_t * num_threads_h 2D grid of threads - int num_threads_t, num_threads_h; - // num_threads_w <= num_threads_h - tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); - int tid_t = tid_within_n / num_threads_h; - int tid_h = tid_within_n % num_threads_h; - - int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; - t_begin = std::min(tid_t * t_per_thread, T_OUT); - t_end = std::min(t_begin + t_per_thread, T_OUT); - - int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; - h_begin = std::min(tid_h * h_per_thread, H_OUT); - h_end = std::min(h_begin + h_per_thread, H_OUT); - } - - for (int n = n_begin; n < n_end; ++n) { - const uint8_t* A_base = A + n * T * H * W * K; - uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; - - for (int t = t_begin; t < t_end; ++t) { - for (int h = h_begin; h < h_end; ++h) { - for (int w = 0; w < W_OUT; ++w) { - depthwise_3x3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC>( - T, - H, - W, - K, - t, - h, - w, - 
stride_t, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias); - } // w - } // h - } // t - } // for each n - - FREE(row_offsets); }; // Dispatch A_SYMMETRIC and B_SYMMETRIC -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> static void depthwise_3x3_pad_1_( int N, int H, @@ -2360,12 +1003,13 @@ static void depthwise_3x3_pad_1_( int32_t A_zero_point, const uint8_t* A, int32_t B_zero_point, - const Packed3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& B, float C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + float act_times_w_scale, int thread_id, int num_threads) { int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; @@ -2375,7 +1019,8 @@ static void depthwise_3x3_pad_1_( FUSE_RELU, HAS_BIAS, true /*A_symmetric*/, - true /*B_symmetric*/>( + true /*B_symmetric*/, + BIAS_TYPE>( N, H, W, @@ -2392,6 +1037,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { @@ -2399,7 +1045,8 @@ static void depthwise_3x3_pad_1_( FUSE_RELU, HAS_BIAS, true /*A_symmetric*/, - false /*B_symmetric*/>( + false /*B_symmetric*/, + BIAS_TYPE>( N, H, W, @@ -2416,6 +1063,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } @@ -2425,7 +1073,8 @@ static void depthwise_3x3_pad_1_( FUSE_RELU, HAS_BIAS, false /*A_symmetric*/, - true /*B_symmetric*/>( + true /*B_symmetric*/, + BIAS_TYPE>( N, H, W, @@ -2442,6 +1091,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { @@ -2449,7 +1099,8 @@ static void depthwise_3x3_pad_1_( FUSE_RELU, HAS_BIAS, false /*A_symmetric*/, - false /*B_symmetric*/>( + false /*B_symmetric*/, + BIAS_TYPE>( N, H, W, @@ -2466,6 +1117,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } @@ -2474,7 +1126,7 @@ static void depthwise_3x3_pad_1_( } // Dispatch HAS_BIAS -template <bool FUSE_RELU> +template <bool FUSE_RELU, typename BIAS_TYPE> static void depthwise_3x3_pad_1_( int N, int H, @@ -2485,16 +1137,17 @@ static void depthwise_3x3_pad_1_( int32_t A_zero_point, const uint8_t* A, int32_t B_zero_point, - const Packed3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& B, float C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + float act_times_w_scale, int thread_id, int num_threads) { if (bias) { - depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>( + depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>( N, H, W, @@ -2510,10 +1163,11 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>( + depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>( N, H, W, @@ -2529,6 +1183,7 @@ static void depthwise_3x3_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } @@ -2536,6 +1191,7 @@ static void depthwise_3x3_pad_1_( // Dispatch input shape and FUSE_RELU // assumption: W > 3 and H > 3 +template <typename BIAS_TYPE> void depthwise_3x3_pad_1( int N, int H, @@ -2546,18 +1202,33 @@ void depthwise_3x3_pad_1( int32_t A_zero_point, const uint8_t* A, int32_t B_zero_point, - const Packed3x3ConvMatrix& B, + const 
PackedDepthWiseConvMatrix& B, float C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, bool fuse_relu, + float act_times_w_scale, int thread_id, int num_threads) { + if (B.GetKernelProduct() != 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. + return; + } if (fuse_relu) { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2573,10 +1244,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2592,10 +1264,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2611,10 +1284,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2630,10 +1304,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2649,12 +1324,13 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } else { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2670,10 +1346,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2689,10 +1366,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2708,10 +1386,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2727,10 +1406,11 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>( N, H, W, @@ -2746,283 +1426,15 @@ void depthwise_3x3_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } } -// Dispatch A_SYMMETRIC and B_SYMMETRIC -template <bool 
FUSE_RELU, bool HAS_BIAS> -static void depthwise_3x3x3_pad_1_( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const Packed3x3x3ConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias, - int thread_id, - int num_threads) { - int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; - if (A_zero_point == 0 || col_offsets == nullptr) { - if (B_zero_point == 0) { - depthwise_3x3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - true /*A_symmetric*/, - true /*B_symmetric*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - true /*A_symmetric*/, - false /*B_symmetric*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } - } else { - if (B_zero_point == 0) { - depthwise_3x3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - false /*A_symmetric*/, - true /*B_symmetric*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - false /*A_symmetric*/, - false /*B_symmetric*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } - } - delete[] C_int32_temp; -} - -// Dispatch HAS_BIAS -template <bool FUSE_RELU> -static void depthwise_3x3x3_pad_1_( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const Packed3x3x3ConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias, - int thread_id, - int num_threads) { - if (bias) { - depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } -} - -// Dispatch FUSE_RELU -void depthwise_3x3x3_pad_1( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const Packed3x3x3ConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias, - bool fuse_relu, - int thread_id, - int num_threads) { - if (fuse_relu) { - depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - 
depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } -} - // Dispatch A_SYMMETRIC -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> static void depthwise_3x3_per_channel_quantization_pad_1_( int N, int H, @@ -3033,12 +1445,13 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + const float* act_times_w_scale, int thread_id, int num_threads) { int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; @@ -3046,7 +1459,8 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, HAS_BIAS, - true /*A_SYMM*/>( + true /*A_SYMM*/, + BIAS_TYPE>( N, H, W, @@ -3063,13 +1477,15 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, HAS_BIAS, - false /*A_SYMM*/>( + false /*A_SYMM*/, + BIAS_TYPE>( N, H, W, @@ -3086,6 +1502,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } @@ -3093,7 +1510,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( } // Dispatch HAS_BIAS -template <bool FUSE_RELU> +template <bool FUSE_RELU, typename BIAS_TYPE> static void depthwise_3x3_per_channel_quantization_pad_1_( int N, int H, @@ -3104,18 +1521,20 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, + const float* act_times_w_scale, int thread_id, int num_threads) { if (bias) { depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, - true /* HAS_BIAS */>( + true /* HAS_BIAS */, + BIAS_TYPE>( N, H, W, @@ -3131,12 +1550,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, - false /* HAS_BIAS */>( + false /* HAS_BIAS */, + BIAS_TYPE>( N, H, W, @@ -3152,12 +1573,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } // Dispatch input shape and FUSE_RELU +template <typename BIAS_TYPE> void depthwise_3x3_per_channel_quantization_pad_1( int N, int H, @@ -3168,18 +1591,35 @@ void depthwise_3x3_per_channel_quantization_pad_1( int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3ConvMatrix& Bp, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, - const int32_t* bias, + const BIAS_TYPE* bias, bool fuse_relu, + const float* act_times_w_scale, int thread_id, int num_threads) { + if (Bp.GetKernelProduct() != 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + 
to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. + return; + } if (fuse_relu) { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3195,10 +1635,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3214,10 +1657,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3233,10 +1679,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3252,10 +1701,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3271,12 +1723,15 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } else { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3292,10 +1747,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3311,10 +1769,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3330,10 +1791,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3349,10 +1813,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, 
bias, + act_times_w_scale, thread_id, num_threads); } else { - depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_< + false /* FUSE_RELU */, + BIAS_TYPE>( N, H, W, @@ -3368,225 +1835,179 @@ void depthwise_3x3_per_channel_quantization_pad_1( C, col_offsets, bias, + act_times_w_scale, thread_id, num_threads); } } } -// Dispatch A_SYMMETRIC -template <bool FUSE_RELU, bool HAS_BIAS> -static void depthwise_3x3x3_per_channel_quantization_pad_1_( +// To be removed +void depthwise_3x3_pad_1( int N, - int T, int H, int W, int K, - int stride_t, int stride_h, int stride_w, int32_t A_zero_point, const uint8_t* A, - const int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& B, - const float* C_multiplier, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, const int32_t* bias, + bool fuse_relu, int thread_id, int num_threads) { - int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; - if (A_zero_point == 0 || col_offsets == nullptr) { - depthwise_3x3x3_per_channel_quantization_pad_1_< - FUSE_RELU, - HAS_BIAS, - true /*A_SYMM*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_per_channel_quantization_pad_1_< - FUSE_RELU, - HAS_BIAS, - false /*A_SYMM*/>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } - delete[] C_int32_temp; + depthwise_3x3_pad_1<std::int32_t>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + fuse_relu, + 1.0f, + thread_id, + num_threads); } -// Dispatch HAS_BIAS -template <bool FUSE_RELU> -static void depthwise_3x3x3_per_channel_quantization_pad_1_( +// To be removed +void depthwise_3x3_per_channel_quantization_pad_1( int N, - int T, int H, int W, int K, - int stride_t, int stride_h, int stride_w, int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, const int32_t* bias, + bool fuse_relu, int thread_id, int num_threads) { - if (bias) { - depthwise_3x3x3_per_channel_quantization_pad_1_< - FUSE_RELU, - true /* HAS_BIAS */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_per_channel_quantization_pad_1_< - FUSE_RELU, - false /* HAS_BIAS */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } + depthwise_3x3_per_channel_quantization_pad_1<std::int32_t>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + fuse_relu, + nullptr, + thread_id, + num_threads); } -// Dispatch FUSE_RELU -void depthwise_3x3x3_per_channel_quantization_pad_1( +template void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, 
+ int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3_per_channel_quantization_pad_1( int N, - int T, int H, int W, int K, - int stride_t, int stride_h, int stride_w, int32_t A_zero_point, const uint8_t* A, const int32_t* B_zero_point, - const Packed3x3x3ConvMatrix& B, + const PackedDepthWiseConvMatrix& Bp, const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, const int32_t* bias, bool fuse_relu, + const float* act_times_w_scale, int thread_id, - int num_threads) { - if (fuse_relu) { - depthwise_3x3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else { - depthwise_3x3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - thread_id, - num_threads); - } -} + int num_threads); + +template void depthwise_3x3_per_channel_quantization_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads); } // namespace fbgemm diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index e52097e..c0fece4 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -8,8 +8,11 @@ #include <asmjit/asmjit.h> #include <cpuinfo.h> #include <map> +#include <mutex> +#include <sstream> #include <string> #include <tuple> +#include "CodeCache.h" #include "fbgemm/Fbgemm.h" /*#define FBGEMM_LOG_CODE 1*/ @@ -40,35 +43,7 @@ class CodeGenBase { * @brief Constructor for initializing AVX2/AVX512 registers. 
   */
  CodeGenBase(const BlockingFactors* params = nullptr)
-      : blocking_params(params),
-        CRegs_avx2_{x86::ymm0,
-                    x86::ymm1,
-                    x86::ymm2,
-                    x86::ymm3,
-                    x86::ymm4,
-                    x86::ymm5,
-                    x86::ymm6,
-                    x86::ymm7,
-                    x86::ymm8,
-                    x86::ymm9,
-                    x86::ymm10,
-                    x86::ymm11},
-        CRegs_avx512_{
-            x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4,
-            x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9,
-            x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14,
-            x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
-            x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24,
-            x86::zmm25, x86::zmm26, x86::zmm27,
-        },
-        AllRegs_avx512_{x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3,
-                        x86::zmm4, x86::zmm5, x86::zmm6, x86::zmm7,
-                        x86::zmm8, x86::zmm9, x86::zmm10, x86::zmm11,
-                        x86::zmm12, x86::zmm13, x86::zmm14, x86::zmm15,
-                        x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
-                        x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23,
-                        x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27,
-                        x86::zmm28, x86::zmm29, x86::zmm30, x86::zmm31} {
+      : blocking_params(params) {
   // vector width in bits
   if (cpuinfo_initialize()) {
     if (fbgemmHasAvx512Support()) {
@@ -143,7 +118,7 @@ class CodeGenBase {
    * (debug-only)
    */
   template <inst_set_t instSet>
-  std::string getCodeLoggingFile(
+  static std::string getCodeLoggingFile(
       bool accum,
       int mc,
       int nc,
@@ -152,48 +127,60 @@ class CodeGenBase {
       int MR,
       int NR,
       int NR_MIN) {
-    std::string fileName = "gemm_";
+    std::ostringstream oss;
+    oss << "gemm_";
     if (std::is_same<accT, std::int16_t>::value) {
-      fileName += "acc16_";
+      oss << "acc16_";
     } else if (std::is_same<accT, std::int32_t>::value) {
-      fileName += "acc32_";
+      oss << "acc32_";
     } else {
-      fileName += "unknown_";
+      oss << "unknown_";
     }
-    fileName += "accum-" + std::to_string(accum);
-    fileName += "_MC-" + std::to_string(mc);
-    fileName += "_NC-" + std::to_string(nc);
-    fileName += "_NCB-" + std::to_string(NCB);
-    fileName += "_NCB-" + std::to_string(KCB);
-    fileName += "_MR-" + std::to_string(MR);
-    fileName += "_NR-" + std::to_string(NR);
-    fileName += "_NR_MIN-" + std::to_string(NR_MIN);
+    oss << "accum-" + std::to_string(accum)
+        << "_MC-" + std::to_string(mc)
+        << "_NC-" + std::to_string(nc)
+        << "_NCB-" + std::to_string(NCB)
+        << "_NCB-" + std::to_string(KCB)
+        << "_MR-" + std::to_string(MR)
+        << "_NR-" + std::to_string(NR)
+        << "_NR_MIN-" + std::to_string(NR_MIN);
     if (instSet == inst_set_t::avx512_vnni) {
-      fileName += "_avx512vnni";
+      oss << "_avx512vnni";
     } else if (instSet == inst_set_t::avx512) {
-      fileName += "_avx512";
+      oss << "_avx512";
     } else if (instSet == inst_set_t::avx2) {
-      fileName += "_avx2";
+      oss << "_avx2";
     }
-    fileName += ".txt";
-    return fileName;
+    oss << ".txt";
+    return oss.str();
   }

  private:
-  x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
-  x86::Zmm
-      CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
-  x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
-
   int vectorWidth_; ///< Vector width in bits.
   int VLEN_; ///< Vector width in elements.
-  static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
-  static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+
+  static asmjit::JitRuntime &runtime() {
+    static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
+                                  // depends on other static
+                                  // variables. Required to prevent
+                                  // initialization order fiasco.
+    return rt;
+  }
+
+  static std::mutex rtMutex_; ///< Controls access to runtime.
+
   // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min
-  static thread_local std::map<
-      std::tuple<bool, int, int, int, int, int, int, int>,
-      jit_micro_kernel_fp>
+  static CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+                   jit_micro_kernel_fp>
       codeCache_; ///< JIT Code Cache for reuse.
 };

+template <typename TA, typename TB, typename TC, typename accT>
+std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+          typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
+    CodeGenBase<TA, TB, TC, accT>::codeCache_;
+
 } // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 1e7e081..cbd5877 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -9,18 +9,6 @@
 
 namespace fbgemm {
 
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local std::map<
-    std::tuple<bool, int, int, int, int, int, int, int>,
-    typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
-    CodeGenBase<TA, TB, TC, accT>::codeCache_;
-
 namespace x86 = asmjit::x86;
 
 /**
@@ -35,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
     int rowRegs,
     int colRegs,
     int leadingDimCReg) {
+  using CRegs = x86::Ymm;
   for (int i = 0; i < rowRegs; ++i) {
     for (int j = 0; j < colRegs; ++j) {
       a->vxorps(
-          CRegs_avx2_[i * leadingDimCReg + j],
-          CRegs_avx2_[i * leadingDimCReg + j],
-          CRegs_avx2_[i * leadingDimCReg + j]);
+          CRegs(i * leadingDimCReg + j),
+          CRegs(i * leadingDimCReg + j),
+          CRegs(i * leadingDimCReg + j));
     }
   }
 }
@@ -66,6 +55,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
 
   x86::Ymm tmpReg = x86::ymm14;
 
+  using CRegs = x86::Ymm;
+
   for (int i = 0; i < rowRegs; ++i) {
     // broadcast A
     a->vpbroadcastw(
@@ -74,9 +65,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
       a->vpmaddubsw(
           tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
       a->vpaddsw(
-          CRegs_avx2_[i * leadingDimCReg + j],
+          CRegs(i * leadingDimCReg + j),
           tmpReg,
-          CRegs_avx2_[i * leadingDimCReg + j]);
+          CRegs(i * leadingDimCReg + j));
       // Prefetching is hurting performance in some cases
       // because a prefetch instruction itself consumes a slot
       // in the pipeline issue stage, thus slowing down the kernel.
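
The getOrCreate() calls that appear throughout this diff rely on the new CodeCache class from CodeCache.h, which this page does not include. The following is only a minimal sketch of the contract those call sites assume (one cached kernel per signature tuple, with the generator lambda run at most once per key and the result shared by all threads); it is a single-mutex variant, and the real class may well be more elaborate, for example by running the generator outside the lock:

#include <functional>
#include <map>
#include <mutex>

template <typename KEY, typename VALUE>
class CodeCache {
 public:
  VALUE getOrCreate(const KEY& key, std::function<VALUE()> generator) {
    // One lock for both lookup and generation keeps the logic trivially
    // race-free; each key's generator runs exactly once.
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = values_.find(key);
    if (it == values_.end()) {
      it = values_.emplace(key, generator()).first;
    }
    return it->second;
  }

 private:
  std::map<KEY, VALUE> values_;
  std::mutex mutex_;
};

Since codeCache_ is now an ordinary static rather than thread_local, threads share one JIT'd kernel per shape tuple instead of each compiling its own copy, and rtMutex_ only has to serialize the final runtime().add() step.
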
@@ -105,12 +96,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< x86::Xmm extractDest128 = x86::xmm15; x86::Ymm extractDest256 = x86::ymm15; + using CRegs = x86::Ymm; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); for (int j = 0; j < colRegs; ++j) { for (int idx = 0; idx < 2; ++idx) { a->vextracti128( - extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx); + extractDest128, CRegs(i * leadingDimCReg + j), idx); a->vpmovsxwd(extractDest256, extractDest128); x86::Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t)); @@ -172,191 +164,186 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as<x86::Emitter>(); + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx2>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - // assert((nc == nRegBlockSize) && - //"nc must be equal to the number of register blocks"); - - // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrame frame; - frame.init(func); - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); - - asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFuncFrame(frame); - frame.finalize(); - - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); - - asmjit::Label Loopk = a->newLabel(); - asmjit::Label LoopMBlocks = a->newLabel(); - - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - // x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp kIdx = a->gpz(14); - - int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); 
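// The save/rewind pair around LoopMBlocks: every M-row block of the
// generated kernel streams through the same packed-B panel, so buffer_B
// is saved once here and restored from buffer_B_saved ("reset B" below)
// after each row block's C tile has been stored.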
- // a->mov(B_pf_saved, B_pf); - - a->bind(LoopMBlocks); - a->inc(iIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx2>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); - - // increment A for next block - a->sub(buffer_A, kSize); - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - // increment C for next block - a->imul( - C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); - a->add(CBase, C_Offset); - // reset B - a->mov(buffer_B, buffer_B_saved); - // a->mov(B_pf, B_pf_saved); - - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + // assert((nc == nRegBlockSize) && + //"nc must be equal to the number of register blocks"); + + // arguments to the function created + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp kIdx = a->gpz(14); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs, + 
colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + // increment C for next block + a->imul(C_Offset, ldcReg, + static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); + a->add(CBase, C_Offset); + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; - // init C registers - initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); + // init C registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - genComputeBlock<inst_set_t::avx2>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock); - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); - a->cmp(kIdx, kSize); - a->jl(LoopkRem); + a->cmp(kIdx, kSize); + a->jl(LoopkRem); - // store C matrix - storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); - } + // store C matrix + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum); + } - a->emitEpilog(frame); + a->emitEpilog(frame); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index a49e440..512c8ba 100644 --- 
a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< int rowRegs, int colRegs, int leadingDimCReg) { + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { a->vxorps( - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); } } } @@ -57,20 +58,22 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< // We start allocating BRegs from zmm27 and then allocate zmm26 and so on. for (int j = 0; j < colRegs; ++j) { a->vmovups( - AllRegs_avx512_[27 - j], + x86::Zmm(27 - j), x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); } + using CRegs = x86::Zmm; + for (int i = 0; i < rowRegs; ++i) { // broadcast A a->vpbroadcastw( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); for (int j = 0; j < colRegs; ++j) { - a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]); + a->vpmaddubsw(tmpReg, AReg, x86::Zmm(27 - j)); a->vpaddsw( - CRegs_avx512_[i * leadingDimCReg + j], + CRegs(i * leadingDimCReg + j), tmpReg, - CRegs_avx512_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j)); // Prefetching is hurting performance in some cases // because prefetch instructions itself consumes a slot // in pipeline issue thus slowing down the kernel. @@ -100,12 +103,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< x86::Ymm extractDest256 = x86::ymm31; x86::Zmm extractDest512 = x86::zmm31; + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); for (int j = 0; j < colRegs; ++j) { for (int idx = 0; idx < 2; ++idx) { a->vextracti32x8( - extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx); + extractDest256, CRegs(i * leadingDimCReg + j), idx); a->vpmovsxwd(extractDest512, extractDest256); x86::Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); @@ -167,262 +171,247 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as<x86::Emitter>(); + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx512>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % 
nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); - int maxMRegs = mRegBlockSize; - int maxNRegs = nRegBlockSize * row_interleave / VLEN_; - assert( - (maxMRegs+1) * maxNRegs <= 28 && - "number of zmm registers for C + one row for loading B: \ + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert((maxMRegs + 1) * maxNRegs <= 28 && + "number of zmm registers for C + one row for loading B: \ MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \ must be <= 28(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrame frame; - frame.init(func); - - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFuncFrame(frame); - frame.finalize(); - - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); - - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); - - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - // x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - // a->mov(B_pf_saved, B_pf); - - int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - int colRegs = std::min(currColRegs, maxNRegs); - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - a->bind(LoopMBlocks); - a->inc(iIdx); - a->mov(jIdx, 0); - - a->bind(LoopNBlocks); - a->inc(jIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize 
* row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - - // increment C for next block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNBlocks); - - // increment A for next block - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next A block - a->sub( - CBase, - static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul( - C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); - a->add(CBase, C_Offset); - - // reset B - a->mov(buffer_B, buffer_B_saved); - // a->mov(B_pf, B_pf_saved); - - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - a->mov(jIdx, 0); - a->bind(LoopNRem); - a->inc(jIdx); - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment C for next block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNRem); - } + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * + sizeof(int32_t))); + a->imul(C_Offset, ldcReg, + static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } - a->emitEpilog(frame); + a->emitEpilog(frame); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index 6b54743..226e974 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -9,18 +9,6 @@ namespace fbgemm { -template <typename TA, typename TB, typename TC, typename accT> -thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_; - -template <typename TA, typename TB, typename TC, typename accT> -thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_; - -template <typename TA, typename TB, typename TC, typename accT> -thread_local std::map< - std::tuple<bool, int, int, int, int, int, int, int>, - typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> - CodeGenBase<TA, TB, TC, accT>::codeCache_; - namespace x86 = asmjit::x86; /** @@ -35,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< int rowRegs, int colRegs, int leadingDimCReg) { + using CRegs = x86::Ymm; for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { a->vxorps( - CRegs_avx2_[i * leadingDimCReg + j], - CRegs_avx2_[i * leadingDimCReg + j], - CRegs_avx2_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); } } } @@ -73,6 +62,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< // temporary register x86::Ymm res1 = x86::ymm14; + using CRegs = x86::Ymm; + for (int j = 0; j < colRegs; ++j) { // load B a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); @@ -83,9 +74,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< a->vpmaddubsw(res1, AReg, BReg); a->vpmaddwd(res1, oneReg, res1); a->vpaddd( - CRegs_avx2_[i * leadingDimCReg + j], + CRegs(i * leadingDimCReg + j), res1, - CRegs_avx2_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j)); } a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); } @@ -106,6 +97,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< x86::Gp ldcReg, bool accum, int leadingDimCReg) { + using CRegs = x86::Ymm; for (int i = 0; i < rowRegs; ++i) { if (i != 0) { a->add(C_Offset, ldcReg); @@ -113,13 +105,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< for (int j = 0; j < colRegs; ++j) { if (accum) { a->vpaddd( - CRegs_avx2_[i * leadingDimCReg + j], - CRegs_avx2_[i * leadingDimCReg + j], + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t))); } a->vmovups( x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)), - CRegs_avx2_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j)); } } } @@ -173,206 +165,169 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != 
codeCache_.end()) { - return codeCache_[kernelSig]; - } - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as<x86::Emitter>(); + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx2>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrame frame; - frame.init(func); - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); - - asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFuncFrame(frame); - frame.finalize(); - - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); - - asmjit::Label Loopk = a->newLabel(); - asmjit::Label LoopMBlocks = a->newLabel(); - - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp kIdx = a->gpz(14); - // x86::Gp B_pf = a->gpz(8); - - x86::Ymm oneReg = x86::ymm15; - // create 16-bit 1s - // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 - // and so on - a->vpcmpeqw(oneReg, oneReg, oneReg); - a->vpsrlw(oneReg, oneReg, 15); - a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); - a->mov(C_Offset, 0); - - int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - a->mov(B_pf_saved, B_pf); - - a->bind(LoopMBlocks); - a->inc(iIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx2>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * 
sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - // a->add(B_pf, 32*sizeof(float)); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment A for next block - a->sub(buffer_A, kSize); - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next block - a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); - a->add(CBase, C_Offset); + // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp kIdx = a->gpz(14); + // x86::Gp B_pf = a->gpz(8); + + x86::Ymm oneReg = x86::ymm15; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); a->mov(C_Offset, 0); - // reset B - a->mov(buffer_B, buffer_B_saved); - a->mov(B_pf, B_pf_saved); - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - // init C registers - initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx2>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - } + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - 
a->emitEpilog(frame); + auto issueLoopOverK = [&](int rowRegs) { + asmjit::Label LoopKLabel = a->newLabel(); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + // Init C (result) vector registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs); + + // Loops over K + a->mov(kIdx, 0); + a->bind(LoopKLabel); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopKLabel); + + // store C matrix + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum, + colRegs); + }; + + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + issueLoopOverK(mRegBlockSize); + + int rowRegs = mRegBlockSize; + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next block + a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); + a->add(CBase, C_Offset); + a->mov(C_Offset, 0); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + issueLoopOverK(mRegBlocksRem); + } + + a->emitEpilog(frame); + + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index 5986e48..5037292 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< int rowRegs, int colRegs, int leadingDimCReg) { + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { a->vxorps( - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); } } } @@ -61,6 +62,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< // temporary register x86::Zmm res1 = x86::zmm28; + using CRegs = x86::Zmm; for (int j = 0; j < colRegs; ++j) { // load B a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); @@ -71,9 +73,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< a->vpmaddubsw(res1, AReg, BReg); a->vpmaddwd(res1, oneReg, res1); a->vpaddd( - 
CRegs_avx512_[i * leadingDimCReg + j], + CRegs(i * leadingDimCReg + j), res1, - CRegs_avx512_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j)); } a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); } @@ -94,6 +96,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< x86::Gp ldcReg, bool accum, int leadingDimCReg) { + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { if (i != 0) { a->add(C_Offset, ldcReg); @@ -103,15 +106,21 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< for (int j = 0; j < colRegs; ++j) { if (accum) { a->vpaddd( - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j], + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), +#ifdef _MSC_VER x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t))); - // x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); +#else + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); +#endif } a->vmovups( +#ifdef _MSC_VER x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)), -// x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), - CRegs_avx512_[i * leadingDimCReg + j]); +#else + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), +#endif + CRegs(i * leadingDimCReg + j)); } } } @@ -165,302 +174,269 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as<x86::Emitter>(); - + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx512>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); - int maxMRegs = mRegBlockSize; - int maxNRegs = nRegBlockSize * row_interleave / VLEN_; - assert( - maxMRegs * maxNRegs <= 28 && - "MR*(NR*ROW_INTERLEAVE*8/512) \ + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \ must be <= 28(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; - // arguments to the function created + // arguments to the function created #ifdef 
_MSC_VER - x86::Gp buffer_A = a->zcx(); - x86::Gp buffer_B = a->zdx(); - x86::Gp B_pf = a->gpz(8); - x86::Gp CBase = a->gpz(9); - x86::Gp kSize = a->zdi(); // a->zsi(); // x86::esi; // a->zsi(); - x86::Gp ldcReg = a->zsi(); // a->zdi(); // x86::edi; // a->zdi(); + x86::Gp buffer_A = a->zcx(); + x86::Gp buffer_B = a->zdx(); + x86::Gp B_pf = a->gpz(8); + x86::Gp CBase = a->gpz(9); + x86::Gp kSize = a->zdi(); + x86::Gp ldcReg = a->zsi(); #else - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); #endif - asmjit::FuncDetail func; -#ifdef _MSC_VER - //func.init(asmjit::FuncSignature4<void, uint8_t*, int8_t*, int8_t*, int32_t*>( - // asmjit::CallConv::kIdHost)); - func.init(asmjit::FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); -#else - func.init( - asmjit:: - FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); -#endif - - asmjit::FuncFrame frame; - frame.init(func); - - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFuncFrame(frame); - frame.finalize(); - - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); - -//#ifdef _MSC_VER -// // retrieve parameters from stack -// a->mov(kSize, asmjit::x86::dword_ptr(asmjit::x86::rsp, func.getArg(4).getStackOffset())); //0x20)); //func.getArg(4).getStackOffset())); -// std::cout << "func.getArg(4).getStackOffset(): " << func.getArg(4).getStackOffset() << std::endl; -// a->mov(ldcReg, asmjit::x86::dword_ptr(asmjit::x86::rsp, func.getArg(5).getStackOffset())); //;0x28)); //func.getArg(5).getStackOffset())); -// std::cout << "func.getArg(5).getStackOffset(): " << func.getArg(5).getStackOffset() << std::endl; -//#endif - - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); - - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); - // x86::Gp B_pf = a->gpz(8); - - x86::Zmm oneReg = x86::zmm29; - // create 16-bit 1s - // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 - // and so on - // a->vpcmpeqw(oneReg, oneReg, oneReg); - a->vpternlogd(oneReg, oneReg, oneReg, 0xff); - a->vpsrlw(oneReg, oneReg, 15); - a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - a->mov(B_pf_saved, B_pf); - - int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - int colRegs = std::min(currColRegs, maxNRegs); - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - a->bind(LoopMBlocks); - a->inc(iIdx); - a->mov(jIdx, 0); - - a->bind(LoopNBlocks); - a->inc(jIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - 
a->mov(kIdx, 0); - a->bind(Loopk); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // increment C for next B block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNBlocks); - - // increment A for next block - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next A block - a->sub( - CBase, - static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); - a->add(CBase, C_Offset); - - // reset B - a->mov(buffer_B, buffer_B_saved); - a->mov(B_pf, B_pf_saved); - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - a->mov(jIdx, 0); - a->bind(LoopNRem); - a->inc(jIdx); - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->mov(buffer_B, buffer_B_saved); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment C for next B block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should 
be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNRem); - } + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); + + x86::Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * + sizeof(int32_t))); + a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->mov(buffer_B, buffer_B_saved); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } - a->emitEpilog(frame); + a->emitEpilog(frame); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc index 8ae0745..1d23e90 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc @@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< int rowRegs, int colRegs, int leadingDimCReg) { + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { a->vxorps( - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); } } } @@ -55,6 +56,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< // used for matrix B x86::Zmm BReg = x86::zmm30; + using CRegs = x86::Zmm; + for (int j = 0; j < colRegs; ++j) { // load B a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); @@ -62,7 +65,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< for (int i = 0; i < rowRegs; ++i) { a->vpbroadcastd( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); - a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg); + a->vpdpbusd(CRegs(i * leadingDimCReg + j), AReg, BReg); } a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); } @@ -83,6 +86,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< x86::Gp ldcReg, bool accum, int leadingDimCReg) { + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { if (i != 0) { a->add(C_Offset, ldcReg); @@ -92,13 +96,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< for (int j = 0; j < colRegs; ++j) { if (accum) { a->vpaddd( - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j], + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); } a->vmovups( x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), - CRegs_avx512_[i * leadingDimCReg + j]); + CRegs(i * leadingDimCReg + j)); } } } @@ -155,277 +159,260 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate< nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as<x86::Emitter>(); + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx512_vnni>( - accum, - mc, 
- nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512_vnni>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); - int maxMRegs = mRegBlockSize; - int maxNRegs = nRegBlockSize * row_interleave / VLEN_; - assert( - maxMRegs * maxNRegs <= 28 && - "MR*(NR*ROW_INTERLEAVE*8/512) \ + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \ must be <= 28(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrame frame; - frame.init(func); - - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFuncFrame(frame); - frame.finalize(); - - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); - - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); - - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); - // x86::Gp B_pf = a->gpz(8); - - x86::Zmm oneReg = x86::zmm29; - // create 16-bit 1s - // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 - // and so on - // a->vpcmpeqw(oneReg, oneReg, oneReg); - a->vpternlogd(oneReg, oneReg, oneReg, 0xff); - a->vpsrlw(oneReg, oneReg, 15); - a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - a->mov(B_pf_saved, B_pf); - - int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - int colRegs = std::min(currColRegs, maxNRegs); - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - a->bind(LoopMBlocks); - a->inc(iIdx); - a->mov(jIdx, 0); - - a->bind(LoopNBlocks); - a->inc(jIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - - // k is incremented by row_interleave 
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512_vnni>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx512_vnni>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // increment C for next B block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNBlocks); - - // increment A for next block - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next A block - a->sub( - CBase, - static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); - a->add(CBase, C_Offset); - - // reset B - a->mov(buffer_B, buffer_B_saved); - a->mov(B_pf, B_pf_saved); - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - a->mov(jIdx, 0); - a->bind(LoopNRem); - a->inc(jIdx); - - // init C registers - initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512_vnni>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // reset A - a->sub(buffer_A, kSize); - // B for next block - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->mov(buffer_B, buffer_B_saved); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // store C matrix - storeCRegs<inst_set_t::avx512_vnni>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment C for next B block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNRem); - } + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); + + x86::Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * 
sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * + sizeof(int32_t))); + a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + // B for next block + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->mov(buffer_B, buffer_B_saved); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? 
jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } - a->emitEpilog(frame); + a->emitEpilog(frame); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 4c5eea5..58ee24d 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -10,8 +10,10 @@ #include <cassert> #include <cstdint> #include <map> +#include <mutex> #include <string> #include <tuple> +#include "CodeCache.h" #include "fbgemm/ConvUtils.h" #include "fbgemm/Fbgemm.h" #include "fbgemm/Utils.h" @@ -217,16 +219,23 @@ class GenConvKernel { template <inst_set_t instSet> void storeResultRowoffset(x86::Emitter* a, int offset = 0); - static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. - static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. - static thread_local std:: map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> codeCache_; ///< JIT Code Cache for reuse. - static thread_local std:: map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel. - private: + static asmjit::JitRuntime &runtime() { + static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, + // depends on other static + // variables. Required to prevent + // initialization order fiasco. + return rt; + } + + static std::mutex rtMutex_; ///< Controls access to runtime. + + static CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> + codeCache_; ///< JIT Code Cache for reuse. + static CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> + codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel. + +private: int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. 
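The GroupwiseConv.h hunk above replaces the per-thread rt_/code_/std::map members with one process-wide JitRuntime behind a function-local static (sidestepping the static initialization order fiasco), a mutex guarding that runtime, and two CodeCache instances. CodeCache.h is added by this patch but its contents are not shown on this page, so the following is only a minimal sketch of the getOrCreate() contract the call sites rely on — a mutex-guarded memoizing map whose generator callback runs once per kernel signature; the names and the locking strategy here are assumptions, not the patch's actual implementation.

#include <functional>
#include <map>
#include <mutex>

// Hypothetical sketch of the CodeCache interface (not the patch's code).
template <typename KEY, typename VALUE>
class CodeCache {
 public:
  VALUE getOrCreate(const KEY& key, std::function<VALUE()> generator) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = values_.find(key);
    if (it != values_.end()) {
      return it->second; // kernel already JITed for this signature
    }
    VALUE v = generator(); // run the JIT emitter exactly once per key
    values_.emplace(key, v);
    return v;
  }

 private:
  std::map<KEY, VALUE> values_; // signature tuple -> kernel function pointer
  std::mutex mutex_;
};

A production version might instead cache futures so that concurrent callers wait on an in-flight compilation rather than serializing on one lock; either way, every call site in this patch collapses the old find-then-insert sequence into a single codeCache_.getOrCreate(kernelSig, [&]() { /* emit and return fn */ }) expression.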
// avx2 specific @@ -272,4 +281,15 @@ class GenConvKernel { int W_PAD_; ///< Padding for width (left and right) }; +template <int SPATIAL_DIM, typename accT> +std::mutex GenConvKernel<SPATIAL_DIM, accT>::rtMutex_; + +template <int SPATIAL_DIM, typename accT> +CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> + GenConvKernel<SPATIAL_DIM, accT>::codeCache_; + +template <int SPATIAL_DIM, typename accT> +CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> + GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_; + } // namespace fbgemm diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index b140c83..d1e0fdd 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -21,20 +21,6 @@ namespace fbgemm { using namespace std; -template <int SPATIAL_DIM, typename accT> -thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_; - -template <int SPATIAL_DIM, typename accT> -thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_; - -template <int SPATIAL_DIM, typename accT> -thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> - GenConvKernel<SPATIAL_DIM, accT>::codeCache_; - -template <int SPATIAL_DIM, typename accT> -thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> - GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_; - namespace x86 = asmjit::x86; template <int SPATIAL_DIM> @@ -91,14 +77,13 @@ jit_conv_kernel_fp getOrCreateConvKernel( // Note: Wrong code is generated if it's not one of the supported convolution assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) != - GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) { - return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig]; - } else { - auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point); - // TODO: Instruction set based dispatch - return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); - } + return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate( + kernelSig, [&]() { + auto genObj = + GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); + }); } template <> @@ -1009,9 +994,9 @@ template <> template <> jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) @@ -1020,7 +1005,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { - code_.setLogger(codeLogger); + code.setLogger(codeLogger); } #endif @@ -1097,13 +1082,15 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( a->emitEpilog(frame); jit_conv_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } if (err) { std::cout << "Error: in fn add" << std::endl; return nullptr; } - auto kernelSig = getKernelSig(conv_param, 
isAZeroPointZero_); - codeCache_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) fclose(codeLogfile); @@ -1489,9 +1476,9 @@ template <> jit_rowoffset_kernel_fp GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) @@ -1500,7 +1487,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { - code_.setLogger(codeLogger); + code.setLogger(codeLogger); } #endif @@ -1570,14 +1557,16 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( a->emitEpilog(frame); + asmjit::Error err; jit_rowoffset_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } if (err) { std::cout << "Error: in fn add" << std::endl; return nullptr; } - auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_); - codeCacheRowOffset_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) delete codeLogger; @@ -1780,7 +1769,8 @@ template < typename outType, bool FUSE_RELU, QuantizationGranularity Q_GRAN, - int SPATIAL_DIM> + int SPATIAL_DIM, + typename BIAS_TYPE> void fbgemmGroupwiseConv( const conv_param_t<SPATIAL_DIM>& conv_param, const std::uint8_t* activations, @@ -1789,10 +1779,10 @@ void fbgemmGroupwiseConv( packed_W& packed_weights, outType* out, int32_t* outBuffer, - const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess, + const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess, int thread_id, int num_threads) { - typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType; + typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE> processOutputType; if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); @@ -1883,15 +1873,17 @@ void fbgemmGroupwiseConv( outProcess.getBZeroPoint()[0] == 0) || rowOffsetBuf == nullptr; - requantizationParams_t r = {a_zero_point, - outProcess.getBZeroPoint(), - outProcess.getCZeroPoint(), - outProcess.getCMultiplier(), - rowOffsetBuf, - outProcess.getColOffsets(), - outProcess.getBias(), - outProcess.getNCols(), - G}; + requantizationParams_t<typename processOutputType::BIAS_T> r = { + a_zero_point, + outProcess.getBZeroPoint(), + outProcess.getCZeroPoint(), + outProcess.getCMultiplier(), + rowOffsetBuf, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.getNCols(), + G, + outProcess.getActWScale()}; const std::int32_t* inp = outBuffer; block_type_t block{i * oh_ow, oh_ow, gOuter * K_per_G, 8 * K_per_G}; @@ -2162,15 +2154,14 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel( // Note: Wrong code is generated if it's not one of the supported convolution assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find( - kernelSig) != - GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) { - return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig]; - } else { - auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point); - // TODO: Instruction set based dispatch - return genObj.template 
getOrCreateRowOffset<inst_set_t::avx2>(conv_param); - } + return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate( + kernelSig, [&]() { + auto genObj = + GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreateRowOffset<inst_set_t::avx2>( + conv_param); + }); } template <int SPATIAL_DIM> @@ -2214,7 +2205,7 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) { template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param); template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param); -#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM) \ +#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, BIAS_TYPE) \ template void fbgemmGroupwiseConv( \ const conv_param_t<SPATIAL_DIM>& conv_param, \ const uint8_t* activations, \ @@ -2223,13 +2214,17 @@ template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param); PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>& packed_weights, \ uint8_t* out, \ int32_t* outBuffer, \ - const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ int thread_id, \ int num_threads); +#define INSTANTIATE_BIAS_T(RELU, Q_GRAN, SPATIAL_DIM) \ + INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, float); \ + INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, int32_t); + #define INSTANTIATE_SPATIAL_DIM(RELU, Q_GRAN) \ - INSTANTIATE_BASE(RELU, Q_GRAN, 2); \ - INSTANTIATE_BASE(RELU, Q_GRAN, 3); + INSTANTIATE_BIAS_T(RELU, Q_GRAN, 2); \ + INSTANTIATE_BIAS_T(RELU, Q_GRAN, 3); #define INSTANTIATE_Q_GRANS(RELU) \ INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::TENSOR); \ @@ -2241,6 +2236,7 @@ INSTANTIATE_Q_GRANS(true); #undef INSTANTIATE_Q_GRANS #undef INSTANTIATE_SPATIAL_DIM +#undef INSTANTIATE_BIAS_T #undef INSTANTIATE_BASE template void fbgemmGroupwiseConv( diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc index 2aca27d..6101fef 100644 --- a/src/PackAWithIm2Col.cc +++ b/src/PackAWithIm2Col.cc @@ -275,6 +275,7 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) { conv_p_.K[1] == 7 && conv_p_.stride[0] == 2 && conv_p_.stride[1] == 2 && conv_p_.pad[0] == 3 && conv_p_.pad[1] == 3 && block.col_size == 147 && block_p.col_size == 148 && block.col_start == 0 && + conv_p_.dilation[0] == 1 && conv_p_.dilation[1] == 1 && std::is_same<T, uint8_t>::value) { if (BaseType::blockColSize() == 256) { pack_a_with_im2col_opt< @@ -350,8 +351,10 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) { int r = grs / conv_p_.K[1] % conv_p_.K[0]; int g = grs / conv_p_.K[1] / conv_p_.K[0]; - int h_in = -conv_p_.pad[0] + h * conv_p_.stride[0] + r; - int w_in = -conv_p_.pad[1] + w * conv_p_.stride[1] + s; + int h_in = + -conv_p_.pad[0] + h * conv_p_.stride[0] + r * conv_p_.dilation[0]; + int w_in = + -conv_p_.pad[1] + w * conv_p_.stride[1] + s * conv_p_.dilation[1]; if (h_in < 0 || h_in >= conv_p_.IN_DIM[0] || w_in < 0 || w_in >= conv_p_.IN_DIM[1]) { @@ -399,9 +402,12 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) { int q = gqrs / conv_p_.K[2] / conv_p_.K[1] % conv_p_.K[0]; int g = gqrs / conv_p_.K[2] / conv_p_.K[1] / conv_p_.K[0]; - int t_in = -conv_p_.pad[0] + t * conv_p_.stride[0] + q; - int h_in = -conv_p_.pad[1] + h * conv_p_.stride[1] + r; - int w_in = -conv_p_.pad[2] + w * conv_p_.stride[2] + s; + int t_in = + -conv_p_.pad[0] + t * conv_p_.stride[0] + q * conv_p_.dilation[0]; + int h_in = + 
-conv_p_.pad[1] + h * conv_p_.stride[1] + r * conv_p_.dilation[1]; + int w_in = + -conv_p_.pad[2] + w * conv_p_.stride[2] + s * conv_p_.dilation[2]; if (t_in < 0 || t_in >= conv_p_.IN_DIM[0] || h_in < 0 || h_in >= conv_p_.IN_DIM[1] || w_in < 0 || diff --git a/src/PackDepthwiseConvMatrixAvx2.cc b/src/PackDepthwiseConvMatrixAvx2.cc new file mode 100644 index 0000000..0e17bcd --- /dev/null +++ b/src/PackDepthwiseConvMatrixAvx2.cc @@ -0,0 +1,203 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "fbgemm/FbgemmI8DepthwiseAvx2.h" + +#include <immintrin.h> + +using namespace std; + +namespace fbgemm { + +// clang-format off +static int masks[8][8] = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. + { 0, 0, 0, 0, 0, 0, 0, 0, }, + { -1, 0, 0, 0, 0, 0, 0, 0, }, + { -1, -1, 0, 0, 0, 0, 0, 0, }, + { -1, -1, -1, 0, 0, 0, 0, 0, }, + { -1, -1, -1, -1, 0, 0, 0, 0, }, + { -1, -1, -1, -1, -1, 0, 0, 0, }, + { -1, -1, -1, -1, -1, -1, 0, 0, }, + { -1, -1, -1, -1, -1, -1, -1, 0, }, +}; +// clang-format on + +PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix( + int K, + int kernel_prod, + const int8_t* smat) + : K_(K), kernel_prod_(kernel_prod) { + // Transpose the input matrix to make packing faster. + alignas(64) int8_t smat_transposed[K * kernel_prod]; + for (int i = 0; i < kernel_prod; ++i) { + for (int j = 0; j < K; ++j) { + smat_transposed[i * K + j] = smat[i + j * kernel_prod]; + } + } + + // Allocate packed arrays + int kernel_prod_aligned = (kernel_prod + 1) / 2 * 2; + // pmat_ = static_cast<int8_t *>(fbgemmAlignedAlloc( + // 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t))); + posix_memalign( + (void**)&pmat_, + 64, + ((K + 31) / 32) * kernel_prod_aligned * 32 * sizeof(int8_t)); + + // Pack input matrix + // The layout is optimized to use vpmaddubsw efficiently (see + // madd_epi16x4_packed function). + // For a group of 32 channels, we have 10 32B SIMD registers. + // Denote ith channel jth filter as (i, j) + // 0th SIMD register: + // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3) + // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3) + // 1st SIMD register: + // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3) + // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3) + // 2nd SIMD register: + // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3) + // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3) + // 3rd SIMD register: + // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3) + // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3) + // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter + // coefficients + // ... 
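Before the REMAINDER cases below, the main-case mapping described above can be written out as plain index arithmetic. The sketch below is a reference reconstruction that mirrors the else branch of PackedDepthWiseConvMatrix::addr() further down in this file, not code from the patch: coefficients are packed four per channel, four channels per 16-byte half of a 32-byte register, and 32 channels per block.

// Reference sketch (not part of the patch): byte offset of element
// (channel r, filter coefficient c) in pmat_ for the main case,
// i.e. for c below the remainder region.
int packed_offset_main(int r, int c, int kernel_prod) {
  int kernel_prod_aligned = (kernel_prod + 1) / 2 * 2;
  int kBlock = r / 32;                    // which group of 32 channels
  int reg_idx = (r % 16) / 4 + c / 4 * 4; // which 32-byte SIMD register
  int blk_idx = kBlock * kernel_prod_aligned + reg_idx;
  int in_blk_idx = (r % 32) / 16 * 16     // low or high 16-byte half
      + 4 * (r % 4) + (c % 4);            // 4 coefficients per channel
  return blk_idx * 32 + in_blk_idx;
}

For example, (r, c) = (16, 2) maps to byte 18 of register 0, matching the second half of the 0th SIMD register in the layout comment above.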
+ // + // REMAINDER + // If kernel_prod % 4 == 1 for example when kernel_prod == 9 + // 8th SIMD register: + // (0, 8), zero, ..., (7, 8), zero + // (16, 8), zero, ..., (23, 8), zero + // 9th SIMD register: + // (8, 8), zero, ..., (15, 8), zero + // (24, 8), zero, ..., (31, 8), zero + // We use madd_epi16_packed for this case + // + // If kernel_prod % 4 == 2 for example when kernel_prod == 10 + // 8th SIMD register: + // (0, 8), (0, 9), ..., (7, 8), (7, 9) + // (16, 8), (16, 9), ..., (23, 8), (23, 9) + // 9th SIMD register: + // (8, 8), (8, 9), ..., (15, 8), (15, 9) + // (24, 8), (24, 9), ..., (31, 8), (31, 9) + // + // If kernel_prod % 4 == 3 for example when kernel_prod == 11 + // 8th SIMD register: + // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero + // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero + // 9th SIMD register: + // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero + // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero + // 10th SIMD register: + // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero + // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero + // 11th SIMD register: + // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero + // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero + for (int k1 = 0; k1 < K; k1 += 32) { + __m256i b_v[kernel_prod]; + int remainder = K - k1; + if (remainder < 32) { + __m256i mask_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(masks[remainder / 4])); + for (int i = 0; i < kernel_prod; ++i) { + b_v[i] = _mm256_maskload_epi32( + reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v); + } + } else { + for (int i = 0; i < kernel_prod; ++i) { + b_v[i] = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1)); + } + } + + // Interleave 2 SIMD registers + __m256i b_interleaved_epi16[kernel_prod_aligned]; + __m256i zero_v = _mm256_setzero_si256(); + for (int i = 0; i < kernel_prod_aligned / 2; ++i) { + if (2 * i + 1 >= kernel_prod) { + b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v); + b_interleaved_epi16[2 * i + 1] = + _mm256_unpackhi_epi8(b_v[2 * i], zero_v); + } else { + b_interleaved_epi16[2 * i] = + _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]); + b_interleaved_epi16[2 * i + 1] = + _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]); + } + } + + // Interleave 4 SIMD registers + __m256i b_interleaved_epi32[kernel_prod_aligned]; + for (int i = 0; i < kernel_prod_aligned / 4; ++i) { + b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16( + b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); + b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16( + b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); + b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16( + b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); + b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16( + b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); + } + for (int i = kernel_prod_aligned / 4 * 4; i < kernel_prod_aligned; ++i) { + b_interleaved_epi32[i] = b_interleaved_epi16[i]; + } + + for (int i = 0; i < kernel_prod_aligned; ++i) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>( + &pmat_[((k1 / 32) * kernel_prod_aligned + i) * 32]), + b_interleaved_epi32[i]); + } + } +} + +int PackedDepthWiseConvMatrix::addr(int r, int c) { + int kernel_prod_aligned = (kernel_prod_ + 1) / 2 * 2; + if (c >= 
kernel_prod_ / 4 * 4 && + (kernel_prod_ % 4 == 1 || kernel_prod_ % 4 == 2)) { + int kBlock = r / 32; + int reg_idx = (r % 16) / 8 + c / 4 * 4; + + int blk_idx = kBlock * kernel_prod_aligned + reg_idx; + + int r_ = r % 8; + int c_ = c % 4; + + int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_; + return blk_idx * 32 + in_blk_idx; + + } else { + int kBlock = r / 32; + int reg_idx = (r % 16) / 4 + c / 4 * 4; + + int blk_idx = kBlock * kernel_prod_aligned + reg_idx; + + int r_ = r % 4; + int c_ = c % 4; + + int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_; + return blk_idx * 32 + in_blk_idx; + } +} + +void PackedDepthWiseConvMatrix::unpack(int8_t* unpacked_data) { + for (int r = 0; r < K_; ++r) { + for (int c = 0; c < kernel_prod_; ++c) { + unpacked_data[r * kernel_prod_ + c] = pmat_[addr(r, c)]; + } + } +} + +PackedDepthWiseConvMatrix::~PackedDepthWiseConvMatrix() { + free(pmat_); +} + +} // namespace fbgemm diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 44f210e..192fb00 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -23,35 +23,17 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( // FbgemmConv.cc switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) { case optimized_conv_t::depthwise: { - if (SPATIAL_DIM == 3) { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = - std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata); - W_gconv_packed_ = nullptr; - } else { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = - std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata); - W_dw_3D_packed_ = nullptr; - W_gconv_packed_ = nullptr; - } + W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>( + conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata); break; } case optimized_conv_t::groupwise: { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = nullptr; W_gconv_packed_ = std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>( matrix_op_t::Transpose, conv_p, sdata, nullptr); break; } case optimized_conv_t::pointwise: { - W_im2col_packed_ = nullptr; - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = nullptr; - W_gconv_packed_ = nullptr; int NDim = conv_p.OC / conv_p.G; int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC; W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>( @@ -77,9 +59,6 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( nullptr, conv_p.G, blocking_params); - W_dw_2D_packed_ = nullptr; - W_dw_3D_packed_ = nullptr; - W_gconv_packed_ = nullptr; break; } } // switch @@ -87,10 +66,8 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( template <int SPATIAL_DIM, typename T, typename accT> void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) { - if (W_dw_2D_packed_) { - W_dw_2D_packed_->unpack(origin_buf); - } else if (W_dw_3D_packed_) { - W_dw_3D_packed_->unpack(origin_buf); + if (W_dw_packed_) { + W_dw_packed_->unpack(origin_buf); } else if (W_gconv_packed_) { W_gconv_packed_->unpack(origin_buf); } else if (W_im2col_packed_) { @@ -139,7 +116,7 @@ std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams( }; auto combineInt = [&combineStr](std::string id, int int1, int int2) { - return combineStr(id, std::to_string(int1), std::to_string(int2)); + return combineStr(id, std::to_string(int1), std::to_string(int2)); }; if (conv_param_.IC != test_conv_p.IC) { diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index 5dde90b..a209efc 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -164,30 +164,34 @@ 
void ChooseRequantizationMultiplier( dst[i] = Quantize<T>(src[i], qparams); \ } \ } -FBGEMM_SPECIALIZED_QUANTIZE(int8_t) FBGEMM_SPECIALIZED_QUANTIZE(uint16_t) FBGEMM_SPECIALIZED_QUANTIZE(int16_t) FBGEMM_SPECIALIZED_QUANTIZE(int32_t) #undef FBGEMM_SPECIALIZED_QUANTIZE -template <> -void Quantize<uint8_t>( - const float* src, - uint8_t* dst, - int len, - const TensorQuantizationParams& qparams) { - bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); - bool fma_support = cpuinfo_has_x86_fma3(); - if (avx2_support && fma_support && qparams.precision == 8) { - // fast path - QuantizeAvx2(src, dst, len, qparams); - } else { - for (std::size_t i = 0; i < len; ++i) { - dst[i] = Quantize<uint8_t>(src[i], qparams); - } - } +#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T) \ +template <> \ +void Quantize<T>( \ + const float* src, \ + T* dst, \ + int len, \ + const TensorQuantizationParams& qparams) { \ + bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \ + bool fma_support = cpuinfo_has_x86_fma3(); \ + if (avx2_support && fma_support && qparams.precision == 8) { \ + /* fast path */ \ + QuantizeAvx2<T>(src, dst, len, qparams); \ + } else { \ + for (std::size_t i = 0; i < len; ++i) { \ + dst[i] = Quantize<T>(src[i], qparams); \ + } \ + } \ } +FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t) +FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t) +#undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2 + #define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \ template <> \ void QuantizeGroupwise<T, layout_t::KCX>( \ diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 875f1dc..9381f0c 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -18,13 +18,16 @@ using namespace std; //////////////////////////////////////////////////////////////////////////////// // Utility functions +template <typename T> void QuantizeAvx2( const float* src, - uint8_t* dst, + T* dst, int len, const TensorQuantizationParams& qparams) { #if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER)) constexpr int VLEN = 8; + constexpr float min_val = std::numeric_limits<T>::min(); + constexpr float max_val = std::numeric_limits<T>::max(); std::size_t i = 0; __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale); __m256i shuffle_mask_v = _mm256_set_epi8( @@ -67,8 +70,8 @@ void QuantizeAvx2( __m256 transformed_v = _mm256_fmadd_ps( src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point)); __m256 clipped_v = _mm256_min_ps( - _mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)), - _mm256_set1_ps(255.f)); + _mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)), + _mm256_set1_ps(max_val)); __m256i rounded_v = _mm256_cvtps_epi32(clipped_v); // An instruction sequence to save 8 32-bit integers as 8 8-bit integers @@ -80,7 +83,7 @@ void QuantizeAvx2( for (; i < len; ++i) { float transformed = qparams.zero_point + src[i] / qparams.scale; - float clipped = std::min(std::max(transformed, 0.f), 255.f); + float clipped = std::min(std::max(transformed, min_val), max_val); // Not exactly the same behavior as the vectorized code. 
// The vectorized code above always rounds to even in halfway cases // (https://software.intel.com/en-us/node/523819), but std::nearbyint @@ -95,6 +98,21 @@ void QuantizeAvx2( #endif } +// Instantiate QuantizeAvx2 for known datatypes +template +void QuantizeAvx2<uint8_t>( + const float* src, + uint8_t* dst, + int len, + const TensorQuantizationParams& qparams); +template +void QuantizeAvx2<int8_t>( + const float* src, + int8_t* dst, + int len, + const TensorQuantizationParams& qparams); + + void FindMinMax(const float* a, float* min, float* max, int len) { if (len <= 0) { *min = 0.0f; @@ -264,14 +282,15 @@ template < bool B_SYMMETRIC, QuantizationGranularity Q_GRAN, bool HAS_BIAS, - bool FUSE_RELU> + bool FUSE_RELU, + typename BIAS_TYPE> void requantizeOutputProcessingAvx2( uint8_t* out, const int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r) { + const requantizationParams_t<BIAS_TYPE>& r) { // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c // using AVX2 instructions int quant_param_idx = 0; @@ -282,6 +301,15 @@ void requantizeOutputProcessingAvx2( } __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v; + if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { + if (is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = + _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); + } + } + __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); @@ -391,22 +419,76 @@ void requantizeOutputProcessingAvx2( } w_v = _mm256_sub_epi32(w_v, row_offset_v); } + __m256 xf_v, yf_v, zf_v, wf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN)); + y_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + act_times_w_rcp_v); + } + xf_v = 
_mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); } /* @@ -423,22 +505,19 @@ void requantizeOutputProcessingAvx2( */ __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); - y_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(y_v), - _mm256_loadu_ps(r.C_multiplier + j + VLEN)); - z_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(z_v), - _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); - w_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(w_v), - _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); + x_scaled_v = + _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN)); + y_scaled_v = + _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN)); + z_scaled_v = + _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); + w_scaled_v = + _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } /* @@ -525,18 +604,35 @@ void requantizeOutputProcessingAvx2( } x_v = _mm256_sub_epi32(x_v, row_offset_v); } + __m256 xf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)), + _mm256_loadu_ps(r.act_times_w_scale + j)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); } __m256 x_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); + x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); } else { - x_scaled_v 
= _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); } __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); @@ -574,6 +670,7 @@ void requantizeOutputProcessingAvx2( int remainder = block.col_start + block.col_size - j; if (remainder > 0) { + // clang-format off alignas(64) const int masks[8][8] = { // NOTE: clang-format wants to use a different formatting but the // current formatting should be easier to read. @@ -586,6 +683,7 @@ void requantizeOutputProcessingAvx2( { -1, -1, -1, -1, -1, -1, 0, 0, }, { -1, -1, -1, -1, -1, -1, -1, 0, }, }; + // clang-format on __m256i mask_v = _mm256_load_si256( reinterpret_cast<const __m256i*>(masks[remainder])); @@ -607,17 +705,40 @@ void requantizeOutputProcessingAvx2( } x_v = _mm256_sub_epi32(x_v, row_offset_v); } + + __m256 xf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v)); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_maskload_ps( + reinterpret_cast<const float*>(r.bias + j), mask_v), + _mm256_maskload_ps(r.act_times_w_scale + j, mask_v)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_maskload_ps( + reinterpret_cast<const float*>(r.bias + j), mask_v), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_maskload_epi32( + reinterpret_cast<const int*>(r.bias + j), mask_v)); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); } __m256 x_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), - _mm256_maskload_ps(r.C_multiplier + j, mask_v)); + x_scaled_v = + _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); } __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); @@ -759,6 +880,7 @@ void requantizeForFloatAvx2( int remainder = block.col_start + block.col_size - j; if (remainder > 0) { + // clang-format off alignas(64) const int masks[8][8] = { // NOTE: clang-format wants to use a different formatting but the // current formatting should be easier to read. 
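Both remainder branches lean on the per-lane mask table shown above. A standalone sketch of the technique, assuming only AVX2 (function name hypothetical):

#include <immintrin.h>
#include <cstdint>

// Load the last `remainder` (1-7) int32 columns of a row without reading past
// the end of the buffer: masks[remainder] sets the sign bit in exactly the
// first `remainder` lanes, and _mm256_maskload_epi32 zero-fills the rest.
static inline __m256i LoadTailInt32(const std::int32_t* src, int remainder) {
  alignas(64) static const int masks[8][8] = {
      {  0,  0,  0,  0,  0,  0,  0,  0 },
      { -1,  0,  0,  0,  0,  0,  0,  0 },
      { -1, -1,  0,  0,  0,  0,  0,  0 },
      { -1, -1, -1,  0,  0,  0,  0,  0 },
      { -1, -1, -1, -1,  0,  0,  0,  0 },
      { -1, -1, -1, -1, -1,  0,  0,  0 },
      { -1, -1, -1, -1, -1, -1,  0,  0 },
      { -1, -1, -1, -1, -1, -1, -1,  0 },
  };
  __m256i mask_v =
      _mm256_load_si256(reinterpret_cast<const __m256i*>(masks[remainder]));
  return _mm256_maskload_epi32(src, mask_v);
}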
@@ -771,6 +893,7 @@ void requantizeForFloatAvx2( { -1, -1, -1, -1, -1, -1, 0, 0, }, { -1, -1, -1, -1, -1, -1, -1, 0, }, }; + // clang-format on __m256i mask_v = _mm256_load_si256( reinterpret_cast<const __m256i*>(masks[remainder])); @@ -823,14 +946,15 @@ template < QuantizationGranularity Q_GRAN, bool HAS_BIAS, bool FUSE_RELU, - int C_PER_G> + int C_PER_G, + typename BIAS_TYPE> void requantizeOutputProcessingGConvAvx2( uint8_t* out, const int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r) { + const requantizationParams_t<BIAS_TYPE>& r) { // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c // using AVX2 instructions int quant_param_idx = 0; @@ -841,6 +965,14 @@ void requantizeOutputProcessingGConvAvx2( } __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v; + if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { + if (is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = + _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); + } + } __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); @@ -1087,22 +1219,135 @@ void requantizeOutputProcessingGConvAvx2( } w_v = _mm256_sub_epi32(w_v, row_offset_v); } + __m256 xf_v, yf_v, zf_v, wf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v = _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)); + __m256 y_bias_v = _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)); + __m256 z_bias_v = _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)); + __m256 w_bias_v = _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)); + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN)); + y_bias_v = _mm256_div_ps( + y_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + z_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + w_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN)); + } else if (Q_GRAN == QuantizationGranularity::GROUP) { + __m256 diviser_v; + if (C_PER_G == 4) { + diviser_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])), + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]), + 1); + x_bias_v = _mm256_div_ps(x_bias_v, diviser_v); + + diviser_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 2])), + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 3]), + 1); + y_bias_v = _mm256_div_ps(y_bias_v, diviser_v); + + diviser_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 4])), + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 5]), + 1); + z_bias_v = _mm256_div_ps(z_bias_v, diviser_v); + + diviser_v = _mm256_insertf128_ps( + 
_mm256_castps128_ps256( + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 6])), + _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 7]), + 1); + w_bias_v = _mm256_div_ps(w_bias_v, diviser_v); + + } else if (C_PER_G == 8) { + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 0]); + x_bias_v = _mm256_div_ps(x_bias_v, diviser_v); + + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 1]); + y_bias_v = _mm256_div_ps(y_bias_v, diviser_v); + + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 2]); + z_bias_v = _mm256_div_ps(z_bias_v, diviser_v); + + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 3]); + w_bias_v = _mm256_div_ps(w_bias_v, diviser_v); + + } else { + assert(C_PER_G == 16); + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 2 + 0]); + x_bias_v = _mm256_div_ps(x_bias_v, diviser_v); + y_bias_v = _mm256_div_ps(y_bias_v, diviser_v); + + diviser_v = _mm256_set1_ps( + r.act_times_w_scale + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 2 + 1]); + z_bias_v = _mm256_div_ps(z_bias_v, diviser_v); + w_bias_v = _mm256_div_ps(w_bias_v, diviser_v); + } + } else { + x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps(y_bias_v, act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps(z_bias_v, act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps(w_bias_v, act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); } /* @@ -1119,17 +1364,13 @@ void requantizeOutputProcessingGConvAvx2( */ __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); - y_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(y_v), - _mm256_loadu_ps(r.C_multiplier + j + VLEN)); - z_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(z_v), - _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); - w_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(w_v), - _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); + x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); + y_scaled_v = + _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + VLEN)); + z_scaled_v = + _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); + w_scaled_v = + _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * 
VLEN)); } else if (Q_GRAN == QuantizationGranularity::GROUP) { if (C_PER_G == 4) { multiplier_v = _mm256_insertf128_ps( @@ -1137,70 +1378,70 @@ void requantizeOutputProcessingGConvAvx2( _mm_set1_ps(r.C_multiplier[quant_param_idx])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]), 1); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]), 1); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]), 1); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 6])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]), 1); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } else if (C_PER_G == 8) { multiplier_v = _mm256_set1_ps( r.C_multiplier [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 1]); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 2]); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 3]); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 1]); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 2]); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 3]); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } else { multiplier_v = _mm256_set1_ps( r.C_multiplier [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 + - 1]); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 2 + 1]); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = 
_mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } /* @@ -1271,46 +1512,69 @@ void requantizeOutputProcessingGConvAvx2( } // i loop } -#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ - template void \ - requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - float* out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationForFloatParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); +#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \ + A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \ + template void \ + requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 4, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 8, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 16, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); + +#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ + INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \ + INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \ + template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ + float* out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationForFloatParams_t& r); #define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \ INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \ diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc 
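The BIAS_TYPE instantiations above close out the AVX2 requantization changes. In scalar form, the float-bias path computes roughly the following; this is a simplified per-tensor sketch with names hypothetical and row/column offset handling omitted:

#include <algorithm>
#include <cmath>
#include <cstdint>

// With a float bias, the bias is first rescaled into accumulator units by
// act_scale * w_scale, added to the int32 accumulator in float, and the sum
// is then scaled by C_multiplier and shifted by the output zero point.
inline std::uint8_t RequantizeFloatBiasSketch(
    std::int32_t acc, // accumulator with offsets already removed
    float bias, // unquantized float bias
    float act_times_w_scale, // activation scale * weight scale
    float C_multiplier, // act_times_w_scale / output scale
    std::int32_t C_zero_point) {
  float x = static_cast<float>(acc) + bias / act_times_w_scale;
  long rounded = std::lrintf(x * C_multiplier) + C_zero_point;
  return static_cast<std::uint8_t>(std::min(255L, std::max(0L, rounded)));
}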
index e3c0eac..dc40d44 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -300,9 +300,11 @@ void im2col_ref( for (int h = 0; h < OUT_DIM[0]; ++h) { for (int w = 0; w < OUT_DIM[1]; ++w) { for (int r = 0; r < K[0]; ++r) { - int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r; + int h_in = + -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0]; for (int s = 0; s < K[1]; ++s) { - int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s; + int w_in = + -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1]; if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 || w_in >= IN_DIM[1]) { for (int g = 0; g < G; ++g) { @@ -363,11 +365,14 @@ void im2col_ref( for (int h = 0; h < OUT_DIM[1]; ++h) { for (int w = 0; w < OUT_DIM[2]; ++w) { for (int q = 0; q < K[0]; ++q) { - int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q; + int t_in = + -conv_p.pad[0] + t * conv_p.stride[0] + q * conv_p.dilation[0]; for (int r = 0; r < K[1]; ++r) { - int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r; + int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + + r * conv_p.dilation[1]; for (int s = 0; s < K[2]; ++s) { - int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s; + int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + + s * conv_p.dilation[2]; if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 || h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) { for (int g = 0; g < G; ++g) { @@ -447,9 +452,11 @@ void conv_ref( for (int m = 0; m < OC / G; ++m) { int sum = 0; for (int r = 0; r < K[0]; ++r) { - int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r; + int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + + r * conv_p.dilation[0]; for (int s = 0; s < K[1]; ++s) { - int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s; + int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + + s * conv_p.dilation[1]; for (int c = 0; c < IC / G; ++c) { int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 || w_in >= IN_DIM[1] @@ -499,11 +506,14 @@ void conv_ref( for (int m = 0; m < OC / G; ++m) { int sum = 0; for (int q = 0; q < K[0]; ++q) { - int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q; + int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + + q * conv_p.dilation[0]; for (int r = 0; r < K[1]; ++r) { - int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r; + int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + + r * conv_p.dilation[1]; for (int s = 0; s < K[2]; ++s) { - int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s; + int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + + s * conv_p.dilation[2]; for (int c = 0; c < IC / G; ++c) { int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 || h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2] @@ -590,406 +600,6 @@ void transposeConvWeights( } } -void depthwise_3x3_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int8_t* B, - int32_t* C) { - constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - - for (int n = 0; n < N; ++n) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int r = 0; r < R; ++r) { - int h_in = -PAD_T + h * stride_h + r; - for (int s = 0; s < S; ++s) { - int w_in = -PAD_L + w * stride_w + s; - int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W - ? 
A_zero_point - : A[((n * H + h_in) * W + w_in) * K + k]; - int b = B[(k * R + r) * S + s]; - sum += a * b; - } - } - C[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; - } - } - } - } // for each n -}; - -void depthwise_3x3_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const int8_t* B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - - vector<int32_t> C_int32(N * H_OUT * W_OUT * K); - depthwise_3x3_pad_1_ref( - N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data()); - - vector<int32_t> row_offsets(N * H_OUT * W_OUT * K); - for (int n = 0; n < N; ++n) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int r = 0; r < R; ++r) { - int h_in = -PAD_T + h * stride_h + r; - for (int s = 0; s < S; ++s) { - int w_in = -PAD_L + w * stride_w + s; - int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W - ? A_zero_point - : A[((n * H + h_in) * W + w_in) * K + k]; - sum += a; - } - } - row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; - } - } - } - } // for each n - - for (int i = 0; i < N * H_OUT * W_OUT; ++i) { - for (int k = 0; k < K; ++k) { - requantize_u8acc32_ref( - 1, - 1, - 1, - C_int32.data() + i * K + k, - C + i * K + k, - &C_multiplier, - C_zero_point, - A_zero_point, - &B_zero_point, - &row_offsets[i * K + k], - col_offsets + k, - bias ? bias + k : nullptr, - 1); - } - } -}; - -void depthwise_3x3_per_channel_quantization_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const int8_t* B, - const float* C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - - vector<int32_t> C_int32(N * H_OUT * W_OUT * K); - depthwise_3x3_pad_1_ref( - N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data()); - - vector<int32_t> row_offsets(N * H_OUT * W_OUT * K); - for (int n = 0; n < N; ++n) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int r = 0; r < R; ++r) { - int h_in = -PAD_T + h * stride_h + r; - for (int s = 0; s < S; ++s) { - int w_in = -PAD_L + w * stride_w + s; - int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W - ? A_zero_point - : A[((n * H + h_in) * W + w_in) * K + k]; - sum += a; - } - } - row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; - } - } - } - } // for each n - - for (int i = 0; i < N * H_OUT * W_OUT; ++i) { - for (int k = 0; k < K; ++k) { - requantize_u8acc32_ref( - 1, - 1, - 1, - C_int32.data() + i * K + k, - C + i * K + k, - &C_multiplier[k], - C_zero_point, - A_zero_point, - &B_zero_point[k], - &row_offsets[i * K + k], - col_offsets + k, - bias ? 
bias + k : nullptr, - 1); - } - } -}; - -void depthwise_3x3x3_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int8_t* B, - int32_t* C) { - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - - for (int n = 0; n < N; ++n) { - for (int t = 0; t < T_OUT; ++t) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int k_t = 0; k_t < K_T; ++k_t) { - int t_in = -PAD_P + t * stride_t + k_t; - for (int k_h = 0; k_h < K_H; ++k_h) { - int h_in = -PAD_T + h * stride_h + k_h; - for (int k_w = 0; k_w < K_W; ++k_w) { - int w_in = -PAD_L + w * stride_w + k_w; - int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || - w_in < 0 || w_in >= W - ? A_zero_point - : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; - int b = B[((k * K_T + k_t) * K_H + k_h) * K_W + k_w]; - sum += a * b; - } - } - } - C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = sum; - } - } // w - } // h - } // t - } // for each n -}; - -void depthwise_3x3x3_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const int8_t* B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - - vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K); - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B, - C_int32.data()); - - vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K); - for (int n = 0; n < N; ++n) { - for (int t = 0; t < T_OUT; ++t) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int k_t = 0; k_t < K_T; ++k_t) { - int t_in = -PAD_P + t * stride_t + k_t; - for (int k_h = 0; k_h < K_H; ++k_h) { - int h_in = -PAD_T + h * stride_h + k_h; - for (int k_w = 0; k_w < K_W; ++k_w) { - int w_in = -PAD_L + w * stride_w + k_w; - int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || - w_in < 0 || w_in >= W - ? A_zero_point - : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; - sum += a; - } - } - } - row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = - sum; - } - } // w - } // h - } // t - } // for each n - - for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) { - for (int k = 0; k < K; ++k) { - requantize_u8acc32_ref( - 1, - 1, - 1, - C_int32.data() + i * K + k, - C + i * K + k, - &C_multiplier, - C_zero_point, - A_zero_point, - &B_zero_point, - &row_offsets[i * K + k], - col_offsets + k, - bias ? 
bias + k : nullptr, - 1); - } - } -}; - -void depthwise_3x3x3_per_channel_quantization_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const int8_t* B, - const float* C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias) { - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - - vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K); - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B, - C_int32.data()); - - vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K); - for (int n = 0; n < N; ++n) { - for (int t = 0; t < T_OUT; ++t) { - for (int h = 0; h < H_OUT; ++h) { - for (int w = 0; w < W_OUT; ++w) { - for (int k = 0; k < K; ++k) { - int sum = 0; - for (int k_t = 0; k_t < K_T; ++k_t) { - int t_in = -PAD_P + t * stride_t + k_t; - for (int k_h = 0; k_h < K_H; ++k_h) { - int h_in = -PAD_T + h * stride_h + k_h; - for (int k_w = 0; k_w < K_W; ++k_w) { - int w_in = -PAD_L + w * stride_w + k_w; - int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || - w_in < 0 || w_in >= W - ? A_zero_point - : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; - sum += a; - } - } - } - row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = - sum; - } - } // w - } // h - } // t - } // for each n - - for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) { - for (int k = 0; k < K; ++k) { - requantize_u8acc32_ref( - 1, - 1, - 1, - C_int32.data() + i * K + k, - C + i * K + k, - &C_multiplier[k], - C_zero_point, - A_zero_point, - &B_zero_point[k], - &row_offsets[i * K + k], - col_offsets + k, - bias ? bias + k : nullptr, - 1); - } - } -}; - template void transposeConvWeights( const conv_param_t<2>& conv_p, const std::int8_t* src, diff --git a/src/RefImplementations.h b/src/RefImplementations.h index 082bdf1..a20e348 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -215,124 +215,4 @@ FBGEMM_API void im2col_ref( std::int32_t A_zero_point, std::uint8_t* Ao); -/* - * @brief Reference implementation of depthwise convolution with a 3x3 filter - * and padding size 1. - */ -FBGEMM_API void depthwise_3x3_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int8_t* B, - std::int32_t* C); - -/* - * @brief Reference implementation of depthwise convolution with a 3x3 filter - * and padding size 1, followed by requantization. (the same scaling factors and - * zero points for each channel). - */ -FBGEMM_API void depthwise_3x3_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - std::int32_t B_zero_point, - const std::int8_t* B, - float C_multiplier, - std::int32_t C_zero_point, - std::uint8_t* C, - const std::int32_t* col_offsets, - const std::int32_t* bias); - -/* - * @brief Reference implementation of depthwise convolution with a 3x3 filter - * and padding size 1, followed by requantization. (different scaling factors - * and zero points for each channel). 
- */ -FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1_ref( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int32_t* B_zero_point, - const std::int8_t* B, - const float* C_multiplier, - std::int32_t C_zero_point, - std::uint8_t* C, - const std::int32_t* col_offsets, - const std::int32_t* bias); - -/* - * @brief Reference implementation of 3D depthwise convolution with a 3x3x3 - * filter and padding size 1. - */ -FBGEMM_API void depthwise_3x3x3_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int8_t* B, - std::int32_t* C); - -/* - * @brief Reference implementation of 3D depthwise convolution with a 3x3x3 - * filter and padding size 1, followed by requantization. - */ -FBGEMM_API void depthwise_3x3x3_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - std::int32_t B_zero_point, - const std::int8_t* B, - float C_multiplier, - std::int32_t C_zero_point, - std::uint8_t* C, - const std::int32_t* col_offsets, - const std::int32_t* bias); - -FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1_ref( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int32_t* B_zero_point, - const std::int8_t* B, - const float* C_multiplier, - std::int32_t C_zero_point, - std::uint8_t* C, - const std::int32_t* col_offsets, - const std::int32_t* bias); - } // namespace fbgemm diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc index 0604879..9de6943 100644 --- a/test/I8DepthwiseTest.cc +++ b/test/I8DepthwiseTest.cc @@ -4,7 +4,6 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "I8DepthwiseTest.h" #include <cmath> #include <cstdio> @@ -22,6 +21,7 @@ using namespace std; namespace fbgemm { // From Xray OCR +// clang-format off static vector<vector<int>> shapes = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. @@ -68,6 +68,17 @@ static vector<vector<int>> shapes = { { 1, 8, 4, 4, 1, }, }; +static vector<vector<int>> shapes_3d = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. 
+ // N, K, T_in, H_in, W_in, stride + { 1, 32, 16, 28, 28, 1, }, + { 1, 128, 8, 14, 14, 2, }, + { 5, 16, 32, 56, 56, 1, }, + { 1, 8, 4, 4, 4, 1, }, +}; +// clang-format on + namespace { class FBGemmDepthWiseTest : public testing::TestWithParam<tuple<bool, bool>> {}; @@ -105,13 +116,29 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) { int stride_h = shape[4]; int stride_w = stride_h; constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2, + PAD_R = (S - 1) / 2; + + conv_param_t<2> conv_p( + N, + K, + K, + {H, W}, + K, + {R, S}, + {stride_h, stride_w}, + {PAD_T, PAD_L, PAD_B, PAD_R}); + int H_OUT = conv_p.OUT_DIM[0]; + int W_OUT = conv_p.OUT_DIM[1]; + + int MDim = N * H_OUT * W_OUT; + int KDim = R * S * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * H * W * K); - aligned_vector<int8_t> B(K * R * S); - aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * K), C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = a_symmetric ? 0 : 43; @@ -119,48 +146,54 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) { randFill<int8_t>(B, -16, 16); int32_t B_zero_point = b_symmetric ? 0 : 5; - depthwise_3x3_pad_1_ref( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); - int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); - - float C_multiplier = 255. / (maximum - minimum); + aligned_vector<float> C_multiplier(1); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); + int32_t C_zero_point = 5; aligned_vector<int32_t> col_offsets(K); aligned_vector<int32_t> bias(K); randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - int32_t C_zero_point = 5; - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3_pad_1_ref( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point, - B.data(), - C_multiplier, - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col; + if (!b_symmetric) { + A_im2col.resize(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); + } + + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + if (!b_symmetric) { + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + } - Packed3x3ConvMatrix Bp(K, B.data()); + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data(), + C_zero_point, + A_zero_point, + &B_zero_point, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } + + PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data()); depthwise_3x3_pad_1( N, @@ -173,12 +206,13 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) { A.data(), B_zero_point, Bp, - C_multiplier, + C_multiplier[0], C_zero_point, C_uint8.data(), a_symmetric ? 
nullptr : col_offsets.data(), bias.data(), false, /* fuse_relu */ + 1.0f, /* act_scale * w_scale */ 0, 1); @@ -220,67 +254,83 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) { constexpr int K_T = 3, K_H = 3, K_W = 3; constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + conv_param_t<3> conv_p( + N, + K, + K, + {T, H, W}, + K, + {K_T, K_H, K_W}, + {stride_t, stride_h, stride_w}, + {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R}); + int T_OUT = conv_p.OUT_DIM[0]; + int H_OUT = conv_p.OUT_DIM[1]; + int W_OUT = conv_p.OUT_DIM[2]; + + int MDim = N * T_OUT * H_OUT * W_OUT; + int KDim = K_T * K_H * K_W * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * T * H * W * K); - aligned_vector<int8_t> B(K * K_T * K_H * K_W); - aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K), - C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = a_symmetric ? 0 : 43; randFill<int8_t>(B, -16, 16); - int32_t B_zero_point = 5; - - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - int32_t minimum = *min_element(C_ref.begin(), C_ref.end()); - int32_t maximum = *max_element(C_ref.begin(), C_ref.end()); + int32_t B_zero_point = b_symmetric ? 0 : 5; - float C_multiplier = 255. / (maximum - minimum); + aligned_vector<float> C_multiplier(1); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); + int32_t C_zero_point = 5; aligned_vector<int32_t> col_offsets(K); aligned_vector<int32_t> bias(K); randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - int32_t C_zero_point = 5; - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point, - B.data(), - C_multiplier, - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col; + if (!b_symmetric) { + A_im2col.resize(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); + } + + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + if (!b_symmetric) { + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + } + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data(), + C_zero_point, + A_zero_point, + &B_zero_point, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } - Packed3x3x3ConvMatrix Bp(K, B.data()); + PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data()); depthwise_3x3x3_pad_1( N, @@ -295,10 +345,10 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) { A.data(), B_zero_point, Bp, - C_multiplier, + C_multiplier[0], C_zero_point, C_uint8.data(), - col_offsets.data(), + a_symmetric ? 
nullptr : col_offsets.data(), bias.data(), false, /* fuse_relu */ 0, @@ -334,14 +384,29 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) { int stride_h = shape[4]; int stride_w = stride_h; constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2, + PAD_R = (S - 1) / 2; + + conv_param_t<2> conv_p( + N, + K, + K, + {H, W}, + K, + {R, S}, + {stride_h, stride_w}, + {PAD_T, PAD_L, PAD_B, PAD_R}); + int H_OUT = conv_p.OUT_DIM[0]; + int W_OUT = conv_p.OUT_DIM[1]; + + int MDim = N * H_OUT * W_OUT; + int KDim = R * S * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * H * W * K); - aligned_vector<int8_t> B(K * R * S); - int32_t C_num_rows = N * H_OUT * W_OUT; - aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = 43; @@ -357,28 +422,8 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) { B_zero_point[k] = 5 + k; } - depthwise_3x3_pad_1_ref( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - aligned_vector<int32_t> C_ref_transpose(C_ref); - transpose_matrix(C_ref.data(), C_num_rows, K); - vector<float> C_multiplier(K); - for (auto k = 0; k < K; ++k) { - auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows; - auto C_ref_k_end = C_ref_k_begin + C_num_rows; - int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end); - int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end); - C_multiplier[k] = 255. 
/ (maximum - minimum); - } + aligned_vector<float> C_multiplier(K); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); int32_t C_zero_point = 5; aligned_vector<int32_t> col_offsets(K); @@ -386,25 +431,40 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) { randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3_per_channel_quantization_pad_1_ref( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point.data(), - B.data(), - C_multiplier.data(), - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + // im2col to compute row offset later + vector<int32_t> row_offsets(MDim); + vector<uint8_t> A_im2col(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); + + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data() + g, + C_zero_point, + A_zero_point, + B_zero_point.data() + g, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } - Packed3x3ConvMatrix Bp(K, B.data()); + PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data()); depthwise_3x3_per_channel_quantization_pad_1( N, @@ -457,14 +517,28 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) { constexpr int K_T = 3, K_H = 3, K_W = 3; constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + conv_param_t<3> conv_p( + N, + K, + K, + {T, H, W}, + K, + {K_T, K_H, K_W}, + {stride_t, stride_h, stride_w}, + {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R}); + int T_OUT = conv_p.OUT_DIM[0]; + int H_OUT = conv_p.OUT_DIM[1]; + int W_OUT = conv_p.OUT_DIM[2]; + + int MDim = N * T_OUT * H_OUT * W_OUT; + int KDim = K_T * K_H * K_W * K; + int KDimPerGroup = KDim / conv_p.G; aligned_vector<uint8_t> A(N * T * H * W * K); - aligned_vector<int8_t> B(K * K_T * K_H * K_W); - int32_t C_num_rows = N * T_OUT * H_OUT * W_OUT; - aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size()); + aligned_vector<int8_t> B(KDim); + aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size()); + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); randFill<uint8_t>(A, 0, 86); int32_t A_zero_point = 43; @@ -480,30 +554,8 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) { B_zero_point[k] = 5 + k; } - depthwise_3x3x3_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B.data(), - C_ref.data()); - - aligned_vector<int32_t> C_ref_transpose(C_ref); - transpose_matrix(C_ref.data(), C_num_rows, K); - vector<float> C_multiplier(K); - for (auto k = 0; k < K; ++k) { - auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows; - auto C_ref_k_end = C_ref_k_begin + C_num_rows; - int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end); - int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end); - C_multiplier[k] = 255. 
/ (maximum - minimum); - } + aligned_vector<float> C_multiplier(K); + randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); int32_t C_zero_point = 5; aligned_vector<int32_t> col_offsets(K); @@ -511,27 +563,40 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) { randFill(col_offsets, -100, 100); randFill(bias, -40, 40); - aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); - depthwise_3x3x3_per_channel_quantization_pad_1_ref( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A.data(), - B_zero_point.data(), - B.data(), - C_multiplier.data(), - C_zero_point, - C_uint8_ref.data(), - col_offsets.data(), - bias.data()); + vector<int32_t> row_offsets(MDim); + // im2col to compute row offset later + vector<uint8_t> A_im2col(MDim * KDim); + im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data()); + + conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data()); + + for (int g = 0; g < conv_p.G; ++g) { + // Compute row offset + row_offsets_u8acc32_ref( + MDim, + KDimPerGroup, + KDim, + A_im2col.data() + g * KDimPerGroup, + row_offsets.data()); + + // Requantization + requantize_u8acc32_ref( + MDim, + 1, + conv_p.G, + C_ref.data() + g, + C_uint8_ref.data() + g, + C_multiplier.data() + g, + C_zero_point, + A_zero_point, + B_zero_point.data() + g, + row_offsets.data(), + col_offsets.data() + g, + bias.data() + g, + K); + } - Packed3x3x3ConvMatrix Bp(K, B.data()); + PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data()); depthwise_3x3x3_per_channel_quantization_pad_1( N, @@ -587,36 +652,8 @@ TEST_P(FBGemmDepthWisePackUnpackTest, TestPackUnpack) { aligned_vector<int8_t> BUnpacked(K * kernel_prod); - if (kernel_prod == 1) { - Packed1ConvMatrix BPacked(K, B.data()); - BPacked.unpack(BUnpacked.data()); - } else if (kernel_prod == 2) { - Packed2ConvMatrix BPacked(K, B.data()); - BPacked.unpack(BUnpacked.data()); - } else if (kernel_prod == 3) { - Packed3ConvMatrix BPacked(K, B.data()); - BPacked.unpack(BUnpacked.data()); - } else if (kernel_prod == 4) { - Packed4ConvMatrix BPacked(K, B.data()); - BPacked.unpack(BUnpacked.data()); - } else if (kernel_prod == 5) { - Packed5ConvMatrix BPacked(K, B.data()); - BPacked.unpack(BUnpacked.data()); - } else if (kernel_prod == 9) { - Packed3x3ConvMatrix BPacked(K, B.data()); - BPacked.unpack(BUnpacked.data()); - } else if (kernel_prod == 10) { - Packed10ConvMatrix BPacked(K, B.data()); - BPacked.unpack(BUnpacked.data()); - } else if (kernel_prod == 11) { - Packed11ConvMatrix BPacked(K, B.data()); - BPacked.unpack(BUnpacked.data()); - } else if (kernel_prod == 27) { - Packed3x3x3ConvMatrix BPacked(K, B.data()); - BPacked.unpack(BUnpacked.data()); - } else { - ASSERT_TRUE(false); - } + PackedDepthWiseConvMatrix BPacked(K, kernel_prod, B.data()); + BPacked.unpack(BUnpacked.data()); ASSERT_EQ(B, BUnpacked) << "Original and unpacked data elements are not the same"; diff --git a/test/I8DepthwiseTest.h b/test/I8DepthwiseTest.h deleted file mode 100644 index d65362a..0000000 --- a/test/I8DepthwiseTest.h +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#pragma once - -#include <vector> - -namespace fbgemm { - -// From ResNeXt-3D-101 -static std::vector<std::vector<int>> shapes_3d = { - // NOTE: clang-format wants to use a different formatting but the current - // formatting should be easier to read. 
-    // N, K, T_in, H_in, W_in, stride
-    { 1, 64, 32, 56, 56, 1, },
-    { 1, 128, 16, 28, 28, 1, },
-    { 1, 256, 8, 14, 14, 1, },
-    { 1, 512, 4, 7, 7, 1, },
-
-    { 1, 128, 32, 56, 56, 2, },
-    { 1, 256, 16, 28, 28, 2, },
-    { 1, 512, 8, 14, 14, 2, },
-
-    { 5, 64, 32, 56, 56, 1, },
-    { 5, 128, 16, 28, 28, 1, },
-    { 5, 256, 8, 14, 14, 1, },
-    { 5, 512, 4, 7, 7, 1, },
-
-    { 5, 128, 32, 56, 56, 2, },
-    { 5, 256, 16, 28, 28, 2, },
-    { 5, 512, 8, 14, 14, 2, },
-
-    { 1, 8, 4, 4, 4, 1, },
-};
-
-} // namespace fbgemm
diff --git a/test/RequantizeOnlyTest.cc b/test/RequantizeOnlyTest.cc
new file mode 100644
index 0000000..2f73d49
--- /dev/null
+++ b/test/RequantizeOnlyTest.cc
@@ -0,0 +1,169 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <random>
+#include <stdexcept>
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "TestUtils.h"
+#include "bench/BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+
+using namespace std;
+using namespace fbgemm;
+
+vector<QuantizationGranularity> qGranularityVals{
+    QuantizationGranularity::TENSOR,
+    QuantizationGranularity::OUT_CHANNEL};
+
+namespace {
+
+// tuple represents #rows, #cols, fuse_relu, quantization_granularity
+class FloatRequantizeTest
+    : public testing::TestWithParam<
+          tuple<int, int, bool, QuantizationGranularity>> {};
+
+}; // namespace
+
+INSTANTIATE_TEST_CASE_P(
+    InstantiationName,
+    FloatRequantizeTest,
+    ::testing::Combine(
+        ::testing::ValuesIn({1, 2, 3, 4}), // number of rows
+        ::testing::ValuesIn(
+            {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 20, 32}), // number of
+                                                                  // cols
+        ::testing::Bool(), // fuse relu
+        ::testing::ValuesIn(qGranularityVals))); // requantization granularity
+
+/**
+ * Test for float bias
+ */
+TEST_P(FloatRequantizeTest, floatBiasTest) {
+  int rows, cols;
+  bool fuse_relu;
+  QuantizationGranularity q_gran;
+  tie(rows, cols, fuse_relu, q_gran) = GetParam();
+
+  int numElements = rows * cols;
+
+  aligned_vector<float> act_times_w_scale(cols);
+  randFill<float>(act_times_w_scale, -8, 8);
+
+  float out_scale = 2.0f;
+
+  aligned_vector<float> C_multiplier(cols);
+  transform(
+      act_times_w_scale.begin(),
+      act_times_w_scale.end(),
+      C_multiplier.begin(),
+      [&out_scale](float i) { return i / out_scale; });
+
+  aligned_vector<int32_t> Bint8_zero_point(cols);
+  randFill<int32_t>(Bint8_zero_point, -8, 8);
+
+  aligned_vector<int32_t> row_offset_buf(rows);
+  randFill<int32_t>(row_offset_buf, -8, 8);
+
+  aligned_vector<int32_t> col_offsets(cols);
+  randFill<int32_t>(col_offsets, -8, 8);
+
+  // quantized bias
+  aligned_vector<int32_t> bias_q(cols);
+  randFill<int32_t>(bias_q, -8, 8);
+
+  // floating point bias
+  aligned_vector<float> bias_f(cols);
+  if (q_gran == QuantizationGranularity::TENSOR) {
+    transform(
+        bias_q.begin(),
+        bias_q.end(),
+        bias_f.begin(),
+        [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; });
+  } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) {
+    transform(
+        act_times_w_scale.begin(),
+        act_times_w_scale.end(),
+        bias_q.begin(),
+        bias_f.begin(),
+        multiplies<float>());
+
+  } else {
+    FAIL();
+  }
+
+  aligned_vector<int32_t> input(numElements);
+  randFill<int32_t>(input, -8, 8);
+
+  aligned_vector<uint8_t> output_q_bias(numElements);
+  aligned_vector<uint8_t> output_f_bias(numElements);
+
+  int32_t C_zero_point = 3;
+  int32_t Aint8_zero_point = 3;
+
+  block_type_t block{0, rows, 0, cols};
+
+  DoNothing<> doNothingObj{};
+
+#define TESTCODE(FUSE_RELU, Q_GRAN)                           \
+  ReQuantizeOutput<FUSE_RELU, Q_GRAN> reqObj_q(               \
+      doNothingObj,                                           \
+      C_multiplier.data(),                                    \
+      C_zero_point,                                           \
+      Aint8_zero_point,                                       \
+      Bint8_zero_point.data(),                                \
+      row_offset_buf.data(),                                  \
+      col_offsets.data(),                                     \
+      bias_q.data(),                                          \
+      cols);                                                  \
+  ReQuantizeOutput<FUSE_RELU, Q_GRAN, float> reqObj_f(        \
+      doNothingObj,                                           \
+      C_multiplier.data(),                                    \
+      C_zero_point,                                           \
+      Aint8_zero_point,                                       \
+      Bint8_zero_point.data(),                                \
+      row_offset_buf.data(),                                  \
+      col_offsets.data(),                                     \
+      bias_f.data(),                                          \
+      cols,                                                   \
+      1,                                                      \
+      act_times_w_scale.data());                              \
+  reqObj_q.f<inst_set_t::avx2>(                               \
+      output_q_bias.data(), input.data(), block, cols, cols); \
+  reqObj_f.f<inst_set_t::avx2>(                               \
+      output_f_bias.data(), input.data(), block, cols, cols);
+
+  if (fuse_relu) {
+    if (q_gran == QuantizationGranularity::TENSOR) {
+      TESTCODE(true, QuantizationGranularity::TENSOR)
+
+    } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) {
+      TESTCODE(true, QuantizationGranularity::OUT_CHANNEL)
+
+    } else {
+      FAIL();
+    }
+
+  } else {
+    if (q_gran == QuantizationGranularity::TENSOR) {
+      TESTCODE(false, QuantizationGranularity::TENSOR)
+
+    } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) {
+      TESTCODE(false, QuantizationGranularity::OUT_CHANNEL)
+
+    } else {
+      FAIL();
+    }
+  }
+#undef TESTCODE
+  ASSERT_EQ(output_q_bias, output_f_bias)
+      << "Requantization with quantized bias and float bias differs";
+}
diff --git a/test/UniConvTest.cc b/test/UniConvTest.cc
index 91bf578..cead3a6 100644
--- a/test/UniConvTest.cc
+++ b/test/UniConvTest.cc
@@ -5,11 +5,10 @@
  * LICENSE file in the root directory of this source tree.
  */
 #include <algorithm>
-#include <random>
 #include <iostream>
+#include <random>
 #include <stdexcept>
-
 #include <gtest/gtest.h>
 
 #include "QuantizationHelpers.h"
@@ -21,6 +20,67 @@
 using namespace std;
 using namespace fbgemm;
 
+vector<QuantizationGranularity> qGranularityVals{
+    QuantizationGranularity::TENSOR,
+    QuantizationGranularity::GROUP,
+    QuantizationGranularity::OUT_CHANNEL};
+
+// clang-format off
+static vector<conv_param_t<>> GetShapes_() {
+  vector<conv_param_t<>> shapes = {
+    // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, {pad_t, pad_l,
+    // pad_b, pad_r}
+    // Regular
+    conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 16, 32, {30, 10}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}, {2, 2}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {3, 3}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 2}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 1}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {1, 2}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {2, 1, 2, 1}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 2, 1, 2}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 2}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 5}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+    // groupwise
+    conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 16, 32, {10, 30}, 8, {3, 3}, {2, 2}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 2}, {2, 1, 2, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 2}, {2, 1, 2, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 1}, {2, 1, 2, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 5}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+    // DW
+    conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 2}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 2, 1, 2}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 2}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 5}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}),
+    conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+    // Pointwise
+    conv_param_t<>(1, 32, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}),
+    conv_param_t<>(1, 16, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 2}, {0, 0, 0, 0}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 2}, {0, 0, 0, 0}),
+    conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 1}, {0, 0, 0, 0}),
+  };
+  return shapes;
+}
+// clang-format on
+
 namespace {
 
 // tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad
@@ -28,8 +88,15 @@ class uniConvTest
     : public testing::TestWithParam<
           tuple<int, int, int, int, int, int, int, int, int, int>> {};
 
+// tuple represents QuantizationGranularity, A symmetric, B symmetric,
+// test_bias, test_float_bias
+class UniConvQGranTest
+    : public testing::TestWithParam<
+          tuple<QuantizationGranularity, bool, bool, bool, bool>> {};
+
 }; // namespace
 
+// Combine only allows at most 10 generators.
 INSTANTIATE_TEST_CASE_P(
     InstantiationName,
     uniConvTest,
@@ -45,6 +112,15 @@ INSTANTIATE_TEST_CASE_P(
     ::testing::ValuesIn({1, 2}), // stride
     ::testing::ValuesIn({0, 1, 2}))); // pad
 
+INSTANTIATE_TEST_CASE_P(
+    InstantiationName,
+    UniConvQGranTest,
+    ::testing::Combine(
+        ::testing::ValuesIn(qGranularityVals),
+        ::testing::Bool(), // A symmetric
+        ::testing::Bool(), // B symmetric
+        ::testing::Bool(), // test_bias
+        ::testing::Bool())); // test_float_bias
 /**
  * Test for conv packing
  */
@@ -71,23 +147,19 @@ TEST_P(uniConvTest, packingTest) {
     case optimized_conv_t::depthwise: {
       ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
           << "im2col packed matrix should be null";
-      ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
-          << "3D depthwise packed matrix should be null";
       ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
           << "groupwise packed matrix should be null";
       ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
           << "pointwise packed matrix should be null";
-      ASSERT_NE(packedB_2D.getPackedWFor2DDW(), nullptr)
-          << "2D depthwise packed matrix is null";
+      ASSERT_NE(packedB_2D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix is null";
       break;
     }
     case optimized_conv_t::groupwise: {
       ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
           << "im2col packed matrix should be null";
-      ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
-          << "2D depthwise packed matrix is null";
-      ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
-          << "3D depthwise packed matrix should be null";
+      ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix should be null";
       ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
           << "pointwise packed matrix should be null";
       ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr)
@@ -97,10 +169,8 @@ TEST_P(uniConvTest, packingTest) {
     case optimized_conv_t::pointwise: {
       ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
           << "im2col packed matrix should be null";
-      ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
-          << "2D depthwise packed matrix is null";
-      ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
-          << "3D depthwise packed matrix should be null";
+      ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix should be null";
       ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
           << "Groupwise packed matrix should be null";
       ASSERT_NE(packedB_2D.getPackedWForPointwise(), nullptr)
@@ -108,10 +178,8 @@ TEST_P(uniConvTest, packingTest) {
       break;
     }
     case optimized_conv_t::im2col: {
-      ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
-          << "2D depthwise packed matrix is null";
-      ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
-          << "3D depthwise packed matrix should be null";
+      ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix should be null";
       ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
           << "groupwise packed matrix should be null";
       ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
@@ -139,16 +207,14 @@ TEST_P(uniConvTest, packingTest) {
 
   switch (ConvFastPath<3, int32_t>(conv_p_3d)) {
     case optimized_conv_t::depthwise: {
-      ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
-          << "2D depthwise packed matrix is null";
       ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
           << "im2col packed matrix should be null";
       ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
           << "groupwise packed matrix should be null";
       ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)
           << "pointwise packed matrix should be null";
-      ASSERT_NE(packedB_3D.getPackedWFor3DDW(), nullptr)
-          << "3D depthwise packed matrix should be null";
+      ASSERT_NE(packedB_3D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix is null";
       break;
     }
     case optimized_conv_t::groupwise: {
@@ -156,10 +222,8 @@ TEST_P(uniConvTest, packingTest) {
       break;
     }
     case optimized_conv_t::pointwise: {
-      ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
-          << "2D depthwise packed matrix is null";
-      ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr)
-          << "3D depthwise packed matrix should be null";
+      ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix should be null";
       ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
           << "groupwise packed matrix should be null";
       ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
@@ -169,10 +233,8 @@ TEST_P(uniConvTest, packingTest) {
       break;
     }
     case optimized_conv_t::im2col: {
-      ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
-          << "2D depthwise packed matrix is null";
-      ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr)
-          << "3D depthwise packed matrix should be null";
+      ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix should be null";
       ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
           << "groupwise packed matrix should be null";
       ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)
@@ -323,3 +385,335 @@ TEST(uniConvTest, cornerCases) {
       0,
       1);
 }
+
+/**
+ * @brief Unit test for uint8 activations, int8 weights, and 32-bit
+ * accumulation. Output processing: requantization -> nothing
+ */
+TEST_P(UniConvQGranTest, requantizeTest) {
+  vector<conv_param_t<>> shapes(GetShapes_());
+  QuantizationGranularity q_granularity;
+  bool a_symmetric, b_symmetric;
+  bool test_bias, test_float_bias;
+  tie(q_granularity, a_symmetric, b_symmetric, test_bias, test_float_bias) =
+      GetParam();
+
+  for (auto conv_p : shapes) {
+    int R = conv_p.K[0];
+    int S = conv_p.K[1];
+    int G = conv_p.G;
+    int OC = conv_p.OC;
+    int OH = conv_p.OUT_DIM[0];
+    int OW = conv_p.OUT_DIM[1];
+    int IC_per_G = conv_p.IC / conv_p.G;
+    int OC_per_G = conv_p.OC / conv_p.G;
+
+    // activations
+    aligned_vector<uint8_t> Aint8(
+        conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
+
+    // weights
+    // The weight matrix is in layout G K/G (R S C/G)
+    aligned_vector<int8_t> Bint8(R * S * conv_p.G * IC_per_G * OC_per_G, 0);
+    aligned_vector<int8_t> Bint8_tr(R * S * G * IC_per_G * OC_per_G, 0);
+
+    aligned_vector<int32_t> Cint32_ref(conv_p.MB * OH * OW * OC, 0);
+    aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0);
+    aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0);
+    aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0);
+
+    randFill<uint8_t>(Aint8, 0, 5);
+    int32_t Aint8_zero_point = a_symmetric ? 0 : 4;
+
+    randFill<int8_t>(Bint8, -4, 4);
+
+    // computing column offset
+    vector<int32_t> col_offsets(G * OC_per_G);
+
+    int ncols_per_quant_group = G * OC_per_G;
+    if (q_granularity == QuantizationGranularity::GROUP) {
+      ncols_per_quant_group = OC_per_G;
+    } else if (q_granularity == QuantizationGranularity::OUT_CHANNEL) {
+      ncols_per_quant_group = 1;
+    }
+
+    aligned_vector<int32_t> Bint8_zero_point(
+        G * OC_per_G / ncols_per_quant_group);
+    if (b_symmetric) {
+      randFill(Bint8_zero_point, 0, 0);
+    } else {
+      randFill(Bint8_zero_point, -3, 3);
+    }
+
+    // matrix dimensions after im2col for each GEMM.
+    // For each group, there is one GEMM of the following dimensions
+    int MDim = conv_p.MB * OH * OW;
+    int NDim = OC_per_G;
+    int KDim = R * S * IC_per_G;
+
+    vector<uint8_t> Aint8_im2col(MDim * KDim * G);
+    im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+    vector<int32_t> row_offsets(MDim);
+
+    // activation_scale * weight_scale
+    aligned_vector<float> act_times_w_scale(Bint8_zero_point.size());
+    randFill(act_times_w_scale, 0.1234f / 2, 0.1234f * 3 / 2);
+
+    float out_scale = 2.0f;
+    aligned_vector<float> C_multiplier(Bint8_zero_point.size());
+    transform(
+        act_times_w_scale.begin(),
+        act_times_w_scale.end(),
+        C_multiplier.begin(),
+        [&out_scale](float i) { return i / out_scale; });
+
+    int32_t C_zero_pt = 5;
+
+    // initialize bias
+    aligned_vector<int32_t> bias_int32(OC);
+    aligned_vector<float> bias_fp32(OC);
+    if (test_bias) {
+      randFill(bias_int32, -8, 8);
+    }
+
+    // floating point bias
+    if (test_float_bias) {
+      if (q_granularity == QuantizationGranularity::TENSOR) {
+        transform(
+            bias_int32.begin(),
+            bias_int32.end(),
+            bias_fp32.begin(),
+            [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; });
+      } else if (q_granularity == QuantizationGranularity::GROUP) {
+        for (int g = 0; g < G; ++g) {
+          for (int c = 0; c < OC_per_G; ++c) {
+            bias_fp32[g * OC_per_G + c] = act_times_w_scale[g] *
+                static_cast<float>(bias_int32[g * OC_per_G + c]);
+          }
+        }
+      } else { // OUT_CHANNEL
+        transform(
+            act_times_w_scale.begin(),
+            act_times_w_scale.end(),
+            bias_int32.begin(),
+            bias_fp32.begin(),
+            multiplies<float>());
+      }
+    }
+    // reference implementation
+    // conv_ref expects weights to be in G (R S C/G) K/G
+    int8_t* rightBData = Bint8.data();
+    transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data());
+    rightBData = Bint8_tr.data();
+    for (int g = 0; g < G; ++g) {
+      col_offsets_with_zero_pt_s8acc32_ref(
+          R * S * IC_per_G,
+          OC_per_G,
+          OC_per_G,
+          rightBData + g * R * S * IC_per_G * OC_per_G,
+          Bint8_zero_point.data() + g * OC_per_G / ncols_per_quant_group,
+          col_offsets.data() + g * OC_per_G,
+          ncols_per_quant_group);
+    }
+    conv_ref(
+        conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data());
+
+    for (int g = 0; g < G; ++g) {
+      row_offsets_u8acc32_ref(
+          MDim,
+          KDim,
+          KDim * G,
+          Aint8_im2col.data() + g * KDim,
+          row_offsets.data());
+
+      requantize_u8acc32_ref(
+          MDim,
+          NDim,
+          G * NDim,
+          Cint32_ref.data() + g * NDim,
+          Cint8_ref.data() + g * NDim,
+          C_multiplier.data() + g * NDim / ncols_per_quant_group,
+          C_zero_pt,
+          Aint8_zero_point,
+          Bint8_zero_point.data() + g * NDim / ncols_per_quant_group,
+          row_offsets.data(),
+          col_offsets.data() + g * NDim,
+          test_bias ? bias_int32.data() + g * NDim : nullptr,
+          ncols_per_quant_group);
+    }
+
+    PackWeightsForConv<2> packedWeights(conv_p, Bint8.data());
+
+    // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv
+    // #ifdef _OPENMP
+    // #pragma omp parallel
+    // #endif
+    {
+      vector<int32_t> row_offset_buf(rowOffsetBufferSizeGConv(conv_p));
+
+      DoNothing<> doNothingObj{};
+
+      int num_threads = fbgemm_get_num_threads();
+      int tid = fbgemm_get_thread_num();
+
+      if (q_granularity == QuantizationGranularity::TENSOR) {
+        if (test_float_bias) {
+          ReQuantizeOutput<false, QuantizationGranularity::TENSOR, float>
+              reqObj(
+                  doNothingObj,
+                  C_multiplier.data(),
+                  C_zero_pt,
+                  Aint8_zero_point,
+                  Bint8_zero_point.data(),
+                  nullptr, /* row offset buffer */
+                  col_offsets.data(),
+                  test_bias ? bias_fp32.data() : nullptr,
+                  G * NDim,
+                  G,
+                  act_times_w_scale.data());
+
+          fbgemmConv(
+              conv_p,
+              Aint8.data(),
+              packedWeights,
+              Cint8_fb.data(),
+              Cint32_fb.data(),
+              reqObj,
+              tid,
+              num_threads);
+
+        } else {
+          ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj(
+              doNothingObj,
+              C_multiplier.data(),
+              C_zero_pt,
+              Aint8_zero_point,
+              Bint8_zero_point.data(),
+              nullptr, /* row offset buffer */
+              col_offsets.data(),
+              test_bias ? bias_int32.data() : nullptr,
+              G * NDim,
+              G);
+
+          fbgemmConv(
+              conv_p,
+              Aint8.data(),
+              packedWeights,
+              Cint8_fb.data(),
+              Cint32_fb.data(),
+              reqObj,
+              tid,
+              num_threads);
+        }
+
+      } else if (q_granularity == QuantizationGranularity::GROUP) {
+        if (test_float_bias) {
+          ReQuantizeOutput<false, QuantizationGranularity::GROUP, float> reqObj(
+              doNothingObj,
+              C_multiplier.data(),
+              C_zero_pt,
+              Aint8_zero_point,
+              Bint8_zero_point.data(),
+              nullptr, /* row offset buffer */
+              col_offsets.data(),
+              test_bias ? bias_fp32.data() : nullptr,
+              G * NDim,
+              G,
+              act_times_w_scale.data());
+
+          fbgemmConv(
+              conv_p,
+              Aint8.data(),
+              packedWeights,
+              Cint8_fb.data(),
+              Cint32_fb.data(),
+              reqObj,
+              tid,
+              num_threads);
+
+        } else {
+          ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj(
+              doNothingObj,
+              C_multiplier.data(),
+              C_zero_pt,
+              Aint8_zero_point,
+              Bint8_zero_point.data(),
+              nullptr, /* row offset buffer */
+              col_offsets.data(),
+              test_bias ? bias_int32.data() : nullptr,
+              G * NDim,
+              G);
+
+          fbgemmConv(
+              conv_p,
+              Aint8.data(),
+              packedWeights,
+              Cint8_fb.data(),
+              Cint32_fb.data(),
+              reqObj,
+              tid,
+              num_threads);
+        }
+
+      } else {
+        if (test_float_bias) {
+          ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL, float>
+              reqObj(
+                  doNothingObj,
+                  C_multiplier.data(),
+                  C_zero_pt,
+                  Aint8_zero_point,
+                  Bint8_zero_point.data(),
+                  nullptr, /* row offset buffer */
+                  col_offsets.data(),
+                  test_bias ? bias_fp32.data() : nullptr,
+                  G * NDim,
+                  G,
+                  act_times_w_scale.data());
+
+          fbgemmConv(
+              conv_p,
+              Aint8.data(),
+              packedWeights,
+              Cint8_fb.data(),
+              Cint32_fb.data(),
+              reqObj,
+              tid,
+              num_threads);
+
+        } else {
+          ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL> reqObj(
+              doNothingObj,
+              C_multiplier.data(),
+              C_zero_pt,
+              Aint8_zero_point,
+              Bint8_zero_point.data(),
+              nullptr, /* row offset buffer */
+              col_offsets.data(),
+              test_bias ? bias_int32.data() : nullptr,
+              G * NDim,
+              G);
+
+          fbgemmConv(
+              conv_p,
+              Aint8.data(),
+              packedWeights,
+              Cint8_fb.data(),
+              Cint32_fb.data(),
+              reqObj,
+              tid,
+              num_threads);
+        }
+      }
+    } // omp parallel
+
+    compare_validate_buffers(
+        Cint8_ref.data(),
+        Cint8_fb.data(),
+        MDim,
+        NDim * G,
+        NDim * G,
+        static_cast<uint8_t>(0));
+  } // for each shape
+}
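A note on the depthwise packing change in the test hunks above: the nine fixed-size Packed*ConvMatrix classes are folded into the single runtime-sized PackedDepthWiseConvMatrix, so TestPackUnpack no longer needs a branch per kernel_prod. The sketch below mirrors the round trip the test now performs; the constructor and unpack calls are taken from the diff, but the header path is our assumption and the helper name is ours.

    #include <cstdint>
    #include <vector>

    #include "fbgemm/FbgemmI8DepthwiseAvx2.h" // assumed location of PackedDepthWiseConvMatrix

    // Pack/unpack round trip as in TestPackUnpack: one class now covers every
    // kernel_prod (KH*KW, or T*KH*KW for 3D) instead of
    // Packed1ConvMatrix ... Packed3x3x3ConvMatrix.
    bool pack_unpack_roundtrip(int K, int kernel_prod) {
      std::vector<std::int8_t> B(K * kernel_prod);
      for (int i = 0; i < K * kernel_prod; ++i) {
        B[i] = static_cast<std::int8_t>(i % 127); // arbitrary test pattern
      }
      std::vector<std::int8_t> BUnpacked(K * kernel_prod);

      fbgemm::PackedDepthWiseConvMatrix BPacked(K, kernel_prod, B.data());
      BPacked.unpack(BUnpacked.data());

      return B == BUnpacked; // pack followed by unpack must be the identity
    }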
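The new RequantizeOnlyTest.cc asserts that a float bias reproduces the quantized-bias output exactly. The identity it relies on is bias_f = bias_q * act_times_w_scale: the float-bias path divides the bias back by act_times_w_scale before requantizing, so both paths feed the same value into the multiply by C_multiplier = act_times_w_scale / out_scale. A self-contained sketch of that arithmetic (our own demonstration, not FBGEMM code; the scale is a power of two so the division is exact):

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    int main() {
      std::int32_t acc = 17; // int32 accumulator entering requantization
      std::int32_t bias_q = -5; // quantized bias
      float act_times_w_scale = 0.25f; // power of two: division below is exact
      float out_scale = 2.0f;
      float C_multiplier = act_times_w_scale / out_scale;
      float bias_f = bias_q * act_times_w_scale; // float bias the test builds

      // Quantized-bias path adds bias_q directly; float-bias path divides
      // bias_f back by act_times_w_scale. Both see the same value.
      float path_q = (acc + bias_q) * C_multiplier;
      float path_f = (acc + bias_f / act_times_w_scale) * C_multiplier;
      assert(std::lround(path_q) == std::lround(path_f));
      return 0;
    }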
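Both test files lean on requantize_u8acc32_ref to build the uint8 reference output. As a reading aid, here is a minimal standalone sketch of the per-element requantization that the calls above imply; the zero-point corrections follow from the col_offsets_with_zero_pt_s8acc32_ref / row_offsets_u8acc32_ref pairing in the diff, but the exact rounding and clamping details are our assumptions, not a copy of FBGEMM's reference code.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Sketch: requantize one 32-bit accumulator C[i][j] down to uint8.
    std::uint8_t requantize_one(
        std::int32_t acc, // raw int32 accumulator
        std::int32_t row_offset, // sum over row i of A
        std::int32_t col_offset, // sum over column j of B, B zero point folded in
        std::int32_t A_zero_point,
        std::int32_t B_zero_point, // zero point of j's quantization group
        std::int32_t bias, // quantized bias; pass 0 when bias is nullptr
        float C_multiplier, // act_scale * weight_scale / out_scale
        std::int32_t C_zero_point) {
      std::int32_t corrected =
          acc - A_zero_point * col_offset - B_zero_point * row_offset + bias;
      long rounded = std::lrintf(corrected * C_multiplier) + C_zero_point;
      return static_cast<std::uint8_t>(std::min(255L, std::max(0L, rounded)));
    }

    int main() {
      // One element with small constants in the spirit of the tests above.
      std::printf("%u\n", requantize_one(100, 7, -3, 3, 2, -5, 0.0625f, 3));
      return 0;
    }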
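The three QuantizationGranularity values exercised by UniConvQGranTest differ only in how many output columns share one scale and zero point, which is exactly the ncols_per_quant_group computed in requantizeTest; the per-group arrays are then indexed with j / ncols_per_quant_group. A small sketch of that mapping (names are ours):

    #include <cstdio>

    enum class Granularity { TENSOR, GROUP, OUT_CHANNEL };

    // Mirrors ncols_per_quant_group in requantizeTest: the number of output
    // columns that share one C_multiplier / B_zero_point entry.
    int ncols_per_quant_group(Granularity g, int groups, int oc_per_group) {
      switch (g) {
        case Granularity::TENSOR:
          return groups * oc_per_group; // one scale for the whole tensor
        case Granularity::GROUP:
          return oc_per_group; // one scale per convolution group
        case Granularity::OUT_CHANNEL:
          return 1; // one scale per output channel
      }
      return 1;
    }

    int main() {
      // G = 8 groups with 4 output channels each (a groupwise GetShapes_ case).
      int t = ncols_per_quant_group(Granularity::TENSOR, 8, 4); // 32
      int g = ncols_per_quant_group(Granularity::GROUP, 8, 4); // 4
      int c = ncols_per_quant_group(Granularity::OUT_CHANNEL, 8, 4); // 1
      // Column j then uses quant group j / ncols_per_quant_group.
      std::printf("TENSOR=%d GROUP=%d OUT_CHANNEL=%d\n", t, g, c);
      return 0;
    }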
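Finally, requantizeTest reduces each convolution to one GEMM per group after im2col, with MDim = MB*OH*OW, NDim = OC/G, and KDim = R*S*(IC/G). The sketch below redoes that bookkeeping for one of the groupwise shapes in GetShapes_, using the standard convolution output-size formula; the helper is ours, on the assumption that conv_param_t fills in OUT_DIM equivalently with dilation defaulting to 1.

    #include <cstdio>

    // Standard output-size formula for one spatial dimension.
    int out_dim(int in, int pad_before, int pad_after, int k, int stride,
                int dilation = 1) {
      return (in + pad_before + pad_after - dilation * (k - 1) - 1) / stride + 1;
    }

    int main() {
      // conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1})
      int MB = 1, IC = 32, OC = 32, IH = 10, IW = 30, G = 8, KH = 3, KW = 3;
      int OH = out_dim(IH, 1, 1, KH, 1); // = 10
      int OW = out_dim(IW, 1, 1, KW, 1); // = 30

      // One GEMM per group after im2col:
      int MDim = MB * OH * OW;       // rows: output pixels
      int NDim = OC / G;             // cols: output channels per group
      int KDim = KH * KW * (IC / G); // reduction: kernel window x IC per group
      std::printf("M=%d N=%d K=%d per group, %d groups\n", MDim, NDim, KDim, G);
      return 0;
    }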