github.com/marian-nmt/FBGEMM.git
author     Young Jin Kim <youki@microsoft.com>   2019-09-25 19:40:48 +0300
committer  Young Jin Kim <youki@microsoft.com>   2019-09-25 19:40:48 +0300
commit     7bd598c9e97871e42c19449fddf7bd317898eb58 (patch)
tree       a3b1fea18477b63add473037d0644a96b115e0da
parent     08763b198ef743741560ae42a9c10a3017c7c9ce (diff)
parent     518d8a1832cf1eb1dda2feace1a278e9e4f302ba (diff)
Merge remote-tracking branch 'upstream/master' into youki/win-jit-debug-int8
Fix for windows build errors
-rw-r--r--  CMakeLists.txt                                 2
-rw-r--r--  CODE_OF_CONDUCT.md                            78
-rw-r--r--  bench/ConvUnifiedBenchmark.cc                 54
-rw-r--r--  bench/Depthwise3DBenchmark.cc                134
-rw-r--r--  bench/DepthwiseBenchmark.cc                  128
-rw-r--r--  bench/GEMMsBenchmark.cc                        2
-rw-r--r--  bench/GEMMsTunableBenchmark.cc                 6
-rw-r--r--  bench/PackedFloatInOutBenchmark.cc             2
-rw-r--r--  bench/PackedRequantizeAcc16Benchmark.cc        2
-rw-r--r--  bench/PackedRequantizeAcc32Benchmark.cc        2
-rw-r--r--  include/fbgemm/ConvUtils.h                    38
-rw-r--r--  include/fbgemm/Fbgemm.h                       44
-rw-r--r--  include/fbgemm/FbgemmI8DepthwiseAvx2.h       153
-rw-r--r--  include/fbgemm/OutputProcessing-inl.h         36
-rw-r--r--  include/fbgemm/QuantUtilsAvx2.h               13
-rw-r--r--  include/fbgemm/UtilsAvx2.h                     5
-rw-r--r--  src/CodeCache.h                               59
-rw-r--r--  src/ExecuteKernelU8S8.cc                      47
-rw-r--r--  src/Fbgemm.cc                                 58
-rw-r--r--  src/FbgemmConv.cc                            189
-rw-r--r--  src/FbgemmFP16.cc                              2
-rw-r--r--  src/FbgemmI8Depthwise3DAvx2.cc              1415
-rw-r--r--  src/FbgemmI8DepthwiseAvx2-inl.h              709
-rw-r--r--  src/FbgemmI8DepthwiseAvx2.cc                2375
-rw-r--r--  src/GenerateKernel.h                         101
-rw-r--r--  src/GenerateKernelU8S8S32ACC16.cc            357
-rw-r--r--  src/GenerateKernelU8S8S32ACC16Avx512.cc      493
-rw-r--r--  src/GenerateKernelU8S8S32ACC32.cc            373
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512.cc      556
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc  519
-rw-r--r--  src/GroupwiseConv.h                           38
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc                118
-rw-r--r--  src/PackAWithIm2Col.cc                        16
-rw-r--r--  src/PackDepthwiseConvMatrixAvx2.cc           203
-rw-r--r--  src/PackWeightsForConv.cc                     33
-rw-r--r--  src/QuantUtils.cc                             38
-rw-r--r--  src/QuantUtilsAvx2.cc                        566
-rw-r--r--  src/RefImplementations.cc                    430
-rw-r--r--  src/RefImplementations.h                     120
-rw-r--r--  test/I8DepthwiseTest.cc                      469
-rw-r--r--  test/I8DepthwiseTest.h                        39
-rw-r--r--  test/RequantizeOnlyTest.cc                   169
-rw-r--r--  test/UniConvTest.cc                          454
43 files changed, 6032 insertions(+), 4613 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0460799..c06b60b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -89,8 +89,10 @@ endif()
#All the source files that either use avx2 instructions statically
set(FBGEMM_AVX2_SRCS
src/FbgemmFP16UKernelsAvx2.cc
+ src/FbgemmI8Depthwise3DAvx2.cc
src/FbgemmI8DepthwiseAvx2.cc
src/OptimizedKernelsAvx2.cc
+ src/PackDepthwiseConvMatrixAvx2.cc
src/QuantUtilsAvx2.cc
src/UtilsAvx2.cc)
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index 0f7ad8b..d1abc70 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -1,5 +1,77 @@
# Code of Conduct
-Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
-Please read the [full text](https://code.fb.com/codeofconduct/)
-so that you can understand what actions will and will not be tolerated.
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <opensource-conduct@fb.com>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
+
diff --git a/bench/ConvUnifiedBenchmark.cc b/bench/ConvUnifiedBenchmark.cc
index b450beb..35a3b2a 100644
--- a/bench/ConvUnifiedBenchmark.cc
+++ b/bench/ConvUnifiedBenchmark.cc
@@ -24,33 +24,39 @@
using namespace std;
using namespace fbgemm;
+// clang-format off
// 2D conv shapes
vector<conv_param_t<2>> shapes_2d = {
- // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w,
- // pad_h_top, pad_w_left, pad_h_bottom, pad_w_right
- // 2D convolutions
- // regular
- conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
- // groupwise
- conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
- // DW
- conv_param_t<>(1, 272, 272, {47, 125}, 272, {3, 3}, {1, 1}, {1, 1, 1, 1}),
- // Pointwise
- conv_param_t<>(1, 128, 128, {56, 56}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0})
+ // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w,
+ // pad_h_top, pad_w_left, pad_h_bottom, pad_w_right
+ // 2D convolutions
+ // regular
+ conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ // regular with dilation
+ conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+ // groupwise
+ conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ // DW
+ conv_param_t<>(1, 272, 272, {47, 125}, 272, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ // Pointwise
+ conv_param_t<>(1, 128, 128, {56, 56}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0})
};
// 3D conv shapes
vector<conv_param_t<3>> shapes_3d = {
- // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h, stride_w},
- // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right}
- // Regular
- conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}),
- // Depthwise
- conv_param_t<3>(1, 64, 64, {8, 14, 14}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}),
- // Pointwise
- conv_param_t<3>(1, 128, 128, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0, 0})
-};
+ // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h,
+ // stride_w},
+ // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right}
+ // Regular
+ conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}),
+ //With dilations
+ conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}, {2, 2, 2}),
+ // Depthwise
+ conv_param_t<3>(1, 64, 64, {8, 14, 14}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}),
+ // Pointwise
+ conv_param_t<3>(1, 128, 128, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0, 0})};
+// clang-format on
template <int SPATIAL_DIM, typename Acc_t>
void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
@@ -81,6 +87,10 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
header += "pad_t, ";
}
header += "pad_h, pad_w, ";
+ if (SPATIAL_DIM == 3) {
+ header += "dilation_t, ";
+ }
+ header += "dilation_h, dilation_w, ";
header += "Type, M, N, K, ";
@@ -278,7 +288,9 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
for (int i = 0; i < SPATIAL_DIM; ++i) {
cout << conv_p.pad[i] << ", ";
}
-
+ for (int i = 0; i < SPATIAL_DIM; ++i) {
+ cout << conv_p.dilation[i] << ", ";
+ }
cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
<< setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
<< KDim << ", ";
diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc
index 0efdcac..ff2be6f 100644
--- a/bench/Depthwise3DBenchmark.cc
+++ b/bench/Depthwise3DBenchmark.cc
@@ -4,7 +4,6 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
-#include "test/I8DepthwiseTest.h"
#include <algorithm>
#include <chrono>
@@ -19,8 +18,8 @@
#include "AlignedVec.h"
#include "BenchUtils.h"
-#include "fbgemm/Utils.h"
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/Utils.h"
#include "src/RefImplementations.h"
using namespace std;
@@ -35,6 +34,34 @@ int main() {
}
#endif
+ // From ResNeXt-3D-101
+ // clang-format off
+ vector<vector<int>> shapes_3d = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ // N, K, T_in, H_in, W_in, stride
+ { 1, 64, 32, 56, 56, 1, },
+ { 1, 128, 16, 28, 28, 1, },
+ { 1, 256, 8, 14, 14, 1, },
+ { 1, 512, 4, 7, 7, 1, },
+
+ { 1, 128, 32, 56, 56, 2, },
+ { 1, 256, 16, 28, 28, 2, },
+ { 1, 512, 8, 14, 14, 2, },
+
+ { 5, 64, 32, 56, 56, 1, },
+ { 5, 128, 16, 28, 28, 1, },
+ { 5, 256, 8, 14, 14, 1, },
+ { 5, 512, 4, 7, 7, 1, },
+
+ { 5, 128, 32, 56, 56, 2, },
+ { 5, 256, 16, 28, 28, 2, },
+ { 5, 512, 8, 14, 14, 2, },
+
+ { 1, 8, 4, 4, 4, 1, },
+ };
+ // clang-format on
+
// Depthwise is memory BW bound so we want to flush LLC.
bool flush = true;
std::vector<char> llc;
@@ -61,14 +88,28 @@ int main() {
constexpr int K_T = 3, K_H = 3, K_W = 3;
constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ conv_param_t<3> conv_p(
+ N,
+ K,
+ K,
+ {T, H, W},
+ K,
+ {K_T, K_H, K_W},
+ {stride_t, stride_h, stride_w},
+ {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R});
+ int T_OUT = conv_p.OUT_DIM[0];
+ int H_OUT = conv_p.OUT_DIM[1];
+ int W_OUT = conv_p.OUT_DIM[2];
+
+ int MDim = N * T_OUT * H_OUT * W_OUT;
+ int KDim = K_T * K_H * K_W * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * T * H * W * K);
- aligned_vector<int8_t> B(K * K_T * K_H * K_W);
- aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K),
- C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = 43;
@@ -76,52 +117,49 @@ int main() {
randFill<int8_t>(B, -16, 16);
int32_t B_zero_point = 5;
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
- int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+ aligned_vector<float> C_multiplier(1);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+ int32_t C_zero_point = 5;
- float C_multiplier = 255. / (maximum - minimum);
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
aligned_vector<int32_t> col_offsets(K);
aligned_vector<int32_t> bias(K);
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- int32_t C_zero_point = 5;
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point,
- B.data(),
- C_multiplier,
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
-
- Packed3x3x3ConvMatrix Bp(K, B.data());
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data(),
+ C_zero_point,
+ A_zero_point,
+ &B_zero_point,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
+
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data());
double ttot = 0;
double bytes = double(NITER) *
@@ -153,7 +191,7 @@ int main() {
A.data(),
B_zero_point,
Bp,
- C_multiplier,
+ C_multiplier[0],
C_zero_point,
C_uint8.data(),
col_offsets.data(),
diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc
index 96921a1..6c2ee17 100644
--- a/bench/DepthwiseBenchmark.cc
+++ b/bench/DepthwiseBenchmark.cc
@@ -17,8 +17,8 @@
#include "AlignedVec.h"
#include "BenchUtils.h"
-#include "fbgemm/Utils.h"
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/Utils.h"
#include "src/RefImplementations.h"
using namespace std;
@@ -34,10 +34,11 @@ int main() {
#endif
// From Xray OCR
+ // clang-format off
vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
- // N, G, H_in, W_in, stride
+ // N, K, H_in, W_in, stride
{ 1, 272, 47, 125, 1, },
{ 1, 272, 64, 125, 1, },
{ 1, 272, 66, 125, 1, },
@@ -138,6 +139,7 @@ int main() {
{ 96, 544, 14, 14, 2, },
{ 100, 544, 14, 14, 2, },
};
+ // clang-format on
// Depthwise is memory BW bound so we want to flush LLC.
bool flush = true;
@@ -155,19 +157,35 @@ int main() {
for (auto shape : shapes) {
int N = shape[0];
- int G = shape[1];
+ int K = shape[1];
int H = shape[2];
int W = shape[3];
int stride_h = shape[4];
int stride_w = stride_h;
constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2,
+ PAD_R = (S - 1) / 2;
+
+ conv_param_t<2> conv_p(
+ N,
+ K,
+ K,
+ {H, W},
+ K,
+ {R, S},
+ {stride_h, stride_w},
+ {PAD_T, PAD_L, PAD_B, PAD_R});
+ int H_OUT = conv_p.OUT_DIM[0];
+ int W_OUT = conv_p.OUT_DIM[1];
+
+ int MDim = N * H_OUT * W_OUT;
+ int KDim = R * S * K;
+ int KDimPerGroup = KDim / conv_p.G;
- aligned_vector<uint8_t> A(N * H * W * G);
- aligned_vector<int8_t> B(G * R * S);
- aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * G), C(C_ref.size());
+ aligned_vector<uint8_t> A(N * H * W * K);
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = 43;
@@ -175,53 +193,54 @@ int main() {
randFill<int8_t>(B, -16, 16);
int32_t B_zero_point = 5;
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- G,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
- int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+ aligned_vector<float> C_multiplier(1);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+ int32_t C_zero_point = 5;
- float C_multiplier = 255. / (maximum - minimum);
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
- aligned_vector<int32_t> col_offsets(G);
- aligned_vector<int32_t> bias(G);
+ aligned_vector<int32_t> col_offsets(K);
+ aligned_vector<int32_t> bias(K);
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- int32_t C_zero_point = 5;
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- G,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point,
- B.data(),
- C_multiplier,
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data(),
+ C_zero_point,
+ A_zero_point,
+ &B_zero_point,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
- Packed3x3ConvMatrix Bp(G, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
double ttot = 0;
double bytes = double(NITER) *
- (G * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S));
- double ops = double(NITER) * N * H_OUT * W_OUT * G * R * S * 2;
+ (K * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S));
+ double ops = double(NITER) * N * H_OUT * W_OUT * K * R * S * 2;
chrono::time_point<chrono::system_clock> t_begin, t_end;
for (int i = 0; i < NWARMUP + NITER; ++i) {
llc_flush();
@@ -235,19 +254,20 @@ int main() {
N,
H,
W,
- G,
+ K,
stride_h,
stride_w,
A_zero_point,
A.data(),
B_zero_point,
Bp,
- C_multiplier,
+ C_multiplier[0],
C_zero_point,
C_uint8.data(),
col_offsets.data(),
bias.data(),
false, /* fuse_relu */
+ 1.0f, /* act_scale * w_scale */
tid,
num_threads);
}
@@ -262,10 +282,10 @@ int main() {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H_OUT; ++h) {
for (int w = 0; w < W_OUT; ++w) {
- for (int g = 0; g < G; ++g) {
+ for (int g = 0; g < K; ++g) {
uint8_t expected =
- C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * G + g];
- uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * G + g];
+ C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * K + g];
+ uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * K + g];
if (expected != actual) {
cerr << "Depthwise 3x3 results differ at (" << n << ", " << h
<< ", " << w << ", " << g << "). expected " << (int)expected
@@ -280,9 +300,9 @@ int main() {
// Report performance
printf(
- "N = %d G = %d H = %d W = %d stride = %d with requantization fused\n",
+ "N = %d K = %d H = %d W = %d stride = %d with requantization fused\n",
N,
- G,
+ K,
H,
W,
stride_h);
diff --git a/bench/GEMMsBenchmark.cc b/bench/GEMMsBenchmark.cc
index b404d8b..f493a96 100644
--- a/bench/GEMMsBenchmark.cc
+++ b/bench/GEMMsBenchmark.cc
@@ -28,6 +28,7 @@ using namespace std;
using namespace fbgemm;
void performance_test() {
+ // clang-format off
static const vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -39,6 +40,7 @@ void performance_test() {
{256, 512, 256},
{1024, 1024, 1024},
};
+ // clang-format on
bool flush = true;
std::vector<char> llc;
diff --git a/bench/GEMMsTunableBenchmark.cc b/bench/GEMMsTunableBenchmark.cc
index a65b51f..2adc556 100644
--- a/bench/GEMMsTunableBenchmark.cc
+++ b/bench/GEMMsTunableBenchmark.cc
@@ -218,7 +218,8 @@ int main(int /* unused */, char** /* unused */) {
}
#endif
- vector<vector<int>> shapes = {
+ // clang-format off
+ vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
// m, n, k
@@ -266,7 +267,8 @@ int main(int /* unused */, char** /* unused */) {
{128, 128, 128},
{256, 512, 256},
{1024, 1024, 1024},
-};
+ };
+ // clang-format on
vector<int> MCBs;
vector<int> NCBs;
diff --git a/bench/PackedFloatInOutBenchmark.cc b/bench/PackedFloatInOutBenchmark.cc
index 66ca67e..dcea65c 100644
--- a/bench/PackedFloatInOutBenchmark.cc
+++ b/bench/PackedFloatInOutBenchmark.cc
@@ -28,6 +28,7 @@ using namespace std;
using namespace fbgemm;
void performance_test() {
+ // clang-format off
vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -66,6 +67,7 @@ void performance_test() {
{1, 128, 2722},
{16, 256, 512},
};
+ // clang-format on
bool flush = true;
std::vector<char> llc;
diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc
index 40ff662..c6e2869 100644
--- a/bench/PackedRequantizeAcc16Benchmark.cc
+++ b/bench/PackedRequantizeAcc16Benchmark.cc
@@ -37,6 +37,7 @@ enum class BenchmarkType {
};
void performance_test() {
+ // clang-format off
vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -67,6 +68,7 @@ void performance_test() {
{392, 2048, 512},
{392, 512, 2048},
};
+ // clang-format on
bool flush = true;
std::vector<char> llc;
diff --git a/bench/PackedRequantizeAcc32Benchmark.cc b/bench/PackedRequantizeAcc32Benchmark.cc
index 2f04795..a61ef5a 100644
--- a/bench/PackedRequantizeAcc32Benchmark.cc
+++ b/bench/PackedRequantizeAcc32Benchmark.cc
@@ -28,6 +28,7 @@ using namespace std;
using namespace fbgemm;
void performance_test() {
+ // clang-format off
vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -70,6 +71,7 @@ void performance_test() {
{1, 128, 2722},
{16, 256, 512},
};
+ // clang-format on
bool flush = true;
std::vector<char> llc;
diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h
index 11f3dcc..5431958 100644
--- a/include/fbgemm/ConvUtils.h
+++ b/include/fbgemm/ConvUtils.h
@@ -8,9 +8,24 @@
#include <array>
#include <string>
+#include <type_traits>
namespace fbgemm {
+template <int N, int... Vals>
+constexpr
+ typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type
+ array_of_ones() {
+ return std::array<int, N>{{Vals...}};
+}
+
+template <int N, int... Vals>
+constexpr
+ typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type
+ array_of_ones() {
+ return array_of_ones<N, Vals..., 1>();
+}
+
/**
* @brief A struct to conveniently store all convolution parameters.
*/
@@ -34,7 +49,6 @@ struct conv_param_t {
/**
* @brief Constructor for initializing the convolution parameters.
- * TODO: Dilation is not handled correctly.
*/
conv_param_t(
int mb,
@@ -44,7 +58,8 @@ struct conv_param_t {
int g,
std::array<int, SPATIAL_DIM> k,
std::array<int, SPATIAL_DIM> strd,
- std::array<int, SPATIAL_DIM * 2> pd)
+ std::array<int, SPATIAL_DIM * 2> pd,
+ std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>())
: MB(mb),
IC(ic),
OC(oc),
@@ -52,7 +67,8 @@ struct conv_param_t {
G(g),
K(k),
stride(strd),
- pad(pd) {
+ pad(pd),
+ dilation(dilations) {
if (ic % g != 0) {
throw std::runtime_error(
"groups = " + std::to_string(g) +
@@ -63,10 +79,10 @@ struct conv_param_t {
"groups = " + std::to_string(g) +
" does not divide number of output channels = " + std::to_string(oc));
}
+
for (int d = 0; d < SPATIAL_DIM; ++d) {
- dilation[d] = 1;
IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
- OUT_DIM[d] = (IN_DIMP[d] - K[d]) / stride[d] + 1;
+ OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1;
}
}
@@ -102,8 +118,12 @@ struct conv_param_t {
}
for (int d = 0; d < SPATIAL_DIM * 2; ++d) {
out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" +
- std::to_string(pad[d]);
- if (d < SPATIAL_DIM * 2 - 1) {
+ std::to_string(pad[d]) + ", ";
+ }
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
+ std::to_string(dilation[d]);
+ if (d < SPATIAL_DIM - 1) {
out += ", ";
}
}
@@ -121,6 +141,10 @@ struct conv_param_t {
out += ", ";
}
}
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "dilation_" + std::to_string(d) + ":" +
+ std::to_string(dilation[d]) + ", ";
+ }
}
return out;
}
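
[editor note] The conv_param_t constructor above now accepts an optional per-dimension dilation and folds it into OUT_DIM. A minimal sketch, not part of the diff and using illustrative values, of the new dilation-aware output-size computation:

    #include "fbgemm/ConvUtils.h"
    #include <cstdio>

    int main() {
      // 56x56 input, 3x3 kernel, stride 1, pad 1 on each side, dilation 2.
      fbgemm::conv_param_t<2> p(
          1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2});
      // OUT_DIM[d] = (IN_DIM[d] + pad_before + pad_after
      //               - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1
      //            = (56 + 1 + 1 - 2 * 2 - 1) / 1 + 1 = 54
      std::printf("%d x %d\n", p.OUT_DIM[0], p.OUT_DIM[1]);  // prints 54 x 54
      return 0;
    }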
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 09ddbe1..668bd42 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -609,12 +609,16 @@ class FBGEMM_API PackWeightsForConv {
return W_im2col_packed_;
}
- std::shared_ptr<Packed3x3ConvMatrix> getPackedWFor2DDW() {
- return W_dw_2D_packed_;
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWForDepthwise() {
+ return W_dw_packed_;
}
- std::shared_ptr<Packed3x3x3ConvMatrix> getPackedWFor3DDW() {
- return W_dw_3D_packed_;
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor2DDW() {
+ return W_dw_packed_;
+ }
+
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor3DDW() {
+ return W_dw_packed_;
}
std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
@@ -663,10 +667,8 @@ class FBGEMM_API PackWeightsForConv {
const conv_param_t<SPATIAL_DIM> conv_param_;
// Packed weights if we use im2col based convolution implementation
std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
- // Packed weights if we use 2D depthwise convolution implementation
- std::shared_ptr<Packed3x3ConvMatrix> W_dw_2D_packed_;
- // Packed weights if we use 3D depthwise convolution implementation
- std::shared_ptr<Packed3x3x3ConvMatrix> W_dw_3D_packed_;
+ // Packed weights if we use depthwise convolution implementation
+ std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_;
// Packed weights if we use groupwise (small channels per group) convolution
// implementation
std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
@@ -1161,12 +1163,15 @@ class FBGEMM_API DoSConvOnInpBuffer {
template <
bool FUSE_RELU,
QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR,
+ typename BIAS_TYPE = std::int32_t,
typename outT = std::uint8_t,
typename inT = std::int32_t,
typename nextOPType = DoNothing<outT, outT>>
class FBGEMM_API ReQuantizeOutput {
public:
static constexpr int RELU_FUSED = FUSE_RELU;
+ static constexpr QuantizationGranularity QGRANType = Q_GRAN;
+ using BIAS_T = BIAS_TYPE;
using outType = outT;
using inpType = inT;
/**
@@ -1187,6 +1192,8 @@ class FBGEMM_API ReQuantizeOutput {
* See PackedRequantizeTest.cc for an example.
* TODO: if Aq_zero_point == 0, allow passing nullptr.
* @params bias can be nullptr otherwise the length should be nCol
+ * @params act_times_w_scale activation_scale * weight_scale. This is only
+ * used if bias is unquantized (i.e., float).
*/
ReQuantizeOutput(
nextOPType& nextop,
@@ -1196,9 +1203,10 @@ class FBGEMM_API ReQuantizeOutput {
const std::int32_t* Bq_zero_point,
const std::int32_t* row_offsets,
const std::int32_t* col_offsets,
- const std::int32_t* bias,
+ const BIAS_T* bias,
std::uint32_t nCol,
- int groups = 1)
+ int groups = 1,
+ const float* act_times_w_scale = nullptr)
: nextop_(nextop),
C_multiplier_(C_multiplier),
C_zero_point_(C_zero_point),
@@ -1208,7 +1216,8 @@ class FBGEMM_API ReQuantizeOutput {
q_col_offsets_(col_offsets),
bias_(bias),
ncols_(nCol),
- groups_(groups) {}
+ groups_(groups),
+ act_times_w_scale_(act_times_w_scale) {}
template <inst_set_t instSet>
inline int f(
@@ -1236,12 +1245,15 @@ class FBGEMM_API ReQuantizeOutput {
const std::int32_t* getColOffsets() const {
return q_col_offsets_;
}
- const std::int32_t* getBias() const {
+ const BIAS_T* getBias() const {
return bias_;
}
std::uint32_t getNCols() const {
return ncols_;
}
+ const float* getActWScale() const {
+ return act_times_w_scale_;
+ }
void setRowOffsets(const std::int32_t* row_offsets) {
q_row_offsets_ = row_offsets;
@@ -1255,9 +1267,10 @@ class FBGEMM_API ReQuantizeOutput {
const std::int32_t* Bq_zero_point_;
const std::int32_t* q_row_offsets_;
const std::int32_t* q_col_offsets_;
- const std::int32_t* bias_;
+ const BIAS_T* bias_;
std::uint32_t ncols_;
int groups_;
+ const float* act_times_w_scale_;
};
/**
@@ -1410,7 +1423,8 @@ template <
typename outType,
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
- int SPATIAL_DIM = 2>
+ int SPATIAL_DIM = 2,
+ typename BIAS_TYPE = std::int32_t>
FBGEMM_API void fbgemmGroupwiseConv(
const conv_param_t<SPATIAL_DIM>& conv_param,
const std::uint8_t* activations,
@@ -1419,7 +1433,7 @@ FBGEMM_API void fbgemmGroupwiseConv(
packed_W& packed_weights,
outType* out,
std::int32_t* outBuffer,
- const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+ const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
int thread_id,
int num_threads);
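
[editor note] ReQuantizeOutput now carries the bias type as a template parameter (BIAS_TYPE) and takes a trailing act_times_w_scale pointer, which is only consulted when the bias is float. A hedged sketch of constructing it with an unquantized bias; every buffer and scale below is a made-up placeholder, and DoNothing<> is assumed to work as the no-op next stage the way it does in the existing tests:

    #include "fbgemm/Fbgemm.h"
    #include <cstdint>
    #include <vector>

    void make_requant_op(
        const std::vector<float>& bias_fp32,          // unquantized bias, length n_cols
        const std::vector<std::int32_t>& col_offsets,
        const std::int32_t* row_offsets,
        int n_cols) {
      float C_multiplier = 0.0015f;
      std::int32_t Bq_zero_point = 5;
      float act_times_w_scale = 0.02f * 0.01f;        // act_scale * weight_scale

      fbgemm::DoNothing<> doNothingObj{};
      fbgemm::ReQuantizeOutput<
          false,                                      // FUSE_RELU
          fbgemm::QuantizationGranularity::TENSOR,
          float>                                      // BIAS_TYPE: bias stays in float
          outputProc(
              doNothingObj,
              &C_multiplier,
              5,                                      // C_zero_point
              3,                                      // Aq_zero_point
              &Bq_zero_point,
              row_offsets,
              col_offsets.data(),
              bias_fp32.data(),     // divided by act_times_w_scale inside f()
              n_cols,
              1,                    // groups
              &act_times_w_scale);  // required when BIAS_TYPE is float
      (void)outputProc;             // would normally be handed to fbgemmPacked(...)
    }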
diff --git a/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/include/fbgemm/FbgemmI8DepthwiseAvx2.h
index e7b0ec4..c454b16 100644
--- a/include/fbgemm/FbgemmI8DepthwiseAvx2.h
+++ b/include/fbgemm/FbgemmI8DepthwiseAvx2.h
@@ -11,19 +11,26 @@
namespace fbgemm {
-// KERNEL_PROD is the product of all kernels.
-// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3.
-template <int KERNEL_PROD>
class FBGEMM_API PackedDepthWiseConvMatrix {
public:
- // smat in GRS layout
- PackedDepthWiseConvMatrix(int K, const std::int8_t* smat);
+ /**
+ * @params K the number of channels (same as the number of groups because
+ * depth-wise convolution has one input/output channel per group)
+ * @params kernel_prod the product of all kernels. For example, kernel_prod =
+ * 9 for 3x3 conv, and 27 for 3x3x3 conv.
+ * @param smat the source unpacked weight in GRS layout
+ */
+ PackedDepthWiseConvMatrix(int K, int kernel_prod, const std::int8_t* smat);
virtual ~PackedDepthWiseConvMatrix();
const std::int8_t* PackedMat() const {
return pmat_;
}
+ int GetKernelProduct() const {
+ return kernel_prod_;
+ }
+
/**
* @brief Unpacks pmat_ into unpack_data.
* Used for recovering the weight matrix into the original format
@@ -36,24 +43,26 @@ class FBGEMM_API PackedDepthWiseConvMatrix {
int addr(int r, int c);
private:
- int K_;
- std::int8_t* pmat_;
-}; // Packed3x3ConvMatrix
-
-using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>;
-using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>;
-using Packed1ConvMatrix = PackedDepthWiseConvMatrix<1>;
-using Packed2ConvMatrix = PackedDepthWiseConvMatrix<2>;
-using Packed3ConvMatrix = PackedDepthWiseConvMatrix<3>;
-using Packed4ConvMatrix = PackedDepthWiseConvMatrix<4>;
-using Packed5ConvMatrix = PackedDepthWiseConvMatrix<5>;
-using Packed10ConvMatrix = PackedDepthWiseConvMatrix<10>;
-using Packed11ConvMatrix = PackedDepthWiseConvMatrix<11>;
+ const int K_; /**< the number of channels */
+ const int kernel_prod_; /** the product of all kernel dims */
+ std::int8_t* pmat_; /** packed weight */
+}; // PackedDepthWiseConvMatrix
-/**
- * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
- * @params A The input image in NHWK layout
- * @params Bp The pre-packed filter
+class FBGEMM_API Packed3x3ConvMatrix : public PackedDepthWiseConvMatrix {
+ public:
+ Packed3x3ConvMatrix(int K, const std::int8_t* smat)
+ : PackedDepthWiseConvMatrix(K, 3 * 3, smat) {}
+};
+
+class FBGEMM_API Packed3x3x3ConvMatrix : public PackedDepthWiseConvMatrix {
+ public:
+ Packed3x3x3ConvMatrix(int K, const std::int8_t* smat)
+ : PackedDepthWiseConvMatrix(K, 3 * 3 * 3, smat) {}
+};
+
+/** To be removed. Keeping it just to make sure we don't change C2 files and
+ * fbgemm files in a single diff
+ *
*/
FBGEMM_API void depthwise_3x3_pad_1(
int N,
@@ -64,8 +73,14 @@ FBGEMM_API void depthwise_3x3_pad_1(
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
- const Packed3x3ConvMatrix& Bp,
- std::int32_t* C,
+ std::int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false,
int thread_id = 0,
int num_threads = 1);
@@ -74,7 +89,10 @@ FBGEMM_API void depthwise_3x3_pad_1(
* This version is fused with requantization.
*
* @col_offsets nullptr if col_offsets are folded into bias
+ * @act_times_w_scale Only used if BIAS_TYPE is float, i.e., bias is
+ * unquantized.
*/
+template <typename BIAS_TYPE = std::int32_t>
FBGEMM_API void depthwise_3x3_pad_1(
int N,
int H,
@@ -85,22 +103,48 @@ FBGEMM_API void depthwise_3x3_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
std::int32_t B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
float C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
const std::int32_t* col_offsets,
- const std::int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu = false,
+ float act_times_w_scale = 1.0f,
int thread_id = 0,
int num_threads = 1);
/**
- * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
- * This version is fused with requantization and uses per-channel quantization.
+ * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8, fused with
+ * requantization, and using per-channel quantization.
*
* @col_offsets nullptr if col_offsets are folded into bias
*/
+template <typename BIAS_TYPE = std::int32_t>
+FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu = false,
+ const float* act_times_w_scale = nullptr,
+ int thread_id = 0,
+ int num_threads = 1);
+
+/** To be removed. Keeping it just to make sure we don't change C2 files and
+ * fbgemm files in a single diff
+ */
FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
int N,
int H,
@@ -111,7 +155,7 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
@@ -121,6 +165,10 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
int thread_id = 0,
int num_threads = 1);
+/** To be removed. Keeping it just to make sure we don't change C2 files and
+ * fbgemm files in a single diff
+ *
+ */
FBGEMM_API void depthwise_3x3x3_pad_1(
int N,
int T,
@@ -132,14 +180,20 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
- const Packed3x3x3ConvMatrix& Bp,
- std::int32_t* C,
+ std::int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false,
int thread_id = 0,
int num_threads = 1);
-
/**
* @col_offsets nullptr if col_offsets are folded into bias
*/
+template <typename BIAS_TYPE = std::int32_t>
FBGEMM_API void depthwise_3x3x3_pad_1(
int N,
int T,
@@ -152,11 +206,38 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
std::int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
float C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu = false,
+ float act_times_w_scale = 1.0f,
+ int thread_id = 0,
+ int num_threads = 1);
+
+/** To be removed. Keeping it just to make sure we don't change C2 files and
+ * fbgemm files in a single diff
+ *
+ */
+FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
const std::int32_t* bias,
bool fuse_relu = false,
int thread_id = 0,
@@ -165,6 +246,7 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
/**
* @col_offsets nullptr if col_offsets are folded into bias
*/
+template <typename BIAS_TYPE = std::int32_t>
FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
int N,
int T,
@@ -177,13 +259,14 @@ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
const std::int32_t* col_offsets,
- const std::int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu = false,
+ const float* act_times_w_scale = nullptr,
int thread_id = 0,
int num_threads = 1);
diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h
index d984c60..04ae100 100644
--- a/include/fbgemm/OutputProcessing-inl.h
+++ b/include/fbgemm/OutputProcessing-inl.h
@@ -59,11 +59,13 @@ inline int DoSConvOnInpBuffer<outT, inT, nextOPType>::f(
template <
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
+ typename BIAS_TYPE,
typename outT,
typename inT,
typename nextOPType>
template <inst_set_t instSet>
-inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
+inline int
+ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f(
outT* out,
const inT* inp,
const block_type_t& block,
@@ -98,11 +100,20 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
raw -= q_row_offsets_[i - block.row_start] *
Bq_zero_point_[Bq_zero_point_idx];
}
+ float raw_f;
if (bias_) {
- raw += bias_[j];
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ raw_f = raw;
+ raw_f += bias_[j] / act_times_w_scale_[Bq_zero_point_idx];
+ } else {
+ raw += bias_[j];
+ raw_f = raw;
+ }
+ } else {
+ raw_f = raw;
}
- float ab = raw * C_multiplier_[Bq_zero_point_idx];
+ float ab = raw_f * C_multiplier_[Bq_zero_point_idx];
long rounded = std::lrintf(ab) + C_zero_point_;
out[i * ld_out + j] = std::max(
@@ -115,15 +126,16 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
Bq_zero_point_[0] == 0) ||
q_row_offsets_ == nullptr;
- requantizationParams_t r = {Aq_zero_point_,
- Bq_zero_point_,
- C_zero_point_,
- C_multiplier_,
- q_row_offsets_,
- q_col_offsets_,
- bias_,
- ncols_,
- groups_};
+ requantizationParams_t<BIAS_TYPE> r = {Aq_zero_point_,
+ Bq_zero_point_,
+ C_zero_point_,
+ C_multiplier_,
+ q_row_offsets_,
+ q_col_offsets_,
+ bias_,
+ ncols_,
+ groups_,
+ act_times_w_scale_};
if (Aq_zero_point_ == 0) {
if (b_symmetric) {
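
[editor note] For illustration, with made-up numbers: the float-bias branch divides the bias by act_times_w_scale to bring it into the int32 accumulation domain before the usual multiplier is applied. If the zero-point-corrected accumulator is raw = 1000, bias = 0.5, act_scale = 0.02 and w_scale = 0.01 (so act_times_w_scale = 0.0002), then raw_f = 1000 + 0.5 / 0.0002 = 3500. With C_multiplier = act_times_w_scale / out_scale = 0.0002 / 0.1 = 0.002 and C_zero_point = 5, the output is lrintf(3500 * 0.002) + 5 = 12 — the same value the int32 path produces when handed the pre-quantized bias round(0.5 / 0.0002) = 2500.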
diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h
index 47f33a8..c7f3f35 100644
--- a/include/fbgemm/QuantUtilsAvx2.h
+++ b/include/fbgemm/QuantUtilsAvx2.h
@@ -40,9 +40,10 @@ struct FBGEMM_API RequantizationParams {
////////////////////////////////////////////////////////////////////////////////
// Utility functions
+template <typename T=std::uint8_t>
void QuantizeAvx2(
const float* src,
- std::uint8_t* dst,
+ T* dst,
int len,
const TensorQuantizationParams& qparams);
@@ -71,14 +72,15 @@ template <
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
- bool FUSE_RELU>
+ bool FUSE_RELU,
+ typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingAvx2(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r);
+ const requantizationParams_t<BIAS_TYPE>& r);
template <
bool A_SYMMETRIC,
@@ -86,14 +88,15 @@ template <
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
- int C_PER_G>
+ int C_PER_G,
+ typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingGConvAvx2(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r);
+ const requantizationParams_t<BIAS_TYPE>& r);
template <
bool A_SYMMETRIC,
diff --git a/include/fbgemm/UtilsAvx2.h b/include/fbgemm/UtilsAvx2.h
index 082edc1..3bac909 100644
--- a/include/fbgemm/UtilsAvx2.h
+++ b/include/fbgemm/UtilsAvx2.h
@@ -44,16 +44,19 @@ struct block_type_t {
* QuantUtilsAvx2.h as it combines all the parameters needed for various
* quantization granularities
*/
+template<typename BIAS_TYPE = std::int32_t>
struct requantizationParams_t {
+ using BIAS_T = BIAS_TYPE;
std::int32_t A_zero_point;
const std::int32_t* B_zero_point;
std::int32_t C_zero_point;
const float* C_multiplier;
const std::int32_t* row_offsets;
const std::int32_t* col_offsets;
- const std::int32_t* bias;
+ const BIAS_T* bias;
std::uint32_t ncols;
int groups;
+ const float* act_times_w_scale;
};
/**
diff --git a/src/CodeCache.h b/src/CodeCache.h
new file mode 100644
index 0000000..08e9c9b
--- /dev/null
+++ b/src/CodeCache.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <condition_variable>
+#include <future>
+#include <map>
+#include <mutex>
+
+namespace fbgemm {
+
+/**
+ * @brief Thread safe cache for microkernels, ensures single creation per key.
+ * @tparam Key Type of unique key (typically a tuple)
+ * @tparam Value Type of the microkernel function (Typically a function pointer)
+ */
+template <typename KEY, typename VALUE> class CodeCache {
+private:
+ std::map<KEY, std::shared_future<VALUE>> values_;
+ std::mutex mutex_;
+
+public:
+ CodeCache(const CodeCache &) = delete;
+ CodeCache &operator=(const CodeCache &) = delete;
+
+ CodeCache(){};
+
+ VALUE getOrCreate(const KEY &key, std::function<VALUE()> generatorFunction) {
+ std::shared_future<VALUE> returnFuture;
+ std::promise<VALUE> returnPromise;
+ bool needsToGenerate = false;
+
+ // Check for existance of the key
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ auto it = values_.find(key);
+ if (it != values_.end()) {
+ returnFuture = it->second;
+ } else {
+ values_[key] = returnFuture = returnPromise.get_future().share();
+ needsToGenerate = true;
+ }
+ }
+
+ // The value (code) generation is not happening under a lock
+ if (needsToGenerate) {
+ returnPromise.set_value(generatorFunction());
+ }
+
+ // Wait for the future and return the value
+ return returnFuture.get();
+ }
+};
+
+} // namespace fbgemm
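
[editor note] A short usage sketch of the new cache, not part of the diff; makeKernel() is a hypothetical stand-in for the real asmjit-based generators that will use it, and src/ is assumed to be on the include path:

    #include <cstdio>
    #include <tuple>
    #include "CodeCache.h"

    using kernel_fp = void (*)();

    static void dummyKernel() { std::puts("kernel body"); }

    // Hypothetical generator standing in for a JIT code generator.
    static kernel_fp makeKernel(int mc, int nc) {
      (void)mc; (void)nc;
      return &dummyKernel;  // real code would emit and return the generated function
    }

    int main() {
      fbgemm::CodeCache<std::tuple<int, int>, kernel_fp> cache;
      // The generator runs only for the first call with a given key; concurrent
      // callers with the same key wait on the shared_future instead of re-generating.
      kernel_fp k1 = cache.getOrCreate(std::make_tuple(12, 8), [] { return makeKernel(12, 8); });
      kernel_fp k2 = cache.getOrCreate(std::make_tuple(12, 8), [] { return makeKernel(12, 8); });
      k1();
      std::printf("cache hit returns same kernel: %d\n", k1 == k2);
      return 0;
    }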
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index 0a4ff55..4ae1b50 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -315,19 +315,23 @@ void ExecuteKernel<
////////////////////////////////////////////////////////////////////////////////
// ReQuantizeOutput
-#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \
- template class ExecuteKernel< \
- PACK_A<uint8_t, ACC_T>, \
- PackBMatrix<int8_t, ACC_T>, \
- uint8_t, \
- ReQuantizeOutput<RELU, Q_GRAN>>;
+#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \
+ template class ExecuteKernel< \
+ PACK_A<uint8_t, ACC_T>, \
+ PackBMatrix<int8_t, ACC_T>, \
+ uint8_t, \
+ ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;
+
+#define INSTANTIATE_REQUANT_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \
+ INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \
+ INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t);
#define INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, RELU) \
- INSTANTIATE_REQUANT_BASE( \
+ INSTANTIATE_REQUANT_BIAS_T( \
PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
- INSTANTIATE_REQUANT_BASE( \
+ INSTANTIATE_REQUANT_BIAS_T( \
PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \
- INSTANTIATE_REQUANT_BASE( \
+ INSTANTIATE_REQUANT_BIAS_T( \
PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_REQUANT_RELU(PACK_A, ACC_T) \
@@ -344,21 +348,27 @@ INSTANTIATE_REQUANT_ACC_T(PackAWithRowOffset);
#undef INSTANTIATE_REQUANT_ACC_T
#undef INSTANTIATE_REQUANT_RELU
#undef INSTANTIATE_REQUANT_Q_GRANS
+#undef INSTANTIATE_REQUANT_BIAS_T
#undef INSTANTIATE_REQUANT_BASE
-#define INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
- template class ExecuteKernel< \
- PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
- PackBMatrix<int8_t, ACC_T>, \
- uint8_t, \
- ReQuantizeOutput<RELU, Q_GRAN>>;
+#define INSTANTIATE_IM2COL_REQUANT_BASE( \
+ ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \
+ template class ExecuteKernel< \
+ PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
+ PackBMatrix<int8_t, ACC_T>, \
+ uint8_t, \
+ ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;
+
+#define INSTANTIATE_IM2COL_REQUANT_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
+ INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \
+ INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t);
#define INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
- INSTANTIATE_IM2COL_REQUANT_BASE( \
+ INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
- INSTANTIATE_IM2COL_REQUANT_BASE( \
+ INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
- INSTANTIATE_IM2COL_REQUANT_BASE( \
+ INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, RELU) \
@@ -375,6 +385,7 @@ INSTANTIATE_IM2COL_REQUANT_RELU(int16_t);
#undef INSTANTIATE_IM2COL_REQUANT_RELU
#undef INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM
#undef INSTANTIATE_IM2COL_REQUANT_Q_GRANS
+#undef INSTANTIATE_IM2COL_REQUANT_BIAS_T
#undef INSTANTIATE_IM2COL_REQUANT_BASE
////////////////////////////////////////////////////////////////////////////////
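
[editor note] Each requantization instantiation macro now also takes BIAS_TYPE, so every (PACK_A, ACC_T, RELU, Q_GRAN) combination is emitted twice, once for float and once for int32_t bias. For illustration, one expansion, derived mechanically from INSTANTIATE_REQUANT_BASE with PACK_A=PackAWithRowOffset, ACC_T=int32_t, RELU=false, Q_GRAN=TENSOR, BIAS_TYPE=float:

    template class ExecuteKernel<
        PackAWithRowOffset<uint8_t, int32_t>,
        PackBMatrix<int8_t, int32_t>,
        uint8_t,
        ReQuantizeOutput<false, QuantizationGranularity::TENSOR, float>>;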
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 1052044..b691b88 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -237,22 +237,26 @@ bool fbgemmSupportedCPU() {
////////////////////////////////////////////////////////////////////////////////
// ReQuantizeOutput
-#define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \
+#define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \
template void fbgemmPacked( \
PackMatrix<PACK_A<uint8_t, ACC_T>, uint8_t, ACC_T>& packA, \
PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
uint8_t* C, \
int32_t* C_buffer, \
uint32_t ldc, \
- const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
int thread_id, \
int num_threads, \
const BlockingFactors* blocking_params);
-#define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \
- INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
- INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \
- INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL);
+#define INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \
+ INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \
+ INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t);
+
+#define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \
+ INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \
+ INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_RELU(PACK_A, ACC_T) \
INSTANTIATE_Q_GRANS(PACK_A, ACC_T, false); \
@@ -268,27 +272,34 @@ INSTANTIATE_ACC_T(PackAWithRowOffset);
#undef INSTANTIATE_ACC_T
#undef INSTANTIATE_RELU
#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
-#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
- template void fbgemmPacked( \
- PackMatrix< \
- PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
- uint8_t, \
- ACC_T>& packA, \
- PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
- uint8_t* C, \
- int32_t* C_buffer, \
- uint32_t ldc, \
- const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
- int thread_id, \
- int num_threads, \
+#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \
+ template void fbgemmPacked( \
+ PackMatrix< \
+ PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
+ uint8_t, \
+ ACC_T>& packA, \
+ PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
+ uint8_t* C, \
+ int32_t* C_buffer, \
+ uint32_t ldc, \
+ const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
+ int thread_id, \
+ int num_threads, \
const BlockingFactors* blocking_params);
-#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
- INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
- INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
- INSTANTIATE_BASE( \
+#define INSTANTIATE_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
+ INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \
+ INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t);
+
+#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
+ INSTANTIATE_BIAS_T( \
+ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_BIAS_T( \
+ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
+ INSTANTIATE_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_SPATIAL_DIM(ACC_T, RELU) \
@@ -305,6 +316,7 @@ INSTANTIATE_RELU(int16_t);
#undef INSTANTIATE_RELU
#undef INSTANTIATE_SPATIAL_DIM
#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc
index 33d1535..de833d2 100644
--- a/src/FbgemmConv.cc
+++ b/src/FbgemmConv.cc
@@ -89,52 +89,121 @@ int fbgemmConv(
// std::cout << "Depthwise fast path" << std::endl;
const std::int32_t* B_zero_point = outProcess.getBZeroPoint();
const float* C_multiplier = outProcess.getCMultiplier();
+ const float* act_times_w_scale = outProcess.getActWScale();
if (SPATIAL_DIM == 3) {
static_assert(
std::is_same<typename processOutputType::outType, std::uint8_t>::
value,
"For depthwise, only requantized output is supported");
- depthwise_3x3x3_pad_1(
- conv_p.MB, // mini batch
- conv_p.IN_DIM[0], // T
- conv_p.IN_DIM[1], // H
- conv_p.IN_DIM[2], // W
- conv_p.OC, // output channels
- conv_p.stride[0], // stride_t
- conv_p.stride[1], // stride_h
- conv_p.stride[2], // stride_w
- outProcess.getAZeroPoint(),
- activations,
- B_zero_point[0],
- *(packed_weights.getPackedWFor3DDW()),
- C_multiplier[0],
- outProcess.getCZeroPoint(),
- out,
- outProcess.getColOffsets(),
- outProcess.getBias(),
- outProcess.RELU_FUSED, // fuse_relu
- thread_id,
- num_threads);
+
+ if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
+ depthwise_3x3x3_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // T
+ conv_p.IN_DIM[1], // H
+ conv_p.IN_DIM[2], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_t
+ conv_p.stride[1], // stride_h
+ conv_p.stride[2], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point[0],
+ *(packed_weights.getPackedWForDepthwise()),
+ C_multiplier[0],
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ act_times_w_scale ? act_times_w_scale[0] : 1.0f,
+ thread_id,
+ num_threads);
+ } else if (
+ processOutputType::QGRANType ==
+ QuantizationGranularity::OUT_CHANNEL ||
+ processOutputType::QGRANType == QuantizationGranularity::GROUP) {
+ depthwise_3x3x3_per_channel_quantization_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // T
+ conv_p.IN_DIM[1], // H
+ conv_p.IN_DIM[2], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_t
+ conv_p.stride[1], // stride_h
+ conv_p.stride[2], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point,
+ *(packed_weights.getPackedWForDepthwise()),
+ C_multiplier,
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ outProcess.getActWScale(), // act_scale * weight_scale
+ thread_id,
+ num_threads);
+ } else {
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] This quantization granularity is "
+ "not supported";
+ throw std::runtime_error(msg);
+ }
} else {
- depthwise_3x3_pad_1(
- conv_p.MB, // mini batch
- conv_p.IN_DIM[0], // H
- conv_p.IN_DIM[1], // W
- conv_p.OC, // output channels
- conv_p.stride[0], // stride_h
- conv_p.stride[1], // stride_w
- outProcess.getAZeroPoint(),
- activations,
- B_zero_point[0],
- *(packed_weights.getPackedWFor2DDW()),
- C_multiplier[0],
- outProcess.getCZeroPoint(),
- out,
- outProcess.getColOffsets(),
- outProcess.getBias(),
- outProcess.RELU_FUSED, // fuse_relu
- thread_id,
- num_threads);
+ if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
+ depthwise_3x3_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // H
+ conv_p.IN_DIM[1], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_h
+ conv_p.stride[1], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point[0],
+ *(packed_weights.getPackedWForDepthwise()),
+ C_multiplier[0],
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ act_times_w_scale ? act_times_w_scale[0] : 1.0f,
+ thread_id,
+ num_threads);
+ } else if (
+ processOutputType::QGRANType ==
+ QuantizationGranularity::OUT_CHANNEL ||
+ processOutputType::QGRANType == QuantizationGranularity::GROUP) {
+ // The number of channels == groups for depthwise convolutions
+ depthwise_3x3_per_channel_quantization_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // H
+ conv_p.IN_DIM[1], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_h
+ conv_p.stride[1], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point,
+ *(packed_weights.getPackedWForDepthwise()),
+ C_multiplier,
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ outProcess.getActWScale(), // act_scale * weight_scale
+ thread_id,
+ num_threads);
+ } else {
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] This quantization granularity is "
+ "not supported";
+ throw std::runtime_error(msg);
+ }
}
break;
}
@@ -195,11 +264,32 @@ int fbgemmConv(
// All other convolutions go through im2col-based implementation
// std::cout << "Im2col path" << std::endl;
std::vector<int32_t> row_offset_buf(
- PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>
- ::rowOffsetBufferSize(blocking_params));
+ PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize(
+ blocking_params));
const std::int32_t* b_zero_point = outProcess.getBZeroPoint();
- bool b_symmetric = b_zero_point[0] == 0;
+ bool b_symmetric = false;
+ if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
+ b_symmetric = b_zero_point[0] == 0;
+ } else if (
+ processOutputType::QGRANType == QuantizationGranularity::GROUP) {
+ b_symmetric =
+ std::all_of(b_zero_point, b_zero_point + conv_p.G, [](int i) {
+ return i == 0;
+ });
+ } else if (
+ processOutputType::QGRANType ==
+ QuantizationGranularity::OUT_CHANNEL) {
+ b_symmetric =
+ std::all_of(b_zero_point, b_zero_point + conv_p.OC, [](int i) {
+ return i == 0;
+ });
+ } else {
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] This quantization granularity is "
+ "not supported";
+ throw std::runtime_error(msg);
+ }
PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> packA(
conv_p,
activations,
@@ -227,21 +317,25 @@ int fbgemmConv(
return 0;
}
-#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \
+#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, BIAS_TYPE) \
template int fbgemmConv( \
const conv_param_t<SPATIAL_DIM>& conv_p, \
const std::uint8_t* activations, \
PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, \
std::uint8_t* out, \
std::int32_t* outBuffer, \
- ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
int thread_id, \
int num_threads, \
const BlockingFactors* blocking_params);
+#define INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \
+ INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, float); \
+ INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, int32_t);
+
#define INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, RELU) \
- INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 2); \
- INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 3);
+ INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 2); \
+ INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 3);
#define INSTANTIATE_RELU(ACC_T, Q_GRAN) \
INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, true); \
@@ -257,6 +351,7 @@ INSTANTIATE_Q_GRANS(std::int32_t);
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_RELU
#undef INSTANTIATE_SPATIAL_DIM
+#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
template bool takeDepthWiseFastPath<2, std::int32_t>(
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
index f357966..b034f2c 100644
--- a/src/FbgemmFP16.cc
+++ b/src/FbgemmFP16.cc
@@ -50,6 +50,7 @@ struct KernelInfo {
// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
+ // clang-format off
static constexpr int partition[121][2][2] = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -175,6 +176,7 @@ struct KernelInfo {
{ { 6, 19 }, { 5, 1 } }, // 119
{ { 6, 20 }, { 0, 0 } }, // 120
};
+ // clang-format on
};
constexpr KernelInfo::knl_ptr KernelInfo::kernel[7];
constexpr int KernelInfo::partition[121][2][2];
diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc
new file mode 100644
index 0000000..925d265
--- /dev/null
+++ b/src/FbgemmI8Depthwise3DAvx2.cc
@@ -0,0 +1,1415 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+
+#include <string>
+#include <tuple> // for tie
+
+#include "FbgemmI8DepthwiseAvx2-inl.h"
+
+using namespace std;
+
+namespace fbgemm {
+
+template <
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t_in,
+ int h_in,
+ int w_in,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* Bp,
+ const int32_t* B_zero_point,
+ int32_t* C,
+ int remainder,
+ int32_t* row_offsets) {
+ __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
+ __m256i mask_v = _mm256_setzero_si256();
+ if (REMAINDER) {
+ mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(masks[remainder / 4]));
+ }
+
+ // The code below could be written as a simple loop over the kernel taps,
+ // but the compiler doesn't unroll it, so we unroll manually. The equivalent
+ // 2-D R*S form of the loop is sketched below for reference:
+ // constexpr int R = 3, S = 3;
+ // array<__m256i, R * S> a_v;
+ // for (int r = 0; r < R; ++r) {
+ // for (int s = 0; s < S; ++s) {
+ // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
+ // if (REMAINDER) {
+ // a_v[r * S + s] =
+ // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
+ // mask_v);
+ // } else {
+ // a_v[r * S + s] =
+ // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
+ // }
+ // } else {
+ // a_v[r * S + s] = A_zero_point_v;
+ // }
+ // }
+ // }
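+ // The 27 taps of the 3x3x3 kernel are consumed in groups of 8 activations
+ // (8 + 8 + 8 + 3), matching the inner_prod_packed_<8, ...> calls below and
+ // the pre-interleaved layout of Bp (register offsets 0, 8, 16, and 24).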
+ __m256i a_v[8];
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
+ }
+ }
+ }
+
+ __m256i a_sum[4];
+ inner_prod_packed_<8, SUM_A, REMAINDER>(
+ a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
+ }
+ }
+ }
+
+ __m256i a_sum_temp[4];
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp);
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp);
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
+ a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp);
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+
+ __m256i B_zero_point_v;
+ for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
+ if (PER_CHANNEL_QUANTIZATION) {
+ B_zero_point_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
+ }
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ }
+ }
+}
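+// When SUM_A is true, inner_prod_3x3x3_packed_ leaves the per-lane activation
+// sums multiplied by the B zero point(s) in row_offsets; requantize_() later
+// subtracts them from the int32 accumulators.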
+
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t,
+ int h,
+ int w,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int t_in = -PAD_P + t * stride_t;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
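+ // Channels are processed 32 at a time. Bp advances by k * 28 because each
+ // 32-channel group of the packed weights occupies 28 interleaved 32-byte
+ // registers (kernel_prod rounded up to an even count, 27 -> 28).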
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ 0,
+ B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ remainder,
+ B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ }
+
+ requantize_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false, /*PER_CHAN_QUANT*/
+ A_SYMMETRIC,
+ B_SYMMETRIC>(
+ A_zero_point,
+ &C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias,
+ &act_times_w_scale);
+}
+
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void
+depthwise_3x3x3_per_channel_quantization_kernel_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t,
+ int h,
+ int w,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const int8_t* Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int t_in = -PAD_P + t * stride_t;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3x3_packed_<
+ true, /*SUM_A*/
+ false, /*remainder*/
+ true /*per-channel*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ B_zero_point + k,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3x3_packed_<
+ true, /*SUM_A*/
+ true, /*remainder*/
+ true /*per-channel*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ B_zero_point + k,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
+ }
+ requantize_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true, /*PER_CHAN_QUANT*/
+ A_SYMMETRIC,
+ false /*B_SYMM*/>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+}
+
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+
+ int n_begin, n_end;
+ int t_begin, t_end, h_begin, h_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ t_begin = 0;
+ t_end = T_OUT;
+ h_begin = 0;
+ h_end = H_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+ // n is processed by a num_threads_t * num_threads_h 2-D grid of threads
+ int num_threads_t, num_threads_h;
+ // num_threads_t <= num_threads_h
+ tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_t = tid_within_n / num_threads_h;
+ int tid_h = tid_within_n % num_threads_h;
+
+ int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
+ t_begin = std::min(tid_t * t_per_thread, T_OUT);
+ t_end = std::min(t_begin + t_per_thread, T_OUT);
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+ }
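+ // Example: with N = 2 and num_threads = 8, each image is handled by 4
+ // threads arranged by closest_factors_(4) as a 2x2 (t, h) grid over the
+ // output volume.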
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * T * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+
+ for (int t = t_begin; t < t_end; ++t) {
+ for (int h = h_begin; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+ } // t
+ } // for each n
+}
+
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void
+depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+
+ int n_begin, n_end;
+ int t_begin, t_end, h_begin, h_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ t_begin = 0;
+ t_end = T_OUT;
+ h_begin = 0;
+ h_end = H_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+ // n is processed by a num_threads_t * num_threads_h 2-D grid of threads
+ int num_threads_t, num_threads_h;
+ // num_threads_t <= num_threads_h
+ tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_t = tid_within_n / num_threads_h;
+ int tid_h = tid_within_n % num_threads_h;
+
+ int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
+ t_begin = std::min(tid_t * t_per_thread, T_OUT);
+ t_end = std::min(t_begin + t_per_thread, T_OUT);
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * T * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+
+ for (int t = t_begin; t < t_end; ++t) {
+ for (int h = h_begin; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_per_channel_quantization_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+ } // t
+ } // for each n
+}
+
+// Dispatch A_SYMMETRIC and B_SYMMETRIC
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
+static void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (A_zero_point == 0 || col_offsets == nullptr) {
+ if (B_zero_point == 0) {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_symmetric*/,
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_symmetric*/,
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ } else {
+ if (B_zero_point == 0) {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_symmetric*/,
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_symmetric*/,
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ }
+}
+
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU, typename BIAS_TYPE>
+static void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (bias) {
+ depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch FUSE_RELU
+template <typename BIAS_TYPE>
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(
+ 0 &&
+ "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
+ if (fuse_relu) {
+ depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
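+// Illustrative call sketch (buffer names are hypothetical); B_packed must be
+// a PackedDepthWiseConvMatrix packed with kernel_prod 3 * 3 * 3:
+//   depthwise_3x3x3_pad_1(
+//       N, T, H, W, K, /*stride_t=*/1, /*stride_h=*/1, /*stride_w=*/1,
+//       A_zero_point, A_uint8, B_zero_point, B_packed,
+//       C_multiplier, C_zero_point, C_uint8,
+//       col_offsets, bias, /*fuse_relu=*/false,
+//       act_times_w_scale, /*thread_id=*/0, /*num_threads=*/1);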
+
+// Dispatch A_SYMMETRIC
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
+static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (A_zero_point == 0 || col_offsets == nullptr) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_SYMM*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_SYMM*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU, typename BIAS_TYPE>
+static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (bias) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ true /* HAS_BIAS */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ false /* HAS_BIAS */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch FUSE_RELU
+template <typename BIAS_TYPE>
+void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(
+ 0 &&
+ "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
+ if (fuse_relu) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
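+// For the per-channel entry point, B_zero_point, C_multiplier, and
+// act_times_w_scale are expected to point to K per-channel values (the number
+// of channels equals the number of groups for depthwise convolution).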
+
+// To be removed
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id,
+ int num_threads) {
+ depthwise_3x3x3_pad_1<int32_t>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ 1.0f, // act_scale * weight_scale
+ thread_id,
+ num_threads);
+}
+
+void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id,
+ int num_threads) {
+ depthwise_3x3x3_per_channel_quantization_pad_1(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ nullptr, // act_scale * weight_scale
+ thread_id,
+ num_threads);
+}
+
+template void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+} // namespace fbgemm
diff --git a/src/FbgemmI8DepthwiseAvx2-inl.h b/src/FbgemmI8DepthwiseAvx2-inl.h
new file mode 100644
index 0000000..7ad39fc
--- /dev/null
+++ b/src/FbgemmI8DepthwiseAvx2-inl.h
@@ -0,0 +1,709 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <algorithm> // for min and max
+#include <cassert>
+#include <cmath> // for lrintf and sqrt
+#include <cstdint>
+#include <type_traits> // for is_same
+
+#include <immintrin.h>
+
+namespace fbgemm {
+
+// clang-format off
+static int masks[8][8] = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ { 0, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, -1, 0, },
+};
+// clang-format on
+
+// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void madd_epi16x4_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
+ __m256i a3_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
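+// Note: _mm256_maddubs_epi16 sums adjacent u8*s8 products with signed 16-bit
+// saturation; the pair sums are then widened to 32 bits via _mm256_madd_epi16
+// against a vector of ones before being accumulated into C.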
+
+// c = a0 * b0 + a1 * b1 + a2 * b2
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void madd_epi16x3_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
+
+// c = a0 * b0 + a1 * b1
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void madd_epi16x2_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// c = a0 * b0
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void madd_epi16_packed(
+ __m256i a_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// K is the number of accumulations we're doing
+template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
+static inline __attribute__((always_inline)) void inner_prod_packed_(
+ const __m256i* a_v,
+ const __m256i* Bp,
+ std::int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
+ __m256i c[4], c_temp[4];
+ __m256i a_sum_temp[2] = {0, 0};
+
+ int k = 0;
+ if (K >= 4) {
+ madd_epi16x4_packed<SUM_A>(
+ a_v[0],
+ a_v[1],
+ a_v[2],
+ a_v[3],
+ Bp,
+ &c[0],
+ &c[1],
+ &c[2],
+ &c[3],
+ a_sum_temp);
+
+ for (k = 4; k < K / 4 * 4; k += 4) {
+ madd_epi16x4_packed<SUM_A>(
+ a_v[k + 0],
+ a_v[k + 1],
+ a_v[k + 2],
+ a_v[k + 3],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp);
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+ } else {
+ c[0] = _mm256_setzero_si256();
+ c[1] = _mm256_setzero_si256();
+ c[2] = _mm256_setzero_si256();
+ c[3] = _mm256_setzero_si256();
+ }
+
+ if (K - k == 3) {
+ madd_epi16x3_packed<SUM_A>(
+ a_v[k],
+ a_v[k + 1],
+ a_v[k + 2],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp);
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
+ c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
+ c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
+ c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
+
+ if (K - k == 0 || K - k == 3) {
+ c[0] = c_temp[0];
+ c[1] = c_temp[1];
+ c[2] = c_temp[2];
+ c[3] = c_temp[3];
+ } else {
+ if (K - k == 1) {
+ madd_epi16_packed<SUM_A>(
+ a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
+ } else if (K - k == 2) {
+ madd_epi16x2_packed<SUM_A>(
+ a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
+ }
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ if (REMAINDER) {
+ for (int r = 0; r < remainder / 8; ++r) {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + r * 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
+ c[r]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
+ }
+ }
+ } else {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + 16),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + 24),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
+ }
+ }
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
+ a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
+ a_sum[2] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
+ a_sum[3] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
+ }
+}
+
+// Almost the same as ReQuantizeOutput in OutputProcessing-inl.h, but with a
+// different row_offsets value for each row because of depthwise convolution.
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool PER_CHANNEL_QUANTIZATION,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void requantize_(
+ std::int32_t A_zero_point,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ const std::int32_t* C_int32,
+ std::uint8_t* C_uint8,
+ int n,
+ const std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale = nullptr) {
+ __m256 multiplier_v = _mm256_setzero_ps();
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v = _mm256_setzero_ps();
+ if (!PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_set1_ps(*C_multiplier);
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v = _mm256_set1_ps(1.0f / (*act_times_w_scale));
+ }
+ }
+
+ __m256i min_v = _mm256_set1_epi8(static_cast<std::uint8_t>(0));
+ __m256i max_v = _mm256_set1_epi8(static_cast<std::uint8_t>(255));
+
+ if (A_SYMMETRIC) {
+ assert(A_zero_point == 0 || col_offsets == nullptr);
+ }
+ __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
+ __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
+ __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
+
+ __m256i permute_mask_v =
+ _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+
+ constexpr int VLEN = 8;
+ int j = 0;
+ for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+ __m256i y_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
+ __m256i z_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
+ __m256i w_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));
+
+ __m256i row_offset_v;
+ if (!B_SYMMETRIC) {
+ row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
+ x_v = _mm256_sub_epi32(x_v, row_offset_v);
+ }
+ __m256i col_off_v;
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ x_v = _mm256_sub_epi32(x_v, col_off_v);
+ }
+
+ if (!B_SYMMETRIC) {
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
+ y_v = _mm256_sub_epi32(y_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
+ y_v = _mm256_sub_epi32(y_v, col_off_v);
+ }
+
+ if (!B_SYMMETRIC) {
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
+ z_v = _mm256_sub_epi32(z_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
+ z_v = _mm256_sub_epi32(z_v, col_off_v);
+ }
+
+ if (!B_SYMMETRIC) {
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
+ w_v = _mm256_sub_epi32(w_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
+ w_v = _mm256_sub_epi32(w_v, col_off_v);
+ }
+
+ // convert to float
+ __m256 xf_v, yf_v, zf_v, wf_v;
+ if (HAS_BIAS) { // static if
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ if (PER_CHANNEL_QUANTIZATION) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 3 * VLEN));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
+ act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
+ act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
+ act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN);
+ }
+ __m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
+ }
+ __m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
+ }
+ __m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
+
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+ __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
+ __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
+ __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
+
+ __m256i xy_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
+ __m256i zw_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
+ __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
+ __m256i xyzw_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(xyzw_packed_v, max_v));
+
+ xyzw_clamped_v =
+ _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
+ } // j loop vectorized and unrolled 4x
+
+ for (; j < n / VLEN * VLEN; j += VLEN) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+
+ if (!B_SYMMETRIC) {
+ __m256i row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
+ x_v = _mm256_sub_epi32(x_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ x_v = _mm256_sub_epi32(x_v, col_off_v);
+ }
+
+ // Convert to float
+ __m256 xf_v;
+ if (HAS_BIAS) { // static if
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (PER_CHANNEL_QUANTIZATION) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
+ _mm256_loadu_ps(act_times_w_scale + j));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+
+ __m256i x_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
+ C_zero_point_epi16_v);
+ x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
+ __m256i x_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(x_packed_v, max_v));
+
+ x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(C_uint8 + j),
+ _mm256_castsi256_si128(x_clamped_v));
+ } // j loop vectorized
+
+ for (; j < n; ++j) {
+ std::int32_t raw = C_int32[j];
+ if (!B_SYMMETRIC) {
+ raw -= row_offsets[j];
+ }
+ if (!A_SYMMETRIC) {
+ raw -= A_zero_point * col_offsets[j];
+ }
+ float raw_f;
+ if (HAS_BIAS) { // static if
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ raw_f = raw;
+ raw_f += bias[j] / act_times_w_scale[PER_CHANNEL_QUANTIZATION ? j : 0];
+ } else {
+ raw += bias[j];
+ raw_f = raw;
+ }
+ } else {
+ raw_f = raw;
+ }
+
+ float ab = raw_f * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0];
+ long rounded = lrintf(ab) + C_zero_point;
+
+ C_uint8[j] = std::max(
+ FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
+ std::min(255l, rounded));
+ }
+}
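+// In scalar form (see the tail loop above), each output is roughly:
+//   acc   = C_int32[j] - row_offsets[j] - A_zero_point * col_offsets[j]
+//   acc_f = acc + bias[j] / act_times_w_scale[.]   (float bias)
+//         = acc + bias[j]                          (int32 bias, added before
+//                                                   the conversion to float)
+//   out   = clamp(lrintf(acc_f * C_multiplier[.]) + C_zero_point, 0, 255)
+// with the lower clamp bound raised to C_zero_point when FUSE_RELU is set.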
+
+template <bool REMAINDER>
+static inline __attribute__((always_inline)) __m256i load_a(
+ const std::uint8_t* A,
+ __m256i mask_v) {
+ if (REMAINDER) {
+ return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
+ } else {
+ return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
+ }
+}
+
+static inline std::pair<int, int> closest_factors_(int n) {
+ int a = static_cast<int>(std::sqrt(n));
+ while (n % a != 0) {
+ a--;
+ }
+ return {a, n / a}; // a <= n / a
+}
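+// e.g. closest_factors_(6) == {2, 3} and closest_factors_(4) == {2, 2}.
+// Used by the depthwise kernels to split the threads working on one image
+// into a 2-D grid over the output.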
+
+} // namespace fbgemm
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index 183a8a9..994f206 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -7,569 +7,15 @@
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
#include "fbgemm/Utils.h"
-#include <algorithm> // for min and max
-#include <cassert>
-#include <cmath> // for lrintf and sqrt
+#include <string>
#include <tuple> // for tie
-#include <immintrin.h>
+#include "FbgemmI8DepthwiseAvx2-inl.h"
using namespace std;
namespace fbgemm {
-static int masks[8][8] = {
- // NOTE: clang-format wants to use a different formatting but the current
- // formatting should be easier to read.
- { 0, 0, 0, 0, 0, 0, 0, 0, },
- { -1, 0, 0, 0, 0, 0, 0, 0, },
- { -1, -1, 0, 0, 0, 0, 0, 0, },
- { -1, -1, -1, 0, 0, 0, 0, 0, },
- { -1, -1, -1, -1, 0, 0, 0, 0, },
- { -1, -1, -1, -1, -1, 0, 0, 0, },
- { -1, -1, -1, -1, -1, -1, 0, 0, },
- { -1, -1, -1, -1, -1, -1, -1, 0, },
-};
-
-template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
- int K,
- const int8_t* smat)
- : K_(K) {
- // Transpose the input matrix to make packing faster.
- int8_t* smat_transposed = static_cast<int8_t *>(ALIGNED_MALLOC(
- K * KERNEL_PROD * sizeof(int8_t), 64));
- for (int i = 0; i < KERNEL_PROD; ++i) {
- for (int j = 0; j < K; ++j) {
- smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD];
- }
- }
-
- // Allocate packed arrays
- constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
-#ifdef _MSC_VER
- pmat_ = static_cast<int8_t *>(_aligned_malloc(
- ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t), 64));
-#else
- posix_memalign(
- (void**)&pmat_,
- 64,
- ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t));
-#endif
-
- // Pack input matrix
- // The layout is optimized to use vpmaddubsw efficiently (see
- // madd_epi16x4_packed function).
- // For a group of 32 channels, we have 10 32B SIMD registers.
- // Denote ith channel jth filter as (i, j)
- // 0th SIMD register:
- // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
- // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
- // 1st SIMD register:
- // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
- // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
- // 2nd SIMD register:
- // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
- // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
- // 3rd SIMD register:
- // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
- // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
- // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter
- // coefficients
- // ...
- //
- // REMAINDER
- // If KERNEL_PROD % 4 == 1 for example when KERNEL_PROD == 9
- // 8th SIMD register:
- // (0, 8), zero, ..., (7, 8), zero
- // (16, 8), zero, ..., (23, 8), zero
- // 9th SIMD register:
- // (8, 8), zero, ..., (15, 8), zero
- // (24, 8), zero, ..., (31, 8), zero
- // We use madd_epi16_packed for this case
- //
- // If KERNEL_PROD % 4 == 2 for example when KERNEL_PROD == 10
- // 8th SIMD register:
- // (0, 8), (0, 9), ..., (7, 8), (7, 9)
- // (16, 8), (16, 9), ..., (23, 8), (23, 9)
- // 9th SIMD register:
- // (8, 8), (8, 9), ..., (15, 8), (15, 9)
- // (24, 8), (24, 9), ..., (31, 8), (31, 9)
- //
- // If KERNEL_PROD % 4 == 3 for example when KERNEL_PROD == 11
- // 8th SIMD register:
- // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
- // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
- // 9th SIMD register:
- // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
- // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
- // 10th SIMD register:
- // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
- // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
- // 11th SIMD register:
- // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
- // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
- for (int k1 = 0; k1 < K; k1 += 32) {
- __m256i b_v[KERNEL_PROD];
- int remainder = K - k1;
- if (remainder < 32) {
- __m256i mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(masks[remainder / 4]));
- for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_maskload_epi32(
- reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v);
- }
- } else {
- for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_lddqu_si256(
- reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1));
- }
- }
-
- // Interleave 2 SIMD registers
- __m256i b_interleaved_epi16[KERNEL_PROD_ALIGNED];
- __m256i zero_v = _mm256_setzero_si256();
- for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) {
- if (2 * i + 1 >= KERNEL_PROD) {
- b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
- b_interleaved_epi16[2 * i + 1] =
- _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
- } else {
- b_interleaved_epi16[2 * i] =
- _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
- b_interleaved_epi16[2 * i + 1] =
- _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
- }
- }
-
- // Interleave 4 SIMD registers
- __m256i b_interleaved_epi32[KERNEL_PROD_ALIGNED];
- for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) {
- b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
- b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
- b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
- b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
- b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
- b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
- b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
- b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
- }
- for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) {
- b_interleaved_epi32[i] = b_interleaved_epi16[i];
- }
-
- for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(
- &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]),
- b_interleaved_epi32[i]);
- }
- }
-
- FREE(smat_transposed);
-}
-
-template <int KERNEL_PROD>
-int PackedDepthWiseConvMatrix<KERNEL_PROD>::addr(int r, int c) {
- constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
- if (c >= KERNEL_PROD / 4 * 4 &&
- (KERNEL_PROD % 4 == 1 || KERNEL_PROD % 4 == 2)) {
- int kBlock = r / 32;
- int reg_idx = (r % 16) / 8 + c / 4 * 4;
-
- int blk_idx = kBlock * KERNEL_PROD_ALIGNED + reg_idx;
-
- int r_ = r % 8;
- int c_ = c % 4;
-
- int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_;
- return blk_idx * 32 + in_blk_idx;
-
- } else {
- int kBlock = r / 32;
- int reg_idx = (r % 16) / 4 + c / 4 * 4;
-
- int blk_idx = kBlock * KERNEL_PROD_ALIGNED + reg_idx;
-
- int r_ = r % 4;
- int c_ = c % 4;
-
- int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_;
- return blk_idx * 32 + in_blk_idx;
- }
-}
-
-template <int KERNEL_PROD>
-void PackedDepthWiseConvMatrix<KERNEL_PROD>::unpack(int8_t* unpacked_data) {
- for (int r = 0; r < K_; ++r) {
- for (int c = 0; c < KERNEL_PROD; ++c) {
- unpacked_data[r * KERNEL_PROD + c] = pmat_[addr(r, c)];
- }
- }
-}
-
-template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() {
-#ifdef _MSC_VER
- _aligned_free(pmat_);
-#else
- free(pmat_);
-#endif
-}
-
-template class PackedDepthWiseConvMatrix<3 * 3>;
-template class PackedDepthWiseConvMatrix<3 * 3 * 3>;
-template class PackedDepthWiseConvMatrix<1>;
-template class PackedDepthWiseConvMatrix<2>;
-template class PackedDepthWiseConvMatrix<3>;
-template class PackedDepthWiseConvMatrix<4>;
-template class PackedDepthWiseConvMatrix<5>;
-template class PackedDepthWiseConvMatrix<5 * 2>;
-template class PackedDepthWiseConvMatrix<11 * 1>;
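The constructor, addr(), and unpack() above are inverses of each other: packing a K x KERNEL_PROD row-major filter and then calling unpack() should reproduce the input byte for byte, because unpack() reads back through the same addr() mapping. A hedged, test-style round-trip sketch against the (pre-refactor) templated class being removed here, assuming K is a multiple of 8 as the kernels require:

// Hedged round-trip check for the templated packer above; illustrative only.
#include <cstdint>
#include <vector>

bool pack_unpack_roundtrip_3x3(int K, const std::int8_t* smat /* K x 9, row-major */) {
  fbgemm::PackedDepthWiseConvMatrix<3 * 3> packed(K, smat);
  std::vector<std::int8_t> unpacked(K * 3 * 3);
  packed.unpack(unpacked.data());
  for (int i = 0; i < K * 3 * 3; ++i) {
    if (unpacked[i] != smat[i]) {
      return false;
    }
  }
  return true;
}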
-
-// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16x4_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- __m256i a3_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
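As a reference for what the vpmaddubsw/vpmaddwd pair above computes, the scalar equivalent of madd_epi16x4_packed is a 4-term dot product per channel, plus a running activation sum when SUM_A is set; the lane permutations described in the comments only affect where each channel's result lands, not its value. A hedged scalar model (not the AVX2 code path; it also ignores the saturating int16 intermediate of vpmaddubsw):

// Hedged scalar model of madd_epi16x4_packed for one 32-channel group.
#include <cstdint>

void madd_x4_scalar_ref(
    const std::uint8_t a[4][32], // four activation taps per channel
    const std::int8_t b[4][32],  // four filter taps per channel
    std::int32_t c[32],          // per-channel dot products
    std::int32_t a_sum[32],      // running activation sums (may be nullptr)
    bool sum_a) {
  for (int ch = 0; ch < 32; ++ch) {
    std::int32_t acc = 0;
    std::int32_t asum = 0;
    for (int k = 0; k < 4; ++k) {
      acc += static_cast<std::int32_t>(a[k][ch]) * b[k][ch];
      asum += a[k][ch];
    }
    c[ch] = acc;
    if (sum_a && a_sum) {
      a_sum[ch] += asum;
    }
  }
}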
-
-// c = a0 * b0 + a1 * b1 + a2 * b2
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16x3_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
-
-// c = a0 * b0 + a1 * b1
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16x2_packed(
- __m256i a0_v,
- __m256i a1_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// c = a0 * b0
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16_packed(
- __m256i a_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// K is the number of accumulations we're doing
-template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
-static inline ALWAYS_INLINE void inner_prod_packed_(
- const __m256i* a_v,
- const __m256i* Bp,
- int32_t* C,
- int remainder,
- __m256i* a_sum = nullptr) {
- __m256i c[4], c_temp[4];
- __m256i a_sum_temp[2] = {0, 0};
-
- int k = 0;
- if (K >= 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[0],
- a_v[1],
- a_v[2],
- a_v[3],
- Bp,
- &c[0],
- &c[1],
- &c[2],
- &c[3],
- a_sum_temp);
-
- for (k = 4; k < K / 4 * 4; k += 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[k + 0],
- a_v[k + 1],
- a_v[k + 2],
- a_v[k + 3],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
- } else {
- c[0] = _mm256_setzero_si256();
- c[1] = _mm256_setzero_si256();
- c[2] = _mm256_setzero_si256();
- c[3] = _mm256_setzero_si256();
- }
-
- if (K - k == 3) {
- madd_epi16x3_packed<SUM_A>(
- a_v[k],
- a_v[k + 1],
- a_v[k + 2],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
- c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
- c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
- c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
-
- if (K - k == 0 || K - k == 3) {
- c[0] = c_temp[0];
- c[1] = c_temp[1];
- c[2] = c_temp[2];
- c[3] = c_temp[3];
- } else {
- if (K - k == 1) {
- madd_epi16_packed<SUM_A>(
- a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- } else if (K - k == 2) {
- madd_epi16x2_packed<SUM_A>(
- a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- }
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- if (REMAINDER) {
- for (int r = 0; r < remainder / 8; ++r) {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + r * 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
- c[r]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
- }
- }
- } else {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 16),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 24),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
- }
- }
-
- if (SUM_A) {
- a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
- a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
- a_sum[2] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
- a_sum[3] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
- }
-}
-
template <bool SUM_A = false, bool REMAINDER = false>
static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
const __m256i* a_v,
@@ -580,238 +26,6 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum);
}
-// Almost the same as ReQuantizeOutput in OutputProcessing-inl.h, but with different
-// row_offsets for each row because of depth-wise convolution
-template <
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool PER_CHANNEL_QUANTIZATION,
- bool A_SYMMETRIC,
- bool B_SYMMETRIC>
-static inline ALWAYS_INLINE void requantize_(
- int32_t A_zero_point,
- const float* C_multiplier,
- int32_t C_zero_point,
- const int32_t* C_int32,
- uint8_t* C_uint8,
- int n,
- const int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- __m256 multiplier_v = _mm256_setzero_ps();
- if (!PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_set1_ps(*C_multiplier);
- }
-
- __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
- __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
-
- if (A_SYMMETRIC) {
- assert(A_zero_point == 0 || col_offsets == nullptr);
- }
- __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
- __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
- __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
-
- __m256i permute_mask_v =
- _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
-
- constexpr int VLEN = 8;
- int j = 0;
- for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
- __m256i x_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
- __m256i y_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
- __m256i z_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
- __m256i w_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));
-
- __m256i row_offset_v;
- if (!B_SYMMETRIC) {
- row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
- x_v = _mm256_sub_epi32(x_v, row_offset_v);
- }
- __m256i col_off_v;
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
- x_v = _mm256_sub_epi32(x_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
- y_v = _mm256_sub_epi32(y_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
- y_v = _mm256_sub_epi32(y_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
- z_v = _mm256_sub_epi32(z_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
- z_v = _mm256_sub_epi32(z_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
- w_v = _mm256_sub_epi32(w_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
- w_v = _mm256_sub_epi32(w_v, col_off_v);
- }
-
- if (HAS_BIAS) { // static if
- x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
- }
-
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j);
- }
- __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + VLEN);
- }
- __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
- }
- __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
- }
- __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
-
- __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
- __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
- __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
- __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
-
- __m256i xy_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
- __m256i zw_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
- __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
- __m256i xyzw_clamped_v = _mm256_max_epu8(
- FUSE_RELU ? C_zero_point_epi8_v : min_v,
- _mm256_min_epu8(xyzw_packed_v, max_v));
-
- xyzw_clamped_v =
- _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
-
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
- } // j loop vectorized and unrolled 4x
-
- for (; j < n / VLEN * VLEN; j += VLEN) {
- __m256i x_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
-
- if (!B_SYMMETRIC) {
- __m256i row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
- x_v = _mm256_sub_epi32(x_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- __m256i col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
- x_v = _mm256_sub_epi32(x_v, col_off_v);
- }
-
- if (HAS_BIAS) { // static if
- x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
- }
-
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j);
- }
- __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
-
- __m256i x_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
- C_zero_point_epi16_v);
- x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
- __m256i x_clamped_v = _mm256_max_epu8(
- FUSE_RELU ? C_zero_point_epi8_v : min_v,
- _mm256_min_epu8(x_packed_v, max_v));
-
- x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
-
- _mm_storel_epi64(
- reinterpret_cast<__m128i*>(C_uint8 + j),
- _mm256_castsi256_si128(x_clamped_v));
- } // j loop vectorized
-
- for (; j < n; ++j) {
- int32_t raw = C_int32[j];
- if (!B_SYMMETRIC) {
- raw -= row_offsets[j];
- }
- if (!A_SYMMETRIC) {
- raw -= A_zero_point * col_offsets[j];
- }
- if (HAS_BIAS) { // static if
- raw += bias[j];
- }
-
- float ab = raw * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0];
- long rounded = lrintf(ab) + C_zero_point;
-
- C_uint8[j] = std::max(
- FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
- std::min(255l, rounded));
- }
-}
-
-template <bool REMAINDER>
-static inline ALWAYS_INLINE __m256i load_a(
- const uint8_t* A,
- __m256i mask_v) {
- if (REMAINDER) {
- return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
- } else {
- return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
- }
-}
-
template <
bool SUM_A,
bool REMAINDER = false,
@@ -924,257 +138,11 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
}
template <
- bool SUM_A,
- bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static inline ALWAYS_INLINE void inner_prod_3x3x3_packed_(
- int T,
- int H,
- int W,
- int K,
- int t_in,
- int h_in,
- int w_in,
- const uint8_t* A,
- int32_t A_zero_point,
- const int8_t* Bp,
- const int32_t* B_zero_point,
- int32_t* C,
- int remainder,
- int32_t* row_offsets) {
- __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
- __m256i mask_v = _mm256_setzero_si256();
- if (REMAINDER) {
- mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(masks[remainder / 4]));
- }
-
- // The code below can be written as a simple R*S loop but the compiler
- // doesn't unroll so we're manually unrolling it.
- // constexpr int R = 3, S = 3;
- // array<__m256i, R * S> a_v;
- // for (int r = 0; r < R; ++r) {
- // for (int s = 0; s < S; ++s) {
- // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
- // if (REMAINDER) {
- // a_v[r * S + s] =
- // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
- // mask_v);
- // } else {
- // a_v[r * S + s] =
- // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
- // }
- // } else {
- // a_v[r * S + s] = A_zero_point_v;
- // }
- // }
- // }
- __m256i a_v[8];
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum[4];
- inner_prod_packed_<8, SUM_A, REMAINDER>(
- a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum_temp[4];
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp);
-
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
-
- __m256i B_zero_point_v;
- for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
- if (PER_CHANNEL_QUANTIZATION) {
- B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
- } else {
- B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
- }
- _mm256_store_si256(
- reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
- }
- }
-}
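The SUM_A tail above is what produces row_offsets: each channel's activation sum over the 3x3x3 window is multiplied by the (possibly per-channel) B zero point so that requantize_ can subtract it later. A hedged scalar restatement of that correction, with an illustrative function name:

// Hedged scalar restatement of the row-offset computation above.
#include <cstdint>

void row_offsets_scalar_ref(
    const std::int32_t a_sum[32],     // per-channel activation sums
    const std::int32_t* B_zero_point, // length 32 if per-channel, else length 1
    bool per_channel_quantization,
    std::int32_t row_offsets[32]) {
  for (int ch = 0; ch < 32; ++ch) {
    std::int32_t bzp =
        per_channel_quantization ? B_zero_point[ch] : B_zero_point[0];
    row_offsets[ch] = a_sum[ch] * bzp;
  }
}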
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
int H,
int W,
@@ -1193,7 +161,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
uint8_t* C_uint8,
int32_t* row_offsets,
const int32_t* col_offsets,
- const int32_t* bias) {
+ const BIAS_TYPE* bias,
+ float act_times_w_scale) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1238,7 +207,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
HAS_BIAS,
false, /*PER_CHAN_QUANT*/
A_SYMMETRIC,
- B_SYMMETRIC>(
+ B_SYMMETRIC,
+ BIAS_TYPE>(
A_zero_point,
&C_multiplier,
C_zero_point,
@@ -1247,95 +217,11 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
K,
row_offsets,
col_offsets,
- bias);
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline ALWAYS_INLINE void depthwise_3x3x3_kernel_(
- int T,
- int H,
- int W,
- int K,
- int t,
- int h,
- int w,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const int8_t* Bp,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int t_in = -PAD_P + t * stride_t;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- 0,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- remainder,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
-
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- false, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- B_SYMMETRIC>(
- A_zero_point,
- &C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
+ bias,
+ &act_times_w_scale);
}
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
static inline ALWAYS_INLINE void
depthwise_3x3_per_channel_quantization_kernel_(
int H,
@@ -1355,7 +241,8 @@ depthwise_3x3_per_channel_quantization_kernel_(
uint8_t* C_uint8,
int32_t* row_offsets,
const int32_t* col_offsets,
- const int32_t* bias) {
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1406,7 +293,8 @@ depthwise_3x3_per_channel_quantization_kernel_(
HAS_BIAS,
true, /*PER_CHAN_QUANT*/
A_SYMMETRIC,
- false /*B_SYMM*/>(
+ false, /*B_SYMM*/
+ BIAS_TYPE>(
A_zero_point,
C_multiplier,
C_zero_point,
@@ -1415,113 +303,20 @@ depthwise_3x3_per_channel_quantization_kernel_(
K,
row_offsets,
col_offsets,
- bias);
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline ALWAYS_INLINE void
-depthwise_3x3x3_per_channel_quantization_kernel_(
- int T,
- int H,
- int W,
- int K,
- int t,
- int h,
- int w,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int t_in = -PAD_P + t * stride_t;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- false, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- 0,
- &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- true, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- remainder,
- &row_offsets[k]);
- }
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- true, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- false /*B_SYMM*/>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
-}
-
-static pair<int, int> closest_factors_(int n) {
- int a = (int)std::sqrt(n);
- while (n % a != 0) {
- a--;
- }
- return {a, n / a}; // a <= n / a
+ bias,
+ act_times_w_scale);
}
// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
// This implementation should be general enough to handle not just 3x3 but other
// filter shapes by parameterizing with R and S but restricting it to just 3x3
// for now.
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
int N,
int H,
@@ -1532,13 +327,14 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
uint8_t* C_uint8,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
assert(K % 8 == 0);
@@ -1597,7 +393,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
if (h_begin == 0) {
if (w_begin == 0) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1615,11 +416,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1637,12 +444,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
w = W_OUT - 1;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1660,14 +473,20 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) {
if (w_begin == 0) {
w = 0;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1685,11 +504,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1707,12 +532,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
w = W_OUT - 1;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1730,7 +561,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
@@ -1738,7 +570,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
h = H_OUT - 1;
w = 0;
if (w_begin == 0) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1756,11 +593,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1778,12 +621,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
w = W_OUT - 1;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1801,126 +650,15 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
} // for each n
FREE(row_offsets);
};
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
-
- int n_begin, n_end;
- int t_begin, t_end, h_begin, h_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- t_begin = 0;
- t_end = T_OUT;
- h_begin = 0;
- h_end = H_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by num_threads_t * num_threads_h 2D grid of threads
- int num_threads_t, num_threads_h;
- // num_threads_w <= num_threads_h
- tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_t = tid_within_n / num_threads_h;
- int tid_h = tid_within_n % num_threads_h;
-
- int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
- t_begin = std::min(tid_t * t_per_thread, T_OUT);
- t_end = std::min(t_begin + t_per_thread, T_OUT);
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * T * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
-
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- depthwise_3x3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC>(
- T,
- H,
- W,
- K,
- t,
- h,
- w,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias);
- } // w
- } // h
- } // t
- } // for each n
-
- FREE(row_offsets);
-};
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
static inline ALWAYS_INLINE void
depthwise_3x3_per_channel_quantization_pad_1_(
int N,
@@ -1932,13 +670,14 @@ depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
const float* C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
uint8_t* C_uint8,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
assert(K % 8 == 0);
@@ -2000,7 +739,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2018,14 +758,16 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2043,7 +785,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
@@ -2051,7 +794,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2069,7 +813,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
@@ -2079,7 +824,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2097,14 +843,16 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2122,7 +870,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
@@ -2130,7 +879,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2148,7 +898,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
@@ -2159,7 +910,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2177,14 +929,16 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2202,7 +956,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
@@ -2210,7 +965,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2228,128 +984,15 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
} // for each n
-
- FREE(row_offsets);
-};
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline ALWAYS_INLINE void
-depthwise_3x3x3_per_channel_quantization_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
-
- int n_begin, n_end;
- int t_begin, t_end, h_begin, h_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- t_begin = 0;
- t_end = T_OUT;
- h_begin = 0;
- h_end = H_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by num_threads_t * num_threads_h 2D grid of threads
- int num_threads_t, num_threads_h;
- // num_threads_w <= num_threads_h
- tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_t = tid_within_n / num_threads_h;
- int tid_h = tid_within_n % num_threads_h;
-
- int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
- t_begin = std::min(tid_t * t_per_thread, T_OUT);
- t_end = std::min(t_begin + t_per_thread, T_OUT);
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * T * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
-
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- depthwise_3x3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC>(
- T,
- H,
- W,
- K,
- t,
- h,
- w,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias);
- } // w
- } // h
- } // t
- } // for each n
-
- FREE(row_offsets);
};
// Dispatch A_SYMMETRIC and B_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
static void depthwise_3x3_pad_1_(
int N,
int H,
@@ -2360,12 +1003,13 @@ static void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
@@ -2375,7 +1019,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
true /*A_symmetric*/,
- true /*B_symmetric*/>(
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2392,6 +1037,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -2399,7 +1045,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
true /*A_symmetric*/,
- false /*B_symmetric*/>(
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2416,6 +1063,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -2425,7 +1073,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
false /*A_symmetric*/,
- true /*B_symmetric*/>(
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2442,6 +1091,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -2449,7 +1099,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
false /*A_symmetric*/,
- false /*B_symmetric*/>(
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2466,6 +1117,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -2474,7 +1126,7 @@ static void depthwise_3x3_pad_1_(
}
// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
+template <bool FUSE_RELU, typename BIAS_TYPE>
static void depthwise_3x3_pad_1_(
int N,
int H,
@@ -2485,16 +1137,17 @@ static void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
if (bias) {
- depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>(
+ depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
N,
H,
W,
@@ -2510,10 +1163,11 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>(
+ depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
N,
H,
W,
@@ -2529,6 +1183,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -2536,6 +1191,7 @@ static void depthwise_3x3_pad_1_(
// Dispatch input shape and FUSE_RELU
// assumption: W > 3 and H > 3
+template <typename BIAS_TYPE>
void depthwise_3x3_pad_1(
int N,
int H,
@@ -2546,18 +1202,33 @@ void depthwise_3x3_pad_1(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2573,10 +1244,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2592,10 +1264,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2611,10 +1284,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2630,10 +1304,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2649,12 +1324,13 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2670,10 +1346,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2689,10 +1366,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2708,10 +1386,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2727,10 +1406,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2746,283 +1426,15 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
}
}
-// Dispatch A_SYMMETRIC and B_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
-static void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- if (B_zero_point == 0) {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- true /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- false /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
- } else {
- if (B_zero_point == 0) {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- true /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- false /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
- }
- delete[] C_int32_temp;
-}
-
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
-static void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- if (bias) {
- depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
-}
-
-// Dispatch FUSE_RELU
-void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- int thread_id,
- int num_threads) {
- if (fuse_relu) {
- depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
-}
-
// Dispatch A_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
static void depthwise_3x3_per_channel_quantization_pad_1_(
int N,
int H,
@@ -3033,12 +1445,13 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
@@ -3046,7 +1459,8 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
HAS_BIAS,
- true /*A_SYMM*/>(
+ true /*A_SYMM*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3063,13 +1477,15 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
HAS_BIAS,
- false /*A_SYMM*/>(
+ false /*A_SYMM*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3086,6 +1502,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -3093,7 +1510,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
}
// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
+template <bool FUSE_RELU, typename BIAS_TYPE>
static void depthwise_3x3_per_channel_quantization_pad_1_(
int N,
int H,
@@ -3104,18 +1521,20 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
if (bias) {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
- true /* HAS_BIAS */>(
+ true /* HAS_BIAS */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3131,12 +1550,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
- false /* HAS_BIAS */>(
+ false /* HAS_BIAS */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3152,12 +1573,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
}
// Dispatch input shape and FUSE_RELU
+template <typename BIAS_TYPE>
void depthwise_3x3_per_channel_quantization_pad_1(
int N,
int H,
@@ -3168,18 +1591,35 @@ void depthwise_3x3_per_channel_quantization_pad_1(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
+ if (Bp.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3195,10 +1635,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3214,10 +1657,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3233,10 +1679,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3252,10 +1701,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3271,12 +1723,15 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3292,10 +1747,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3311,10 +1769,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3330,10 +1791,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3349,10 +1813,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3368,225 +1835,179 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
}
}
-// Dispatch A_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
-static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+// To be removed
+void depthwise_3x3_pad_1(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
+ bool fuse_relu,
int thread_id,
int num_threads) {
- int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_SYMM*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_SYMM*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
- delete[] C_int32_temp;
+ depthwise_3x3_pad_1<std::int32_t>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ 1.0f,
+ thread_id,
+ num_threads);
}
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
-static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+// To be removed
+void depthwise_3x3_per_channel_quantization_pad_1(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
+ bool fuse_relu,
int thread_id,
int num_threads) {
- if (bias) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- true /* HAS_BIAS */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- false /* HAS_BIAS */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
+ depthwise_3x3_per_channel_quantization_pad_1<std::int32_t>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ nullptr,
+ thread_id,
+ num_threads);
}
-// Dispatch FUSE_RELU
-void depthwise_3x3x3_per_channel_quantization_pad_1(
+template void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3_per_channel_quantization_pad_1(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
bool fuse_relu,
+ const float* act_times_w_scale,
int thread_id,
- int num_threads) {
- if (fuse_relu) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
-}
+ int num_threads);
+
+template void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
} // namespace fbgemm
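Note on the diff above: the 2D depthwise entry points are now templated on BIAS_TYPE (int32_t or float bias), every dispatch layer forwards the new act_times_w_scale argument, and the old int32_t-only signatures remain only as "To be removed" wrappers next to explicit instantiations for both bias types. The sketch below is a minimal, self-contained illustration of that dispatch idiom (compile-time flags chosen by runtime branches, plus explicit instantiation); the names kernel_impl, kernel_dispatch_bias, and kernel are hypothetical, not FBGEMM APIs.

#include <cstdint>
#include <iostream>

// Innermost worker: FUSE_RELU, HAS_BIAS and BIAS_TYPE are fixed at compile
// time, so the hot path carries no per-element branching.
template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
static void kernel_impl(int n, const BIAS_TYPE* bias, float act_times_w_scale) {
  float acc = act_times_w_scale * n;
  if (HAS_BIAS) {
    acc += static_cast<float>(bias[0]);
  }
  if (FUSE_RELU && acc < 0.0f) {
    acc = 0.0f;
  }
  std::cout << acc << '\n';
}

// Dispatch HAS_BIAS: a runtime null check picks the compile-time variant.
template <bool FUSE_RELU, typename BIAS_TYPE>
static void kernel_dispatch_bias(int n, const BIAS_TYPE* bias, float scale) {
  if (bias) {
    kernel_impl<FUSE_RELU, true, BIAS_TYPE>(n, bias, scale);
  } else {
    kernel_impl<FUSE_RELU, false, BIAS_TYPE>(n, bias, scale);
  }
}

// Public entry point: dispatch FUSE_RELU, keep BIAS_TYPE as the only template
// parameter visible to callers.
template <typename BIAS_TYPE>
void kernel(int n, const BIAS_TYPE* bias, bool fuse_relu, float scale) {
  if (fuse_relu) {
    kernel_dispatch_bias<true>(n, bias, scale);
  } else {
    kernel_dispatch_bias<false>(n, bias, scale);
  }
}

// Explicit instantiations for the two supported bias types, mirroring the
// instantiations at the end of the file above.
template void kernel<std::int32_t>(int, const std::int32_t*, bool, float);
template void kernel<float>(int, const float*, bool, float);

int main() {
  std::int32_t b = -7;
  kernel(3, &b, /*fuse_relu=*/true, /*scale=*/1.0f);     // prints 0 (ReLU clamps -4)
  kernel<float>(3, nullptr, /*fuse_relu=*/false, 2.0f);  // prints 6
}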
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index e52097e..c0fece4 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -8,8 +8,11 @@
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <map>
+#include <mutex>
+#include <sstream>
#include <string>
#include <tuple>
+#include "CodeCache.h"
#include "fbgemm/Fbgemm.h"
/*#define FBGEMM_LOG_CODE 1*/
@@ -40,35 +43,7 @@ class CodeGenBase {
* @brief Constructor for initializing AVX2/AVX512 registers.
*/
CodeGenBase(const BlockingFactors* params = nullptr)
- : blocking_params(params),
- CRegs_avx2_{x86::ymm0,
- x86::ymm1,
- x86::ymm2,
- x86::ymm3,
- x86::ymm4,
- x86::ymm5,
- x86::ymm6,
- x86::ymm7,
- x86::ymm8,
- x86::ymm9,
- x86::ymm10,
- x86::ymm11},
- CRegs_avx512_{
- x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4,
- x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9,
- x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14,
- x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
- x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24,
- x86::zmm25, x86::zmm26, x86::zmm27,
- },
- AllRegs_avx512_{x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3,
- x86::zmm4, x86::zmm5, x86::zmm6, x86::zmm7,
- x86::zmm8, x86::zmm9, x86::zmm10, x86::zmm11,
- x86::zmm12, x86::zmm13, x86::zmm14, x86::zmm15,
- x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
- x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23,
- x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27,
- x86::zmm28, x86::zmm29, x86::zmm30, x86::zmm31} {
+ : blocking_params(params) {
// vector width in bits
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
@@ -143,7 +118,7 @@ class CodeGenBase {
* (debug-only)
*/
template <inst_set_t instSet>
- std::string getCodeLoggingFile(
+ static std::string getCodeLoggingFile(
bool accum,
int mc,
int nc,
@@ -152,48 +127,60 @@ class CodeGenBase {
int MR,
int NR,
int NR_MIN) {
- std::string fileName = "gemm_";
+ std::ostringstream oss;
+ oss << "gemm_";
if (std::is_same<accT, std::int16_t>::value) {
- fileName += "acc16_";
+ oss << "acc16_";
} else if (std::is_same<accT, std::int32_t>::value) {
- fileName += "acc32_";
+ oss << "acc32_";
} else {
- fileName += "unknown_";
+ oss << "unknown_";
}
- fileName += "accum-" + std::to_string(accum);
- fileName += "_MC-" + std::to_string(mc);
- fileName += "_NC-" + std::to_string(nc);
- fileName += "_NCB-" + std::to_string(NCB);
- fileName += "_NCB-" + std::to_string(KCB);
- fileName += "_MR-" + std::to_string(MR);
- fileName += "_NR-" + std::to_string(NR);
- fileName += "_NR_MIN-" + std::to_string(NR_MIN);
+ oss << "accum-" + std::to_string(accum)
+ << "_MC-" + std::to_string(mc)
+ << "_NC-" + std::to_string(nc)
+ << "_NCB-" + std::to_string(NCB)
+        << "_KCB-" + std::to_string(KCB)
+ << "_MR-" + std::to_string(MR)
+ << "_NR-" + std::to_string(NR)
+ << "_NR_MIN-" + std::to_string(NR_MIN);
if (instSet == inst_set_t::avx512_vnni) {
- fileName += "_avx512vnni";
+ oss << "_avx512vnni";
} else if (instSet == inst_set_t::avx512) {
- fileName += "_avx512";
+ oss << "_avx512";
} else if (instSet == inst_set_t::avx2) {
- fileName += "_avx2";
+ oss << "_avx2";
}
- fileName += ".txt";
- return fileName;
+ oss << ".txt";
+ return oss.str();
}
private:
- x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
- x86::Zmm
- CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
- x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
-
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
- static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
- static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+
+ static asmjit::JitRuntime &runtime() {
+    static asmjit::JitRuntime rt; //< JIT Runtime for asmjit;
+                                  // depends on other static
+                                  // variables. Required to prevent
+                                  // the static initialization order fiasco.
+ return rt;
+ }
+
+  static std::mutex rtMutex_; ///< Controls access to the JIT runtime.
+
// The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min
- static thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- jit_micro_kernel_fp>
+ static CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+ jit_micro_kernel_fp>
codeCache_; ///< JIT Code Cache for reuse.
};
+template <typename TA, typename TB, typename TC, typename accT>
+std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+ typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
+ CodeGenBase<TA, TB, TC, accT>::codeCache_;
+
} // namespace fbgemm
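Note on the header change above: the thread_local JitRuntime, CodeHolder and std::map members are gone; generated kernels are shared process-wide through a CodeCache, and the JitRuntime itself is returned from a function-local static guarded by rtMutex_, which avoids the static initialization order fiasco mentioned in the comment. The sketch below shows only that construct-on-first-use plus mutex pattern, with a hypothetical Runtime struct standing in for asmjit::JitRuntime.

#include <mutex>

// Hypothetical stand-in for the JIT runtime owned by the code generator.
struct Runtime {
  int add_calls = 0;
};

class KernelFactory {
 public:
  // Function-local static: constructed on first use, so it exists before any
  // caller can reach it, regardless of translation-unit initialization order.
  static Runtime& runtime() {
    static Runtime rt;
    return rt;
  }

  void emit() {
    // Serialize mutation of the shared runtime across threads.
    std::lock_guard<std::mutex> lock(rtMutex_);
    ++runtime().add_calls;
  }

 private:
  static std::mutex rtMutex_;
};

std::mutex KernelFactory::rtMutex_;

int main() {
  KernelFactory f;
  f.emit();
  return KernelFactory::runtime().add_calls == 1 ? 0 : 1;
}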
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 1e7e081..cbd5877 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -9,18 +9,6 @@
namespace fbgemm {
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
- CodeGenBase<TA, TB, TC, accT>::codeCache_;
-
namespace x86 = asmjit::x86;
/**
@@ -35,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -66,6 +55,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
x86::Ymm tmpReg = x86::ymm14;
+ using CRegs = x86::Ymm;
+
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
a->vpbroadcastw(
@@ -74,9 +65,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
a->vpmaddubsw(
tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
a->vpaddsw(
- CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
tmpReg,
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
// Prefetching is hurting performance in some cases
// because prefetch instructions itself consumes a slot
// in pipeline issue thus slowing down the kernel.
@@ -105,12 +96,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
x86::Xmm extractDest128 = x86::xmm15;
x86::Ymm extractDest256 = x86::ymm15;
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti128(
- extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx);
+ extractDest128, CRegs(i * leadingDimCReg + j), idx);
a->vpmovsxwd(extractDest256, extractDest128);
x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
@@ -172,191 +164,186 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx2>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- // assert((nc == nRegBlockSize) &&
- //"nc must be equal to the number of register blocks");
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label Loopk = a->newLabel();
- asmjit::Label LoopMBlocks = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp kIdx = a->gpz(14);
-
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
-
- // increment A for next block
- a->sub(buffer_A, kSize);
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
- // increment C for next block
- a->imul(
- C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ // assert((nc == nRegBlockSize) &&
+ //"nc must be equal to the number of register blocks");
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+ // increment C for next block
+ a->imul(C_Offset, ldcReg,
+ static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock);
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
- }
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
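Note on the restructuring above: the explicit find/insert on a thread_local map is replaced by handing the whole generator body to codeCache_.getOrCreate(kernelSig, ...) as a lambda, and only the runtime().add(&fn, &code) call is taken under rtMutex_. The sketch below is a simplified generate-on-miss cache with that shape; it is not the CodeCache defined in src/CodeCache.h (which the header now includes), and it locks around the whole lookup for brevity rather than only around the add step.

#include <functional>
#include <iostream>
#include <map>
#include <mutex>
#include <tuple>

// Hypothetical generate-on-miss cache keyed by a kernel signature tuple.
template <typename Key, typename Value>
class CodeCache {
 public:
  Value getOrCreate(const Key& key, std::function<Value()> generator) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = values_.find(key);
    if (it != values_.end()) {
      return it->second;      // hit: reuse the previously generated kernel
    }
    Value v = generator();    // miss: run the expensive generator exactly once
    values_.emplace(key, v);
    return v;
  }

 private:
  std::map<Key, Value> values_;
  std::mutex mutex_;
};

using KernelSig = std::tuple<bool, int, int>;  // e.g. (accum, MCB, NCB)
using KernelFn = int (*)(int);

static int dummyKernel(int x) { return x + 1; }

int main() {
  CodeCache<KernelSig, KernelFn> cache;
  KernelSig sig{true, 12, 8};
  int generated = 0;
  for (int i = 0; i < 3; ++i) {
    KernelFn fn = cache.getOrCreate(sig, [&]() -> KernelFn {
      ++generated;            // stands in for assembling and adding JIT code
      return &dummyKernel;
    });
    std::cout << fn(i) << '\n';
  }
  std::cout << "generator ran " << generated << " time(s)\n";  // 1
}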
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index a49e440..512c8ba 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -57,20 +58,22 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
// We start allocating BRegs from zmm27 and then allocate zmm26 and so on.
for (int j = 0; j < colRegs; ++j) {
a->vmovups(
- AllRegs_avx512_[27 - j],
+ x86::Zmm(27 - j),
x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
}
+ using CRegs = x86::Zmm;
+
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
a->vpbroadcastw(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
for (int j = 0; j < colRegs; ++j) {
- a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]);
+ a->vpmaddubsw(tmpReg, AReg, x86::Zmm(27 - j));
a->vpaddsw(
- CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
tmpReg,
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
// Prefetching is hurting performance in some cases
// because prefetch instructions itself consumes a slot
// in pipeline issue thus slowing down the kernel.
@@ -100,12 +103,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
x86::Ymm extractDest256 = x86::ymm31;
x86::Zmm extractDest512 = x86::zmm31;
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti32x8(
- extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx);
+ extractDest256, CRegs(i * leadingDimCReg + j), idx);
a->vpmovsxwd(extractDest512, extractDest256);
x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
@@ -167,262 +171,247 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
-
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- (maxMRegs+1) * maxNRegs <= 28 &&
- "number of zmm registers for C + one row for loading B: \
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert((maxMRegs + 1) * maxNRegs <= 28 &&
+ "number of zmm registers for C + one row for loading B: \
MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(
- C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg,
+ static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
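Note on the register handling in both generators above: the fixed CRegs_avx2_, CRegs_avx512_ and AllRegs_avx512_ member arrays are dropped in favor of constructing a register from its index on demand (using CRegs = x86::Ymm or x86::Zmm, then CRegs(i * leadingDimCReg + j)), which removes per-instance register state from CodeGenBase; the AVX-512 path still asserts that the C tile plus one row of B loads fits in 28 zmm registers. The sketch below shows only the index-to-register idea, with a hypothetical Reg type in place of asmjit's register classes.

#include <cassert>
#include <iostream>

// Hypothetical register handle: a thin wrapper over a register id, playing the
// role of asmjit's x86::Ymm / x86::Zmm constructed directly from an index.
struct Reg {
  explicit Reg(int i) : id(i) { assert(i >= 0 && i < 32); }
  int id;
};

// Emit a (pretend) "vxorps reg, reg, reg" for every accumulator in a
// rowRegs x colRegs tile, mirroring the zeroing loop in initCRegs.
void initCRegs(int rowRegs, int colRegs, int leadingDimCReg) {
  for (int i = 0; i < rowRegs; ++i) {
    for (int j = 0; j < colRegs; ++j) {
      Reg c(i * leadingDimCReg + j);  // register chosen by index, no lookup table
      std::cout << "vxorps ymm" << c.id << ", ymm" << c.id << ", ymm" << c.id
                << '\n';
    }
  }
}

int main() {
  initCRegs(/*rowRegs=*/3, /*colRegs=*/2, /*leadingDimCReg=*/2);
}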
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 6b54743..226e974 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -9,18 +9,6 @@
namespace fbgemm {
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
- CodeGenBase<TA, TB, TC, accT>::codeCache_;
-
namespace x86 = asmjit::x86;
/**
@@ -35,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -73,6 +62,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
// temporary register
x86::Ymm res1 = x86::ymm14;
+ using CRegs = x86::Ymm;
+
for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
@@ -83,9 +74,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
a->vpmaddubsw(res1, AReg, BReg);
a->vpmaddwd(res1, oneReg, res1);
a->vpaddd(
- CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
res1,
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
}
a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
}
@@ -106,6 +97,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
@@ -113,13 +105,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
for (int j = 0; j < colRegs; ++j) {
if (accum) {
a->vpaddd(
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)));
}
a->vmovups(
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)),
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
}
}
}
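Note on the ACC32 inner step in the hunks above: each product is formed with vpmaddubsw (unsigned A byte times signed B byte, adjacent pairs summed into saturated int16), widened and horizontally added to int32 by vpmaddwd against a register that holds 0x0001 in every 16-bit lane (built further down in this file with vpcmpeqw followed by vpsrlw by 15), and finally accumulated into the C register with vpaddd. The scalar sketch below reproduces that arithmetic on four lanes; the sizes and data are illustrative only, not the real 32-byte ymm lanes.

#include <cstdint>
#include <iostream>

// Scalar emulation of one column step of the u8 * s8 -> s32 idiom:
//   vpmaddubsw: a[2k]*b[2k] + a[2k+1]*b[2k+1], saturated to int16
//   vpmaddwd with 16-bit ones: p0*1 + p1*1, widened to int32
//   vpaddd: add the result into the existing 32-bit accumulator
int32_t dot4_u8s8(const uint8_t* a, const int8_t* b) {
  auto sat16 = [](int32_t v) {
    return static_cast<int16_t>(v > 32767 ? 32767 : (v < -32768 ? -32768 : v));
  };
  int16_t p0 = sat16(int32_t(a[0]) * b[0] + int32_t(a[1]) * b[1]);  // vpmaddubsw
  int16_t p1 = sat16(int32_t(a[2]) * b[2] + int32_t(a[3]) * b[3]);
  return int32_t(p0) + int32_t(p1);                                 // vpmaddwd
}

int main() {
  uint8_t a[4] = {1, 2, 3, 4};
  int8_t b[4] = {10, -1, 2, 5};
  int32_t c = 100;         // existing accumulator value
  c += dot4_u8s8(a, b);    // vpaddd step
  std::cout << c << '\n';  // 100 + (10 - 2 + 6 + 20) = 134
}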
@@ -173,206 +165,169 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx2>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label Loopk = a->newLabel();
- asmjit::Label LoopMBlocks = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp kIdx = a->gpz(14);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Ymm oneReg = x86::ymm15;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
- a->mov(C_Offset, 0);
-
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- // a->add(B_pf, 32*sizeof(float));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment A for next block
- a->sub(buffer_A, kSize);
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next block
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
+ // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Ymm oneReg = x86::ymm15;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
a->mov(C_Offset, 0);
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
- }
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- a->emitEpilog(frame);
+ auto issueLoopOverK = [&](int rowRegs) {
+ asmjit::Label LoopKLabel = a->newLabel();
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ // Init C (result) vector registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // Loops over K
+ a->mov(kIdx, 0);
+ a->bind(LoopKLabel);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopKLabel);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum,
+ colRegs);
+ };
+
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ issueLoopOverK(mRegBlockSize);
+
+ int rowRegs = mRegBlockSize;
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+ a->mov(C_Offset, 0);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ issueLoopOverK(mRegBlocksRem);
+ }
+
+ a->emitEpilog(frame);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
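The hunks above replace the manual find-then-cache bookkeeping on codeCache_ (and the shared code_/rt_ members) with a single codeCache_.getOrCreate(kernelSig, generator) call: the generator lambda assembles the kernel into a local CodeHolder and registers it with the shared JitRuntime under rtMutex_, so each kernel signature is compiled at most once. The commit's actual cache lives in src/CodeCache.h, which is not shown in this part of the diff, so the following is only a minimal sketch of the getOrCreate() idea under that assumption (the real class may, for instance, hand out shared futures so concurrent callers wait on a single compilation):

    #include <functional>
    #include <map>
    #include <mutex>

    // Minimal sketch of a signature -> kernel cache with a getOrCreate()
    // API like the one used above. KEY is the kernel signature tuple,
    // VALUE the JITed function pointer type.
    template <typename KEY, typename VALUE>
    class CodeCacheSketch {
     public:
      VALUE getOrCreate(const KEY& key, std::function<VALUE()> generator) {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = values_.find(key);
        if (it != values_.end()) {
          return it->second;         // kernel already generated for this signature
        }
        VALUE value = generator();   // assemble and register the kernel once
        values_.emplace(key, value);
        return value;
      }

     private:
      std::mutex mutex_;
      std::map<KEY, VALUE> values_;
    };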
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 5986e48..5037292 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -61,6 +62,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
// temporary register
x86::Zmm res1 = x86::zmm28;
+ using CRegs = x86::Zmm;
for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
@@ -71,9 +73,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
a->vpmaddubsw(res1, AReg, BReg);
a->vpmaddwd(res1, oneReg, res1);
a->vpaddd(
- CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
res1,
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
}
a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
}
@@ -94,6 +96,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
@@ -103,15 +106,21 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
for (int j = 0; j < colRegs; ++j) {
if (accum) {
a->vpaddd(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+#ifdef _MSC_VER
x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)));
- // x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
+#else
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
+#endif
}
a->vmovups(
+#ifdef _MSC_VER
x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)),
-// x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
- CRegs_avx512_[i * leadingDimCReg + j]);
+#else
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
+#endif
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -165,302 +174,269 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
-
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- maxMRegs * maxNRegs <= 28 &&
- "MR*(NR*ROW_INTERLEAVE*8/512) \
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
- // arguments to the function created
+ // arguments to the function created
#ifdef _MSC_VER
- x86::Gp buffer_A = a->zcx();
- x86::Gp buffer_B = a->zdx();
- x86::Gp B_pf = a->gpz(8);
- x86::Gp CBase = a->gpz(9);
- x86::Gp kSize = a->zdi(); // a->zsi(); // x86::esi; // a->zsi();
- x86::Gp ldcReg = a->zsi(); // a->zdi(); // x86::edi; // a->zdi();
+ x86::Gp buffer_A = a->zcx();
+ x86::Gp buffer_B = a->zdx();
+ x86::Gp B_pf = a->gpz(8);
+ x86::Gp CBase = a->gpz(9);
+ x86::Gp kSize = a->zdi();
+ x86::Gp ldcReg = a->zsi();
#else
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
#endif
- asmjit::FuncDetail func;
-#ifdef _MSC_VER
- //func.init(asmjit::FuncSignature4<void, uint8_t*, int8_t*, int8_t*, int32_t*>(
- // asmjit::CallConv::kIdHost));
- func.init(asmjit::FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-#else
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-#endif
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
-//#ifdef _MSC_VER
-// // retrieve parameters from stack
-// a->mov(kSize, asmjit::x86::dword_ptr(asmjit::x86::rsp, func.getArg(4).getStackOffset())); //0x20)); //func.getArg(4).getStackOffset()));
-// std::cout << "func.getArg(4).getStackOffset(): " << func.getArg(4).getStackOffset() << std::endl;
-// a->mov(ldcReg, asmjit::x86::dword_ptr(asmjit::x86::rsp, func.getArg(5).getStackOffset())); //;0x28)); //func.getArg(5).getStackOffset()));
-// std::cout << "func.getArg(5).getStackOffset(): " << func.getArg(5).getStackOffset() << std::endl;
-//#endif
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Zmm oneReg = x86::zmm29;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- // a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->mov(buffer_B, buffer_B_saved);
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
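Two recurring edits in this file are worth calling out. First, the fixed CRegs_avx512_ register array is dropped in favor of a local `using CRegs = x86::Zmm;` alias, so the accumulator register for row i, column j is constructed directly from its index. Second, storeCRegs addresses the C matrix through a->gpz(9) under _MSC_VER and a->zcx() otherwise, matching the argument registers chosen earlier (the Windows x64 calling convention passes the fourth argument in r9, while the System V ABI uses rcx). A small sketch of the register-indexing idiom, relying only on the fact, visible in the diff itself, that asmjit's x86::Zmm can be built from a register id:

    #include <asmjit/asmjit.h>

    // Sketch: build the accumulator register for (row i, column j) on the
    // fly instead of indexing a per-class array of pre-built registers.
    inline asmjit::x86::Zmm accReg(int i, int j, int leadingDimCReg) {
      return asmjit::x86::Zmm(i * leadingDimCReg + j);  // zmm0..zmm27 in the kernels above
    }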
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
index 8ae0745..1d23e90 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -23,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -55,6 +56,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
// used for matrix B
x86::Zmm BReg = x86::zmm30;
+ using CRegs = x86::Zmm;
+
for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
@@ -62,7 +65,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
for (int i = 0; i < rowRegs; ++i) {
a->vpbroadcastd(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
- a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg);
+ a->vpdpbusd(CRegs(i * leadingDimCReg + j), AReg, BReg);
}
a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
}
@@ -83,6 +86,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
@@ -92,13 +96,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
for (int j = 0; j < colRegs; ++j) {
if (accum) {
a->vpaddd(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
}
a->vmovups(
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -155,277 +159,260 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512_vnni>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512_vnni>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- maxMRegs * maxNRegs <= 28 &&
- "MR*(NR*ROW_INTERLEAVE*8/512) \
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Zmm oneReg = x86::zmm29;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- // a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512_vnni>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512_vnni>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512_vnni>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
- // B for next block
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->mov(buffer_B, buffer_B_saved);
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512_vnni>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
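Functionally, the VNNI kernel differs from the plain AVX-512 one mainly in genComputeBlock: the vpmaddubsw + vpmaddwd (against the vector of 16-bit ones) + vpaddd sequence collapses into a single vpdpbusd, which multiplies unsigned 8-bit activations by signed 8-bit weights and accumulates straight into 32-bit lanes, freeing the ones register and the temporary. A hedged intrinsics sketch of the same contrast, written as a standalone illustration rather than code from this commit:

    #include <immintrin.h>

    // acc += dot(u8 activations, s8 weights), widened to 32-bit lanes.
    // Pre-VNNI path (AVX512BW): three instructions plus a ones vector.
    static inline __m512i dot_acc_avx512(__m512i acc, __m512i a_u8, __m512i b_s8,
                                         __m512i ones_epi16) {
      __m512i p16 = _mm512_maddubs_epi16(a_u8, b_s8);    // pairwise u8*s8 -> s16
      __m512i p32 = _mm512_madd_epi16(p16, ones_epi16);  // s16 pairs -> s32
      return _mm512_add_epi32(acc, p32);
    }

    // VNNI path: one fused multiply-accumulate (vpdpbusd).
    static inline __m512i dot_acc_vnni(__m512i acc, __m512i a_u8, __m512i b_s8) {
      return _mm512_dpbusd_epi32(acc, a_u8, b_s8);
    }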
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 4c5eea5..58ee24d 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -10,8 +10,10 @@
#include <cassert>
#include <cstdint>
#include <map>
+#include <mutex>
#include <string>
#include <tuple>
+#include "CodeCache.h"
#include "fbgemm/ConvUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/Utils.h"
@@ -217,16 +219,23 @@ class GenConvKernel {
template <inst_set_t instSet>
void storeResultRowoffset(x86::Emitter* a, int offset = 0);
- static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
- static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
- static thread_local std::
- map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- codeCache_; ///< JIT Code Cache for reuse.
- static thread_local std::
- map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
- private:
+ static asmjit::JitRuntime &runtime() {
+    static asmjit::JitRuntime rt; //< JIT Runtime for asmjit;
+                                  // depends on other static
+                                  // variables. Required to prevent the
+                                  // static initialization order fiasco.
+ return rt;
+ }
+
+  static std::mutex rtMutex_; ///< Controls access to runtime().
+
+ static CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ codeCache_; ///< JIT Code Cache for reuse.
+ static CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
+
+private:
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
// avx2 specific
@@ -272,4 +281,15 @@ class GenConvKernel {
int W_PAD_; ///< Padding for width (left and right)
};
+template <int SPATIAL_DIM, typename accT>
+std::mutex GenConvKernel<SPATIAL_DIM, accT>::rtMutex_;
+
+template <int SPATIAL_DIM, typename accT>
+CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
+
+template <int SPATIAL_DIM, typename accT>
+CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
+
} // namespace fbgemm
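Besides switching the caches to CodeCache, the header replaces the thread_local JitRuntime and CodeHolder statics with a runtime() accessor that holds a function-local static, guarded by rtMutex_ when kernels are added. The function-local static is the standard Meyers-singleton cure for the initialization order problem the comment mentions: it is constructed on first use, after anything it depends on. A minimal standalone sketch of the idiom, with a placeholder Runtime type standing in for asmjit::JitRuntime:

    #include <mutex>

    struct Runtime {};  // placeholder for asmjit::JitRuntime in this sketch

    class KernelGeneratorSketch {
     public:
      static Runtime& runtime() {
        // Constructed on first call, so its initialization cannot race ahead
        // of other static objects it depends on (no cross-TU ordering issues).
        static Runtime rt;
        return rt;
      }

     private:
      static std::mutex rtMutex_;  // serializes runtime().add(...) across threads
    };

    // Out-of-line definition, analogous to the GenConvKernel members above.
    std::mutex KernelGeneratorSketch::rtMutex_;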
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index b140c83..d1e0fdd 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -21,20 +21,6 @@ namespace fbgemm {
using namespace std;
-template <int SPATIAL_DIM, typename accT>
-thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
-
namespace x86 = asmjit::x86;
template <int SPATIAL_DIM>
@@ -91,14 +77,13 @@ jit_conv_kernel_fp getOrCreateConvKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) !=
- GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) {
- return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig];
- } else {
- auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
- }
+ return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate(
+ kernelSig, [&]() {
+ auto genObj =
+ GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
+ });
}
template <>
@@ -1009,9 +994,9 @@ template <>
template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1020,7 +1005,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code_.setLogger(codeLogger);
+ code.setLogger(codeLogger);
}
#endif
@@ -1097,13 +1082,15 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
a->emitEpilog(frame);
jit_conv_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
- auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
- codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
fclose(codeLogfile);
@@ -1489,9 +1476,9 @@ template <>
jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1500,7 +1487,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code_.setLogger(codeLogger);
+ code.setLogger(codeLogger);
}
#endif
@@ -1570,14 +1557,16 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a->emitEpilog(frame);
+ asmjit::Error err;
jit_rowoffset_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
- auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
- codeCacheRowOffset_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
delete codeLogger;
@@ -1780,7 +1769,8 @@ template <
typename outType,
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
- int SPATIAL_DIM>
+ int SPATIAL_DIM,
+ typename BIAS_TYPE>
void fbgemmGroupwiseConv(
const conv_param_t<SPATIAL_DIM>& conv_param,
const std::uint8_t* activations,
@@ -1789,10 +1779,10 @@ void fbgemmGroupwiseConv(
packed_W& packed_weights,
outType* out,
int32_t* outBuffer,
- const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+ const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
int thread_id,
int num_threads) {
- typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType;
+ typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE> processOutputType;
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
@@ -1883,15 +1873,17 @@ void fbgemmGroupwiseConv(
outProcess.getBZeroPoint()[0] == 0) ||
rowOffsetBuf == nullptr;
- requantizationParams_t r = {a_zero_point,
- outProcess.getBZeroPoint(),
- outProcess.getCZeroPoint(),
- outProcess.getCMultiplier(),
- rowOffsetBuf,
- outProcess.getColOffsets(),
- outProcess.getBias(),
- outProcess.getNCols(),
- G};
+ requantizationParams_t<typename processOutputType::BIAS_T> r = {
+ a_zero_point,
+ outProcess.getBZeroPoint(),
+ outProcess.getCZeroPoint(),
+ outProcess.getCMultiplier(),
+ rowOffsetBuf,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.getNCols(),
+ G,
+ outProcess.getActWScale()};
const std::int32_t* inp = outBuffer;
block_type_t block{i * oh_ow, oh_ow, gOuter * K_per_G, 8 * K_per_G};
@@ -2162,15 +2154,14 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find(
- kernelSig) !=
- GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) {
- return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig];
- } else {
- auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param);
- }
+ return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate(
+ kernelSig, [&]() {
+ auto genObj =
+ GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(
+ conv_param);
+ });
}
template <int SPATIAL_DIM>
@@ -2214,7 +2205,7 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param);
template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param);
-#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM) \
+#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, BIAS_TYPE) \
template void fbgemmGroupwiseConv( \
const conv_param_t<SPATIAL_DIM>& conv_param, \
const uint8_t* activations, \
@@ -2223,13 +2214,17 @@ template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param);
PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>& packed_weights, \
uint8_t* out, \
int32_t* outBuffer, \
- const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
int thread_id, \
int num_threads);
+#define INSTANTIATE_BIAS_T(RELU, Q_GRAN, SPATIAL_DIM) \
+ INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, float); \
+ INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, int32_t);
+
#define INSTANTIATE_SPATIAL_DIM(RELU, Q_GRAN) \
- INSTANTIATE_BASE(RELU, Q_GRAN, 2); \
- INSTANTIATE_BASE(RELU, Q_GRAN, 3);
+ INSTANTIATE_BIAS_T(RELU, Q_GRAN, 2); \
+ INSTANTIATE_BIAS_T(RELU, Q_GRAN, 3);
#define INSTANTIATE_Q_GRANS(RELU) \
INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::TENSOR); \
@@ -2241,6 +2236,7 @@ INSTANTIATE_Q_GRANS(true);
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_SPATIAL_DIM
+#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
template void fbgemmGroupwiseConv(
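The other change in this file threads a BIAS_TYPE parameter through fbgemmGroupwiseConv and its requantization parameters, so the bias can be either int32_t or float (with the activation-times-weight scale carried alongside), and the explicit instantiations gain one more macro layer, INSTANTIATE_BIAS_T, to cover both bias types. A toy, self-contained version of that macro layering, using a hypothetical toyGroupwiseConv in place of the real function:

    #include <cstdint>

    // Toy illustration of the nested instantiation macros above: the new
    // BIAS_TYPE level doubles each instantiation (float and int32_t bias)
    // without touching the outer macro layers.
    template <bool RELU, int SPATIAL_DIM, typename BIAS_TYPE>
    void toyGroupwiseConv(const BIAS_TYPE* /*bias*/) {}

    #define INSTANTIATE_BASE(RELU, SPATIAL_DIM, BIAS_TYPE)              \
      template void toyGroupwiseConv<RELU, SPATIAL_DIM, BIAS_TYPE>(     \
          const BIAS_TYPE*);

    #define INSTANTIATE_BIAS_T(RELU, SPATIAL_DIM) \
      INSTANTIATE_BASE(RELU, SPATIAL_DIM, float)  \
      INSTANTIATE_BASE(RELU, SPATIAL_DIM, std::int32_t)

    #define INSTANTIATE_SPATIAL_DIM(RELU) \
      INSTANTIATE_BIAS_T(RELU, 2)         \
      INSTANTIATE_BIAS_T(RELU, 3)

    INSTANTIATE_SPATIAL_DIM(true)
    INSTANTIATE_SPATIAL_DIM(false)

    #undef INSTANTIATE_SPATIAL_DIM
    #undef INSTANTIATE_BIAS_T
    #undef INSTANTIATE_BASE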
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index 2aca27d..6101fef 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -275,6 +275,7 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
conv_p_.K[1] == 7 && conv_p_.stride[0] == 2 && conv_p_.stride[1] == 2 &&
conv_p_.pad[0] == 3 && conv_p_.pad[1] == 3 && block.col_size == 147 &&
block_p.col_size == 148 && block.col_start == 0 &&
+ conv_p_.dilation[0] == 1 && conv_p_.dilation[1] == 1 &&
std::is_same<T, uint8_t>::value) {
if (BaseType::blockColSize() == 256) {
pack_a_with_im2col_opt<
@@ -350,8 +351,10 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
int r = grs / conv_p_.K[1] % conv_p_.K[0];
int g = grs / conv_p_.K[1] / conv_p_.K[0];
- int h_in = -conv_p_.pad[0] + h * conv_p_.stride[0] + r;
- int w_in = -conv_p_.pad[1] + w * conv_p_.stride[1] + s;
+ int h_in =
+ -conv_p_.pad[0] + h * conv_p_.stride[0] + r * conv_p_.dilation[0];
+ int w_in =
+ -conv_p_.pad[1] + w * conv_p_.stride[1] + s * conv_p_.dilation[1];
if (h_in < 0 || h_in >= conv_p_.IN_DIM[0] || w_in < 0 ||
w_in >= conv_p_.IN_DIM[1]) {
@@ -399,9 +402,12 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
int q = gqrs / conv_p_.K[2] / conv_p_.K[1] % conv_p_.K[0];
int g = gqrs / conv_p_.K[2] / conv_p_.K[1] / conv_p_.K[0];
- int t_in = -conv_p_.pad[0] + t * conv_p_.stride[0] + q;
- int h_in = -conv_p_.pad[1] + h * conv_p_.stride[1] + r;
- int w_in = -conv_p_.pad[2] + w * conv_p_.stride[2] + s;
+ int t_in =
+ -conv_p_.pad[0] + t * conv_p_.stride[0] + q * conv_p_.dilation[0];
+ int h_in =
+ -conv_p_.pad[1] + h * conv_p_.stride[1] + r * conv_p_.dilation[1];
+ int w_in =
+ -conv_p_.pad[2] + w * conv_p_.stride[2] + s * conv_p_.dilation[2];
if (t_in < 0 || t_in >= conv_p_.IN_DIM[0] || h_in < 0 ||
h_in >= conv_p_.IN_DIM[1] || w_in < 0 ||
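The PackAWithIm2Col change folds dilation into the im2col coordinate math: each kernel offset r (or s, q) is scaled by conv_p_.dilation before being added, and the specialized 7x7 / stride-2 fast path is additionally guarded to dilation == 1, which it implicitly assumed. A small sketch of the resulting 2-D index mapping, matching the formulas in the hunk above:

    // For output position (h, w) and kernel tap (r, s), compute the input
    // coordinate with dilation; returns false when the tap lands in padding.
    struct Coord2D { int h_in; int w_in; };

    inline bool dilatedInputCoord(int h, int w, int r, int s,
                                  const int pad[2], const int stride[2],
                                  const int dilation[2], const int in_dim[2],
                                  Coord2D* out) {
      out->h_in = -pad[0] + h * stride[0] + r * dilation[0];
      out->w_in = -pad[1] + w * stride[1] + s * dilation[1];
      return out->h_in >= 0 && out->h_in < in_dim[0] &&
             out->w_in >= 0 && out->w_in < in_dim[1];
    }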
diff --git a/src/PackDepthwiseConvMatrixAvx2.cc b/src/PackDepthwiseConvMatrixAvx2.cc
new file mode 100644
index 0000000..0e17bcd
--- /dev/null
+++ b/src/PackDepthwiseConvMatrixAvx2.cc
@@ -0,0 +1,203 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+
+#include <immintrin.h>
+
+using namespace std;
+
+namespace fbgemm {
+
+// clang-format off
+static int masks[8][8] = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ { 0, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, -1, 0, },
+};
+// clang-format on
+
+PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix(
+ int K,
+ int kernel_prod,
+ const int8_t* smat)
+ : K_(K), kernel_prod_(kernel_prod) {
+ // Transpose the input matrix to make packing faster.
+ alignas(64) int8_t smat_transposed[K * kernel_prod];
+ for (int i = 0; i < kernel_prod; ++i) {
+ for (int j = 0; j < K; ++j) {
+ smat_transposed[i * K + j] = smat[i + j * kernel_prod];
+ }
+ }
+
+ // Allocate packed arrays
+ int kernel_prod_aligned = (kernel_prod + 1) / 2 * 2;
+ // pmat_ = static_cast<int8_t *>(fbgemmAlignedAlloc(
+ // 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+ posix_memalign(
+ (void**)&pmat_,
+ 64,
+ ((K + 31) / 32) * kernel_prod_aligned * 32 * sizeof(int8_t));
+
+ // Pack input matrix
+ // The layout is optimized to use vpmaddubsw efficiently (see
+ // madd_epi16x4_packed function).
+ // For a group of 32 channels, we have 10 32B SIMD registers.
+ // Denote ith channel jth filter as (i, j)
+ // 0th SIMD register:
+ // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
+ // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
+ // 1st SIMD register:
+ // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
+ // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
+ // 2nd SIMD register:
+ // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
+ // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
+ // 3rd SIMD register:
+ // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
+ // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
+ // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter
+ // coefficients
+ // ...
+ //
+ // REMAINDER
+ // If kernel_prod % 4 == 1 for example when kernel_prod == 9
+ // 8th SIMD register:
+ // (0, 8), zero, ..., (7, 8), zero
+ // (16, 8), zero, ..., (23, 8), zero
+ // 9th SIMD register:
+ // (8, 8), zero, ..., (15, 8), zero
+ // (24, 8), zero, ..., (31, 8), zero
+ // We use madd_epi16_packed for this case
+ //
+ // If kernel_prod % 4 == 2 for example when kernel_prod == 10
+ // 8th SIMD register:
+ // (0, 8), (0, 9), ..., (7, 8), (7, 9)
+ // (16, 8), (16, 9), ..., (23, 8), (23, 9)
+ // 9th SIMD register:
+ // (8, 8), (8, 9), ..., (15, 8), (15, 9)
+ // (24, 8), (24, 9), ..., (31, 8), (31, 9)
+ //
+ // If kernel_prod % 4 == 3 for example when kernel_prod == 11
+ // 8th SIMD register:
+ // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
+ // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
+ // 9th SIMD register:
+ // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
+ // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
+ // 10th SIMD register:
+ // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
+ // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
+ // 11th SIMD register:
+ // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
+ // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
+ for (int k1 = 0; k1 < K; k1 += 32) {
+ __m256i b_v[kernel_prod];
+ int remainder = K - k1;
+ if (remainder < 32) {
+ __m256i mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(masks[remainder / 4]));
+ for (int i = 0; i < kernel_prod; ++i) {
+ b_v[i] = _mm256_maskload_epi32(
+ reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v);
+ }
+ } else {
+ for (int i = 0; i < kernel_prod; ++i) {
+ b_v[i] = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1));
+ }
+ }
+
+ // Interleave 2 SIMD registers
+ __m256i b_interleaved_epi16[kernel_prod_aligned];
+ __m256i zero_v = _mm256_setzero_si256();
+ for (int i = 0; i < kernel_prod_aligned / 2; ++i) {
+ if (2 * i + 1 >= kernel_prod) {
+ b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
+ } else {
+ b_interleaved_epi16[2 * i] =
+ _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ }
+ }
+
+ // Interleave 4 SIMD registers
+ __m256i b_interleaved_epi32[kernel_prod_aligned];
+ for (int i = 0; i < kernel_prod_aligned / 4; ++i) {
+ b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ }
+ for (int i = kernel_prod_aligned / 4 * 4; i < kernel_prod_aligned; ++i) {
+ b_interleaved_epi32[i] = b_interleaved_epi16[i];
+ }
+
+ for (int i = 0; i < kernel_prod_aligned; ++i) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(
+ &pmat_[((k1 / 32) * kernel_prod_aligned + i) * 32]),
+ b_interleaved_epi32[i]);
+ }
+ }
+}
+
+int PackedDepthWiseConvMatrix::addr(int r, int c) {
+ int kernel_prod_aligned = (kernel_prod_ + 1) / 2 * 2;
+ if (c >= kernel_prod_ / 4 * 4 &&
+ (kernel_prod_ % 4 == 1 || kernel_prod_ % 4 == 2)) {
+ int kBlock = r / 32;
+ int reg_idx = (r % 16) / 8 + c / 4 * 4;
+
+ int blk_idx = kBlock * kernel_prod_aligned + reg_idx;
+
+ int r_ = r % 8;
+ int c_ = c % 4;
+
+ int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_;
+ return blk_idx * 32 + in_blk_idx;
+
+ } else {
+ int kBlock = r / 32;
+ int reg_idx = (r % 16) / 4 + c / 4 * 4;
+
+ int blk_idx = kBlock * kernel_prod_aligned + reg_idx;
+
+ int r_ = r % 4;
+ int c_ = c % 4;
+
+ int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_;
+ return blk_idx * 32 + in_blk_idx;
+ }
+}
+
+void PackedDepthWiseConvMatrix::unpack(int8_t* unpacked_data) {
+ for (int r = 0; r < K_; ++r) {
+ for (int c = 0; c < kernel_prod_; ++c) {
+ unpacked_data[r * kernel_prod_ + c] = pmat_[addr(r, c)];
+ }
+ }
+}
+
+PackedDepthWiseConvMatrix::~PackedDepthWiseConvMatrix() {
+ free(pmat_);
+}
+
+} // namespace fbgemm
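The constructor above lays out a K x kernel_prod depthwise filter matrix in 32-channel blocks so that vpmaddubsw can consume pairs of filter taps, and addr() is the scalar inverse of that layout. A minimal sketch of a pack/unpack round-trip check built only on the interface shown here (the helper name and fill pattern are illustrative, and K is kept a multiple of 32 to stay on the unmasked load path):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    #include "fbgemm/FbgemmI8DepthwiseAvx2.h"

    // Illustrative check: packing followed by unpack() should reproduce the
    // original row-major K x kernel_prod matrix (row = channel, col = tap).
    void check_pack_unpack_roundtrip(int K /* multiple of 32 */, int kernel_prod) {
      std::vector<std::int8_t> smat(K * kernel_prod);
      for (int i = 0; i < K * kernel_prod; ++i) {
        smat[i] = static_cast<std::int8_t>(i % 251 - 125);  // arbitrary pattern
      }
      fbgemm::PackedDepthWiseConvMatrix Bp(K, kernel_prod, smat.data());
      std::vector<std::int8_t> unpacked(K * kernel_prod);
      Bp.unpack(unpacked.data());
      for (int r = 0; r < K; ++r) {
        for (int c = 0; c < kernel_prod; ++c) {
          assert(unpacked[r * kernel_prod + c] == smat[r * kernel_prod + c]);
        }
      }
    }
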
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 44f210e..192fb00 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -23,35 +23,17 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
// FbgemmConv.cc
switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
case optimized_conv_t::depthwise: {
- if (SPATIAL_DIM == 3) {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ =
- std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata);
- W_gconv_packed_ = nullptr;
- } else {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ =
- std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata);
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
- }
+ W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>(
+ conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata);
break;
}
case optimized_conv_t::groupwise: {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
W_gconv_packed_ =
std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
matrix_op_t::Transpose, conv_p, sdata, nullptr);
break;
}
case optimized_conv_t::pointwise: {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
int NDim = conv_p.OC / conv_p.G;
int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>(
@@ -77,9 +59,6 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
nullptr,
conv_p.G,
blocking_params);
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
break;
}
} // switch
@@ -87,10 +66,8 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
template <int SPATIAL_DIM, typename T, typename accT>
void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
- if (W_dw_2D_packed_) {
- W_dw_2D_packed_->unpack(origin_buf);
- } else if (W_dw_3D_packed_) {
- W_dw_3D_packed_->unpack(origin_buf);
+ if (W_dw_packed_) {
+ W_dw_packed_->unpack(origin_buf);
} else if (W_gconv_packed_) {
W_gconv_packed_->unpack(origin_buf);
} else if (W_im2col_packed_) {
@@ -139,7 +116,7 @@ std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams(
};
auto combineInt = [&combineStr](std::string id, int int1, int int2) {
- return combineStr(id, std::to_string(int1), std::to_string(int2));
+ return combineStr(id, std::to_string(int1), std::to_string(int2));
};
if (conv_param_.IC != test_conv_p.IC) {
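With the unified PackedDepthWiseConvMatrix, the depthwise branch above no longer needs separate 2D and 3D packed types; only the kernel product changes. A hedged, self-contained restatement of that selection (the helper name and parameters are illustrative, not FBGEMM API):

    #include <cstdint>
    #include <memory>

    #include "fbgemm/FbgemmI8DepthwiseAvx2.h"

    // Pack G depthwise filters; the kernel product is 9 (3x3) for 2D
    // convolutions and 27 (3x3x3) for 3D convolutions.
    std::shared_ptr<fbgemm::PackedDepthWiseConvMatrix> pack_depthwise_weights(
        int spatial_dim, int G, const std::int8_t* sdata) {
      int kernel_prod = (spatial_dim == 3) ? 3 * 3 * 3 : 3 * 3;
      return std::make_shared<fbgemm::PackedDepthWiseConvMatrix>(
          G, kernel_prod, sdata);
    }
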
diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc
index 5dde90b..a209efc 100644
--- a/src/QuantUtils.cc
+++ b/src/QuantUtils.cc
@@ -164,30 +164,34 @@ void ChooseRequantizationMultiplier(
dst[i] = Quantize<T>(src[i], qparams); \
} \
}
-FBGEMM_SPECIALIZED_QUANTIZE(int8_t)
FBGEMM_SPECIALIZED_QUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_QUANTIZE(int16_t)
FBGEMM_SPECIALIZED_QUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_QUANTIZE
-template <>
-void Quantize<uint8_t>(
- const float* src,
- uint8_t* dst,
- int len,
- const TensorQuantizationParams& qparams) {
- bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support();
- bool fma_support = cpuinfo_has_x86_fma3();
- if (avx2_support && fma_support && qparams.precision == 8) {
- // fast path
- QuantizeAvx2(src, dst, len, qparams);
- } else {
- for (std::size_t i = 0; i < len; ++i) {
- dst[i] = Quantize<uint8_t>(src[i], qparams);
- }
- }
+#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T) \
+template <> \
+void Quantize<T>( \
+ const float* src, \
+ T* dst, \
+ int len, \
+ const TensorQuantizationParams& qparams) { \
+ bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \
+ bool fma_support = cpuinfo_has_x86_fma3(); \
+ if (avx2_support && fma_support && qparams.precision == 8) { \
+ /* fast path */ \
+ QuantizeAvx2<T>(src, dst, len, qparams); \
+ } else { \
+ for (std::size_t i = 0; i < len; ++i) { \
+ dst[i] = Quantize<T>(src[i], qparams); \
+ } \
+ } \
}
+FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t)
+FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t)
+#undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2
+
#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \
template <> \
void QuantizeGroupwise<T, layout_t::KCX>( \
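Both the fallback loop in the macro above and the AVX2 fast path implement the same per-element affine quantization; the only type-dependent part is the clamp range. A scalar sketch of that math (standalone, not the FBGEMM API; as the source notes, the vectorized path may differ in round-half cases):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <limits>

    // q = clamp(round(x / scale + zero_point),
    //           numeric_limits<T>::min(), numeric_limits<T>::max())
    template <typename T>
    T quantize_scalar(float x, float scale, std::int32_t zero_point) {
      float transformed = zero_point + x / scale;
      float lo = static_cast<float>(std::numeric_limits<T>::min());
      float hi = static_cast<float>(std::numeric_limits<T>::max());
      float clipped = std::min(std::max(transformed, lo), hi);
      return static_cast<T>(std::nearbyint(clipped));
    }

    // e.g. quantize_scalar<std::int8_t>(1.7f, 0.05f, 0) == 34
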
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 875f1dc..9381f0c 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -18,13 +18,16 @@ using namespace std;
////////////////////////////////////////////////////////////////////////////////
// Utility functions
+template <typename T>
void QuantizeAvx2(
const float* src,
- uint8_t* dst,
+ T* dst,
int len,
const TensorQuantizationParams& qparams) {
#if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER))
constexpr int VLEN = 8;
+ constexpr float min_val = std::numeric_limits<T>::min();
+ constexpr float max_val = std::numeric_limits<T>::max();
std::size_t i = 0;
__m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale);
__m256i shuffle_mask_v = _mm256_set_epi8(
@@ -67,8 +70,8 @@ void QuantizeAvx2(
__m256 transformed_v = _mm256_fmadd_ps(
src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point));
__m256 clipped_v = _mm256_min_ps(
- _mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)),
- _mm256_set1_ps(255.f));
+ _mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)),
+ _mm256_set1_ps(max_val));
__m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
// An instruction sequence to save 8 32-bit integers as 8 8-bit integers
@@ -80,7 +83,7 @@ void QuantizeAvx2(
for (; i < len; ++i) {
float transformed = qparams.zero_point + src[i] / qparams.scale;
- float clipped = std::min(std::max(transformed, 0.f), 255.f);
+ float clipped = std::min(std::max(transformed, min_val), max_val);
// Not exactly the same behavior as the vectorized code.
// The vectorized code above always rounds to even in halfway cases
// (https://software.intel.com/en-us/node/523819), but std::nearbyint
@@ -95,6 +98,21 @@ void QuantizeAvx2(
#endif
}
+// Instantiate QuantizeAvx2 for known datatypes
+template
+void QuantizeAvx2<uint8_t>(
+ const float* src,
+ uint8_t* dst,
+ int len,
+ const TensorQuantizationParams& qparams);
+template
+void QuantizeAvx2<int8_t>(
+ const float* src,
+ int8_t* dst,
+ int len,
+ const TensorQuantizationParams& qparams);
+
+
void FindMinMax(const float* a, float* min, float* max, int len) {
if (len <= 0) {
*min = 0.0f;
@@ -264,14 +282,15 @@ template <
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
- bool FUSE_RELU>
+ bool FUSE_RELU,
+ typename BIAS_TYPE>
void requantizeOutputProcessingAvx2(
uint8_t* out,
const int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r) {
+ const requantizationParams_t<BIAS_TYPE>& r) {
// Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
// using AVX2 instructions
int quant_param_idx = 0;
@@ -282,6 +301,15 @@ void requantizeOutputProcessingAvx2(
}
__m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v;
+ if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+ if (is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v =
+ _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
+ }
+ }
+
__m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
__m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
@@ -391,22 +419,76 @@ void requantizeOutputProcessingAvx2(
}
w_v = _mm256_sub_epi32(w_v, row_offset_v);
}
+ __m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
}
/*
@@ -423,22 +505,19 @@ void requantizeOutputProcessingAvx2(
*/
__m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
- y_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(y_v),
- _mm256_loadu_ps(r.C_multiplier + j + VLEN));
- z_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(z_v),
- _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
- w_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(w_v),
- _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+ x_scaled_v =
+ _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN));
+ y_scaled_v =
+ _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN));
+ z_scaled_v =
+ _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+ w_scaled_v =
+ _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
/*
@@ -525,18 +604,35 @@ void requantizeOutputProcessingAvx2(
}
x_v = _mm256_sub_epi32(x_v, row_offset_v);
}
+ __m256 xf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+ _mm256_loadu_ps(r.act_times_w_scale + j));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
}
__m256 x_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
+ x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
}
__m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -574,6 +670,7 @@ void requantizeOutputProcessingAvx2(
int remainder = block.col_start + block.col_size - j;
if (remainder > 0) {
+ // clang-format off
alignas(64) const int masks[8][8] = {
// NOTE: clang-format wants to use a different formatting but the
// current formatting should be easier to read.
@@ -586,6 +683,7 @@ void requantizeOutputProcessingAvx2(
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
+ // clang-format on
__m256i mask_v = _mm256_load_si256(
reinterpret_cast<const __m256i*>(masks[remainder]));
@@ -607,17 +705,40 @@ void requantizeOutputProcessingAvx2(
}
x_v = _mm256_sub_epi32(x_v, row_offset_v);
}
+
+ __m256 xf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_maskload_ps(
+ reinterpret_cast<const float*>(r.bias + j), mask_v),
+ _mm256_maskload_ps(r.act_times_w_scale + j, mask_v));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_maskload_ps(
+ reinterpret_cast<const float*>(r.bias + j), mask_v),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_maskload_epi32(
+ reinterpret_cast<const int*>(r.bias + j), mask_v));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
}
__m256 x_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v),
- _mm256_maskload_ps(r.C_multiplier + j, mask_v));
+ x_scaled_v =
+ _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
}
__m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -759,6 +880,7 @@ void requantizeForFloatAvx2(
int remainder = block.col_start + block.col_size - j;
if (remainder > 0) {
+ // clang-format off
alignas(64) const int masks[8][8] = {
// NOTE: clang-format wants to use a different formatting but the
// current formatting should be easier to read.
@@ -771,6 +893,7 @@ void requantizeForFloatAvx2(
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
+ // clang-format on
__m256i mask_v = _mm256_load_si256(
reinterpret_cast<const __m256i*>(masks[remainder]));
@@ -823,14 +946,15 @@ template <
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
- int C_PER_G>
+ int C_PER_G,
+ typename BIAS_TYPE>
void requantizeOutputProcessingGConvAvx2(
uint8_t* out,
const int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r) {
+ const requantizationParams_t<BIAS_TYPE>& r) {
// Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
// using AVX2 instructions
int quant_param_idx = 0;
@@ -841,6 +965,14 @@ void requantizeOutputProcessingGConvAvx2(
}
__m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v;
+ if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+ if (is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v =
+ _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
+ }
+ }
__m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
__m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
@@ -1087,22 +1219,135 @@ void requantizeOutputProcessingGConvAvx2(
}
w_v = _mm256_sub_epi32(w_v, row_offset_v);
}
+ __m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN));
+ __m256 y_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN));
+ __m256 z_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN));
+ __m256 w_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN));
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ y_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ z_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ w_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+ __m256 diviser_v;
+ if (C_PER_G == 4) {
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]),
+ 1);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 2])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 3]),
+ 1);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 4])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 5]),
+ 1);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 6])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 7]),
+ 1);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+ } else if (C_PER_G == 8) {
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 0]);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+ } else {
+ assert(C_PER_G == 16);
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 0]);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+ }
+ } else {
+ x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(y_bias_v, act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(z_bias_v, act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(w_bias_v, act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
}
/*
@@ -1119,17 +1364,13 @@ void requantizeOutputProcessingGConvAvx2(
*/
__m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
- y_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(y_v),
- _mm256_loadu_ps(r.C_multiplier + j + VLEN));
- z_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(z_v),
- _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
- w_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(w_v),
- _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+ x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
+ y_scaled_v =
+ _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + VLEN));
+ z_scaled_v =
+ _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+ w_scaled_v =
+ _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
if (C_PER_G == 4) {
multiplier_v = _mm256_insertf128_ps(
@@ -1137,70 +1378,70 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_ps(r.C_multiplier[quant_param_idx])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 1]),
1);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 2])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 3]),
1);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 4])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 5]),
1);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 6])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
1);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
} else if (C_PER_G == 8) {
multiplier_v = _mm256_set1_ps(
r.C_multiplier
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 1]);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 2]);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 3]);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
} else {
multiplier_v = _mm256_set1_ps(
r.C_multiplier
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 +
- 1]);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
/*
@@ -1271,46 +1512,69 @@ void requantizeOutputProcessingGConvAvx2(
} // i loop
}
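For GROUP granularity the code above has to pick a different multiplier (and, with float bias, a different act_times_w_scale divisor) per group inside each 32-channel chunk: two groups per 8-lane register when C_PER_G == 4, one group per register when C_PER_G == 8, and one group per pair of registers when C_PER_G == 16. Collapsed to scalar form, the selection reduces to the sketch below (assumes block.col_start is group-aligned, as the surrounding code does; the helper is illustrative):

    // Scalar equivalent of the per-group parameter selection above: the
    // quantization-parameter index that applies to output column `col`.
    int group_param_index(int col, int col_start, int c_per_g, int base_idx) {
      // base_idx plays the role of quant_param_idx, the group of col_start.
      return base_idx + (col - col_start) / c_per_g;
    }
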
-#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
- template void \
- requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- float* out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationForFloatParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r);
+#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \
+ A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \
+ template void \
+ requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 4, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 8, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 16, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r);
+
+#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
+ INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \
+ INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \
+ template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
+ float* out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationForFloatParams_t& r);
#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \
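The BIAS_TYPE = float instantiations added above treat the bias as a real-valued quantity: it is rescaled by 1 / (act_scale * w_scale), per channel or broadcast depending on granularity, before joining the int32 accumulator, whereas an int32_t bias is added directly as before. A scalar, per-element sketch of the float-bias path (names are illustrative; the A/B zero-point and row/column offset corrections are assumed to have been applied to acc already):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Scalar sketch of the float-bias requantization performed above.
    std::uint8_t requantize_float_bias(
        std::int32_t acc,           // zero-point/offset-corrected accumulator
        float bias,                 // real-valued bias for this output channel
        float act_times_w_scale,    // activation scale * weight scale
        float C_multiplier,
        std::int32_t C_zero_point) {
      float xf = static_cast<float>(acc) + bias / act_times_w_scale;
      std::int32_t q =
          static_cast<std::int32_t>(std::nearbyint(xf * C_multiplier)) +
          C_zero_point;
      q = std::max<std::int32_t>(0, std::min<std::int32_t>(255, q));
      return static_cast<std::uint8_t>(q);
    }
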
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index e3c0eac..dc40d44 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -300,9 +300,11 @@ void im2col_ref(
for (int h = 0; h < OUT_DIM[0]; ++h) {
for (int w = 0; w < OUT_DIM[1]; ++w) {
for (int r = 0; r < K[0]; ++r) {
- int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ int h_in =
+ -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
for (int s = 0; s < K[1]; ++s) {
- int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
+ int w_in =
+ -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1];
if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
w_in >= IN_DIM[1]) {
for (int g = 0; g < G; ++g) {
@@ -363,11 +365,14 @@ void im2col_ref(
for (int h = 0; h < OUT_DIM[1]; ++h) {
for (int w = 0; w < OUT_DIM[2]; ++w) {
for (int q = 0; q < K[0]; ++q) {
- int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ int t_in =
+ -conv_p.pad[0] + t * conv_p.stride[0] + q * conv_p.dilation[0];
for (int r = 0; r < K[1]; ++r) {
- int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
+ r * conv_p.dilation[1];
for (int s = 0; s < K[2]; ++s) {
- int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
+ s * conv_p.dilation[2];
if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) {
for (int g = 0; g < G; ++g) {
@@ -447,9 +452,11 @@ void conv_ref(
for (int m = 0; m < OC / G; ++m) {
int sum = 0;
for (int r = 0; r < K[0]; ++r) {
- int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ int h_in = -conv_p.pad[0] + h * conv_p.stride[0] +
+ r * conv_p.dilation[0];
for (int s = 0; s < K[1]; ++s) {
- int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
+ int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
+ s * conv_p.dilation[1];
for (int c = 0; c < IC / G; ++c) {
int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
w_in >= IN_DIM[1]
@@ -499,11 +506,14 @@ void conv_ref(
for (int m = 0; m < OC / G; ++m) {
int sum = 0;
for (int q = 0; q < K[0]; ++q) {
- int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
+ q * conv_p.dilation[0];
for (int r = 0; r < K[1]; ++r) {
- int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
+ r * conv_p.dilation[1];
for (int s = 0; s < K[2]; ++s) {
- int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
+ s * conv_p.dilation[2];
for (int c = 0; c < IC / G; ++c) {
int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]
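The index arithmetic changed in the four reference hunks above is the standard dilated-convolution mapping: kernel tap r of output position o reads input coordinate -pad + o * stride + r * dilation, and taps that land in the padding contribute the zero point instead. A one-dimensional sketch of that index computation and its bounds check:

    // 1-D form of the dilated input index used by im2col_ref/conv_ref above.
    // Returns -1 when the tap falls into the padding region.
    int dilated_input_index(
        int out, int tap, int stride, int pad, int dilation, int in_dim) {
      int in = -pad + out * stride + tap * dilation;
      return (in < 0 || in >= in_dim) ? -1 : in;
    }
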
@@ -590,406 +600,6 @@ void transposeConvWeights(
}
}
-void depthwise_3x3_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int8_t* B,
- int32_t* C) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
-
- for (int n = 0; n < N; ++n) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int r = 0; r < R; ++r) {
- int h_in = -PAD_T + h * stride_h + r;
- for (int s = 0; s < S; ++s) {
- int w_in = -PAD_L + w * stride_w + s;
- int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
- ? A_zero_point
- : A[((n * H + h_in) * W + w_in) * K + k];
- int b = B[(k * R + r) * S + s];
- sum += a * b;
- }
- }
- C[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
- }
- }
- }
- } // for each n
-};
-
-void depthwise_3x3_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const int8_t* B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
-
- vector<int32_t> C_int32(N * H_OUT * W_OUT * K);
- depthwise_3x3_pad_1_ref(
- N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data());
-
- vector<int32_t> row_offsets(N * H_OUT * W_OUT * K);
- for (int n = 0; n < N; ++n) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int r = 0; r < R; ++r) {
- int h_in = -PAD_T + h * stride_h + r;
- for (int s = 0; s < S; ++s) {
- int w_in = -PAD_L + w * stride_w + s;
- int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
- ? A_zero_point
- : A[((n * H + h_in) * W + w_in) * K + k];
- sum += a;
- }
- }
- row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
- }
- }
- }
- } // for each n
-
- for (int i = 0; i < N * H_OUT * W_OUT; ++i) {
- for (int k = 0; k < K; ++k) {
- requantize_u8acc32_ref(
- 1,
- 1,
- 1,
- C_int32.data() + i * K + k,
- C + i * K + k,
- &C_multiplier,
- C_zero_point,
- A_zero_point,
- &B_zero_point,
- &row_offsets[i * K + k],
- col_offsets + k,
- bias ? bias + k : nullptr,
- 1);
- }
- }
-};
-
-void depthwise_3x3_per_channel_quantization_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
-
- vector<int32_t> C_int32(N * H_OUT * W_OUT * K);
- depthwise_3x3_pad_1_ref(
- N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data());
-
- vector<int32_t> row_offsets(N * H_OUT * W_OUT * K);
- for (int n = 0; n < N; ++n) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int r = 0; r < R; ++r) {
- int h_in = -PAD_T + h * stride_h + r;
- for (int s = 0; s < S; ++s) {
- int w_in = -PAD_L + w * stride_w + s;
- int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
- ? A_zero_point
- : A[((n * H + h_in) * W + w_in) * K + k];
- sum += a;
- }
- }
- row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
- }
- }
- }
- } // for each n
-
- for (int i = 0; i < N * H_OUT * W_OUT; ++i) {
- for (int k = 0; k < K; ++k) {
- requantize_u8acc32_ref(
- 1,
- 1,
- 1,
- C_int32.data() + i * K + k,
- C + i * K + k,
- &C_multiplier[k],
- C_zero_point,
- A_zero_point,
- &B_zero_point[k],
- &row_offsets[i * K + k],
- col_offsets + k,
- bias ? bias + k : nullptr,
- 1);
- }
- }
-};
-
-void depthwise_3x3x3_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int8_t* B,
- int32_t* C) {
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
-
- for (int n = 0; n < N; ++n) {
- for (int t = 0; t < T_OUT; ++t) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int k_t = 0; k_t < K_T; ++k_t) {
- int t_in = -PAD_P + t * stride_t + k_t;
- for (int k_h = 0; k_h < K_H; ++k_h) {
- int h_in = -PAD_T + h * stride_h + k_h;
- for (int k_w = 0; k_w < K_W; ++k_w) {
- int w_in = -PAD_L + w * stride_w + k_w;
- int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
- w_in < 0 || w_in >= W
- ? A_zero_point
- : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
- int b = B[((k * K_T + k_t) * K_H + k_h) * K_W + k_w];
- sum += a * b;
- }
- }
- }
- C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = sum;
- }
- } // w
- } // h
- } // t
- } // for each n
-};
-
-void depthwise_3x3x3_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const int8_t* B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
-
- vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K);
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B,
- C_int32.data());
-
- vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K);
- for (int n = 0; n < N; ++n) {
- for (int t = 0; t < T_OUT; ++t) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int k_t = 0; k_t < K_T; ++k_t) {
- int t_in = -PAD_P + t * stride_t + k_t;
- for (int k_h = 0; k_h < K_H; ++k_h) {
- int h_in = -PAD_T + h * stride_h + k_h;
- for (int k_w = 0; k_w < K_W; ++k_w) {
- int w_in = -PAD_L + w * stride_w + k_w;
- int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
- w_in < 0 || w_in >= W
- ? A_zero_point
- : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
- sum += a;
- }
- }
- }
- row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] =
- sum;
- }
- } // w
- } // h
- } // t
- } // for each n
-
- for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) {
- for (int k = 0; k < K; ++k) {
- requantize_u8acc32_ref(
- 1,
- 1,
- 1,
- C_int32.data() + i * K + k,
- C + i * K + k,
- &C_multiplier,
- C_zero_point,
- A_zero_point,
- &B_zero_point,
- &row_offsets[i * K + k],
- col_offsets + k,
- bias ? bias + k : nullptr,
- 1);
- }
- }
-};
-
-void depthwise_3x3x3_per_channel_quantization_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
-
- vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K);
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B,
- C_int32.data());
-
- vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K);
- for (int n = 0; n < N; ++n) {
- for (int t = 0; t < T_OUT; ++t) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int k_t = 0; k_t < K_T; ++k_t) {
- int t_in = -PAD_P + t * stride_t + k_t;
- for (int k_h = 0; k_h < K_H; ++k_h) {
- int h_in = -PAD_T + h * stride_h + k_h;
- for (int k_w = 0; k_w < K_W; ++k_w) {
- int w_in = -PAD_L + w * stride_w + k_w;
- int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
- w_in < 0 || w_in >= W
- ? A_zero_point
- : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
- sum += a;
- }
- }
- }
- row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] =
- sum;
- }
- } // w
- } // h
- } // t
- } // for each n
-
- for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) {
- for (int k = 0; k < K; ++k) {
- requantize_u8acc32_ref(
- 1,
- 1,
- 1,
- C_int32.data() + i * K + k,
- C + i * K + k,
- &C_multiplier[k],
- C_zero_point,
- A_zero_point,
- &B_zero_point[k],
- &row_offsets[i * K + k],
- col_offsets + k,
- bias ? bias + k : nullptr,
- 1);
- }
- }
-};
-
template void transposeConvWeights(
const conv_param_t<2>& conv_p,
const std::int8_t* src,
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
index 082bdf1..a20e348 100644
--- a/src/RefImplementations.h
+++ b/src/RefImplementations.h
@@ -215,124 +215,4 @@ FBGEMM_API void im2col_ref(
std::int32_t A_zero_point,
std::uint8_t* Ao);
-/*
- * @brief Reference implementation of depthwise convolution with a 3x3 filter
- * and padding size 1.
- */
-FBGEMM_API void depthwise_3x3_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int8_t* B,
- std::int32_t* C);
-
-/*
- * @brief Reference implementation of depthwise convolution with a 3x3 filter
- * and padding size 1, followed by requantization. (the same scaling factors and
- * zero points for each channel).
- */
-FBGEMM_API void depthwise_3x3_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- std::int32_t B_zero_point,
- const std::int8_t* B,
- float C_multiplier,
- std::int32_t C_zero_point,
- std::uint8_t* C,
- const std::int32_t* col_offsets,
- const std::int32_t* bias);
-
-/*
- * @brief Reference implementation of depthwise convolution with a 3x3 filter
- * and padding size 1, followed by requantization. (different scaling factors
- * and zero points for each channel).
- */
-FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int32_t* B_zero_point,
- const std::int8_t* B,
- const float* C_multiplier,
- std::int32_t C_zero_point,
- std::uint8_t* C,
- const std::int32_t* col_offsets,
- const std::int32_t* bias);
-
-/*
- * @brief Reference implementation of 3D depthwise convolution with a 3x3x3
- * filter and padding size 1.
- */
-FBGEMM_API void depthwise_3x3x3_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int8_t* B,
- std::int32_t* C);
-
-/*
- * @brief Reference implementation of 3D depthwise convolution with a 3x3x3
- * filter and padding size 1, followed by requantization.
- */
-FBGEMM_API void depthwise_3x3x3_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- std::int32_t B_zero_point,
- const std::int8_t* B,
- float C_multiplier,
- std::int32_t C_zero_point,
- std::uint8_t* C,
- const std::int32_t* col_offsets,
- const std::int32_t* bias);
-
-FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int32_t* B_zero_point,
- const std::int8_t* B,
- const float* C_multiplier,
- std::int32_t C_zero_point,
- std::uint8_t* C,
- const std::int32_t* col_offsets,
- const std::int32_t* bias);
-
} // namespace fbgemm
diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc
index 0604879..9de6943 100644
--- a/test/I8DepthwiseTest.cc
+++ b/test/I8DepthwiseTest.cc
@@ -4,7 +4,6 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
-#include "I8DepthwiseTest.h"
#include <cmath>
#include <cstdio>
@@ -22,6 +21,7 @@ using namespace std;
namespace fbgemm {
// From Xray OCR
+// clang-format off
static vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -68,6 +68,17 @@ static vector<vector<int>> shapes = {
{ 1, 8, 4, 4, 1, },
};
+static vector<vector<int>> shapes_3d = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ // N, K, T_in, H_in, W_in, stride
+ { 1, 32, 16, 28, 28, 1, },
+ { 1, 128, 8, 14, 14, 2, },
+ { 5, 16, 32, 56, 56, 1, },
+ { 1, 8, 4, 4, 4, 1, },
+};
+// clang-format on
+
namespace {
class FBGemmDepthWiseTest : public testing::TestWithParam<tuple<bool, bool>> {};
@@ -105,13 +116,29 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) {
int stride_h = shape[4];
int stride_w = stride_h;
constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2,
+ PAD_R = (S - 1) / 2;
+
+ conv_param_t<2> conv_p(
+ N,
+ K,
+ K,
+ {H, W},
+ K,
+ {R, S},
+ {stride_h, stride_w},
+ {PAD_T, PAD_L, PAD_B, PAD_R});
+ int H_OUT = conv_p.OUT_DIM[0];
+ int W_OUT = conv_p.OUT_DIM[1];
+
+ int MDim = N * H_OUT * W_OUT;
+ int KDim = R * S * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * H * W * K);
- aligned_vector<int8_t> B(K * R * S);
- aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * K), C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = a_symmetric ? 0 : 43;
@@ -119,48 +146,54 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) {
randFill<int8_t>(B, -16, 16);
int32_t B_zero_point = b_symmetric ? 0 : 5;
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
- int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
-
- float C_multiplier = 255. / (maximum - minimum);
+ aligned_vector<float> C_multiplier(1);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+ int32_t C_zero_point = 5;
aligned_vector<int32_t> col_offsets(K);
aligned_vector<int32_t> bias(K);
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- int32_t C_zero_point = 5;
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point,
- B.data(),
- C_multiplier,
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col;
+ if (!b_symmetric) {
+ A_im2col.resize(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
+ }
+
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ if (!b_symmetric) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+ }
- Packed3x3ConvMatrix Bp(K, B.data());
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data(),
+ C_zero_point,
+ A_zero_point,
+ &B_zero_point,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
+
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
depthwise_3x3_pad_1(
N,
@@ -173,12 +206,13 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) {
A.data(),
B_zero_point,
Bp,
- C_multiplier,
+ C_multiplier[0],
C_zero_point,
C_uint8.data(),
a_symmetric ? nullptr : col_offsets.data(),
bias.data(),
false, /* fuse_relu */
+ 1.0f, /* act_scale * w_scale */
0,
1);
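The rewritten Test3x3 derives H_OUT/W_OUT from conv_param_t<2>::OUT_DIM and builds its reference output with conv_ref plus per-group row offsets and requantize_u8acc32_ref, replacing the removed depthwise_3x3_pad_1_ref. With dilation 1, OUT_DIM matches the expression the test previously inlined; a small check of that equivalence (helper name is illustrative):

    // For dilation == 1 this reproduces conv_p.OUT_DIM[i] as used in the test:
    // e.g. out_dim(28, 1, 1, 3, 1) == 28 and out_dim(14, 1, 1, 3, 2) == 7.
    int out_dim(int in_dim, int pad_lo, int pad_hi, int k, int stride) {
      return (in_dim + pad_lo + pad_hi - k) / stride + 1;
    }
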
@@ -220,67 +254,83 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) {
constexpr int K_T = 3, K_H = 3, K_W = 3;
constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ conv_param_t<3> conv_p(
+ N,
+ K,
+ K,
+ {T, H, W},
+ K,
+ {K_T, K_H, K_W},
+ {stride_t, stride_h, stride_w},
+ {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R});
+ int T_OUT = conv_p.OUT_DIM[0];
+ int H_OUT = conv_p.OUT_DIM[1];
+ int W_OUT = conv_p.OUT_DIM[2];
+
+ int MDim = N * T_OUT * H_OUT * W_OUT;
+ int KDim = K_T * K_H * K_W * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * T * H * W * K);
- aligned_vector<int8_t> B(K * K_T * K_H * K_W);
- aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K),
- C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = a_symmetric ? 0 : 43;
randFill<int8_t>(B, -16, 16);
- int32_t B_zero_point = 5;
-
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
- int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+ int32_t B_zero_point = b_symmetric ? 0 : 5;
- float C_multiplier = 255. / (maximum - minimum);
+ aligned_vector<float> C_multiplier(1);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+ int32_t C_zero_point = 5;
aligned_vector<int32_t> col_offsets(K);
aligned_vector<int32_t> bias(K);
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- int32_t C_zero_point = 5;
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point,
- B.data(),
- C_multiplier,
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col;
+ if (!b_symmetric) {
+ A_im2col.resize(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
+ }
+
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ if (!b_symmetric) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+ }
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data(),
+ C_zero_point,
+ A_zero_point,
+ &B_zero_point,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
- Packed3x3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data());
depthwise_3x3x3_pad_1(
N,
@@ -295,10 +345,10 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) {
A.data(),
B_zero_point,
Bp,
- C_multiplier,
+ C_multiplier[0],
C_zero_point,
C_uint8.data(),
- col_offsets.data(),
+ a_symmetric ? nullptr : col_offsets.data(),
bias.data(),
false, /* fuse_relu */
0,
@@ -334,14 +384,29 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
int stride_h = shape[4];
int stride_w = stride_h;
constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2,
+ PAD_R = (S - 1) / 2;
+
+ conv_param_t<2> conv_p(
+ N,
+ K,
+ K,
+ {H, W},
+ K,
+ {R, S},
+ {stride_h, stride_w},
+ {PAD_T, PAD_L, PAD_B, PAD_R});
+ int H_OUT = conv_p.OUT_DIM[0];
+ int W_OUT = conv_p.OUT_DIM[1];
+
+ int MDim = N * H_OUT * W_OUT;
+ int KDim = R * S * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * H * W * K);
- aligned_vector<int8_t> B(K * R * S);
- int32_t C_num_rows = N * H_OUT * W_OUT;
- aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = 43;
@@ -357,28 +422,8 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
B_zero_point[k] = 5 + k;
}
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- aligned_vector<int32_t> C_ref_transpose(C_ref);
- transpose_matrix(C_ref.data(), C_num_rows, K);
- vector<float> C_multiplier(K);
- for (auto k = 0; k < K; ++k) {
- auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows;
- auto C_ref_k_end = C_ref_k_begin + C_num_rows;
- int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end);
- int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end);
- C_multiplier[k] = 255. / (maximum - minimum);
- }
+ aligned_vector<float> C_multiplier(K);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
int32_t C_zero_point = 5;
aligned_vector<int32_t> col_offsets(K);
@@ -386,25 +431,40 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3_per_channel_quantization_pad_1_ref(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point.data(),
- B.data(),
- C_multiplier.data(),
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ // im2col to compute row offset later
+ vector<int32_t> row_offsets(MDim);
+ vector<uint8_t> A_im2col(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
+
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data() + g,
+ C_zero_point,
+ A_zero_point,
+ B_zero_point.data() + g,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
- Packed3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
depthwise_3x3_per_channel_quantization_pad_1(
N,
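Note that the per-channel test above models the 3x3 depthwise convolution as a grouped convolution with G == K, so IC_per_G == OC_per_G == 1 and each group's post-im2col GEMM collapses to a single output channel. A small sketch of the resulting per-group dimensions, using the same names as the test:

    // Per-group GEMM dimensions for depthwise-as-grouped-conv (G == K).
    struct DepthwiseGemmDims {
      int MDim;         // N * H_OUT * W_OUT output positions
      int NDim;         // 1 output channel per group
      int KDimPerGroup; // R * S reduction elements per group
    };

    inline DepthwiseGemmDims
    depthwise_gemm_dims(int N, int H_OUT, int W_OUT, int R, int S) {
      return {N * H_OUT * W_OUT, 1, R * S};
    }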
@@ -457,14 +517,28 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) {
constexpr int K_T = 3, K_H = 3, K_W = 3;
constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ conv_param_t<3> conv_p(
+ N,
+ K,
+ K,
+ {T, H, W},
+ K,
+ {K_T, K_H, K_W},
+ {stride_t, stride_h, stride_w},
+ {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R});
+ int T_OUT = conv_p.OUT_DIM[0];
+ int H_OUT = conv_p.OUT_DIM[1];
+ int W_OUT = conv_p.OUT_DIM[2];
+
+ int MDim = N * T_OUT * H_OUT * W_OUT;
+ int KDim = K_T * K_H * K_W * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * T * H * W * K);
- aligned_vector<int8_t> B(K * K_T * K_H * K_W);
- int32_t C_num_rows = N * T_OUT * H_OUT * W_OUT;
- aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = 43;
@@ -480,30 +554,8 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) {
B_zero_point[k] = 5 + k;
}
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- aligned_vector<int32_t> C_ref_transpose(C_ref);
- transpose_matrix(C_ref.data(), C_num_rows, K);
- vector<float> C_multiplier(K);
- for (auto k = 0; k < K; ++k) {
- auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows;
- auto C_ref_k_end = C_ref_k_begin + C_num_rows;
- int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end);
- int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end);
- C_multiplier[k] = 255. / (maximum - minimum);
- }
+ aligned_vector<float> C_multiplier(K);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
int32_t C_zero_point = 5;
aligned_vector<int32_t> col_offsets(K);
@@ -511,27 +563,40 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) {
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3x3_per_channel_quantization_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point.data(),
- B.data(),
- C_multiplier.data(),
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
+
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data() + g,
+ C_zero_point,
+ A_zero_point,
+ B_zero_point.data() + g,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
- Packed3x3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data());
depthwise_3x3x3_per_channel_quantization_pad_1(
N,
@@ -587,36 +652,8 @@ TEST_P(FBGemmDepthWisePackUnpackTest, TestPackUnpack) {
aligned_vector<int8_t> BUnpacked(K * kernel_prod);
- if (kernel_prod == 1) {
- Packed1ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 2) {
- Packed2ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 3) {
- Packed3ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 4) {
- Packed4ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 5) {
- Packed5ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 9) {
- Packed3x3ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 10) {
- Packed10ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 11) {
- Packed11ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 27) {
- Packed3x3x3ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else {
- ASSERT_TRUE(false);
- }
+ PackedDepthWiseConvMatrix BPacked(K, kernel_prod, B.data());
+ BPacked.unpack(BUnpacked.data());
ASSERT_EQ(B, BUnpacked)
<< "Original and unpacked data elements are not the same";
diff --git a/test/I8DepthwiseTest.h b/test/I8DepthwiseTest.h
deleted file mode 100644
index d65362a..0000000
--- a/test/I8DepthwiseTest.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- * All rights reserved.
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-#pragma once
-
-#include <vector>
-
-namespace fbgemm {
-
-// From ResNeXt-3D-101
-static std::vector<std::vector<int>> shapes_3d = {
- // NOTE: clang-format wants to use a different formatting but the current
- // formatting should be easier to read.
- // N, K, T_in, H_in, W_in, stride
- { 1, 64, 32, 56, 56, 1, },
- { 1, 128, 16, 28, 28, 1, },
- { 1, 256, 8, 14, 14, 1, },
- { 1, 512, 4, 7, 7, 1, },
-
- { 1, 128, 32, 56, 56, 2, },
- { 1, 256, 16, 28, 28, 2, },
- { 1, 512, 8, 14, 14, 2, },
-
- { 5, 64, 32, 56, 56, 1, },
- { 5, 128, 16, 28, 28, 1, },
- { 5, 256, 8, 14, 14, 1, },
- { 5, 512, 4, 7, 7, 1, },
-
- { 5, 128, 32, 56, 56, 2, },
- { 5, 256, 16, 28, 28, 2, },
- { 5, 512, 8, 14, 14, 2, },
-
- { 1, 8, 4, 4, 4, 1, },
-};
-
-} // namespace fbgemm
diff --git a/test/RequantizeOnlyTest.cc b/test/RequantizeOnlyTest.cc
new file mode 100644
index 0000000..2f73d49
--- /dev/null
+++ b/test/RequantizeOnlyTest.cc
@@ -0,0 +1,169 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <random>
+#include <stdexcept>
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "TestUtils.h"
+#include "bench/BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+
+using namespace std;
+using namespace fbgemm;
+
+vector<QuantizationGranularity> qGranularityVals{
+ QuantizationGranularity::TENSOR,
+ QuantizationGranularity::OUT_CHANNEL};
+
+namespace {
+
+// tuple represents #rows, #cols, fuse_relu, quantization_granularity
+class FloatRequantizeTest
+ : public testing::TestWithParam<
+ tuple<int, int, bool, QuantizationGranularity>> {};
+
+}; // namespace
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ FloatRequantizeTest,
+ ::testing::Combine(
+ ::testing::ValuesIn({1, 2, 3, 4}), // number of rows
+ ::testing::ValuesIn(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 20, 32}), // number of
+ // cols
+ ::testing::Bool(), // fuse relu
+ ::testing::ValuesIn(qGranularityVals))); // requantization granularity
+
+/**
+ * Test for float bias
+ */
+TEST_P(FloatRequantizeTest, floatBiasTest) {
+ int rows, cols;
+ bool fuse_relu;
+ QuantizationGranularity q_gran;
+ tie(rows, cols, fuse_relu, q_gran) = GetParam();
+
+ int numElements = rows * cols;
+
+ aligned_vector<float> act_times_w_scale(cols);
+ randFill<float>(act_times_w_scale, -8, 8);
+
+ float out_scale = 2.0f;
+
+ aligned_vector<float> C_multiplier(cols);
+ transform(
+ act_times_w_scale.begin(),
+ act_times_w_scale.end(),
+ C_multiplier.begin(),
+ [&out_scale](float i) { return i / out_scale; });
+
+ aligned_vector<int32_t> Bint8_zero_point(cols);
+ randFill<int32_t>(Bint8_zero_point, -8, 8);
+
+ aligned_vector<int32_t> row_offset_buf(rows);
+ randFill<int32_t>(row_offset_buf, -8, 8);
+
+ aligned_vector<int32_t> col_offsets(cols);
+ randFill<int32_t>(col_offsets, -8, 8);
+
+ // quantized bias
+ aligned_vector<int32_t> bias_q(cols);
+ randFill<int32_t>(bias_q, -8, 8);
+
+ // floating point bias
+ aligned_vector<float> bias_f(cols);
+ if (q_gran == QuantizationGranularity::TENSOR) {
+ transform(
+ bias_q.begin(),
+ bias_q.end(),
+ bias_f.begin(),
+ [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; });
+ } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) {
+ transform(
+ act_times_w_scale.begin(),
+ act_times_w_scale.end(),
+ bias_q.begin(),
+ bias_f.begin(),
+ multiplies<float>());
+
+ } else {
+ FAIL();
+ }
+
+ aligned_vector<int32_t> input(numElements);
+ randFill<int32_t>(input, -8, 8);
+
+ aligned_vector<uint8_t> output_q_bias(numElements);
+ aligned_vector<uint8_t> output_f_bias(numElements);
+
+ int32_t C_zero_point = 3;
+ int32_t Aint8_zero_point = 3;
+
+ block_type_t block{0, rows, 0, cols};
+
+ DoNothing<> doNothingObj{};
+
+#define TESTCODE(FUSE_RELU, Q_GRAN) \
+ ReQuantizeOutput<FUSE_RELU, Q_GRAN> reqObj_q( \
+ doNothingObj, \
+ C_multiplier.data(), \
+ C_zero_point, \
+ Aint8_zero_point, \
+ Bint8_zero_point.data(), \
+ row_offset_buf.data(), \
+ col_offsets.data(), \
+ bias_q.data(), \
+ cols); \
+ ReQuantizeOutput<FUSE_RELU, Q_GRAN, float> reqObj_f( \
+ doNothingObj, \
+ C_multiplier.data(), \
+ C_zero_point, \
+ Aint8_zero_point, \
+ Bint8_zero_point.data(), \
+ row_offset_buf.data(), \
+ col_offsets.data(), \
+ bias_f.data(), \
+ cols, \
+ 1, \
+ act_times_w_scale.data()); \
+ reqObj_q.f<inst_set_t::avx2>( \
+ output_q_bias.data(), input.data(), block, cols, cols); \
+ reqObj_f.f<inst_set_t::avx2>( \
+ output_f_bias.data(), input.data(), block, cols, cols);
+
+ if (fuse_relu) {
+ if (q_gran == QuantizationGranularity::TENSOR) {
+ TESTCODE(true, QuantizationGranularity::TENSOR)
+
+ } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) {
+ TESTCODE(true, QuantizationGranularity::OUT_CHANNEL)
+
+ } else {
+ FAIL();
+ }
+
+ } else {
+ if (q_gran == QuantizationGranularity::TENSOR) {
+ TESTCODE(false, QuantizationGranularity::TENSOR)
+
+ } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) {
+ TESTCODE(false, QuantizationGranularity::OUT_CHANNEL)
+
+ } else {
+ FAIL();
+ }
+ }
+#undef TESTCODE
+ ASSERT_EQ(output_q_bias, output_f_bias)
+ << "Requantization with quantized bias and float bias differs";
+}
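The new test passes only because the float bias is constructed to be exactly the quantized bias scaled by act_times_w_scale for the corresponding quantization group, so both ReQuantizeOutput instantiations contribute the same amount to each accumulator. A sketch of that relationship, assuming the tensor/out-channel indexing used above:

    #include <cstdint>

    // Float bias equivalent to an int32 (quantized) bias for column `col`.
    // TENSOR granularity uses act_times_w_scale[0] for every column;
    // OUT_CHANNEL granularity uses act_times_w_scale[col].
    inline float float_bias_for_column(
        std::int32_t bias_q,
        const float* act_times_w_scale,
        int col,
        bool per_channel) {
      return static_cast<float>(bias_q) * act_times_w_scale[per_channel ? col : 0];
    }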
diff --git a/test/UniConvTest.cc b/test/UniConvTest.cc
index 91bf578..cead3a6 100644
--- a/test/UniConvTest.cc
+++ b/test/UniConvTest.cc
@@ -5,11 +5,10 @@
* LICENSE file in the root directory of this source tree.
*/
#include <algorithm>
-#include <random>
#include <iostream>
+#include <random>
#include <stdexcept>
-
#include <gtest/gtest.h>
#include "QuantizationHelpers.h"
@@ -21,6 +20,67 @@
using namespace std;
using namespace fbgemm;
+vector<QuantizationGranularity> qGranularityVals{
+ QuantizationGranularity::TENSOR,
+ QuantizationGranularity::GROUP,
+ QuantizationGranularity::OUT_CHANNEL};
+
+// clang-format off
+static vector<conv_param_t<>> GetShapes_() {
+ vector<conv_param_t<>> shapes = {
+ // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w},
+ // {pad_t, pad_l, pad_b, pad_r}, and optionally {dilation_h, dilation_w}
+ // Regular
+ conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 32, {30, 10}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}, {2, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {3, 3}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {1, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {2, 1, 2, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 2, 1, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 5}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+ // groupwise
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 32, {10, 30}, 8, {3, 3}, {2, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 2}, {2, 1, 2, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 2}, {2, 1, 2, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 1}, {2, 1, 2, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 5}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+ // DW
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 2, 1, 2}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 5}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+ // Pointwise
+ conv_param_t<>(1, 32, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 16, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 2}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 2}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 1}, {0, 0, 0, 0}),
+ };
+ return shapes;
+}
+// clang-format on
+
namespace {
// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad
@@ -28,8 +88,15 @@ class uniConvTest
: public testing::TestWithParam<
tuple<int, int, int, int, int, int, int, int, int, int>> {};
+// tuple represents QuantizationGranularity, A symmetric, B symmetric,
+// test_bias, test_float_bias
+class UniConvQGranTest
+ : public testing::TestWithParam<
+ tuple<QuantizationGranularity, bool, bool, bool, bool>> {};
+
}; // namespace
+// Combine only allows at most 10 generators.
INSTANTIATE_TEST_CASE_P(
InstantiationName,
uniConvTest,
@@ -45,6 +112,15 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn({1, 2}), // stride
::testing::ValuesIn({0, 1, 2}))); // pad
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ UniConvQGranTest,
+ ::testing::Combine(
+ ::testing::ValuesIn(qGranularityVals),
+ ::testing::Bool(), // A symmetric
+ ::testing::Bool(), // B symmetric
+ ::testing::Bool(), // test_bias
+ ::testing::Bool())); // test_float_bias
/**
* Test for conv packing
*/
@@ -71,23 +147,19 @@ TEST_P(uniConvTest, packingTest) {
case optimized_conv_t::depthwise: {
ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
<< "im2col packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
<< "pointwise packed matrix should be null";
- ASSERT_NE(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
+ ASSERT_NE(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix is null";
break;
}
case optimized_conv_t::groupwise: {
ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
<< "im2col packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
<< "pointwise packed matrix should be null";
ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr)
@@ -97,10 +169,8 @@ TEST_P(uniConvTest, packingTest) {
case optimized_conv_t::pointwise: {
ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
<< "im2col packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should null";
ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
<< "Groupwise packed matrix should be null";
ASSERT_NE(packedB_2D.getPackedWForPointwise(), nullptr)
@@ -108,10 +178,8 @@ TEST_P(uniConvTest, packingTest) {
break;
}
case optimized_conv_t::im2col: {
- ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
@@ -139,16 +207,14 @@ TEST_P(uniConvTest, packingTest) {
switch (ConvFastPath<3, int32_t>(conv_p_3d)) {
case optimized_conv_t::depthwise: {
- ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
<< "im2col packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)
<< "pointwise packed matrix should be null";
- ASSERT_NE(packedB_3D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_NE(packedB_3D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix is null";
break;
}
case optimized_conv_t::groupwise: {
@@ -156,10 +222,8 @@ TEST_P(uniConvTest, packingTest) {
break;
}
case optimized_conv_t::pointwise: {
- ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
@@ -169,10 +233,8 @@ TEST_P(uniConvTest, packingTest) {
break;
}
case optimized_conv_t::im2col: {
- ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)
@@ -323,3 +385,335 @@ TEST(uniConvTest, cornerCases) {
0,
1);
}
+
+/**
+ * @brief Unit test for uint8 activations, int8 weights, and 32-bit
+ * accumulation. Output processing: requantization -> nothing
+ */
+TEST_P(UniConvQGranTest, requantizeTest) {
+ vector<conv_param_t<>> shapes(GetShapes_());
+ QuantizationGranularity q_granularity;
+ bool a_symmetric, b_symmetric;
+ bool test_bias, test_float_bias;
+ tie(q_granularity, a_symmetric, b_symmetric, test_bias, test_float_bias) =
+ GetParam();
+
+ for (auto conv_p : shapes) {
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+ int G = conv_p.G;
+ int OC = conv_p.OC;
+ int OH = conv_p.OUT_DIM[0];
+ int OW = conv_p.OUT_DIM[1];
+ int IC_per_G = conv_p.IC / conv_p.G;
+ int OC_per_G = conv_p.OC / conv_p.G;
+
+ // activations
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
+
+ // weights
+ // The weight matrix is in layout G K/G (R S C/G)
+ aligned_vector<int8_t> Bint8(R * S * conv_p.G * IC_per_G * OC_per_G, 0);
+ aligned_vector<int8_t> Bint8_tr(R * S * G * IC_per_G * OC_per_G, 0);
+
+ aligned_vector<int32_t> Cint32_ref(conv_p.MB * OH * OW * OC, 0);
+ aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0);
+ aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0);
+ aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0);
+
+ randFill<uint8_t>(Aint8, 0, 5);
+ int32_t Aint8_zero_point = a_symmetric ? 0 : 4;
+
+ randFill<int8_t>(Bint8, -4, 4);
+
+ // computing column offset
+ vector<int32_t> col_offsets(G * OC_per_G);
+
+ int ncols_per_quant_group = G * OC_per_G;
+ if (q_granularity == QuantizationGranularity::GROUP) {
+ ncols_per_quant_group = OC_per_G;
+ } else if (q_granularity == QuantizationGranularity::OUT_CHANNEL) {
+ ncols_per_quant_group = 1;
+ }
+
+ aligned_vector<int32_t> Bint8_zero_point(
+ G * OC_per_G / ncols_per_quant_group);
+ if (b_symmetric) {
+ randFill(Bint8_zero_point, 0, 0);
+ } else {
+ randFill(Bint8_zero_point, -3, 3);
+ }
+
+ // matrix dimensions after im2col for each GEMM.
+ // For each group, there is one GEMM of the following dimensions
+ int MDim = conv_p.MB * OH * OW;
+ int NDim = OC_per_G;
+ int KDim = R * S * IC_per_G;
+
+ vector<uint8_t> Aint8_im2col(MDim * KDim * G);
+ im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+ vector<int32_t> row_offsets(MDim);
+
+ // activation_scale * weight_scale
+ aligned_vector<float> act_times_w_scale(Bint8_zero_point.size());
+ randFill(act_times_w_scale, 0.1234f / 2, 0.1234f * 3 / 2);
+
+ float out_scale = 2.0f;
+ aligned_vector<float> C_multiplier(Bint8_zero_point.size());
+ transform(
+ act_times_w_scale.begin(),
+ act_times_w_scale.end(),
+ C_multiplier.begin(),
+ [&out_scale](float i) { return i / out_scale; });
+
+ int32_t C_zero_pt = 5;
+
+ // initialize bias
+ aligned_vector<int32_t> bias_int32(OC);
+ aligned_vector<float> bias_fp32(OC);
+ if (test_bias) {
+ randFill(bias_int32, -8, 8);
+ }
+
+ // floating point bias
+ if (test_float_bias) {
+ if (q_granularity == QuantizationGranularity::TENSOR) {
+ transform(
+ bias_int32.begin(),
+ bias_int32.end(),
+ bias_fp32.begin(),
+ [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; });
+ } else if (q_granularity == QuantizationGranularity::GROUP) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < OC_per_G; ++c) {
+ bias_fp32[g * OC_per_G + c] = act_times_w_scale[g] *
+ static_cast<float>(bias_int32[g * OC_per_G + c]);
+ }
+ }
+ } else { // OUT_CHANNEL
+ transform(
+ act_times_w_scale.begin(),
+ act_times_w_scale.end(),
+ bias_int32.begin(),
+ bias_fp32.begin(),
+ multiplies<float>());
+ }
+ }
+ // reference implementation
+ // conv_ref expects weights to be in G (R S C/G) K/G
+ int8_t* rightBData = Bint8.data();
+ transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data());
+ rightBData = Bint8_tr.data();
+ for (int g = 0; g < G; ++g) {
+ col_offsets_with_zero_pt_s8acc32_ref(
+ R * S * IC_per_G,
+ OC_per_G,
+ OC_per_G,
+ rightBData + g * R * S * IC_per_G * OC_per_G,
+ Bint8_zero_point.data() + g * OC_per_G / ncols_per_quant_group,
+ col_offsets.data() + g * OC_per_G,
+ ncols_per_quant_group);
+ }
+ conv_ref(
+ conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data());
+
+ for (int g = 0; g < G; ++g) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDim,
+ KDim * G,
+ Aint8_im2col.data() + g * KDim,
+ row_offsets.data());
+
+ requantize_u8acc32_ref(
+ MDim,
+ NDim,
+ G * NDim,
+ Cint32_ref.data() + g * NDim,
+ Cint8_ref.data() + g * NDim,
+ C_multiplier.data() + g * NDim / ncols_per_quant_group,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data() + g * NDim / ncols_per_quant_group,
+ row_offsets.data(),
+ col_offsets.data() + g * NDim,
+ test_bias ? bias_int32.data() + g * NDim : nullptr,
+ ncols_per_quant_group);
+ }
+
+ PackWeightsForConv<2> packedWeights(conv_p, Bint8.data());
+
+ // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv
+ // #ifdef _OPENMP
+ // #pragma omp parallel
+ // #endif
+ {
+ vector<int32_t> row_offset_buf(rowOffsetBufferSizeGConv(conv_p));
+
+ DoNothing<> doNothingObj{};
+
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+
+ if (q_granularity == QuantizationGranularity::TENSOR) {
+ if (test_float_bias) {
+ ReQuantizeOutput<false, QuantizationGranularity::TENSOR, float>
+ reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_fp32.data() : nullptr,
+ G * NDim,
+ G,
+ act_times_w_scale.data());
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+
+ } else {
+ ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_int32.data() : nullptr,
+ G * NDim,
+ G);
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+ }
+
+ } else if (q_granularity == QuantizationGranularity::GROUP) {
+ if (test_float_bias) {
+ ReQuantizeOutput<false, QuantizationGranularity::GROUP, float> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_fp32.data() : nullptr,
+ G * NDim,
+ G,
+ act_times_w_scale.data());
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+
+ } else {
+ ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_int32.data() : nullptr,
+ G * NDim,
+ G);
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+ }
+
+ } else {
+ if (test_float_bias) {
+ ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL, float>
+ reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_fp32.data() : nullptr,
+ G * NDim,
+ G,
+ act_times_w_scale.data());
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+
+ } else {
+ ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_int32.data() : nullptr,
+ G * NDim,
+ G);
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+ }
+ }
+ } // omp parallel
+
+ compare_validate_buffers(
+ Cint8_ref.data(),
+ Cint8_fb.data(),
+ MDim,
+ NDim * G,
+ NDim * G,
+ static_cast<uint8_t>(0));
+ } // for each shape
+}
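For reference, the requantizeTest above derives the number of output columns that share one scale and zero point directly from the quantization granularity; that value sizes Bint8_zero_point, C_multiplier, and act_times_w_scale and sets the indexing stride into them. A minimal sketch of the mapping:

    #include "fbgemm/Fbgemm.h"

    // Columns per quantization group for a conv with G groups and OC_per_G
    // output channels per group, mirroring the switch in the test above.
    int cols_per_quant_group(fbgemm::QuantizationGranularity q, int G, int OC_per_G) {
      switch (q) {
        case fbgemm::QuantizationGranularity::TENSOR:
          return G * OC_per_G; // one scale/zero point for the whole tensor
        case fbgemm::QuantizationGranularity::GROUP:
          return OC_per_G;     // one per group
        case fbgemm::QuantizationGranularity::OUT_CHANNEL:
        default:
          return 1;            // one per output channel
      }
    }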