
github.com/marian-nmt/FBGEMM.git
author     Young Jin Kim <youki@microsoft.com>   2019-12-03 22:53:14 +0300
committer  GitHub <noreply@github.com>           2019-12-03 22:53:14 +0300
commit     84e66a976046180187724aff60a236c5378fde7c (patch)
tree       f2c4e39fe4d46df1b7a23602d244d21c9f9ee35b
parent     f0b354327aaf2330c65340725b1981040c8bec9e (diff)
parent     e6e9b167426c12cd048c3d7d76651492f818daec (diff)
Merge pull request #1 from marian-nmt/youki/win-jit-debug-int8
Youki/win jit debug int8
-rw-r--r--  CMakeLists.txt | 4
-rw-r--r--  CODE_OF_CONDUCT.md | 78
-rw-r--r--  README.md | 9
-rw-r--r--  bench/ConvUnifiedBenchmark.cc | 45
-rw-r--r--  bench/Depthwise3DBenchmark.cc | 134
-rw-r--r--  bench/DepthwiseBenchmark.cc | 128
-rw-r--r--  bench/GEMMsBenchmark.cc | 2
-rw-r--r--  bench/GEMMsTunableBenchmark.cc | 6
-rw-r--r--  bench/PackedFloatInOutBenchmark.cc | 2
-rw-r--r--  bench/PackedRequantizeAcc16Benchmark.cc | 2
-rw-r--r--  bench/PackedRequantizeAcc32Benchmark.cc | 2
-rw-r--r--  include/fbgemm/ConvUtils.h | 38
-rw-r--r--  include/fbgemm/Fbgemm.h | 158
-rw-r--r--  include/fbgemm/FbgemmFP16.h | 32
-rw-r--r--  include/fbgemm/FbgemmI8DepthwiseAvx2.h | 155
-rw-r--r--  include/fbgemm/OutputProcessing-inl.h | 36
-rw-r--r--  include/fbgemm/PackingTraits-inl.h | 50
-rw-r--r--  include/fbgemm/QuantUtils.h | 35
-rw-r--r--  include/fbgemm/QuantUtilsAvx2.h | 13
-rw-r--r--  include/fbgemm/Utils.h | 50
-rw-r--r--  include/fbgemm/UtilsAvx2.h | 5
-rw-r--r--  src/CodeCache.h | 59
-rw-r--r--  src/ExecuteKernelU8S8.cc | 94
-rw-r--r--  src/Fbgemm.cc | 76
-rw-r--r--  src/FbgemmConv.cc | 247
-rw-r--r--  src/FbgemmFP16.cc | 2
-rw-r--r--  src/FbgemmI8Depthwise3DAvx2.cc | 1423
-rw-r--r--  src/FbgemmI8DepthwiseAvx2-inl.h | 710
-rw-r--r--  src/FbgemmI8DepthwiseAvx2.cc | 2329
-rw-r--r--  src/GenerateKernel.h | 123
-rw-r--r--  src/GenerateKernelU8S8S32ACC16.cc | 403
-rw-r--r--  src/GenerateKernelU8S8S32ACC16Avx512.cc | 543
-rw-r--r--  src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc | 102
-rw-r--r--  src/GenerateKernelU8S8S32ACC32.cc | 422
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512.cc | 574
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc | 435
-rw-r--r--  src/GroupwiseConv.h | 142
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc | 304
-rw-r--r--  src/OptimizedKernelsAvx2.cc | 51
-rw-r--r--  src/PackAMatrix.cc | 30
-rw-r--r--  src/PackAWithIm2Col.cc | 49
-rw-r--r--  src/PackAWithQuantRowOffset.cc | 37
-rw-r--r--  src/PackAWithRowOffset.cc | 37
-rw-r--r--  src/PackBMatrix.cc | 260
-rw-r--r--  src/PackDepthwiseConvMatrixAvx2.cc | 211
-rw-r--r--  src/PackMatrix.cc | 56
-rw-r--r--  src/PackWeightMatrixForGConv.cc | 139
-rw-r--r--  src/PackWeightsForConv.cc | 151
-rw-r--r--  src/QuantUtils.cc | 139
-rw-r--r--  src/QuantUtilsAvx2.cc | 760 (mode changed from -rwxr-xr-x)
-rw-r--r--  src/RefImplementations.cc | 474
-rw-r--r--  src/RefImplementations.h | 120
-rwxr-xr-x  src/Utils.cc | 7
-rw-r--r--  test/FP16Test.cc | 116
-rw-r--r--  test/GConvTest.cc | 65
-rw-r--r--  test/I8DepthwiseTest.cc | 478
-rw-r--r--  test/I8DepthwiseTest.h | 39
-rw-r--r--  test/Im2ColFusedRequantizeTest.cc | 5
-rw-r--r--  test/PackedRequantizeAcc16Test.cc | 93
-rw-r--r--  test/PackedRequantizeTest.cc | 93
-rw-r--r--  test/QuantUtilsTest.cc | 183
-rw-r--r--  test/RequantizeOnlyTest.cc | 169
-rw-r--r--  test/TestUtils.h | 9
-rw-r--r--  test/UniConvPackingTest.cc | 148
-rw-r--r--  test/UniConvTest.cc | 714
m---------  third_party/asmjit (submodule) | 0
66 files changed, 8610 insertions, 4995 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e6c7419..c06b60b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,8 +37,10 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/FbgemmI8Spmdm.cc
src/GenerateKernelU8S8S32ACC16.cc
src/GenerateKernelU8S8S32ACC16Avx512.cc
+ src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
src/GenerateKernelU8S8S32ACC32.cc
src/GenerateKernelU8S8S32ACC32Avx512.cc
+ src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
src/GroupwiseConvAcc32Avx2.cc
src/PackAMatrix.cc
src/PackAWithIm2Col.cc
@@ -87,8 +89,10 @@ endif()
#All the source files that either use avx2 instructions statically
set(FBGEMM_AVX2_SRCS
src/FbgemmFP16UKernelsAvx2.cc
+ src/FbgemmI8Depthwise3DAvx2.cc
src/FbgemmI8DepthwiseAvx2.cc
src/OptimizedKernelsAvx2.cc
+ src/PackDepthwiseConvMatrixAvx2.cc
src/QuantUtilsAvx2.cc
src/UtilsAvx2.cc)
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index 0f7ad8b..d1abc70 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -1,5 +1,77 @@
# Code of Conduct
-Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
-Please read the [full text](https://code.fb.com/codeofconduct/)
-so that you can understand what actions will and will not be tolerated.
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <opensource-conduct@fb.com>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
+
diff --git a/README.md b/README.md
index 2335b81..5f3ca40 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,9 @@ row-wise quantization and outlier-aware quantization. FBGEMM also exploits
fusion opportunities in order to overcome the unique challenges of matrix
multiplication at lower precision with bandwidth-bound operations.
-FBGEMM is used as a backend of Caffe2 quantized operators for x86 machines
-(https://github.com/pytorch/pytorch/tree/master/caffe2/quantization/server).
-We also plan to integrate FBGEMM into PyTorch.
+FBGEMM is used as a backend of Caffe2 and PyTorch quantized operators for x86 machines:
+* Caffe2: https://github.com/pytorch/pytorch/tree/master/caffe2/quantization/server
+* PyTorch: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native/quantized/cpu
## Examples
@@ -64,6 +64,9 @@ General build instructions are as follows:
```
git clone --recursive https://github.com/pytorch/FBGEMM.git
cd FBGEMM
+# if you are updating an existing checkout
+git submodule sync
+git submodule update --init --recursive
mkdir build && cd build
cmake ..
make
diff --git a/bench/ConvUnifiedBenchmark.cc b/bench/ConvUnifiedBenchmark.cc
index 6bc2cf4..35a3b2a 100644
--- a/bench/ConvUnifiedBenchmark.cc
+++ b/bench/ConvUnifiedBenchmark.cc
@@ -24,6 +24,7 @@
using namespace std;
using namespace fbgemm;
+// clang-format off
// 2D conv shapes
vector<conv_param_t<2>> shapes_2d = {
// MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w,
@@ -31,22 +32,31 @@ vector<conv_param_t<2>> shapes_2d = {
// 2D convolutions
// regular
conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ // regular with dilation
+ conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
// groupwise
conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
-
// DW
conv_param_t<>(1, 272, 272, {47, 125}, 272, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ // Pointwise
+ conv_param_t<>(1, 128, 128, {56, 56}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0})
+
};
// 3D conv shapes
vector<conv_param_t<3>> shapes_3d = {
- // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h, stride_w},
- // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right}
- // Regular
- conv_param_t<3>(1, 64, 64, {32, 56, 56}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}),
- // Depthwise
- conv_param_t<3>(1, 64, 64, {32, 56, 56}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1})
-};
+ // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h,
+ // stride_w},
+ // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right}
+ // Regular
+ conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}),
+ //With dilations
+ conv_param_t<3>(1, 64, 64, {8, 14, 14}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}, {2, 2, 2}),
+ // Depthwise
+ conv_param_t<3>(1, 64, 64, {8, 14, 14}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}),
+ // Pointwise
+ conv_param_t<3>(1, 128, 128, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0, 0})};
+// clang-format on
template <int SPATIAL_DIM, typename Acc_t>
void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
@@ -77,6 +87,10 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
header += "pad_t, ";
}
header += "pad_h, pad_w, ";
+ if (SPATIAL_DIM == 3) {
+ header += "dilation_t, ";
+ }
+ header += "dilation_h, dilation_w, ";
header += "Type, M, N, K, ";
@@ -110,6 +124,9 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
aligned_vector<int8_t> Bint8(
kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+ aligned_vector<int8_t> Bint8_tr(
+ kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
int im_out_dim = accumulate(
conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies<int>());
aligned_vector<int32_t> Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC);
@@ -132,14 +149,14 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2);
int32_t C_zero_point = 5;
- aligned_vector<float> Bfp32(Bint8.begin(), Bint8.end());
-
// reference implementation
+ // conv_ref expects weights to be in G (R S C/G) K/G
+ transposeConvWeights<SPATIAL_DIM>(conv_p, Bint8.data(), Bint8_tr.data());
conv_ref(
conv_p,
Aint8.data(),
Aint8_zero_point,
- Bint8.data(),
+ Bint8_tr.data(),
Cint32_ref.data());
// matrix dimensions after im2col
@@ -162,7 +179,7 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
KDimPerGroup,
OC_per_G,
OC_per_G,
- Bint8.data() + g * KDimPerGroup * OC_per_G,
+ Bint8_tr.data() + g * KDimPerGroup * OC_per_G,
Bint8_zero_point.data(),
col_offsets.data() + g * OC_per_G,
conv_p.OC);
@@ -271,7 +288,9 @@ void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
for (int i = 0; i < SPATIAL_DIM; ++i) {
cout << conv_p.pad[i] << ", ";
}
-
+ for (int i = 0; i < SPATIAL_DIM; ++i) {
+ cout << conv_p.dilation[i] << ", ";
+ }
cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
<< setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
<< KDim << ", ";
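
A hedged sketch of the weight-layout handling used in the benchmark above: conv_ref expects group/depth-wise weights in G (R S C/G) K/G order, so the benchmark first transposes them with transposeConvWeights. Buffer names and the 2D specialization below are illustrative, and the headers follow the benchmark's own includes.

#include <cstdint>
#include <vector>
#include "fbgemm/Fbgemm.h"
#include "src/RefImplementations.h"  // conv_ref, as in the benchmark

void reference_conv_with_transposed_weights(
    const fbgemm::conv_param_t<2>& conv_p,
    const std::vector<std::uint8_t>& Aint8, std::int32_t Aint8_zero_point,
    const std::vector<std::int8_t>& Bint8,  // weights as laid out for packing
    std::vector<std::int32_t>& Cint32_ref) {
  // conv_ref expects weights to be in G (R S C/G) K/G, so transpose first.
  std::vector<std::int8_t> Bint8_tr(Bint8.size());
  fbgemm::transposeConvWeights<2>(conv_p, Bint8.data(), Bint8_tr.data());
  fbgemm::conv_ref(
      conv_p, Aint8.data(), Aint8_zero_point, Bint8_tr.data(), Cint32_ref.data());
}
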
diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc
index 0efdcac..ff2be6f 100644
--- a/bench/Depthwise3DBenchmark.cc
+++ b/bench/Depthwise3DBenchmark.cc
@@ -4,7 +4,6 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
-#include "test/I8DepthwiseTest.h"
#include <algorithm>
#include <chrono>
@@ -19,8 +18,8 @@
#include "AlignedVec.h"
#include "BenchUtils.h"
-#include "fbgemm/Utils.h"
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/Utils.h"
#include "src/RefImplementations.h"
using namespace std;
@@ -35,6 +34,34 @@ int main() {
}
#endif
+ // From ResNeXt-3D-101
+ // clang-format off
+ vector<vector<int>> shapes_3d = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ // N, K, T_in, H_in, W_in, stride
+ { 1, 64, 32, 56, 56, 1, },
+ { 1, 128, 16, 28, 28, 1, },
+ { 1, 256, 8, 14, 14, 1, },
+ { 1, 512, 4, 7, 7, 1, },
+
+ { 1, 128, 32, 56, 56, 2, },
+ { 1, 256, 16, 28, 28, 2, },
+ { 1, 512, 8, 14, 14, 2, },
+
+ { 5, 64, 32, 56, 56, 1, },
+ { 5, 128, 16, 28, 28, 1, },
+ { 5, 256, 8, 14, 14, 1, },
+ { 5, 512, 4, 7, 7, 1, },
+
+ { 5, 128, 32, 56, 56, 2, },
+ { 5, 256, 16, 28, 28, 2, },
+ { 5, 512, 8, 14, 14, 2, },
+
+ { 1, 8, 4, 4, 4, 1, },
+ };
+ // clang-format on
+
// Depthwise is memory BW bound so we want to flush LLC.
bool flush = true;
std::vector<char> llc;
@@ -61,14 +88,28 @@ int main() {
constexpr int K_T = 3, K_H = 3, K_W = 3;
constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ conv_param_t<3> conv_p(
+ N,
+ K,
+ K,
+ {T, H, W},
+ K,
+ {K_T, K_H, K_W},
+ {stride_t, stride_h, stride_w},
+ {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R});
+ int T_OUT = conv_p.OUT_DIM[0];
+ int H_OUT = conv_p.OUT_DIM[1];
+ int W_OUT = conv_p.OUT_DIM[2];
+
+ int MDim = N * T_OUT * H_OUT * W_OUT;
+ int KDim = K_T * K_H * K_W * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * T * H * W * K);
- aligned_vector<int8_t> B(K * K_T * K_H * K_W);
- aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K),
- C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = 43;
@@ -76,52 +117,49 @@ int main() {
randFill<int8_t>(B, -16, 16);
int32_t B_zero_point = 5;
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
- int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+ aligned_vector<float> C_multiplier(1);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+ int32_t C_zero_point = 5;
- float C_multiplier = 255. / (maximum - minimum);
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
aligned_vector<int32_t> col_offsets(K);
aligned_vector<int32_t> bias(K);
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- int32_t C_zero_point = 5;
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point,
- B.data(),
- C_multiplier,
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
-
- Packed3x3x3ConvMatrix Bp(K, B.data());
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data(),
+ C_zero_point,
+ A_zero_point,
+ &B_zero_point,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
+
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data());
double ttot = 0;
double bytes = double(NITER) *
@@ -153,7 +191,7 @@ int main() {
A.data(),
B_zero_point,
Bp,
- C_multiplier,
+ C_multiplier[0],
C_zero_point,
C_uint8.data(),
col_offsets.data(),
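
A hedged sketch of what the per-group reference path above computes for a single output element: the int32 accumulator from conv_ref is corrected by the activation and weight zero points (via column and row offsets), the bias is added, and the result is scaled, rounded, and clamped to uint8. The helper name and argument bundling are illustrative; the real code uses row_offsets_u8acc32_ref and requantize_u8acc32_ref over whole groups.

#include <algorithm>
#include <cmath>
#include <cstdint>

inline std::uint8_t requantize_one(
    std::int32_t acc,           // int32 accumulator from conv_ref
    std::int32_t row_offset,    // sum of the activation row (from im2col)
    std::int32_t col_offset,    // sum of the weight column
    std::int32_t bias,
    std::int32_t A_zero_point,
    std::int32_t B_zero_point,
    float C_multiplier,
    std::int32_t C_zero_point) {
  std::int32_t raw =
      acc - A_zero_point * col_offset - B_zero_point * row_offset + bias;
  long rounded = std::lrintf(raw * C_multiplier) + C_zero_point;
  return static_cast<std::uint8_t>(std::max(0L, std::min(255L, rounded)));
}
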
diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc
index 96921a1..6c2ee17 100644
--- a/bench/DepthwiseBenchmark.cc
+++ b/bench/DepthwiseBenchmark.cc
@@ -17,8 +17,8 @@
#include "AlignedVec.h"
#include "BenchUtils.h"
-#include "fbgemm/Utils.h"
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/Utils.h"
#include "src/RefImplementations.h"
using namespace std;
@@ -34,10 +34,11 @@ int main() {
#endif
// From Xray OCR
+ // clang-format off
vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
- // N, G, H_in, W_in, stride
+ // N, K, H_in, W_in, stride
{ 1, 272, 47, 125, 1, },
{ 1, 272, 64, 125, 1, },
{ 1, 272, 66, 125, 1, },
@@ -138,6 +139,7 @@ int main() {
{ 96, 544, 14, 14, 2, },
{ 100, 544, 14, 14, 2, },
};
+ // clang-format on
// Depthwise is memory BW bound so we want to flush LLC.
bool flush = true;
@@ -155,19 +157,35 @@ int main() {
for (auto shape : shapes) {
int N = shape[0];
- int G = shape[1];
+ int K = shape[1];
int H = shape[2];
int W = shape[3];
int stride_h = shape[4];
int stride_w = stride_h;
constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2,
+ PAD_R = (S - 1) / 2;
+
+ conv_param_t<2> conv_p(
+ N,
+ K,
+ K,
+ {H, W},
+ K,
+ {R, S},
+ {stride_h, stride_w},
+ {PAD_T, PAD_L, PAD_B, PAD_R});
+ int H_OUT = conv_p.OUT_DIM[0];
+ int W_OUT = conv_p.OUT_DIM[1];
+
+ int MDim = N * H_OUT * W_OUT;
+ int KDim = R * S * K;
+ int KDimPerGroup = KDim / conv_p.G;
- aligned_vector<uint8_t> A(N * H * W * G);
- aligned_vector<int8_t> B(G * R * S);
- aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * G), C(C_ref.size());
+ aligned_vector<uint8_t> A(N * H * W * K);
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = 43;
@@ -175,53 +193,54 @@ int main() {
randFill<int8_t>(B, -16, 16);
int32_t B_zero_point = 5;
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- G,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
- int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+ aligned_vector<float> C_multiplier(1);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+ int32_t C_zero_point = 5;
- float C_multiplier = 255. / (maximum - minimum);
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
- aligned_vector<int32_t> col_offsets(G);
- aligned_vector<int32_t> bias(G);
+ aligned_vector<int32_t> col_offsets(K);
+ aligned_vector<int32_t> bias(K);
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- int32_t C_zero_point = 5;
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- G,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point,
- B.data(),
- C_multiplier,
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data(),
+ C_zero_point,
+ A_zero_point,
+ &B_zero_point,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
- Packed3x3ConvMatrix Bp(G, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
double ttot = 0;
double bytes = double(NITER) *
- (G * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S));
- double ops = double(NITER) * N * H_OUT * W_OUT * G * R * S * 2;
+ (K * (N * (2 * sizeof(int32_t) * H_OUT * W_OUT + H * W) + R * S));
+ double ops = double(NITER) * N * H_OUT * W_OUT * K * R * S * 2;
chrono::time_point<chrono::system_clock> t_begin, t_end;
for (int i = 0; i < NWARMUP + NITER; ++i) {
llc_flush();
@@ -235,19 +254,20 @@ int main() {
N,
H,
W,
- G,
+ K,
stride_h,
stride_w,
A_zero_point,
A.data(),
B_zero_point,
Bp,
- C_multiplier,
+ C_multiplier[0],
C_zero_point,
C_uint8.data(),
col_offsets.data(),
bias.data(),
false, /* fuse_relu */
+ 1.0f, /* act_scale * w_scale */
tid,
num_threads);
}
@@ -262,10 +282,10 @@ int main() {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H_OUT; ++h) {
for (int w = 0; w < W_OUT; ++w) {
- for (int g = 0; g < G; ++g) {
+ for (int g = 0; g < K; ++g) {
uint8_t expected =
- C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * G + g];
- uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * G + g];
+ C_uint8_ref[((n * H_OUT + h) * W_OUT + w) * K + g];
+ uint8_t actual = C_uint8[((n * H_OUT + h) * W_OUT + w) * K + g];
if (expected != actual) {
cerr << "Depthwise 3x3 results differ at (" << n << ", " << h
<< ", " << w << ", " << g << "). expected " << (int)expected
@@ -280,9 +300,9 @@ int main() {
// Report performance
printf(
- "N = %d G = %d H = %d W = %d stride = %d with requantization fused\n",
+ "N = %d K = %d H = %d W = %d stride = %d with requantization fused\n",
N,
- G,
+ K,
H,
W,
stride_h);
diff --git a/bench/GEMMsBenchmark.cc b/bench/GEMMsBenchmark.cc
index b404d8b..f493a96 100644
--- a/bench/GEMMsBenchmark.cc
+++ b/bench/GEMMsBenchmark.cc
@@ -28,6 +28,7 @@ using namespace std;
using namespace fbgemm;
void performance_test() {
+ // clang-format off
static const vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -39,6 +40,7 @@ void performance_test() {
{256, 512, 256},
{1024, 1024, 1024},
};
+ // clang-format on
bool flush = true;
std::vector<char> llc;
diff --git a/bench/GEMMsTunableBenchmark.cc b/bench/GEMMsTunableBenchmark.cc
index a65b51f..2adc556 100644
--- a/bench/GEMMsTunableBenchmark.cc
+++ b/bench/GEMMsTunableBenchmark.cc
@@ -218,7 +218,8 @@ int main(int /* unused */, char** /* unused */) {
}
#endif
- vector<vector<int>> shapes = {
+ // clang-format off
+ vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
// m, n, k
@@ -266,7 +267,8 @@ int main(int /* unused */, char** /* unused */) {
{128, 128, 128},
{256, 512, 256},
{1024, 1024, 1024},
-};
+ };
+ // clang-format on
vector<int> MCBs;
vector<int> NCBs;
diff --git a/bench/PackedFloatInOutBenchmark.cc b/bench/PackedFloatInOutBenchmark.cc
index 66ca67e..dcea65c 100644
--- a/bench/PackedFloatInOutBenchmark.cc
+++ b/bench/PackedFloatInOutBenchmark.cc
@@ -28,6 +28,7 @@ using namespace std;
using namespace fbgemm;
void performance_test() {
+ // clang-format off
vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -66,6 +67,7 @@ void performance_test() {
{1, 128, 2722},
{16, 256, 512},
};
+ // clang-format on
bool flush = true;
std::vector<char> llc;
diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc
index 40ff662..c6e2869 100644
--- a/bench/PackedRequantizeAcc16Benchmark.cc
+++ b/bench/PackedRequantizeAcc16Benchmark.cc
@@ -37,6 +37,7 @@ enum class BenchmarkType {
};
void performance_test() {
+ // clang-format off
vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -67,6 +68,7 @@ void performance_test() {
{392, 2048, 512},
{392, 512, 2048},
};
+ // clang-format on
bool flush = true;
std::vector<char> llc;
diff --git a/bench/PackedRequantizeAcc32Benchmark.cc b/bench/PackedRequantizeAcc32Benchmark.cc
index 2f04795..a61ef5a 100644
--- a/bench/PackedRequantizeAcc32Benchmark.cc
+++ b/bench/PackedRequantizeAcc32Benchmark.cc
@@ -28,6 +28,7 @@ using namespace std;
using namespace fbgemm;
void performance_test() {
+ // clang-format off
vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -70,6 +71,7 @@ void performance_test() {
{1, 128, 2722},
{16, 256, 512},
};
+ // clang-format on
bool flush = true;
std::vector<char> llc;
diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h
index 11f3dcc..5431958 100644
--- a/include/fbgemm/ConvUtils.h
+++ b/include/fbgemm/ConvUtils.h
@@ -8,9 +8,24 @@
#include <array>
#include <string>
+#include <type_traits>
namespace fbgemm {
+template <int N, int... Vals>
+constexpr
+ typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type
+ array_of_ones() {
+ return std::array<int, N>{{Vals...}};
+}
+
+template <int N, int... Vals>
+constexpr
+ typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type
+ array_of_ones() {
+ return array_of_ones<N, Vals..., 1>();
+}
+
/**
* @brief A struct to conveniently store all convolution parameters.
*/
@@ -34,7 +49,6 @@ struct conv_param_t {
/**
* @brief Constructor for initializing the convolution parameters.
- * TODO: Dilation is not handled correctly.
*/
conv_param_t(
int mb,
@@ -44,7 +58,8 @@ struct conv_param_t {
int g,
std::array<int, SPATIAL_DIM> k,
std::array<int, SPATIAL_DIM> strd,
- std::array<int, SPATIAL_DIM * 2> pd)
+ std::array<int, SPATIAL_DIM * 2> pd,
+ std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>())
: MB(mb),
IC(ic),
OC(oc),
@@ -52,7 +67,8 @@ struct conv_param_t {
G(g),
K(k),
stride(strd),
- pad(pd) {
+ pad(pd),
+ dilation(dilations) {
if (ic % g != 0) {
throw std::runtime_error(
"groups = " + std::to_string(g) +
@@ -63,10 +79,10 @@ struct conv_param_t {
"groups = " + std::to_string(g) +
" does not divide number of output channels = " + std::to_string(oc));
}
+
for (int d = 0; d < SPATIAL_DIM; ++d) {
- dilation[d] = 1;
IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
- OUT_DIM[d] = (IN_DIMP[d] - K[d]) / stride[d] + 1;
+ OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1;
}
}
@@ -102,8 +118,12 @@ struct conv_param_t {
}
for (int d = 0; d < SPATIAL_DIM * 2; ++d) {
out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" +
- std::to_string(pad[d]);
- if (d < SPATIAL_DIM * 2 - 1) {
+ std::to_string(pad[d]) + ", ";
+ }
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
+ std::to_string(dilation[d]);
+ if (d < SPATIAL_DIM - 1) {
out += ", ";
}
}
@@ -121,6 +141,10 @@ struct conv_param_t {
out += ", ";
}
}
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "dilation_" + std::to_string(d) + ":" +
+ std::to_string(dilation[d]) + ", ";
+ }
}
return out;
}
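
A small sketch of the new optional dilation argument and the updated output-size computation in the constructor above; the shape values mirror the dilated benchmark entry, and only APIs visible in this header are used.

#include "fbgemm/ConvUtils.h"

// 56x56 input, 3x3 kernel, stride 1, pad 1, dilation 2.
// Effective kernel extent = dilation * (K - 1) + 1 = 5, so
// OUT_DIM[d] = (56 + 1 + 1 - 2 * (3 - 1) - 1) / 1 + 1 = 54 per spatial dim.
fbgemm::conv_param_t<2> conv_p_dilated(
    /*mb=*/1, /*ic=*/128, /*oc=*/128, /*IN_DIM=*/{56, 56}, /*g=*/1,
    /*K=*/{3, 3}, /*stride=*/{1, 1}, /*pad=*/{1, 1, 1, 1}, /*dilation=*/{2, 2});
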
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 90d1ee9..668bd42 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -416,6 +416,19 @@ class FBGEMM_API PackBMatrix final
const BlockingFactors* params = nullptr);
/**
+ * This constructor accepts pre-packed matrix as an input.
+ * And, it skips the actual packing procedure.
+ */
+ PackBMatrix(
+ matrix_op_t trans,
+ std::int32_t nRow,
+ std::int32_t nCol,
+ inpType* prepackedmat,
+ std::int32_t ld,
+ int groups = 1,
+ const BlockingFactors* params = nullptr);
+
+ /**
* Weight matrices are usually constant so worth pre-packing.
*/
bool isPrePacked() const {
@@ -445,14 +458,17 @@ class FBGEMM_API PackBMatrix final
std::int32_t addr(std::int32_t i, std::int32_t j) const;
/**
- * @brief Packs a block of source matrix into pmat buffer.
+ * @brief Packs a block of source matrix into pmat buffer. The blocking
+ * parameters are needed to compute the buffer size of each group.
+ * It will use default blocking parameters if params is not provided.
*/
- void pack(const block_type_t& block);
+ void pack(const block_type_t& block, const BlockingFactors* params = nullptr);
/**
* @brief Print the packed block.
*/
- void printPackedMatrix(std::string name);
+ void printPackedMatrix(std::string name,
+ const BlockingFactors* params = nullptr);
/**
* @return true if meta information like matrix shape is the same.
@@ -467,7 +483,7 @@ class FBGEMM_API PackBMatrix final
* @brief Unpack pmat buffer to the origin_buf (Used for the serialization to
* recover weight matrix).
*/
- void unpack(T* origin_buf);
+ void unpack(T* origin_buf, const BlockingFactors* params = nullptr);
~PackBMatrix() {}
@@ -476,6 +492,16 @@ class FBGEMM_API PackBMatrix final
const T* smat_;
std::int32_t ld_;
std::int32_t row_interleave_;
+
+ /**
+ * @brief Internal function performing both pack & unpack
+ */
+ void pack_unpack_(
+ const block_type_t& block,
+ T* unpack_buf,
+ T* pack_buf,
+ bool ispack,
+ const BlockingFactors* params = nullptr);
};
/**
@@ -508,6 +534,11 @@ class FBGEMM_API PackWeightMatrixForGConv {
void pack();
/**
+ * @brief Unpacks a pmat buffer into source matrix.
+ */
+ void unpack(T* origin_buf);
+
+ /**
* @brief Return packed data
*/
inpType* getBuf() {
@@ -530,6 +561,22 @@ class FBGEMM_API PackWeightMatrixForGConv {
const T* sdata_;
T* pdata_;
bool bufAllocatedHere_;
+
+ /**
+ * @brief Internal function performing both pack & unpack
+ */
+ void pack_unpack_(const T* src, T* dst, bool ispack);
+
+ /**
+ * @brief Get the index of the unpacked data
+ */
+ int unpacked_index_(int r, int s, int k, int g, int c, bool tr);
+
+ /**
+ * @brief Get the index of the packed data
+ */
+ int packed_index_(int r, int s, int k, int g, int c);
+
};
/**
@@ -562,12 +609,16 @@ class FBGEMM_API PackWeightsForConv {
return W_im2col_packed_;
}
- std::shared_ptr<Packed3x3ConvMatrix> getPackedWFor2DDW() {
- return W_dw_2D_packed_;
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWForDepthwise() {
+ return W_dw_packed_;
}
- std::shared_ptr<Packed3x3x3ConvMatrix> getPackedWFor3DDW() {
- return W_dw_3D_packed_;
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor2DDW() {
+ return W_dw_packed_;
+ }
+
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor3DDW() {
+ return W_dw_packed_;
}
std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
@@ -575,17 +626,55 @@ class FBGEMM_API PackWeightsForConv {
return W_gconv_packed_;
}
+ std::shared_ptr<PackBMatrix<T, accT>> getPackedWForPointwise() {
+ return W_pointwise_packed_;
+ }
+
+ int inputChannels() {
+ return conv_param_.IC;
+ }
+
+ int outputChannels() {
+ return conv_param_.OC;
+ }
+
+ std::array<int, SPATIAL_DIM> kernelDims() {
+ return conv_param_.K;
+ }
+
+ int groups() {
+ return conv_param_.G;
+ }
+
+ /**
+ * @brief Returns true if the packed weights would work for the given
+ * convolution parameters, and false otherwise
+ */
+ bool isPackingCompliant(const conv_param_t<SPATIAL_DIM>& conv_p);
+
+ /**
+ * @brief Returns a string of mismatching parameters
+ */
+ std::string mismatchingParams(const conv_param_t<SPATIAL_DIM>& conv_p);
+
+ /**
+ * @brief Unpack packed matrix into origin_buf (Used for the serialization to
+ * recover weight matrix).
+ */
+ void unpack(T* origin_buf);
+
private:
+ const conv_param_t<SPATIAL_DIM> conv_param_;
// Packed weights if we use im2col based convolution implementation
std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
- // Packed weights if we use 2D depthwise convolution implementation
- std::shared_ptr<Packed3x3ConvMatrix> W_dw_2D_packed_;
- // Packed weights if we use 3D depthwise convolution implementation
- std::shared_ptr<Packed3x3x3ConvMatrix> W_dw_3D_packed_;
+ // Packed weights if we use depthwise convolution implementation
+ std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_;
// Packed weights if we use groupwise (small channels per group) convolution
// implementation
std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
W_gconv_packed_;
+ // Packed weights if we use direct gemm for pointwise convolution
+ std::shared_ptr<PackBMatrix<T, accT>> W_pointwise_packed_;
};
/**
@@ -661,7 +750,11 @@ class FBGEMM_API PackAWithIm2Col
~PackAWithIm2Col() {
if (rowOffsetAllocatedHere) {
+#ifdef _MSC_VER
+ _aligned_free(row_offset_);
+#else
free(row_offset_);
+#endif
}
}
@@ -752,7 +845,11 @@ class FBGEMM_API PackAWithRowOffset final
~PackAWithRowOffset() {
if (rowOffsetAllocatedHere) {
+#ifdef _MSC_VER
+ _aligned_free(row_offset_);
+#else
free(row_offset_);
+#endif
}
}
@@ -845,7 +942,11 @@ class FBGEMM_API PackAWithQuantRowOffset final
~PackAWithQuantRowOffset() {
if (rowOffsetAllocatedHere) {
+#ifdef _MSC_VER
+ _aligned_free(row_offset_);
+#else
free(row_offset_);
+#endif
}
}
@@ -1062,12 +1163,15 @@ class FBGEMM_API DoSConvOnInpBuffer {
template <
bool FUSE_RELU,
QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR,
+ typename BIAS_TYPE = std::int32_t,
typename outT = std::uint8_t,
typename inT = std::int32_t,
typename nextOPType = DoNothing<outT, outT>>
class FBGEMM_API ReQuantizeOutput {
public:
static constexpr int RELU_FUSED = FUSE_RELU;
+ static constexpr QuantizationGranularity QGRANType = Q_GRAN;
+ using BIAS_T = BIAS_TYPE;
using outType = outT;
using inpType = inT;
/**
@@ -1088,6 +1192,8 @@ class FBGEMM_API ReQuantizeOutput {
* See PackedRequantizeTest.cc for an example.
* TODO: if Aq_zero_point == 0, allow passing nullptr.
* @params bias can be nullptr otherwise the length should be nCol
+ * @params act_times_w_scale activation_scale * weight_scale. This is only
+ * used if bias is unquantized (i.e., float).
*/
ReQuantizeOutput(
nextOPType& nextop,
@@ -1097,9 +1203,10 @@ class FBGEMM_API ReQuantizeOutput {
const std::int32_t* Bq_zero_point,
const std::int32_t* row_offsets,
const std::int32_t* col_offsets,
- const std::int32_t* bias,
+ const BIAS_T* bias,
std::uint32_t nCol,
- int groups = 1)
+ int groups = 1,
+ const float* act_times_w_scale = nullptr)
: nextop_(nextop),
C_multiplier_(C_multiplier),
C_zero_point_(C_zero_point),
@@ -1109,7 +1216,8 @@ class FBGEMM_API ReQuantizeOutput {
q_col_offsets_(col_offsets),
bias_(bias),
ncols_(nCol),
- groups_(groups) {}
+ groups_(groups),
+ act_times_w_scale_(act_times_w_scale) {}
template <inst_set_t instSet>
inline int f(
@@ -1137,12 +1245,15 @@ class FBGEMM_API ReQuantizeOutput {
const std::int32_t* getColOffsets() const {
return q_col_offsets_;
}
- const std::int32_t* getBias() const {
+ const BIAS_T* getBias() const {
return bias_;
}
std::uint32_t getNCols() const {
return ncols_;
}
+ const float* getActWScale() const {
+ return act_times_w_scale_;
+ }
void setRowOffsets(const std::int32_t* row_offsets) {
q_row_offsets_ = row_offsets;
@@ -1156,9 +1267,10 @@ class FBGEMM_API ReQuantizeOutput {
const std::int32_t* Bq_zero_point_;
const std::int32_t* q_row_offsets_;
const std::int32_t* q_col_offsets_;
- const std::int32_t* bias_;
+ const BIAS_T* bias_;
std::uint32_t ncols_;
int groups_;
+ const float* act_times_w_scale_;
};
/**
@@ -1311,7 +1423,8 @@ template <
typename outType,
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
- int SPATIAL_DIM = 2>
+ int SPATIAL_DIM = 2,
+ typename BIAS_TYPE = std::int32_t>
FBGEMM_API void fbgemmGroupwiseConv(
const conv_param_t<SPATIAL_DIM>& conv_param,
const std::uint8_t* activations,
@@ -1320,7 +1433,7 @@ FBGEMM_API void fbgemmGroupwiseConv(
packed_W& packed_weights,
outType* out,
std::int32_t* outBuffer,
- const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+ const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
int thread_id,
int num_threads);
@@ -1361,6 +1474,13 @@ template <int SPATIAL_DIM>
FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p);
/**
+ * @brief Is this convolution a direct matrix-matrix multiplication, i.e., 1x1
+ * (aka pointwise) with the right padding, stride, etc.?
+ */
+template <int SPATIAL_DIM>
+FBGEMM_API bool takePointWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p);
+
+/**
* @brief Allocate __size bytes of uninitialized storage whose alignment is
* specified by __align.
*/
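
A hedged sketch of instantiating the extended ReQuantizeOutput above with an unquantized float bias: BIAS_TYPE is set to float and act_times_w_scale (activation_scale * weight_scale) is supplied so the bias can be brought into the int32 accumulator scale. The quantization values are made up; the DoNothing<> terminal op follows the pattern used in the FBGEMM tests.

#include <cstdint>
#include "fbgemm/Fbgemm.h"

void build_float_bias_requant(
    const std::int32_t* row_offsets,
    const std::int32_t* col_offsets,
    const float* bias_fp32,  // unquantized bias, length nCol
    std::uint32_t nCol) {
  using namespace fbgemm;
  static DoNothing<> doNothingObj{};
  static const float C_multiplier = 0.0123f;
  static const std::int32_t Bq_zero_point = 3;
  static const float act_times_w_scale = 0.5f;  // activation_scale * weight_scale

  ReQuantizeOutput<false /*FUSE_RELU*/, QuantizationGranularity::TENSOR, float>
      outputProc(
          doNothingObj,
          &C_multiplier,
          /*C_zero_point=*/5,
          /*Aq_zero_point=*/43,
          &Bq_zero_point,
          row_offsets,
          col_offsets,
          bias_fp32,
          nCol,
          /*groups=*/1,
          &act_times_w_scale);
  (void)outputProc;  // would be handed to fbgemmPacked() as the output pipeline
}
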
diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h
index 3d84977..8da0b56 100644
--- a/include/fbgemm/FbgemmFP16.h
+++ b/include/fbgemm/FbgemmFP16.h
@@ -104,6 +104,14 @@ class PackedGemmMatrixFP16 {
}
}
+ void setPacked(bool p) {
+ packed_ = p;
+ }
+
+ bool packed() const {
+ return packed_;
+ }
+
void initializeMemory() {
// allocate and initialize packed memory
const int padding = 1024; // required by sw pipelined kernels
@@ -128,6 +136,16 @@ class PackedGemmMatrixFP16 {
#endif
}
+ void unpackFromSrc(const matrix_op_t trans, float16* src_mat) {
+ bool tr = (trans == matrix_op_t::Transpose);
+ for (int i = 0; i < numRows(); i++) {
+ for (int j = 0; j < numCols(); j++) {
+ pmat_[tr ? i + numRows() * j : i * numCols() + j] = src_mat[addr(i, j)];
+ }
+ }
+ packed_ = false;
+ }
+
// protected:
// blocked row-major format address arithmetic
uint64_t addr(const int r_, const int c_) const {
@@ -163,6 +181,19 @@ class PackedGemmMatrixFP16 {
pmat_[addr(i, j)]);
}
}
+ packed_ = true;
+ }
+
+ // This function takes in an unpacked float16 matrix of the same size and
+ // packs it. There is no floating type conversion.
+ void packFromSrc(const matrix_op_t trans, const float16* smat) {
+ bool tr = (trans == matrix_op_t::Transpose);
+ for (int i = 0; i < numRows(); ++i) {
+ for (int j = 0; j < numCols(); ++j) {
+ pmat_[addr(i, j)] = smat[tr ? i + numRows() * j : i * numCols() + j];
+ }
+ }
+ packed_ = true;
}
const float16& operator()(const int r, const int c) const {
@@ -210,6 +241,7 @@ class PackedGemmMatrixFP16 {
uint64_t size_;
int kernel_ncol_blocks_;
float16* pmat_;
+ bool packed_{false};
friend void cblas_gemm_compute(
const matrix_op_t transa,
diff --git a/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/include/fbgemm/FbgemmI8DepthwiseAvx2.h
index 069ff77..c454b16 100644
--- a/include/fbgemm/FbgemmI8DepthwiseAvx2.h
+++ b/include/fbgemm/FbgemmI8DepthwiseAvx2.h
@@ -11,31 +11,58 @@
namespace fbgemm {
-// KERNEL_PROD is the product of all kernels.
-// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3.
-template <int KERNEL_PROD>
class FBGEMM_API PackedDepthWiseConvMatrix {
public:
- // smat in RSG layout
- PackedDepthWiseConvMatrix(int K, const std::int8_t* smat);
+ /**
+ * @params K the number of channels (same as the number of groups because
+ * depth-wise convolution has one input/output channel per group)
+ * @params kernel_prod the product of all kernels. For example, kernel_prod =
+ * 9 for 3x3 conv, and 27 for 3x3x3 conv.
+ * @param smat the source unpacked weight in GRS layout
+ */
+ PackedDepthWiseConvMatrix(int K, int kernel_prod, const std::int8_t* smat);
virtual ~PackedDepthWiseConvMatrix();
const std::int8_t* PackedMat() const {
return pmat_;
}
+ int GetKernelProduct() const {
+ return kernel_prod_;
+ }
+
+ /**
+ * @brief Unpacks pmat_ into unpack_data.
+ * Used for recovering the weight matrix into the original format
+ */
+ void unpack(std::int8_t* unpacked_data);
+
+ /**
+ * @brief returns the index into pmat_ given the row and column for smat
+ */
+ int addr(int r, int c);
+
private:
- int K_;
- std::int8_t* pmat_;
-}; // Packed3x3ConvMatrix
+ const int K_; /**< the number of channels */
+ const int kernel_prod_; /** the product of all kernel dims */
+ std::int8_t* pmat_; /** packed weight */
+}; // PackedDepthWiseConvMatrix
-using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>;
-using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>;
+class FBGEMM_API Packed3x3ConvMatrix : public PackedDepthWiseConvMatrix {
+ public:
+ Packed3x3ConvMatrix(int K, const std::int8_t* smat)
+ : PackedDepthWiseConvMatrix(K, 3 * 3, smat) {}
+};
-/**
- * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
- * @params A The input image in NHWK layout
- * @params Bp The pre-packed filter
+class FBGEMM_API Packed3x3x3ConvMatrix : public PackedDepthWiseConvMatrix {
+ public:
+ Packed3x3x3ConvMatrix(int K, const std::int8_t* smat)
+ : PackedDepthWiseConvMatrix(K, 3 * 3 * 3, smat) {}
+};
+
+/** To be removed. Keeping it just to make sure we don't change C2 files and
+ * fbgemm files in a single diff
+ *
*/
FBGEMM_API void depthwise_3x3_pad_1(
int N,
@@ -46,8 +73,14 @@ FBGEMM_API void depthwise_3x3_pad_1(
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
- const Packed3x3ConvMatrix& Bp,
- std::int32_t* C,
+ std::int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false,
int thread_id = 0,
int num_threads = 1);
@@ -56,7 +89,10 @@ FBGEMM_API void depthwise_3x3_pad_1(
* This version is fused with requantization.
*
* @col_offsets nullptr if col_offsets are folded into bias
+ * @act_times_w_scale Only used if BIAS_TYPE is float, i.e., bias is
+ * unquantized.
*/
+template <typename BIAS_TYPE = std::int32_t>
FBGEMM_API void depthwise_3x3_pad_1(
int N,
int H,
@@ -67,22 +103,24 @@ FBGEMM_API void depthwise_3x3_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
std::int32_t B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
float C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
const std::int32_t* col_offsets,
- const std::int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu = false,
+ float act_times_w_scale = 1.0f,
int thread_id = 0,
int num_threads = 1);
/**
- * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
- * This version is fused with requantization and uses per-channel quantization.
+ * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8, fused with
+ * requantization, and using per-channel quantization.
*
* @col_offsets nullptr if col_offsets are folded into bias
*/
+template <typename BIAS_TYPE = std::int32_t>
FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
int N,
int H,
@@ -93,7 +131,31 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu = false,
+ const float* act_times_w_scale = nullptr,
+ int thread_id = 0,
+ int num_threads = 1);
+
+/** To be removed. Keeping it just to make sure we don't change C2 files and
+ * fbgemm files in a single diff
+ */
+FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
@@ -103,6 +165,10 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
int thread_id = 0,
int num_threads = 1);
+/** To be removed. Keeping it just to make sure we don't change C2 files and
+ * fbgemm files in a single diff
+ *
+ */
FBGEMM_API void depthwise_3x3x3_pad_1(
int N,
int T,
@@ -114,14 +180,20 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
- const Packed3x3x3ConvMatrix& Bp,
- std::int32_t* C,
+ std::int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false,
int thread_id = 0,
int num_threads = 1);
-
/**
* @col_offsets nullptr if col_offsets are folded into bias
*/
+template <typename BIAS_TYPE = std::int32_t>
FBGEMM_API void depthwise_3x3x3_pad_1(
int N,
int T,
@@ -134,11 +206,38 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
std::int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
float C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu = false,
+ float act_times_w_scale = 1.0f,
+ int thread_id = 0,
+ int num_threads = 1);
+
+/** To be removed. Keeping it just to make sure we don't change C2 files and
+ * fbgemm files in a single diff
+ *
+ */
+FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
const std::int32_t* bias,
bool fuse_relu = false,
int thread_id = 0,
@@ -147,6 +246,7 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
/**
* @col_offsets nullptr if col_offsets are folded into bias
*/
+template <typename BIAS_TYPE = std::int32_t>
FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
int N,
int T,
@@ -159,13 +259,14 @@ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
const std::int32_t* col_offsets,
- const std::int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu = false,
+ const float* act_times_w_scale = nullptr,
int thread_id = 0,
int num_threads = 1);
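
A minimal sketch of the generalized PackedDepthWiseConvMatrix declared above: the kernel product is now a runtime constructor argument (9 for 3x3, 27 for 3x3x3), and unpack() recovers the original GRS-layout weights. The sizes below are illustrative.

#include <cstdint>
#include <vector>
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"

void pack_unpack_roundtrip() {
  const int K = 16;               // channels == groups for depth-wise conv
  const int kernel_prod = 3 * 3;  // 9 for 3x3, 27 for 3x3x3
  std::vector<std::int8_t> W(K * kernel_prod, 1);  // unpacked weights, GRS layout

  fbgemm::PackedDepthWiseConvMatrix Bp(K, kernel_prod, W.data());

  std::vector<std::int8_t> W_recovered(W.size());
  Bp.unpack(W_recovered.data());  // W_recovered should match W
}
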
diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h
index d984c60..04ae100 100644
--- a/include/fbgemm/OutputProcessing-inl.h
+++ b/include/fbgemm/OutputProcessing-inl.h
@@ -59,11 +59,13 @@ inline int DoSConvOnInpBuffer<outT, inT, nextOPType>::f(
template <
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
+ typename BIAS_TYPE,
typename outT,
typename inT,
typename nextOPType>
template <inst_set_t instSet>
-inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
+inline int
+ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f(
outT* out,
const inT* inp,
const block_type_t& block,
@@ -98,11 +100,20 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
raw -= q_row_offsets_[i - block.row_start] *
Bq_zero_point_[Bq_zero_point_idx];
}
+ float raw_f;
if (bias_) {
- raw += bias_[j];
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ raw_f = raw;
+ raw_f += bias_[j] / act_times_w_scale_[Bq_zero_point_idx];
+ } else {
+ raw += bias_[j];
+ raw_f = raw;
+ }
+ } else {
+ raw_f = raw;
}
- float ab = raw * C_multiplier_[Bq_zero_point_idx];
+ float ab = raw_f * C_multiplier_[Bq_zero_point_idx];
long rounded = std::lrintf(ab) + C_zero_point_;
out[i * ld_out + j] = std::max(
@@ -115,15 +126,16 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
Bq_zero_point_[0] == 0) ||
q_row_offsets_ == nullptr;
- requantizationParams_t r = {Aq_zero_point_,
- Bq_zero_point_,
- C_zero_point_,
- C_multiplier_,
- q_row_offsets_,
- q_col_offsets_,
- bias_,
- ncols_,
- groups_};
+ requantizationParams_t<BIAS_TYPE> r = {Aq_zero_point_,
+ Bq_zero_point_,
+ C_zero_point_,
+ C_multiplier_,
+ q_row_offsets_,
+ q_col_offsets_,
+ bias_,
+ ncols_,
+ groups_,
+ act_times_w_scale_};
if (Aq_zero_point_ == 0) {
if (b_symmetric) {
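
A hedged numeric walk-through of the float-bias branch added above (all values are made up): the float bias is divided by act_times_w_scale so it lands in the same scale as the int32 accumulator before the usual multiply, round, and zero-point shift.

#include <cmath>
#include <cstdint>

void float_bias_requant_example() {
  std::int32_t raw = 1000;          // accumulator after zero-point/offset corrections
  float bias_f = 12.5f;             // unquantized (float) bias
  float act_times_w_scale = 0.25f;  // activation_scale * weight_scale
  float C_multiplier = 0.02f;
  std::int32_t C_zero_point = 5;

  float raw_f = raw + bias_f / act_times_w_scale;                   // 1000 + 50 = 1050
  long rounded = std::lrintf(raw_f * C_multiplier) + C_zero_point;  // 21 + 5 = 26
  (void)rounded;  // with an int32 bias, bias would instead be added to raw directly
}
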
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h
index 76eb425..baccfad 100644
--- a/include/fbgemm/PackingTraits-inl.h
+++ b/include/fbgemm/PackingTraits-inl.h
@@ -222,3 +222,53 @@ struct PackingTraits<
128}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.
};
+
+/**
+ * @brief Helper struct to type specialize for int16_t and int32_t together.
+ */
+template <typename T>
+struct is_16or32bit {
+ static constexpr bool value =
+ std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value;
+};
+
+/**
+ * @brief Packing parameter specialization for accumulation into 32-bit/16-bit
+ * integers.
+ *
+ * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t
+ * to int32_t accumulation and use the same blocking parameters as int32_t.
+ *
+ * This is picked when T is of int8 type (signed or unsigned) and instruction
+ * set is avx512_vnni.
+ */
+template <typename T, typename accT>
+struct PackingTraits<
+ T,
+ accT,
+ inst_set_t::avx512_vnni,
+ typename std::enable_if<
+ is_8bit<T>::value && is_16or32bit<accT>::value>::type> {
+ static constexpr int MR{8}; ///< Register block for M dimension.
+ static constexpr int NR_MIN{
+ 16}; ///< Minimum register block for N dimension.
+ ///< 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector.
+ static constexpr int NR{
+ 32}; ///< Register block for N dimension.
+ ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector. Total registers used for
+ ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x
+ ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers
+ ///< for C accumulations.
+
+ static constexpr int ROW_INTERLEAVE{
+ 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing
+ ///< B matrix.
+
+ static constexpr int MCB{
+ 128}; ///< Cache block for M dimension (multiple of MR).
+ static constexpr int NCB{
+ 32}; ///< Cache block for N dimension (multiple of NR).
+ static constexpr int KCB{256}; ///< Cache block for K dimension.
+};
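
A quick, hedged sanity check of the register budget these avx512_vnni defaults imply; the 28-register limit mirrors the updated isValidBlockingFactor check in Utils.h later in this diff.

// MR x (NR / NR_MIN) zmm registers hold the C accumulators; the rest are
// left for loading A/B and for the VNNI multiply-add instructions.
constexpr int kMR = 8, kNR = 32, kNR_MIN = 16;
static_assert(kMR * (kNR / kNR_MIN) <= 28,
              "C-accumulator registers must fit within the zmm budget");
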
diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h
index 43855d8..508ce7d 100644
--- a/include/fbgemm/QuantUtils.h
+++ b/include/fbgemm/QuantUtils.h
@@ -7,6 +7,7 @@
#include <limits>
#include "FbgemmBuild.h"
#include "QuantUtilsAvx2.h"
+#include "Utils.h"
namespace fbgemm {
@@ -78,6 +79,40 @@ FBGEMM_API void Quantize(
int len,
const TensorQuantizationParams& qparams);
+/*
+ * @brief Quantize floating point data in src to type T
+ *
+ * @tparam T output quantized data type (int8_t, uint8_t and int32_t are
+ * supported)
+ *
+ * @tparam LAYOUT layout of input tensor in src. (KCX and KXC are supported)
+ * KCX corresponds to KCRS or KCTRS (for weight tensors with
+ * time dimension)
+ * KXC corresponds to KRSC or KTRSC (for weight tensors with
+ * time dimension)
+ *
+ * @params K Output channels for weight tensors
+ * @params C Number of channels
+ * @params X R*S or T*R*S
+ * @params G Groups (if G == C the function performs channelwise quantization;
+ * if 1 < G < C the function performs groupwise quantization;
+ * if G == 1 the function performs per tensor quantization;)
+ * @params scales floating point scales.
+ * Size should be equal to G
+ * @params zero_points zero points (should be representable in type T).
+ * Size should be equal to G
+ */
+template <typename T, layout_t LAYOUT = layout_t::KCX>
+FBGEMM_API void QuantizeGroupwise(
+ const float* src,
+ int K,
+ int C,
+ int X,
+ int G,
+ const float* scales,
+ const std::int32_t* zero_points,
+ T* dst);
+
template <typename T>
FBGEMM_API float Dequantize(T src, const TensorQuantizationParams& qparams) {
return qparams.scale * (src - qparams.zero_point);
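
A short usage sketch for the new QuantizeGroupwise declared above (sizes and scales are illustrative): a K x C x R x S weight tensor in KCX layout is quantized with one (scale, zero_point) pair per group.

#include <cstdint>
#include <vector>
#include "fbgemm/QuantUtils.h"

void quantize_weights_groupwise() {
  const int K = 8, C = 4, X = 3 * 3, G = 2;     // G == 1: per-tensor, G == C: channelwise
  std::vector<float> w_fp32(K * C * X, 0.5f);   // weights in KCRS (KCX) layout
  std::vector<float> scales(G, 0.05f);          // one scale per group
  std::vector<std::int32_t> zero_points(G, 0);  // one zero point per group
  std::vector<std::int8_t> w_int8(w_fp32.size());

  fbgemm::QuantizeGroupwise<std::int8_t, fbgemm::layout_t::KCX>(
      w_fp32.data(), K, C, X, G, scales.data(), zero_points.data(), w_int8.data());
}
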
diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h
index 47f33a8..c7f3f35 100644
--- a/include/fbgemm/QuantUtilsAvx2.h
+++ b/include/fbgemm/QuantUtilsAvx2.h
@@ -40,9 +40,10 @@ struct FBGEMM_API RequantizationParams {
////////////////////////////////////////////////////////////////////////////////
// Utility functions
+template <typename T=std::uint8_t>
void QuantizeAvx2(
const float* src,
- std::uint8_t* dst,
+ T* dst,
int len,
const TensorQuantizationParams& qparams);
@@ -71,14 +72,15 @@ template <
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
- bool FUSE_RELU>
+ bool FUSE_RELU,
+ typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingAvx2(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r);
+ const requantizationParams_t<BIAS_TYPE>& r);
template <
bool A_SYMMETRIC,
@@ -86,14 +88,15 @@ template <
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
- int C_PER_G>
+ int C_PER_G,
+ typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingGConvAvx2(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r);
+ const requantizationParams_t<BIAS_TYPE>& r);
template <
bool A_SYMMETRIC,
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 9f8e1ee..3976790 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -5,6 +5,7 @@
* LICENSE file in the root directory of this source tree.
*/
#pragma once
+#include <array>
#include <string>
#include <type_traits>
#include "FbgemmBuild.h"
@@ -39,12 +40,12 @@ enum class matrix_op_t { NoTranspose, Transpose };
/**
* @brief Typed enum for supported instruction sets.
*/
-enum class inst_set_t { anyarch, avx2, avx512 };
+enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };
/**
* @brief Typed enum for optimized paths for convolutions
*/
-enum class optimized_conv_t { depthwise, groupwise, im2col };
+enum class optimized_conv_t { depthwise, groupwise, pointwise, im2col };
/**
* @brief Typed enum for implementation type.
@@ -54,6 +55,13 @@ enum class optimized_conv_t { depthwise, groupwise, im2col };
enum class impl_type_t { ref, opt };
/**
+ * @brief Typed enum to specify data layout.
+ * KCX can be KCRS format or KCTRS format (e.g., for 3-D convolutions)
+ * KXC can be KRSC format or KTRSC format (e.g., for 3-D convolutions)
+ */
+enum class layout_t { KCX, KXC };
+
+/**
* @brief A function to compare data in two buffers for closeness/equality.
*/
template <typename T>
@@ -103,6 +111,11 @@ FBGEMM_API bool fbgemmHasAvx512Support();
FBGEMM_API bool fbgemmHasAvx2Support();
/**
+ * @brief Are we running on a AVX512_VNNI supported cpu?
+ */
+FBGEMM_API bool fbgemmHasAvx512VnniSupport();
+
+/**
* @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
*
* This structure is optional. If not used, the default values for these
@@ -119,6 +132,16 @@ struct FBGEMM_API BlockingFactors {
int NCB;
};
+template <int SIZE, typename T = std::int32_t>
+FBGEMM_API std::string arrayToString(const std::array<T, SIZE>& inp) {
+ std::string out = "[";
+ for (int i = 0; i < SIZE; ++i) {
+ out += std::to_string(inp[i]);
+ out += (i != SIZE - 1) ? std::string(", ") : std::string("]");
+ }
+ return out;
+}
+
template <typename accT = std::int32_t>
FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
constexpr bool is_32bit = std::is_same<accT, int32_t>::value;
@@ -129,10 +152,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
return false;
if (fbgemmHasAvx512Support()) {
- if (param->NR != 16)
+ if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
return false;
} else if (fbgemmHasAvx2Support()) {
- if (param->NR != 8)
+ if (param->NR_MIN != 8 || param->NR % param->NR_MIN)
return false;
}
} else if (is_16bit) {
@@ -140,10 +163,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
return false;
if (fbgemmHasAvx512Support()) {
- if (param->NR != 32)
+ if (param->NR_MIN != 32 || param->NR % param->NR_MIN)
return false;
} else if (fbgemmHasAvx2Support()) {
- if (param->NR != 16)
+ if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
return false;
}
}
@@ -153,10 +176,19 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
if (param->NCB % param->NR)
return false;
if (fbgemmHasAvx512Support()) {
- if (param->MR * (param->NCB / param->NR) > 24)
- return false;
+ if (is_32bit) {
+ // Zmm register usage for C
+ if (param->MR * (param->NR / param->NR_MIN) > 28)
+ return false;
+ } else if (is_16bit) {
+ // Zmm register usage for C + one row for loading B
+ if ((param->MR * (param->NR / param->NR_MIN) +
+ (param->NR / param->NR_MIN)) > 28)
+ return false;
+ }
+
} else if (fbgemmHasAvx2Support()) {
- if (param->MR * (param->NCB / param->NR) > 16)
+ if (param->MR * (param->NR / param->NR_MIN) > 12)
return false;
}
return true;
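
The register-budget checks above count the vector registers needed to keep the MR x NR tile of C resident: each of the MR rows takes NR / NR_MIN registers (one register per NR_MIN accumulators), so the product must stay within 28 of the 32 ZMM registers on AVX-512, with the 16-bit-accumulation path reserving one extra row's worth for loading B, and within 12 of the 16 YMM registers on AVX2. As a worked check with hypothetical values, an AVX-512 int32 configuration with MR = 14, NR = 32, NR_MIN = 16 needs 14 * (32 / 16) = 28 registers and just fits, while MR = 15 at the same NR would need 30 and be rejected.
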
diff --git a/include/fbgemm/UtilsAvx2.h b/include/fbgemm/UtilsAvx2.h
index 082edc1..3bac909 100644
--- a/include/fbgemm/UtilsAvx2.h
+++ b/include/fbgemm/UtilsAvx2.h
@@ -44,16 +44,19 @@ struct block_type_t {
* QuantUtilsAvx2.h as it combines all the parameters needed for various
* quantization granularities
*/
+template<typename BIAS_TYPE = std::int32_t>
struct requantizationParams_t {
+ using BIAS_T = BIAS_TYPE;
std::int32_t A_zero_point;
const std::int32_t* B_zero_point;
std::int32_t C_zero_point;
const float* C_multiplier;
const std::int32_t* row_offsets;
const std::int32_t* col_offsets;
- const std::int32_t* bias;
+ const BIAS_T* bias;
std::uint32_t ncols;
int groups;
+ const float* act_times_w_scale;
};
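
Templating requantizationParams_t on BIAS_TYPE lets the same requantization kernels consume either an int32 bias or a float bias, and the new act_times_w_scale pointer (activation scale times weight scale, one or more values depending on the quantization granularity) supplies the factor needed to combine a float bias with the int32 accumulators. A minimal sketch of filling the float-bias instantiation; all sizes and buffer contents below are hypothetical:

    // Minimal sketch (hypothetical sizes and values): the BIAS_TYPE = float variant.
    #include <cstdint>
    #include <vector>
    #include "fbgemm/UtilsAvx2.h"

    void fill_requant_params_example() {
      const int N = 64;  // hypothetical number of output columns
      static std::vector<std::int32_t> B_zp(N, 0), row_off(16, 0), col_off(N, 0);
      static std::vector<float> C_mult(N, 0.02f), bias(N, 0.0f), act_w(N, 0.1f);

      fbgemm::requantizationParams_t<float> r;
      r.A_zero_point = 3;                  // activation zero point
      r.B_zero_point = B_zp.data();        // weight zero point(s)
      r.C_zero_point = 5;                  // output zero point
      r.C_multiplier = C_mult.data();      // requantization multiplier(s)
      r.row_offsets = row_off.data();
      r.col_offsets = col_off.data();
      r.bias = bias.data();                // BIAS_T = float
      r.ncols = N;
      r.groups = 1;
      r.act_times_w_scale = act_w.data();  // activation_scale * weight_scale
    }
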
/**
diff --git a/src/CodeCache.h b/src/CodeCache.h
new file mode 100644
index 0000000..08e9c9b
--- /dev/null
+++ b/src/CodeCache.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <condition_variable>
+#include <future>
+#include <map>
+#include <mutex>
+
+namespace fbgemm {
+
+/**
+ * @brief Thread-safe cache for microkernels; ensures a single creation per key.
+ * @tparam KEY Type of the unique key (typically a tuple)
+ * @tparam VALUE Type of the cached microkernel function (typically a function pointer)
+ */
+template <typename KEY, typename VALUE> class CodeCache {
+private:
+ std::map<KEY, std::shared_future<VALUE>> values_;
+ std::mutex mutex_;
+
+public:
+ CodeCache(const CodeCache &) = delete;
+ CodeCache &operator=(const CodeCache &) = delete;
+
+ CodeCache(){};
+
+ VALUE getOrCreate(const KEY &key, std::function<VALUE()> generatorFunction) {
+ std::shared_future<VALUE> returnFuture;
+ std::promise<VALUE> returnPromise;
+ bool needsToGenerate = false;
+
+    // Check for existence of the key
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ auto it = values_.find(key);
+ if (it != values_.end()) {
+ returnFuture = it->second;
+ } else {
+ values_[key] = returnFuture = returnPromise.get_future().share();
+ needsToGenerate = true;
+ }
+ }
+
+    // The value (code) generation happens outside the lock
+ if (needsToGenerate) {
+ returnPromise.set_value(generatorFunction());
+ }
+
+ // Wait for the future and return the value
+ return returnFuture.get();
+ }
+};
+
+} // namespace fbgemm
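
getOrCreate publishes a shared_future under the map lock and runs the generator outside it: the first caller for a key installs the future and produces the value, while concurrent callers for the same key simply wait on that future, so each microkernel is generated exactly once. A minimal usage sketch with hypothetical key and value types (the real callers cache JIT-generated kernels keyed on tuples of blocking parameters):

    // Minimal sketch (hypothetical types): cache one function pointer per
    // (accumulate, mr, nr) tuple; the generator runs at most once per key.
    #include <cstdio>
    #include <functional>
    #include <tuple>
    #include "CodeCache.h"

    using kernel_fp = void (*)(int);
    static void dummy_kernel(int n) { std::printf("n = %d\n", n); }

    static fbgemm::CodeCache<std::tuple<bool, int, int>, kernel_fp> cache;

    kernel_fp get_kernel(bool accum, int mr, int nr) {
      return cache.getOrCreate(std::make_tuple(accum, mr, nr), []() -> kernel_fp {
        // Expensive code generation would go here.
        return dummy_kernel;
      });
    }

One property worth noting: if the generator were to throw before the promise is fulfilled, later waiters on that key would block, so generators used with this cache are expected not to throw.
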
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index f7292fd..4ae1b50 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -49,7 +49,8 @@ ExecuteKernel<
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() ||
+ fbgemmHasAvx2Support()) {
mbSize_ = params->MCB;
nbSize_ = params->NCB;
nrMinSize_ = params->NR_MIN;
@@ -59,7 +60,20 @@ ExecuteKernel<
assert(0 && "unsupported architecure");
}
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NCB;
+ nrMinSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NR_MIN;
+ } else if (fbgemmHasAvx512Support()) {
mbSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
@@ -118,7 +132,25 @@ void ExecuteKernel<
typename BaseType::jit_micro_kernel_fp fn;
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
+ // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ } else {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ }
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum,
packed_rows_A,
@@ -148,7 +180,10 @@ void ExecuteKernel<
if (jb == bColBlocks - 1) {
int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
if (nc != nbSize_) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
} else if (fbgemmHasAvx2Support()) {
@@ -213,7 +248,7 @@ void ExecuteKernel<
int32_t nSize =
C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols();
if (nSize) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
@@ -238,7 +273,7 @@ void ExecuteKernel<
if (C_buffer_start == C_tile_) {
// When C_tile_ scratchpad was used to avoid accessing memory past
// C_buffer_ .
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
@@ -280,19 +315,23 @@ void ExecuteKernel<
////////////////////////////////////////////////////////////////////////////////
// ReQuantizeOutput
-#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \
- template class ExecuteKernel< \
- PACK_A<uint8_t, ACC_T>, \
- PackBMatrix<int8_t, ACC_T>, \
- uint8_t, \
- ReQuantizeOutput<RELU, Q_GRAN>>;
+#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \
+ template class ExecuteKernel< \
+ PACK_A<uint8_t, ACC_T>, \
+ PackBMatrix<int8_t, ACC_T>, \
+ uint8_t, \
+ ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;
+
+#define INSTANTIATE_REQUANT_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \
+ INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \
+ INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t);
#define INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, RELU) \
- INSTANTIATE_REQUANT_BASE( \
+ INSTANTIATE_REQUANT_BIAS_T( \
PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
- INSTANTIATE_REQUANT_BASE( \
+ INSTANTIATE_REQUANT_BIAS_T( \
PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \
- INSTANTIATE_REQUANT_BASE( \
+ INSTANTIATE_REQUANT_BIAS_T( \
PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_REQUANT_RELU(PACK_A, ACC_T) \
@@ -309,21 +348,27 @@ INSTANTIATE_REQUANT_ACC_T(PackAWithRowOffset);
#undef INSTANTIATE_REQUANT_ACC_T
#undef INSTANTIATE_REQUANT_RELU
#undef INSTANTIATE_REQUANT_Q_GRANS
+#undef INSTANTIATE_REQUANT_BIAS_T
#undef INSTANTIATE_REQUANT_BASE
-#define INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
- template class ExecuteKernel< \
- PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
- PackBMatrix<int8_t, ACC_T>, \
- uint8_t, \
- ReQuantizeOutput<RELU, Q_GRAN>>;
+#define INSTANTIATE_IM2COL_REQUANT_BASE( \
+ ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \
+ template class ExecuteKernel< \
+ PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
+ PackBMatrix<int8_t, ACC_T>, \
+ uint8_t, \
+ ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;
+
+#define INSTANTIATE_IM2COL_REQUANT_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
+ INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \
+ INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t);
#define INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
- INSTANTIATE_IM2COL_REQUANT_BASE( \
+ INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
- INSTANTIATE_IM2COL_REQUANT_BASE( \
+ INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
- INSTANTIATE_IM2COL_REQUANT_BASE( \
+ INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, RELU) \
@@ -340,6 +385,7 @@ INSTANTIATE_IM2COL_REQUANT_RELU(int16_t);
#undef INSTANTIATE_IM2COL_REQUANT_RELU
#undef INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM
#undef INSTANTIATE_IM2COL_REQUANT_Q_GRANS
+#undef INSTANTIATE_IM2COL_REQUANT_BIAS_T
#undef INSTANTIATE_IM2COL_REQUANT_BASE
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 2f641ee..b691b88 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -48,7 +48,8 @@ void fbgemmPacked(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -62,7 +63,20 @@ void fbgemmPacked(
MR = blocking_params->MR;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::KCB;
+ MR = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MR;
+ } else if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<
typename packingAMatrix::inpType,
typename packingAMatrix::accType,
@@ -223,22 +237,26 @@ bool fbgemmSupportedCPU() {
////////////////////////////////////////////////////////////////////////////////
// ReQuantizeOutput
-#define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \
+#define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \
template void fbgemmPacked( \
PackMatrix<PACK_A<uint8_t, ACC_T>, uint8_t, ACC_T>& packA, \
PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
uint8_t* C, \
int32_t* C_buffer, \
uint32_t ldc, \
- const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
int thread_id, \
int num_threads, \
const BlockingFactors* blocking_params);
-#define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \
- INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
- INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \
- INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL);
+#define INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \
+ INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \
+ INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t);
+
+#define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \
+ INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \
+ INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_RELU(PACK_A, ACC_T) \
INSTANTIATE_Q_GRANS(PACK_A, ACC_T, false); \
@@ -254,27 +272,34 @@ INSTANTIATE_ACC_T(PackAWithRowOffset);
#undef INSTANTIATE_ACC_T
#undef INSTANTIATE_RELU
#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
-#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
- template void fbgemmPacked( \
- PackMatrix< \
- PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
- uint8_t, \
- ACC_T>& packA, \
- PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
- uint8_t* C, \
- int32_t* C_buffer, \
- uint32_t ldc, \
- const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
- int thread_id, \
- int num_threads, \
+#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \
+ template void fbgemmPacked( \
+ PackMatrix< \
+ PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
+ uint8_t, \
+ ACC_T>& packA, \
+ PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
+ uint8_t* C, \
+ int32_t* C_buffer, \
+ uint32_t ldc, \
+ const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
+ int thread_id, \
+ int num_threads, \
const BlockingFactors* blocking_params);
-#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
- INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
- INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
- INSTANTIATE_BASE( \
+#define INSTANTIATE_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
+ INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \
+ INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t);
+
+#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
+ INSTANTIATE_BIAS_T( \
+ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_BIAS_T( \
+ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
+ INSTANTIATE_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_SPATIAL_DIM(ACC_T, RELU) \
@@ -291,6 +316,7 @@ INSTANTIATE_RELU(int16_t);
#undef INSTANTIATE_RELU
#undef INSTANTIATE_SPATIAL_DIM
#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
////////////////////////////////////////////////////////////////////////////////
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc
index 5db63f6..de833d2 100644
--- a/src/FbgemmConv.cc
+++ b/src/FbgemmConv.cc
@@ -6,8 +6,9 @@
*/
#include <algorithm>
-#include <iostream>
+#include <numeric>
#include <vector>
+#include <functional>
#include "fbgemm/Fbgemm.h"
namespace fbgemm {
@@ -33,12 +34,24 @@ bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
});
}
+template <int SPATIAL_DIM>
+bool takePointWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
+ return std::accumulate(conv_p.K.begin(), conv_p.K.end(), 0) == SPATIAL_DIM &&
+ std::accumulate(conv_p.stride.begin(), conv_p.stride.end(), 0) ==
+ SPATIAL_DIM &&
+ std::accumulate(conv_p.dilation.begin(), conv_p.dilation.end(), 0) ==
+ SPATIAL_DIM &&
+ std::accumulate(conv_p.pad.begin(), conv_p.pad.end(), 0) == 0;
+}
+
template <int SPATIAL_DIM, typename ACC_T>
optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
if (takeDepthWiseFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
return optimized_conv_t::depthwise;
} else if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_p)) {
return optimized_conv_t::groupwise;
+ } else if (takePointWiseFastPath<SPATIAL_DIM>(conv_p)) {
+ return optimized_conv_t::pointwise;
} else {
return optimized_conv_t::im2col;
}
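
takePointWiseFastPath classifies a convolution as pointwise when every kernel dimension, stride, and dilation equals 1 and there is no padding; since these parameters are positive integers, summing each array and comparing against SPATIAL_DIM (or 0 for the padding) is a compact way of saying "all entries are 1" (or all zero). A 1x1, stride-1, unpadded convolution is just a GEMM over the flattened spatial positions, which is why the pointwise case further down packs the activations with PackAWithRowOffset and calls fbgemmPacked directly instead of going through im2col. A minimal sketch of the same test written out for the 2-D case, with hypothetical shapes:

    // Minimal sketch (hypothetical 2-D parameters): the "all ones / all zeros"
    // pointwise test spelled out without the std::accumulate shorthand.
    #include <array>

    bool is_pointwise_2d(const std::array<int, 2>& K,
                         const std::array<int, 2>& stride,
                         const std::array<int, 2>& dilation,
                         const std::array<int, 4>& pad) {
      bool all_ones = K[0] == 1 && K[1] == 1 &&
                      stride[0] == 1 && stride[1] == 1 &&
                      dilation[0] == 1 && dilation[1] == 1;
      bool no_pad = pad[0] == 0 && pad[1] == 0 && pad[2] == 0 && pad[3] == 0;
      return all_ones && no_pad;  // e.g. a 1x1, stride-1, unpadded convolution
    }
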
@@ -58,58 +71,139 @@ int fbgemmConv(
static_assert(
SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
"Only 2D and 3D convolutions are supported");
+
+ if (!packed_weights.isPackingCompliant(conv_p)) {
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] Convolution parameters "
+ "mismatch between pre-packed weights and conv invocation! ";
+ msg += packed_weights.mismatchingParams(conv_p);
+ msg += std::string(
+ " Please pack weights using the same parameters "
+ "with which convolution operation is invoked!");
+ throw std::logic_error(msg);
+ }
+
switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
case optimized_conv_t::depthwise: {
// 2D and 3D depthwise fast path
// std::cout << "Depthwise fast path" << std::endl;
const std::int32_t* B_zero_point = outProcess.getBZeroPoint();
const float* C_multiplier = outProcess.getCMultiplier();
+ const float* act_times_w_scale = outProcess.getActWScale();
if (SPATIAL_DIM == 3) {
static_assert(
std::is_same<typename processOutputType::outType, std::uint8_t>::
value,
"For depthwise, only requantized output is supported");
- depthwise_3x3x3_pad_1(
- conv_p.MB, // mini batch
- conv_p.IN_DIM[0], // T
- conv_p.IN_DIM[1], // H
- conv_p.IN_DIM[2], // W
- conv_p.OC, // output channels
- conv_p.stride[0], // stride_t
- conv_p.stride[1], // stride_h
- conv_p.stride[2], // stride_w
- outProcess.getAZeroPoint(),
- activations,
- B_zero_point[0],
- *(packed_weights.getPackedWFor3DDW()),
- C_multiplier[0],
- outProcess.getCZeroPoint(),
- out,
- outProcess.getColOffsets(),
- outProcess.getBias(),
- outProcess.RELU_FUSED, // fuse_relu
- thread_id,
- num_threads);
+
+ if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
+ depthwise_3x3x3_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // T
+ conv_p.IN_DIM[1], // H
+ conv_p.IN_DIM[2], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_t
+ conv_p.stride[1], // stride_h
+ conv_p.stride[2], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point[0],
+ *(packed_weights.getPackedWForDepthwise()),
+ C_multiplier[0],
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ act_times_w_scale ? act_times_w_scale[0] : 1.0f,
+ thread_id,
+ num_threads);
+ } else if (
+ processOutputType::QGRANType ==
+ QuantizationGranularity::OUT_CHANNEL ||
+ processOutputType::QGRANType == QuantizationGranularity::GROUP) {
+ depthwise_3x3x3_per_channel_quantization_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // T
+ conv_p.IN_DIM[1], // H
+ conv_p.IN_DIM[2], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_t
+ conv_p.stride[1], // stride_h
+ conv_p.stride[2], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point,
+ *(packed_weights.getPackedWForDepthwise()),
+ C_multiplier,
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ outProcess.getActWScale(), // act_scale * weight_scale
+ thread_id,
+ num_threads);
+ } else {
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] This quantization granularity is "
+ "not supported";
+ throw std::runtime_error(msg);
+ }
} else {
- depthwise_3x3_pad_1(
- conv_p.MB, // mini batch
- conv_p.IN_DIM[0], // H
- conv_p.IN_DIM[1], // W
- conv_p.OC, // output channels
- conv_p.stride[0], // stride_h
- conv_p.stride[1], // stride_w
- outProcess.getAZeroPoint(),
- activations,
- B_zero_point[0],
- *(packed_weights.getPackedWFor2DDW()),
- C_multiplier[0],
- outProcess.getCZeroPoint(),
- out,
- outProcess.getColOffsets(),
- outProcess.getBias(),
- outProcess.RELU_FUSED, // fuse_relu
- thread_id,
- num_threads);
+ if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
+ depthwise_3x3_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // H
+ conv_p.IN_DIM[1], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_h
+ conv_p.stride[1], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point[0],
+ *(packed_weights.getPackedWForDepthwise()),
+ C_multiplier[0],
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ act_times_w_scale ? act_times_w_scale[0] : 1.0f,
+ thread_id,
+ num_threads);
+ } else if (
+ processOutputType::QGRANType ==
+ QuantizationGranularity::OUT_CHANNEL ||
+ processOutputType::QGRANType == QuantizationGranularity::GROUP) {
+ // The number of channels == groups for depthwise convolutions
+ depthwise_3x3_per_channel_quantization_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // H
+ conv_p.IN_DIM[1], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_h
+ conv_p.stride[1], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point,
+ *(packed_weights.getPackedWForDepthwise()),
+ C_multiplier,
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ outProcess.getActWScale(), // act_scale * weight_scale
+ thread_id,
+ num_threads);
+ } else {
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] This quantization granularity is "
+ "not supported";
+ throw std::runtime_error(msg);
+ }
}
break;
}
@@ -134,14 +228,68 @@ int fbgemmConv(
num_threads);
break;
}
+ case optimized_conv_t::pointwise: {
+ std::vector<int32_t> row_offset_buf(
+ PackAWithRowOffset<uint8_t>::rowOffsetBufferSize(blocking_params));
+ int image_dim = std::accumulate(
+ conv_p.IN_DIM.begin(),
+ conv_p.IN_DIM.end(),
+ 1,
+ std::multiplies<int>());
+ PackAWithRowOffset<uint8_t, ACC_T> packA(
+ matrix_op_t::NoTranspose,
+ conv_p.MB * image_dim,
+ conv_p.IC,
+ activations,
+ conv_p.IC,
+ nullptr,
+ conv_p.G,
+ row_offset_buf.data(),
+ blocking_params);
+
+ outProcess.setRowOffsets(row_offset_buf.data());
+ fbgemmPacked(
+ packA,
+ *(packed_weights.getPackedWForPointwise()),
+ out,
+ outBuffer,
+ conv_p.OC,
+ outProcess,
+ thread_id,
+ num_threads,
+ blocking_params);
+ break;
+ }
case optimized_conv_t::im2col: {
// All other convolutions go through im2col-based implementation
// std::cout << "Im2col path" << std::endl;
std::vector<int32_t> row_offset_buf(
- PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize());
+ PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize(
+ blocking_params));
const std::int32_t* b_zero_point = outProcess.getBZeroPoint();
- bool b_symmetric = b_zero_point[0] == 0;
+ bool b_symmetric = false;
+ if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
+ b_symmetric = b_zero_point[0] == 0;
+ } else if (
+ processOutputType::QGRANType == QuantizationGranularity::GROUP) {
+ b_symmetric =
+ std::all_of(b_zero_point, b_zero_point + conv_p.G, [](int i) {
+ return i == 0;
+ });
+ } else if (
+ processOutputType::QGRANType ==
+ QuantizationGranularity::OUT_CHANNEL) {
+ b_symmetric =
+ std::all_of(b_zero_point, b_zero_point + conv_p.OC, [](int i) {
+ return i == 0;
+ });
+ } else {
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] This quantization granularity is "
+ "not supported";
+ throw std::runtime_error(msg);
+ }
PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> packA(
conv_p,
activations,
@@ -169,21 +317,25 @@ int fbgemmConv(
return 0;
}
-#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \
+#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, BIAS_TYPE) \
template int fbgemmConv( \
const conv_param_t<SPATIAL_DIM>& conv_p, \
const std::uint8_t* activations, \
PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, \
std::uint8_t* out, \
std::int32_t* outBuffer, \
- ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
int thread_id, \
int num_threads, \
const BlockingFactors* blocking_params);
+#define INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \
+ INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, float); \
+ INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, int32_t);
+
#define INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, RELU) \
- INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 2); \
- INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 3);
+ INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 2); \
+ INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 3);
#define INSTANTIATE_RELU(ACC_T, Q_GRAN) \
INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, true); \
@@ -199,6 +351,7 @@ INSTANTIATE_Q_GRANS(std::int32_t);
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_RELU
#undef INSTANTIATE_SPATIAL_DIM
+#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
template bool takeDepthWiseFastPath<2, std::int32_t>(
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
index f357966..b034f2c 100644
--- a/src/FbgemmFP16.cc
+++ b/src/FbgemmFP16.cc
@@ -50,6 +50,7 @@ struct KernelInfo {
// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
+ // clang-format off
static constexpr int partition[121][2][2] = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -175,6 +176,7 @@ struct KernelInfo {
{ { 6, 19 }, { 5, 1 } }, // 119
{ { 6, 20 }, { 0, 0 } }, // 120
};
+ // clang-format on
};
constexpr KernelInfo::knl_ptr KernelInfo::kernel[7];;
constexpr int KernelInfo::partition[121][2][2];
diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc
new file mode 100644
index 0000000..2114b20
--- /dev/null
+++ b/src/FbgemmI8Depthwise3DAvx2.cc
@@ -0,0 +1,1423 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+
+#include <string>
+#include <tuple> // for tie
+
+#include "FbgemmI8DepthwiseAvx2-inl.h"
+
+using namespace std;
+
+namespace fbgemm {
+
+template <
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline ALWAYS_INLINE void inner_prod_3x3x3_packed_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t_in,
+ int h_in,
+ int w_in,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* Bp,
+ const int32_t* B_zero_point,
+ int32_t* C,
+ int remainder,
+ int32_t* row_offsets) {
+ __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
+ __m256i mask_v = _mm256_setzero_si256();
+ if (REMAINDER) {
+ mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(masks[remainder / 4]));
+ }
+
+ // The code below can be written as a simple R*S loop but the compiler
+ // doesn't unroll so we're manually unrolling it.
+ // constexpr int R = 3, S = 3;
+ // array<__m256i, R * S> a_v;
+ // for (int r = 0; r < R; ++r) {
+ // for (int s = 0; s < S; ++s) {
+ // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
+ // if (REMAINDER) {
+ // a_v[r * S + s] =
+ // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
+ // mask_v);
+ // } else {
+ // a_v[r * S + s] =
+ // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
+ // }
+ // } else {
+ // a_v[r * S + s] = A_zero_point_v;
+ // }
+ // }
+ // }
+ __m256i a_v[8];
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
+ }
+ }
+ }
+
+ __m256i a_sum[4];
+ inner_prod_packed_<8, SUM_A, REMAINDER>(
+ a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
+ }
+ }
+ }
+
+ __m256i a_sum_temp[4];
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp);
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp);
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
+ a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp);
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+
+ __m256i B_zero_point_v;
+ for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
+ if (PER_CHANNEL_QUANTIZATION) {
+ B_zero_point_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
+ }
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ }
+ }
+}
+
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline ALWAYS_INLINE void depthwise_3x3x3_kernel_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t,
+ int h,
+ int w,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int t_in = -PAD_P + t * stride_t;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ 0,
+ B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ remainder,
+ B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ }
+
+ requantize_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false, /*PER_CHAN_QUANT*/
+ A_SYMMETRIC,
+ B_SYMMETRIC>(
+ A_zero_point,
+ &C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias,
+ &act_times_w_scale);
+}
+
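
The H_OUT and W_OUT (and, in the callers, T_OUT) extents above follow the standard convolution output-size formula OUT = (IN + pad_before + pad_after - kernel) / stride + 1. As a quick check with hypothetical sizes: for H = W = 56, a 3x3x3 kernel, padding 1 on every side and stride 1, H_OUT = (56 + 1 + 1 - 3) / 1 + 1 = 56, so the padded depthwise convolution preserves the spatial size; with stride 2 the same input gives (56 + 1 + 1 - 3) / 2 + 1 = 28 (integer division).
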
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
+static inline ALWAYS_INLINE void
+depthwise_3x3x3_per_channel_quantization_kernel_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t,
+ int h,
+ int w,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const int8_t* Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int t_in = -PAD_P + t * stride_t;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3x3_packed_<
+ true, /*SUM_A*/
+ false, /*remainder*/
+ true /*per-channel*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ B_zero_point + k,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3x3_packed_<
+ true, /*SUM_A*/
+ true, /*remainder*/
+ true /*per-channel*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ B_zero_point + k,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
+ }
+ requantize_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true, /*PER_CHAN_QUANT*/
+ A_SYMMETRIC,
+ false /*B_SYMM*/>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+}
+
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+  // Previously a stack VLA: int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+  // VLAs and the alignment attribute are not portable, so use an aligned heap allocation.
+ int32_t* row_offsets
+ = static_cast<int32_t*>(ALIGNED_MALLOC((K + 31) / 32 * 32 * sizeof(int32_t), 64));
+
+ int n_begin, n_end;
+ int t_begin, t_end, h_begin, h_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ t_begin = 0;
+ t_end = T_OUT;
+ h_begin = 0;
+ h_end = H_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+ // n is processed by num_threads_t * num_threads_h 2D grid of threads
+ int num_threads_t, num_threads_h;
+    // num_threads_t <= num_threads_h
+ tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_t = tid_within_n / num_threads_h;
+ int tid_h = tid_within_n % num_threads_h;
+
+ int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
+ t_begin = std::min(tid_t * t_per_thread, T_OUT);
+ t_end = std::min(t_begin + t_per_thread, T_OUT);
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * T * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+
+ for (int t = t_begin; t < t_end; ++t) {
+ for (int h = h_begin; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+ } // t
+ } // for each n
+ FREE(row_offsets);
+};
+
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
+static inline ALWAYS_INLINE void
+depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+  // Previously a stack VLA: int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+  // VLAs and the alignment attribute are not portable, so use an aligned heap allocation.
+ int32_t* row_offsets
+ = static_cast<int32_t*>(ALIGNED_MALLOC((K + 31) / 32 * 32 * sizeof(int32_t), 64));
+
+ int n_begin, n_end;
+ int t_begin, t_end, h_begin, h_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ t_begin = 0;
+ t_end = T_OUT;
+ h_begin = 0;
+ h_end = H_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+ // n is processed by num_threads_t * num_threads_h 2D grid of threads
+ int num_threads_t, num_threads_h;
+    // num_threads_t <= num_threads_h
+ tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_t = tid_within_n / num_threads_h;
+ int tid_h = tid_within_n % num_threads_h;
+
+ int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
+ t_begin = std::min(tid_t * t_per_thread, T_OUT);
+ t_end = std::min(t_begin + t_per_thread, T_OUT);
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * T * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+
+ for (int t = t_begin; t < t_end; ++t) {
+ for (int h = h_begin; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_per_channel_quantization_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+ } // t
+ } // for each n
+ FREE(row_offsets);
+};
+
+// Dispatch A_SYMMETRIC and B_SYMMETRIC
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
+static void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
+ if (A_zero_point == 0 || col_offsets == nullptr) {
+ if (B_zero_point == 0) {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_symmetric*/,
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_symmetric*/,
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ } else {
+ if (B_zero_point == 0) {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_symmetric*/,
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_symmetric*/,
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ }
+ delete[] C_int32_temp;
+}
+
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU, typename BIAS_TYPE>
+static void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (bias) {
+ depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch FUSE_RELU
+template <typename BIAS_TYPE>
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(
+ 0 &&
+ "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
+ if (fuse_relu) {
+ depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch A_SYMMETRIC
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
+static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
+ if (A_zero_point == 0 || col_offsets == nullptr) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_SYMM*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_SYMM*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ delete[] C_int32_temp;
+}
+
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU, typename BIAS_TYPE>
+static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (bias) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ true /* HAS_BIAS */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ false /* HAS_BIAS */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch FUSE_RELU
+template <typename BIAS_TYPE>
+void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(
+ 0 &&
+ "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
+ if (fuse_relu) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// To be removed
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id,
+ int num_threads) {
+ depthwise_3x3x3_pad_1<int32_t>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ 1.0f, // act_scale * weight_scale
+ thread_id,
+ num_threads);
+}
+
+void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id,
+ int num_threads) {
+ depthwise_3x3x3_per_channel_quantization_pad_1(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ nullptr, // act_scale * weight_scale
+ thread_id,
+ num_threads);
+}
+
+template void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+} // namespace fbgemm
diff --git a/src/FbgemmI8DepthwiseAvx2-inl.h b/src/FbgemmI8DepthwiseAvx2-inl.h
new file mode 100644
index 0000000..aee9ab3
--- /dev/null
+++ b/src/FbgemmI8DepthwiseAvx2-inl.h
@@ -0,0 +1,710 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <algorithm> // for min and max
+#include <cassert>
+#include <cmath> // for lrintf and sqrt
+#include <cstdint>
+#include <type_traits> // for is_same
+
+#include <immintrin.h>
+#include "fbgemm/Utils.h"
+
+namespace fbgemm {
+
+// clang-format off
+static int masks[8][8] = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ { 0, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, -1, 0, },
+};
+// clang-format on
+
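
The masks table above drives the REMAINDER path of the inner products: masks[remainder / 4] has one all-ones 32-bit lane per group of four remaining uint8 channels, so _mm256_maskload_epi32 reads exactly the valid tail of a channel row and zeroes the rest without touching memory past the end. In the depthwise kernels, remainder is K modulo 32 and K is asserted to be a multiple of 8, so remainder / 4 lands on rows 2, 4, or 6 of the table. A minimal, self-contained sketch of the masked load (the table is copied locally; all data is hypothetical):

    // Minimal sketch (hypothetical data): limit a 256-bit load to a 24-byte tail
    // using the same mask layout as the masks table above.
    #include <cstdint>
    #include <cstdio>
    #include <immintrin.h>

    static const int masks_demo[8][8] = {
        {  0,  0,  0,  0,  0,  0,  0, 0 }, { -1,  0,  0,  0,  0,  0,  0, 0 },
        { -1, -1,  0,  0,  0,  0,  0, 0 }, { -1, -1, -1,  0,  0,  0,  0, 0 },
        { -1, -1, -1, -1,  0,  0,  0, 0 }, { -1, -1, -1, -1, -1,  0,  0, 0 },
        { -1, -1, -1, -1, -1, -1,  0, 0 }, { -1, -1, -1, -1, -1, -1, -1, 0 },
    };

    int main() {
      std::uint8_t row[24];  // only 24 valid bytes, i.e. remainder == 24
      for (int i = 0; i < 24; ++i) row[i] = static_cast<std::uint8_t>(i + 1);

      int remainder = 24;
      __m256i mask_v = _mm256_loadu_si256(
          reinterpret_cast<const __m256i*>(masks_demo[remainder / 4]));
      // Loads lanes 0..5 (24 bytes) and leaves lanes 6..7 zeroed.
      __m256i a_v = _mm256_maskload_epi32(reinterpret_cast<const int*>(row), mask_v);

      alignas(32) std::int32_t lanes[8];
      _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), a_v);
      std::printf("lane 6 = %d, lane 7 = %d (masked off)\n", lanes[6], lanes[7]);
      return 0;
    }
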
+// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline ALWAYS_INLINE void madd_epi16x4_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
+ __m256i a3_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
+
+// c = a0 * b0 + a1 * b1 + a2 * b2
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline ALWAYS_INLINE void madd_epi16x3_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
+
+// c = a0 * b0 + a1 * b1
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline ALWAYS_INLINE void madd_epi16x2_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// c = a0 * b0
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline ALWAYS_INLINE void madd_epi16_packed(
+ __m256i a_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// K is the number of accumulations we're doing
+template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
+static inline ALWAYS_INLINE void inner_prod_packed_(
+ const __m256i* a_v,
+ const __m256i* Bp,
+ std::int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
+ __m256i c[4], c_temp[4];
+ __m256i a_sum_temp[2] = {0, 0};
+
+ int k = 0;
+ if (K >= 4) {
+ madd_epi16x4_packed<SUM_A>(
+ a_v[0],
+ a_v[1],
+ a_v[2],
+ a_v[3],
+ Bp,
+ &c[0],
+ &c[1],
+ &c[2],
+ &c[3],
+ a_sum_temp);
+
+ for (k = 4; k < K / 4 * 4; k += 4) {
+ madd_epi16x4_packed<SUM_A>(
+ a_v[k + 0],
+ a_v[k + 1],
+ a_v[k + 2],
+ a_v[k + 3],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp);
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+ } else {
+ c[0] = _mm256_setzero_si256();
+ c[1] = _mm256_setzero_si256();
+ c[2] = _mm256_setzero_si256();
+ c[3] = _mm256_setzero_si256();
+ }
+
+ if (K - k == 3) {
+ madd_epi16x3_packed<SUM_A>(
+ a_v[k],
+ a_v[k + 1],
+ a_v[k + 2],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp);
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
+ c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
+ c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
+ c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
+
+ if (K - k == 0 || K - k == 3) {
+ c[0] = c_temp[0];
+ c[1] = c_temp[1];
+ c[2] = c_temp[2];
+ c[3] = c_temp[3];
+ } else {
+ if (K - k == 1) {
+ madd_epi16_packed<SUM_A>(
+ a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
+ } else if (K - k == 2) {
+ madd_epi16x2_packed<SUM_A>(
+ a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
+ }
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ if (REMAINDER) {
+ for (int r = 0; r < remainder / 8; ++r) {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + r * 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
+ c[r]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
+ }
+ }
+ } else {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + 16),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + 24),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
+ }
+ }
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
+ a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
+ a_sum[2] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
+ a_sum[3] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
+ }
+}
+
+// Almost the same as ReQuantizeOutput in OutputProcessing-inl.h, but with
+// different row_offsets for each row because of depth-wise convolution.
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool PER_CHANNEL_QUANTIZATION,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline ALWAYS_INLINE void requantize_(
+ std::int32_t A_zero_point,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ const std::int32_t* C_int32,
+ std::uint8_t* C_uint8,
+ int n,
+ const std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale = nullptr) {
+ __m256 multiplier_v = _mm256_setzero_ps();
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v = _mm256_setzero_ps();
+ if (!PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_set1_ps(*C_multiplier);
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v = _mm256_set1_ps(1.0f / (*act_times_w_scale));
+ }
+ }
+
+ __m256i min_v = _mm256_set1_epi8(static_cast<std::uint8_t>(0));
+ __m256i max_v = _mm256_set1_epi8(static_cast<std::uint8_t>(255));
+
+ if (A_SYMMETRIC) {
+ assert(A_zero_point == 0 || col_offsets == nullptr);
+ }
+ __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
+ __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
+ __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
+
+ __m256i permute_mask_v =
+ _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+
+ constexpr int VLEN = 8;
+ int j = 0;
+ for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+ __m256i y_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
+ __m256i z_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
+ __m256i w_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));
+
+ __m256i row_offset_v;
+ if (!B_SYMMETRIC) {
+ row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
+ x_v = _mm256_sub_epi32(x_v, row_offset_v);
+ }
+ __m256i col_off_v;
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ x_v = _mm256_sub_epi32(x_v, col_off_v);
+ }
+
+ if (!B_SYMMETRIC) {
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
+ y_v = _mm256_sub_epi32(y_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
+ y_v = _mm256_sub_epi32(y_v, col_off_v);
+ }
+
+ if (!B_SYMMETRIC) {
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
+ z_v = _mm256_sub_epi32(z_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
+ z_v = _mm256_sub_epi32(z_v, col_off_v);
+ }
+
+ if (!B_SYMMETRIC) {
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
+ w_v = _mm256_sub_epi32(w_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
+ w_v = _mm256_sub_epi32(w_v, col_off_v);
+ }
+
+ // convert to float
+ __m256 xf_v, yf_v, zf_v, wf_v;
+ if (HAS_BIAS) { // static if
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ if (PER_CHANNEL_QUANTIZATION) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 3 * VLEN));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
+ act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
+ act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
+ act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN);
+ }
+ __m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
+ }
+ __m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
+ }
+ __m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
+
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+ __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
+ __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
+ __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
+
+ __m256i xy_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
+ __m256i zw_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
+ __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
+ __m256i xyzw_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(xyzw_packed_v, max_v));
+
+ xyzw_clamped_v =
+ _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
+ } // j loop vectorized and unrolled 4x
+
+ for (; j < n / VLEN * VLEN; j += VLEN) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+
+ if (!B_SYMMETRIC) {
+ __m256i row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
+ x_v = _mm256_sub_epi32(x_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ x_v = _mm256_sub_epi32(x_v, col_off_v);
+ }
+
+ // Convert to float
+ __m256 xf_v;
+ if (HAS_BIAS) { // static if
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (PER_CHANNEL_QUANTIZATION) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
+ _mm256_loadu_ps(act_times_w_scale + j));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+
+ __m256i x_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
+ C_zero_point_epi16_v);
+ x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
+ __m256i x_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(x_packed_v, max_v));
+
+ x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(C_uint8 + j),
+ _mm256_castsi256_si128(x_clamped_v));
+ } // j loop vectorized
+
+ for (; j < n; ++j) {
+ std::int32_t raw = C_int32[j];
+ if (!B_SYMMETRIC) {
+ raw -= row_offsets[j];
+ }
+ if (!A_SYMMETRIC) {
+ raw -= A_zero_point * col_offsets[j];
+ }
+ float raw_f;
+ if (HAS_BIAS) { // static if
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ raw_f = raw;
+ raw_f += bias[j] / act_times_w_scale[PER_CHANNEL_QUANTIZATION ? j : 0];
+ } else {
+ raw += bias[j];
+ raw_f = raw;
+ }
+ } else {
+ raw_f = raw;
+ }
+
+ float ab = raw_f * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0];
+ long rounded = lrintf(ab) + C_zero_point;
+
+ C_uint8[j] = std::max(
+ FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
+ std::min(255l, rounded));
+ }
+}
+
+template <bool REMAINDER>
+static inline ALWAYS_INLINE __m256i load_a(
+ const std::uint8_t* A,
+ __m256i mask_v) {
+ if (REMAINDER) {
+ return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
+ } else {
+ return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
+ }
+}
+
+static inline std::pair<int, int> closest_factors_(int n) {
+ int a = static_cast<int>(std::sqrt(n));
+ while (n % a != 0) {
+ a--;
+ }
+ return {a, n / a}; // a <= n / a
+}
+
+} // namespace fbgemm
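For reference, the per-element math that requantize_() vectorizes (and that its
scalar tail loop spells out) can be written as the following standalone sketch
for the float-bias, per-tensor-quantization case; the helper name
requantize_scalar_ref and its parameter comments are illustrative, not part of
this patch:

#include <algorithm>
#include <cmath>
#include <cstdint>

static std::uint8_t requantize_scalar_ref(
    std::int32_t acc,           // raw int32 accumulator, C_int32[j]
    std::int32_t row_offset,    // per-row A sum already multiplied by B_zero_point
    std::int32_t A_zero_point,
    std::int32_t col_offset,    // column offset of B for this channel
    float bias,                 // float bias
    float act_times_w_scale,    // activation scale times weight scale
    float C_multiplier,         // output requantization multiplier
    std::int32_t C_zero_point,
    bool fuse_relu) {
  acc -= row_offset;                           // undo A * B_zero_point term
  acc -= A_zero_point * col_offset;            // undo A_zero_point * B term
  float val = acc + bias / act_times_w_scale;  // fold float bias into int32 domain
  long rounded = std::lrintf(val * C_multiplier) + C_zero_point;
  long lo = fuse_relu ? static_cast<long>(C_zero_point) : 0l;
  return static_cast<std::uint8_t>(std::max(lo, std::min(255l, rounded)));
}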
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index f96d1d2..994f206 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -7,523 +7,15 @@
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
#include "fbgemm/Utils.h"
-#include <algorithm> // for min and max
-#include <cassert>
-#include <cmath> // for lrintf and sqrt
+#include <string>
#include <tuple> // for tie
-#include <immintrin.h>
+#include "FbgemmI8DepthwiseAvx2-inl.h"
using namespace std;
namespace fbgemm {
-static int masks[8][8] = {
- // NOTE: clang-format wants to use a different formatting but the current
- // formatting should be easier to read.
- { 0, 0, 0, 0, 0, 0, 0, 0, },
- { -1, 0, 0, 0, 0, 0, 0, 0, },
- { -1, -1, 0, 0, 0, 0, 0, 0, },
- { -1, -1, -1, 0, 0, 0, 0, 0, },
- { -1, -1, -1, -1, 0, 0, 0, 0, },
- { -1, -1, -1, -1, -1, 0, 0, 0, },
- { -1, -1, -1, -1, -1, -1, 0, 0, },
- { -1, -1, -1, -1, -1, -1, -1, 0, },
-};
-
-template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
- int K,
- const int8_t* smat)
- : K_(K) {
- // Transpose the input matrix to make packing faster.
- int8_t* smat_transposed = static_cast<int8_t *>(ALIGNED_MALLOC(
- K * KERNEL_PROD * sizeof(int8_t), 64));
- for (int i = 0; i < KERNEL_PROD; ++i) {
- for (int j = 0; j < K; ++j) {
- smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD];
- }
- }
-
- // Allocate packed arrays
- constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
-#ifdef _MSC_VER
- pmat_ = static_cast<int8_t *>(_aligned_malloc(
- ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t), 64));
-#else
- posix_memalign(
- (void**)&pmat_,
- 64,
- ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t));
-#endif
-
- // Pack input matrix
- // The layout is optimized to use vpmaddubsw efficiently (see
- // madd_epi16x4_packed function).
- // For a group of 32 channels, we have 10 32B SIMD registers.
- // Denote ith channel jth filter as (i, j)
- // 0th SIMD register:
- // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
- // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
- // 1st SIMD register:
- // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
- // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
- // 2nd SIMD register:
- // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
- // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
- // 3rd SIMD register:
- // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
- // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
- // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter
- // coefficients
- // ...
- //
- // REMAINDER
- // If KERNEL_PROD % 4 == 1 for example when KERNEL_PROD == 9
- // 8th SIMD register:
- // (0, 8), zero, ..., (7, 8), zero
- // (16, 8), zero, ..., (23, 8), zero
- // 9th SIMD register:
- // (8, 8), zero, ..., (15, 8), zero
- // (24, 8), zero, ..., (31, 8), zero
- // We use madd_epi16_packed for this case
- //
- // If KERNEL_PROD % 4 == 2 for example when KERNEL_PROD == 10
- // 8th SIMD register:
- // (0, 8), (0, 9), ..., (7, 8), (7, 9)
- // (16, 8), (16, 9), ..., (23, 8), (23, 9)
- // 9th SIMD register:
- // (8, 8), (8, 9), ..., (15, 8), (15, 9)
- // (24, 8), (24, 9), ..., (31, 8), (31, 9)
- //
- // If KERNEL_PROD % 4 == 3 for example when KERNEL_PROD == 11
- // 8th SIMD register:
- // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
- // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
- // 9th SIMD register:
- // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
- // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
- // 10th SIMD register:
- // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
- // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
- // 11th SIMD register:
- // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
- // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
- for (int k1 = 0; k1 < K; k1 += 32) {
- __m256i b_v[KERNEL_PROD];
- int remainder = K - k1;
- if (remainder < 32) {
- __m256i mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(masks[remainder / 4]));
- for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_maskload_epi32(
- reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v);
- }
- } else {
- for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_lddqu_si256(
- reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1));
- }
- }
-
- // Interleave 2 SIMD registers
- __m256i b_interleaved_epi16[KERNEL_PROD_ALIGNED];
- __m256i zero_v = _mm256_setzero_si256();
- for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) {
- if (2 * i + 1 >= KERNEL_PROD) {
- b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
- b_interleaved_epi16[2 * i + 1] =
- _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
- } else {
- b_interleaved_epi16[2 * i] =
- _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
- b_interleaved_epi16[2 * i + 1] =
- _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
- }
- }
-
- // Interleave 4 SIMD registers
- __m256i b_interleaved_epi32[KERNEL_PROD_ALIGNED];
- for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) {
- b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
- b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
- b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
- b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
- b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
- b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
- b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
- b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
- }
- for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) {
- b_interleaved_epi32[i] = b_interleaved_epi16[i];
- }
-
- for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(
- &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]),
- b_interleaved_epi32[i]);
- }
- }
-
- FREE(smat_transposed);
-}
-
-template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() {
-#ifdef _MSC_VER
- _aligned_free(pmat_);
-#else
- free(pmat_);
-#endif
-}
-
-template class PackedDepthWiseConvMatrix<3 * 3>;
-template class PackedDepthWiseConvMatrix<3 * 3 * 3>;
-
-// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16x4_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- __m256i a3_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
-
-// c = a0 * b0 + a1 * b1 + a2 * b2
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16x3_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
-
-// c = a0 * b0 + a1 * b1
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16x2_packed(
- __m256i a0_v,
- __m256i a1_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// c = a0 * b0
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16_packed(
- __m256i a_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// K is the number of accumulations we're doing
-template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
-static inline ALWAYS_INLINE void inner_prod_packed_(
- const __m256i* a_v,
- const __m256i* Bp,
- int32_t* C,
- int remainder,
- __m256i* a_sum = nullptr) {
- __m256i c[4], c_temp[4];
- __m256i a_sum_temp[2] = {0, 0};
-
- int k = 0;
- if (K >= 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[0],
- a_v[1],
- a_v[2],
- a_v[3],
- Bp,
- &c[0],
- &c[1],
- &c[2],
- &c[3],
- a_sum_temp);
-
- for (k = 4; k < K / 4 * 4; k += 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[k + 0],
- a_v[k + 1],
- a_v[k + 2],
- a_v[k + 3],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
- } else {
- c[0] = _mm256_setzero_si256();
- c[1] = _mm256_setzero_si256();
- c[2] = _mm256_setzero_si256();
- c[3] = _mm256_setzero_si256();
- }
-
- if (K - k == 3) {
- madd_epi16x3_packed<SUM_A>(
- a_v[k],
- a_v[k + 1],
- a_v[k + 2],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
- c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
- c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
- c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
-
- if (K - k == 0 || K - k == 3) {
- c[0] = c_temp[0];
- c[1] = c_temp[1];
- c[2] = c_temp[2];
- c[3] = c_temp[3];
- } else {
- if (K - k == 1) {
- madd_epi16_packed<SUM_A>(
- a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- } else if (K - k == 2) {
- madd_epi16x2_packed<SUM_A>(
- a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- }
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- if (REMAINDER) {
- for (int r = 0; r < remainder / 8; ++r) {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + r * 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
- c[r]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
- }
- }
- } else {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 16),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 24),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
- }
- }
-
- if (SUM_A) {
- a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
- a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
- a_sum[2] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
- a_sum[3] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
- }
-}
-
template <bool SUM_A = false, bool REMAINDER = false>
static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
const __m256i* a_v,
@@ -534,238 +26,6 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum);
}
-// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
-// row_offsets for each row because of depth-wise convolution
-template <
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool PER_CHANNEL_QUANTIZATION,
- bool A_SYMMETRIC,
- bool B_SYMMETRIC>
-static inline ALWAYS_INLINE void requantize_(
- int32_t A_zero_point,
- const float* C_multiplier,
- int32_t C_zero_point,
- const int32_t* C_int32,
- uint8_t* C_uint8,
- int n,
- const int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- __m256 multiplier_v = _mm256_setzero_ps();
- if (!PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_set1_ps(*C_multiplier);
- }
-
- __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
- __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
-
- if (A_SYMMETRIC) {
- assert(A_zero_point == 0 || col_offsets == nullptr);
- }
- __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
- __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
- __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
-
- __m256i permute_mask_v =
- _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
-
- constexpr int VLEN = 8;
- int j = 0;
- for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
- __m256i x_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
- __m256i y_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
- __m256i z_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
- __m256i w_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));
-
- __m256i row_offset_v;
- if (!B_SYMMETRIC) {
- row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
- x_v = _mm256_sub_epi32(x_v, row_offset_v);
- }
- __m256i col_off_v;
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
- x_v = _mm256_sub_epi32(x_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
- y_v = _mm256_sub_epi32(y_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
- y_v = _mm256_sub_epi32(y_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
- z_v = _mm256_sub_epi32(z_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
- z_v = _mm256_sub_epi32(z_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
- w_v = _mm256_sub_epi32(w_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
- w_v = _mm256_sub_epi32(w_v, col_off_v);
- }
-
- if (HAS_BIAS) { // static if
- x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
- }
-
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j);
- }
- __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + VLEN);
- }
- __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
- }
- __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
- }
- __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
-
- __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
- __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
- __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
- __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
-
- __m256i xy_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
- __m256i zw_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
- __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
- __m256i xyzw_clamped_v = _mm256_max_epu8(
- FUSE_RELU ? C_zero_point_epi8_v : min_v,
- _mm256_min_epu8(xyzw_packed_v, max_v));
-
- xyzw_clamped_v =
- _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
-
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
- } // j loop vectorized and unrolled 4x
-
- for (; j < n / VLEN * VLEN; j += VLEN) {
- __m256i x_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
-
- if (!B_SYMMETRIC) {
- __m256i row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
- x_v = _mm256_sub_epi32(x_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- __m256i col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
- x_v = _mm256_sub_epi32(x_v, col_off_v);
- }
-
- if (HAS_BIAS) { // static if
- x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
- }
-
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j);
- }
- __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
-
- __m256i x_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
- C_zero_point_epi16_v);
- x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
- __m256i x_clamped_v = _mm256_max_epu8(
- FUSE_RELU ? C_zero_point_epi8_v : min_v,
- _mm256_min_epu8(x_packed_v, max_v));
-
- x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
-
- _mm_storel_epi64(
- reinterpret_cast<__m128i*>(C_uint8 + j),
- _mm256_castsi256_si128(x_clamped_v));
- } // j loop vectorized
-
- for (; j < n; ++j) {
- int32_t raw = C_int32[j];
- if (!B_SYMMETRIC) {
- raw -= row_offsets[j];
- }
- if (!A_SYMMETRIC) {
- raw -= A_zero_point * col_offsets[j];
- }
- if (HAS_BIAS) { // static if
- raw += bias[j];
- }
-
- float ab = raw * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0];
- long rounded = lrintf(ab) + C_zero_point;
-
- C_uint8[j] = std::max(
- FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
- std::min(255l, rounded));
- }
-}
-
-template <bool REMAINDER>
-static inline ALWAYS_INLINE __m256i load_a(
- const uint8_t* A,
- __m256i mask_v) {
- if (REMAINDER) {
- return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
- } else {
- return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
- }
-}
-
template <
bool SUM_A,
bool REMAINDER = false,
@@ -878,257 +138,11 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
}
template <
- bool SUM_A,
- bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static inline ALWAYS_INLINE void inner_prod_3x3x3_packed_(
- int T,
- int H,
- int W,
- int K,
- int t_in,
- int h_in,
- int w_in,
- const uint8_t* A,
- int32_t A_zero_point,
- const int8_t* Bp,
- const int32_t* B_zero_point,
- int32_t* C,
- int remainder,
- int32_t* row_offsets) {
- __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
- __m256i mask_v = _mm256_setzero_si256();
- if (REMAINDER) {
- mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(masks[remainder / 4]));
- }
-
- // The code below can be written as a simple R*S loop but the compiler
- // doesn't unroll so we're manually unrolling it.
- // constexpr int R = 3, S = 3;
- // array<__m256i, R * S> a_v;
- // for (int r = 0; r < R; ++r) {
- // for (int s = 0; s < S; ++s) {
- // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
- // if (REMAINDER) {
- // a_v[r * S + s] =
- // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
- // mask_v);
- // } else {
- // a_v[r * S + s] =
- // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
- // }
- // } else {
- // a_v[r * S + s] = A_zero_point_v;
- // }
- // }
- // }
- __m256i a_v[8];
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum[4];
- inner_prod_packed_<8, SUM_A, REMAINDER>(
- a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum_temp[4];
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp);
-
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
-
- __m256i B_zero_point_v;
- for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
- if (PER_CHANNEL_QUANTIZATION) {
- B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
- } else {
- B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
- }
- _mm256_store_si256(
- reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
- }
- }
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
int H,
int W,
@@ -1147,7 +161,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
uint8_t* C_uint8,
int32_t* row_offsets,
const int32_t* col_offsets,
- const int32_t* bias) {
+ const BIAS_TYPE* bias,
+ float act_times_w_scale) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1192,7 +207,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
HAS_BIAS,
false, /*PER_CHAN_QUANT*/
A_SYMMETRIC,
- B_SYMMETRIC>(
+ B_SYMMETRIC,
+ BIAS_TYPE>(
A_zero_point,
&C_multiplier,
C_zero_point,
@@ -1201,95 +217,11 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
K,
row_offsets,
col_offsets,
- bias);
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline ALWAYS_INLINE void depthwise_3x3x3_kernel_(
- int T,
- int H,
- int W,
- int K,
- int t,
- int h,
- int w,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const int8_t* Bp,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int t_in = -PAD_P + t * stride_t;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- 0,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- remainder,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
-
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- false, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- B_SYMMETRIC>(
- A_zero_point,
- &C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
+ bias,
+ &act_times_w_scale);
}
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
static inline ALWAYS_INLINE void
depthwise_3x3_per_channel_quantization_kernel_(
int H,
@@ -1309,7 +241,8 @@ depthwise_3x3_per_channel_quantization_kernel_(
uint8_t* C_uint8,
int32_t* row_offsets,
const int32_t* col_offsets,
- const int32_t* bias) {
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1360,7 +293,8 @@ depthwise_3x3_per_channel_quantization_kernel_(
HAS_BIAS,
true, /*PER_CHAN_QUANT*/
A_SYMMETRIC,
- false /*B_SYMM*/>(
+ false, /*B_SYMM*/
+ BIAS_TYPE>(
A_zero_point,
C_multiplier,
C_zero_point,
@@ -1369,113 +303,20 @@ depthwise_3x3_per_channel_quantization_kernel_(
K,
row_offsets,
col_offsets,
- bias);
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline ALWAYS_INLINE void
-depthwise_3x3x3_per_channel_quantization_kernel_(
- int T,
- int H,
- int W,
- int K,
- int t,
- int h,
- int w,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int t_in = -PAD_P + t * stride_t;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- false, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- 0,
- &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- true, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- remainder,
- &row_offsets[k]);
- }
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- true, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- false /*B_SYMM*/>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
-}
-
-static pair<int, int> closest_factors_(int n) {
- int a = (int)std::sqrt(n);
- while (n % a != 0) {
- a--;
- }
- return {a, n / a}; // a <= n / a
+ bias,
+ act_times_w_scale);
}
// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
// This implementation should be general enough to handle not just 3x3 but other
// filter shapes by parameterizing with R and S, but it is restricted to 3x3
// for now.
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
int N,
int H,
@@ -1486,13 +327,14 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
uint8_t* C_uint8,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
assert(K % 8 == 0);
@@ -1551,7 +393,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
if (h_begin == 0) {
if (w_begin == 0) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1569,11 +416,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1591,12 +444,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
w = W_OUT - 1;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1614,14 +473,20 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) {
if (w_begin == 0) {
w = 0;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1639,11 +504,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1661,12 +532,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
w = W_OUT - 1;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1684,7 +561,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
@@ -1692,7 +570,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
h = H_OUT - 1;
w = 0;
if (w_begin == 0) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1710,11 +593,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1732,12 +621,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
w = W_OUT - 1;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1755,126 +650,15 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
} // for each n
FREE(row_offsets);
};
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
-
- int n_begin, n_end;
- int t_begin, t_end, h_begin, h_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- t_begin = 0;
- t_end = T_OUT;
- h_begin = 0;
- h_end = H_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by num_threads_t * num_threads_h 2D grid of threads
- int num_threads_t, num_threads_h;
- // num_threads_w <= num_threads_h
- tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_t = tid_within_n / num_threads_h;
- int tid_h = tid_within_n % num_threads_h;
-
- int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
- t_begin = std::min(tid_t * t_per_thread, T_OUT);
- t_end = std::min(t_begin + t_per_thread, T_OUT);
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * T * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
-
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- depthwise_3x3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC>(
- T,
- H,
- W,
- K,
- t,
- h,
- w,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias);
- } // w
- } // h
- } // t
- } // for each n
-
- FREE(row_offsets);
-};
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
static inline ALWAYS_INLINE void
depthwise_3x3_per_channel_quantization_pad_1_(
int N,
@@ -1886,13 +670,14 @@ depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
const float* C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
uint8_t* C_uint8,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
assert(K % 8 == 0);
@@ -1954,7 +739,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1972,14 +758,16 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1997,7 +785,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
@@ -2005,7 +794,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2023,7 +813,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
@@ -2033,7 +824,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2051,14 +843,16 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2076,7 +870,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
@@ -2084,7 +879,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2102,7 +898,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
@@ -2113,7 +910,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2131,14 +929,16 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2156,7 +956,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
@@ -2164,7 +965,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2182,128 +984,15 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
} // for each n
-
- FREE(row_offsets);
-};
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline ALWAYS_INLINE void
-depthwise_3x3x3_per_channel_quantization_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
-
- int n_begin, n_end;
- int t_begin, t_end, h_begin, h_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- t_begin = 0;
- t_end = T_OUT;
- h_begin = 0;
- h_end = H_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by num_threads_t * num_threads_h 2D grid of threads
- int num_threads_t, num_threads_h;
- // num_threads_w <= num_threads_h
- tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_t = tid_within_n / num_threads_h;
- int tid_h = tid_within_n % num_threads_h;
-
- int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
- t_begin = std::min(tid_t * t_per_thread, T_OUT);
- t_end = std::min(t_begin + t_per_thread, T_OUT);
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * T * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
-
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- depthwise_3x3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC>(
- T,
- H,
- W,
- K,
- t,
- h,
- w,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias);
- } // w
- } // h
- } // t
- } // for each n
-
- FREE(row_offsets);
};
// Dispatch A_SYMMETRIC and B_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
static void depthwise_3x3_pad_1_(
int N,
int H,
@@ -2314,12 +1003,13 @@ static void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
@@ -2329,7 +1019,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
true /*A_symmetric*/,
- true /*B_symmetric*/>(
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2346,6 +1037,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -2353,7 +1045,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
true /*A_symmetric*/,
- false /*B_symmetric*/>(
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2370,6 +1063,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -2379,7 +1073,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
false /*A_symmetric*/,
- true /*B_symmetric*/>(
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2396,6 +1091,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -2403,7 +1099,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
false /*A_symmetric*/,
- false /*B_symmetric*/>(
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2420,6 +1117,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -2428,7 +1126,7 @@ static void depthwise_3x3_pad_1_(
}
// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
+template <bool FUSE_RELU, typename BIAS_TYPE>
static void depthwise_3x3_pad_1_(
int N,
int H,
@@ -2439,16 +1137,17 @@ static void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
if (bias) {
- depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>(
+ depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
N,
H,
W,
@@ -2464,10 +1163,11 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>(
+ depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
N,
H,
W,
@@ -2483,6 +1183,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -2490,6 +1191,7 @@ static void depthwise_3x3_pad_1_(
// Dispatch input shape and FUSE_RELU
// assumption: W > 3 and H > 3
+template <typename BIAS_TYPE>
void depthwise_3x3_pad_1(
int N,
int H,
@@ -2500,18 +1202,33 @@ void depthwise_3x3_pad_1(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+    // In C2, batch size 0 is allowed, so we just return early.
+ return;
+ }
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2527,10 +1244,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2546,10 +1264,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2565,10 +1284,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2584,10 +1304,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2603,12 +1324,13 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2624,10 +1346,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2643,10 +1366,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2662,10 +1386,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2681,10 +1406,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2700,283 +1426,15 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
}
}
-// Dispatch A_SYMMETRIC and B_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
-static void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- if (B_zero_point == 0) {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- true /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- false /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
- } else {
- if (B_zero_point == 0) {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- true /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- false /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
- }
- delete[] C_int32_temp;
-}
-
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
-static void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- if (bias) {
- depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
-}
-
-// Dispatch FUSE_RELU
-void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- int thread_id,
- int num_threads) {
- if (fuse_relu) {
- depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
-}
-
// Dispatch A_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
static void depthwise_3x3_per_channel_quantization_pad_1_(
int N,
int H,
@@ -2987,12 +1445,13 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
@@ -3000,7 +1459,8 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
HAS_BIAS,
- true /*A_SYMM*/>(
+ true /*A_SYMM*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3017,13 +1477,15 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
HAS_BIAS,
- false /*A_SYMM*/>(
+ false /*A_SYMM*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3040,6 +1502,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -3047,7 +1510,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
}
// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
+template <bool FUSE_RELU, typename BIAS_TYPE>
static void depthwise_3x3_per_channel_quantization_pad_1_(
int N,
int H,
@@ -3058,18 +1521,20 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
if (bias) {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
- true /* HAS_BIAS */>(
+ true /* HAS_BIAS */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3085,12 +1550,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
- false /* HAS_BIAS */>(
+ false /* HAS_BIAS */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3106,12 +1573,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
}
// Dispatch input shape and FUSE_RELU
+template <typename BIAS_TYPE>
void depthwise_3x3_per_channel_quantization_pad_1(
int N,
int H,
@@ -3122,18 +1591,35 @@ void depthwise_3x3_per_channel_quantization_pad_1(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
+ if (Bp.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+    // In C2, batch size 0 is allowed, so we just return early.
+ return;
+ }
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3149,10 +1635,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3168,10 +1657,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3187,10 +1679,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3206,10 +1701,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3225,12 +1723,15 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3246,10 +1747,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3265,10 +1769,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3284,10 +1791,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3303,10 +1813,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3322,225 +1835,179 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
}
}
-// Dispatch A_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
-static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+// To be removed
+void depthwise_3x3_pad_1(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
+ bool fuse_relu,
int thread_id,
int num_threads) {
- int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_SYMM*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_SYMM*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
- delete[] C_int32_temp;
+ depthwise_3x3_pad_1<std::int32_t>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ 1.0f,
+ thread_id,
+ num_threads);
}
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
-static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+// To be removed
+void depthwise_3x3_per_channel_quantization_pad_1(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
+ bool fuse_relu,
int thread_id,
int num_threads) {
- if (bias) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- true /* HAS_BIAS */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- false /* HAS_BIAS */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
+ depthwise_3x3_per_channel_quantization_pad_1<std::int32_t>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ nullptr,
+ thread_id,
+ num_threads);
}
-// Dispatch FUSE_RELU
-void depthwise_3x3x3_per_channel_quantization_pad_1(
+template void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3_per_channel_quantization_pad_1(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
bool fuse_relu,
+ const float* act_times_w_scale,
int thread_id,
- int num_threads) {
- if (fuse_relu) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
-}
+ int num_threads);
+
+template void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
} // namespace fbgemm
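
For context on the hunks above: the 2D depthwise entry points are now templated on BIAS_TYPE and take an act_times_w_scale argument, so a float bias can be folded into requantization without being pre-quantized to int32; the old int32-bias signatures remain only as the thin "To be removed" wrappers. The sketch below is illustrative and not part of this commit: the PackedDepthWiseConvMatrix(K, kernel_prod, smat) constructor and the reading of act_times_w_scale as activation_scale * weight_scale are inferred from the parameter names in this diff, not from declarations shown here.

// Illustrative sketch (not part of this commit). Assumes the packing
// constructor PackedDepthWiseConvMatrix(K, kernel_prod, smat) and that
// act_times_w_scale == activation_scale * weight_scale; both are inferred.
#include <cstdint>
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"

void depthwise_3x3_with_float_bias(
    int N, int H, int W, int K,             // K must be a multiple of 8
    std::int32_t A_zero_point, const std::uint8_t* A,
    std::int32_t B_zero_point, const std::int8_t* W_int8,
    float C_multiplier, std::int32_t C_zero_point, std::uint8_t* C,
    const std::int32_t* col_offsets,
    const float* bias_fp32,                 // float bias, not pre-quantized
    float act_scale, float weight_scale) {
  // Pack the depthwise filter; kernel_prod is 3 * 3 for this path.
  fbgemm::PackedDepthWiseConvMatrix B(K, 3 * 3, W_int8);

  // BIAS_TYPE is deduced as float from bias_fp32; the kernel is then expected
  // to scale the float bias by 1 / act_times_w_scale during requantization.
  fbgemm::depthwise_3x3_pad_1(
      N, H, W, K,
      /*stride_h=*/1, /*stride_w=*/1,
      A_zero_point, A,
      B_zero_point, B,
      C_multiplier, C_zero_point, C,
      col_offsets, bias_fp32,
      /*fuse_relu=*/false,
      /*act_times_w_scale=*/act_scale * weight_scale,
      /*thread_id=*/0, /*num_threads=*/1);
}

One apparent motivation for the float-bias path is that the same packed weights and float bias can be reused when the activation scale changes, instead of re-quantizing the bias to int32 for each scale.
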
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index dccdfc5..c0fece4 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -8,8 +8,11 @@
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <map>
+#include <mutex>
+#include <sstream>
#include <string>
#include <tuple>
+#include "CodeCache.h"
#include "fbgemm/Fbgemm.h"
/*#define FBGEMM_LOG_CODE 1*/
@@ -18,7 +21,7 @@ namespace fbgemm {
namespace x86 = asmjit::x86;
/**
- * @brief AVX2/AVX512 JIT assembly code generator.
+ * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator.
* @tparam TA Type of matrix A.
* @tparam TB Type of matrix B.
* @tparam TC Type of matrix C.
@@ -40,35 +43,7 @@ class CodeGenBase {
* @brief Constructor for initializing AVX2/AVX512 registers.
*/
CodeGenBase(const BlockingFactors* params = nullptr)
- : blocking_params(params),
- CRegs_avx2_{x86::ymm0,
- x86::ymm1,
- x86::ymm2,
- x86::ymm3,
- x86::ymm4,
- x86::ymm5,
- x86::ymm6,
- x86::ymm7,
- x86::ymm8,
- x86::ymm9,
- x86::ymm10,
- x86::ymm11},
- CRegs_avx512_{
- x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4,
- x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9,
- x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14,
- x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
- x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24,
- x86::zmm25, x86::zmm26, x86::zmm27,
- },
- AllRegs_avx512_{x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3,
- x86::zmm4, x86::zmm5, x86::zmm6, x86::zmm7,
- x86::zmm8, x86::zmm9, x86::zmm10, x86::zmm11,
- x86::zmm12, x86::zmm13, x86::zmm14, x86::zmm15,
- x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
- x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23,
- x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27,
- x86::zmm28, x86::zmm29, x86::zmm30, x86::zmm31} {
+ : blocking_params(params) {
// vector width in bits
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
@@ -104,7 +79,7 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void initCRegs(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCRegAssign = 4);
@@ -114,10 +89,10 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void genComputeBlock(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
@@ -129,11 +104,11 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void storeCRegs(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCRegAssign = 4);
@@ -143,7 +118,7 @@ class CodeGenBase {
* (debug-only)
*/
template <inst_set_t instSet>
- std::string getCodeLoggingFile(
+ static std::string getCodeLoggingFile(
bool accum,
int mc,
int nc,
@@ -152,48 +127,60 @@ class CodeGenBase {
int MR,
int NR,
int NR_MIN) {
- std::string fileName = "gemm_";
+ std::ostringstream oss;
+ oss << "gemm_";
if (std::is_same<accT, std::int16_t>::value) {
- fileName += "acc16_";
+ oss << "acc16_";
} else if (std::is_same<accT, std::int32_t>::value) {
- fileName += "acc32_";
+ oss << "acc32_";
} else {
- fileName += "unknown_";
+ oss << "unknown_";
}
- fileName += "accum-" + std::to_string(accum);
- fileName += "_MC-" + std::to_string(mc);
- fileName += "_NC-" + std::to_string(nc);
- fileName += "_NCB-" + std::to_string(NCB);
- fileName += "_NCB-" + std::to_string(KCB);
- fileName += "_MR-" + std::to_string(MR);
- fileName += "_NR-" + std::to_string(NR);
- fileName += "_NR_MIN-" + std::to_string(NR_MIN);
- if (instSet == inst_set_t::avx512) {
- fileName += "_avx512";
+ oss << "accum-" + std::to_string(accum)
+ << "_MC-" + std::to_string(mc)
+ << "_NC-" + std::to_string(nc)
+ << "_NCB-" + std::to_string(NCB)
+ << "_NCB-" + std::to_string(KCB)
+ << "_MR-" + std::to_string(MR)
+ << "_NR-" + std::to_string(NR)
+ << "_NR_MIN-" + std::to_string(NR_MIN);
+ if (instSet == inst_set_t::avx512_vnni) {
+ oss << "_avx512vnni";
+ } else if (instSet == inst_set_t::avx512) {
+ oss << "_avx512";
} else if (instSet == inst_set_t::avx2) {
- fileName += "_avx2";
+ oss << "_avx2";
}
- fileName += ".txt";
- return fileName;
+ oss << ".txt";
+ return oss.str();
}
private:
- asmjit::X86Ymm
- CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
- asmjit::X86Zmm
- CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
- asmjit::X86Zmm
- AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
-
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
- static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
- static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+
+ static asmjit::JitRuntime &runtime() {
+    static asmjit::JitRuntime rt; ///< JIT runtime for asmjit; it depends on
+                                  ///< other static variables, so it is kept
+                                  ///< function-local to prevent the static
+                                  ///< initialization order fiasco.
+ return rt;
+ }
+
+  static std::mutex rtMutex_; ///< Controls access to runtime().
+
// The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min
- static thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- jit_micro_kernel_fp>
+ static CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+ jit_micro_kernel_fp>
codeCache_; ///< JIT Code Cache for reuse.
};
+template <typename TA, typename TB, typename TC, typename accT>
+std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+ typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
+ CodeGenBase<TA, TB, TC, accT>::codeCache_;
+
} // namespace fbgemm
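
The header change above replaces the per-thread JitRuntime, CodeHolder, and std::map statics with one process-wide runtime behind runtime() and rtMutex_, plus a CodeCache keyed by the (accum, mc, nc, ncb, kcb, mr, nr, nr_min) tuple. CodeCache.h is added elsewhere in this commit and its body is not shown in this diff, so the following is only a rough sketch of the getOrCreate(key, generator) shape used by the generator .cc files, not the actual implementation (which may, for example, cache futures so concurrent callers do not serialize on JIT compilation).

// Rough, assumed sketch of a getOrCreate-style cache; the real CodeCache in
// this commit is not shown here and may differ.
#include <functional>
#include <map>
#include <mutex>

template <typename Key, typename Value>
class NaiveCodeCache {
 public:
  Value getOrCreate(const Key& key, std::function<Value()> generator) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = values_.find(key);
    if (it != values_.end()) {
      return it->second;          // kernel already JITed for this signature
    }
    Value v = generator();        // JIT-compile once per unique signature
    values_.emplace(key, v);
    return v;
  }

 private:
  std::mutex mutex_;
  std::map<Key, Value> values_;
};

The function-local static inside runtime() gives the usual guarantee that the JitRuntime is constructed on first use, which is what the comment about the static initialization order fiasco refers to.
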
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 082518c..205af14 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -9,18 +9,6 @@
namespace fbgemm {
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
- CodeGenBase<TA, TB, TC, accT>::codeCache_;
-
namespace x86 = asmjit::x86;
/**
@@ -31,16 +19,17 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- int leadingDimCRegAssign) {
+ int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx2_[i * leadingDimCRegAssign + j],
- CRegs_avx2_[i * leadingDimCRegAssign + j],
- CRegs_avx2_[i * leadingDimCRegAssign + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -53,18 +42,20 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
- int leadingDimCRegAssign) {
+ int leadingDimCReg) {
// used for matrix A
- asmjit::X86Ymm AReg = x86::ymm12;
+ x86::Ymm AReg = x86::ymm12;
- asmjit::X86Ymm tmpReg = x86::ymm14;
+ x86::Ymm tmpReg = x86::ymm14;
+
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
@@ -74,9 +65,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
a->vpmaddubsw(
tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
a->vpaddsw(
- CRegs_avx2_[i * leadingDimCRegAssign + j],
+ CRegs(i * leadingDimCReg + j),
tmpReg,
- CRegs_avx2_[i * leadingDimCRegAssign + j]);
+ CRegs(i * leadingDimCReg + j));
       // Prefetching hurts performance in some cases because the prefetch
       // instructions themselves consume a pipeline issue slot, slowing down
       // the kernel.
@@ -95,25 +86,30 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
- int leadingDimCRegAssign) {
- asmjit::X86Xmm extractDest128 = x86::xmm15;
- asmjit::X86Ymm extractDest256 = x86::ymm15;
+ int leadingDimCReg) {
+ x86::Xmm extractDest128 = x86::xmm15;
+ x86::Ymm extractDest256 = x86::ymm15;
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti128(
- extractDest128, CRegs_avx2_[i * leadingDimCRegAssign + j], idx);
+ extractDest128, CRegs(i * leadingDimCReg + j), idx);
a->vpmovsxwd(extractDest256, extractDest128);
- asmjit::X86Mem destAddr = x86::dword_ptr(
+ x86::Mem destAddr = x86::dword_ptr(
+#ifdef _MSC_VER
+ a->gpz(9), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
+#else
a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
+#endif
if (accum) {
a->vpaddd(extractDest256, extractDest256, destAddr);
}
@@ -172,192 +168,195 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx2>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- // assert((nc == nRegBlockSize) &&
- //"nc must be equal to the number of register blocks");
-
- // arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
-
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFrameInfo(ffi);
-
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
-
- asmjit::Label Loopk = a->newLabel();
- asmjit::Label LoopMBlocks = a->newLabel();
-
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp kIdx = a->gpzRef(14);
-
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
-
- // increment A for next block
- a->sub(buffer_A, kSize);
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
- // increment C for next block
- a->imul(
- C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ // assert((nc == nRegBlockSize) &&
+ //"nc must be equal to the number of register blocks");
+
+ // arguments to the function created
+#ifdef _MSC_VER
+ x86::Gp buffer_A = a->zcx();
+ x86::Gp buffer_B = a->zdx();
+ x86::Gp B_pf = a->gpz(8);
+ x86::Gp CBase = a->gpz(9);
+ x86::Gp kSize = a->zdi();
+ x86::Gp ldcReg = a->zsi();
+#else
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+#endif
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+ // increment C for next block
+ a->imul(C_Offset, ldcReg,
+ static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock);
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
- }
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum);
+ }
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
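
Note: the hunks above replace the old per-thread rt_/code_/codeCache_ statics with a single codeCache_.getOrCreate(kernelSig, ...) call whose lambda assembles the kernel and registers it through a mutex-guarded runtime().add. A minimal sketch of that lookup-or-generate pattern, for illustration only (the actual cache type behind codeCache_ may synchronize differently):

#include <functional>
#include <map>
#include <mutex>

// Minimal stand-in for the thread-safe kernel cache used above (illustrative).
template <typename KEY, typename VALUE>
class SimpleCodeCache {
 public:
  VALUE getOrCreate(const KEY& key, std::function<VALUE()> generator) {
    std::lock_guard<std::mutex> lock(mutex_);  // one generation at a time
    auto it = values_.find(key);
    if (it != values_.end()) {
      return it->second;  // kernel already generated for this signature
    }
    VALUE fn = generator();  // run the lambda that assembles and adds the kernel
    values_.emplace(key, fn);
    return fn;
  }

 private:
  std::map<KEY, VALUE> values_;
  std::mutex mutex_;
};
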
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index 505fec1..819f33b 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -19,16 +19,17 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- int leadingDimCRegAssign) {
+ int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx512_[i * leadingDimCRegAssign + j],
- CRegs_avx512_[i * leadingDimCRegAssign + j],
- CRegs_avx512_[i * leadingDimCRegAssign + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -41,37 +42,38 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
- int leadingDimCRegAssign) {
+ int leadingDimCReg) {
// used for matrix A
- asmjit::X86Zmm AReg = x86::zmm29;
+ x86::Zmm AReg = x86::zmm29;
- asmjit::X86Zmm tmpReg = x86::zmm30;
+ x86::Zmm tmpReg = x86::zmm30;
// We start allocating BRegs from zmm27 and then allocate zmm26 and so on.
for (int j = 0; j < colRegs; ++j) {
a->vmovups(
- AllRegs_avx512_[27 - j],
+ x86::Zmm(27 - j),
x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
}
+ using CRegs = x86::Zmm;
+
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
a->vpbroadcastw(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
for (int j = 0; j < colRegs; ++j) {
- a->vpmaddubsw(
- tmpReg, AReg, AllRegs_avx512_[27-j]);
+ a->vpmaddubsw(tmpReg, AReg, x86::Zmm(27 - j));
a->vpaddsw(
- CRegs_avx512_[i * leadingDimCRegAssign + j],
+ CRegs(i * leadingDimCReg + j),
tmpReg,
- CRegs_avx512_[i * leadingDimCRegAssign + j]);
+ CRegs(i * leadingDimCReg + j));
// Prefetching is hurting performance in some cases
      // because a prefetch instruction itself consumes an issue slot
      // in the pipeline, thus slowing down the kernel.
@@ -90,25 +92,31 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+
bool accum,
- int leadingDimCRegAssign) {
- asmjit::X86Ymm extractDest256 = x86::ymm31;
- asmjit::X86Zmm extractDest512 = x86::zmm31;
+ int leadingDimCReg) {
+ x86::Ymm extractDest256 = x86::ymm31;
+ x86::Zmm extractDest512 = x86::zmm31;
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti32x8(
- extractDest256, CRegs_avx512_[i * leadingDimCRegAssign + j], idx);
+ extractDest256, CRegs(i * leadingDimCReg + j), idx);
a->vpmovsxwd(extractDest512, extractDest256);
- asmjit::X86Mem destAddr = x86::dword_ptr(
+ x86::Mem destAddr = x86::dword_ptr(
+#ifdef _MSC_VER
+ a->gpz(9), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
+#else
a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
+#endif
if (accum) {
a->vpaddd(extractDest512, extractDest512, destAddr);
}
@@ -167,261 +175,256 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
-
- code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- maxMRegs * maxNRegs <= 24 &&
- "MR*(NR*ROW_INTERLEAVE*8/512) \
- must be <= 24(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp,
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFrameInfo(ffi);
-
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp jIdx = a->gpzRef(14);
- asmjit::X86Gp kIdx = a->gpzRef(15);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(
- C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert((maxMRegs + 1) * maxNRegs <= 28 &&
+ "number of zmm registers for C + one row for loading B: \
+ MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \
+ must be <= 28(available registers constraint)");
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+#ifdef _MSC_VER
+ x86::Gp buffer_A = a->zcx();
+ x86::Gp buffer_B = a->zdx();
+ x86::Gp B_pf = a->gpz(8);
+ x86::Gp CBase = a->gpz(9);
+ x86::Gp kSize = a->zdi();
+ x86::Gp ldcReg = a->zsi();
+#else
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+#endif
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg,
+ static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
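
Note on the #ifdef _MSC_VER blocks above (and in the other generators): they only change which physical registers asmjit binds the six kernel arguments to, because Windows x64 and System V pass arguments in different registers. A hedged sketch of the mapping; the typedef name below is illustrative, not taken from the source:

#include <cstdint>

// Illustrative signature matching the FuncSignatureT used above.
typedef void (*jit_ukernel_t)(
    uint8_t* bufferA,    // System V: rdi   Win64: rcx
    int8_t*  bufferB,    // System V: rsi   Win64: rdx
    int8_t*  bPrefetch,  // System V: rdx   Win64: r8
    int32_t* cBase,      // System V: rcx   Win64: r9
    int      kSize,      // System V: r8    Win64: stack, moved to rdi by emitArgsAssignment
    int      ldc);       // System V: r9    Win64: stack, moved to rsi by emitArgsAssignment
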
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
new file mode 100644
index 0000000..f559aba
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.initCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, leadingDimCReg);
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, buffer_B, rowRegs, colRegs, lda, leadingDimCReg);
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+ bool accum,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, leadingDimCReg);
+}
+
+/**
+ * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<
+ inst_set_t::avx512_vnni>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ return codeObj.getOrCreate<inst_set_t::avx512_vnni>(accum, mc, nc, kc, kc);
+}
+
+} // namespace fbgemm
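
The asserts in this new file reflect that the VNNI dot-product instruction (vpdpbusd, used by the 32-bit kernels further below) always accumulates into 32-bit lanes, so there is no separate 16-bit accumulation path and the specializations simply forward to the int32 code object. A scalar sketch of one 32-bit lane of vpdpbusd, for illustration only:

#include <cstdint>

// Four unsigned-by-signed byte products summed into a 32-bit accumulator
// (non-saturating form, matching vpdpbusd rather than vpdpbusds).
inline int32_t vpdpbusd_lane(int32_t acc, const uint8_t a[4], const int8_t b[4]) {
  int32_t sum = 0;
  for (int k = 0; k < 4; ++k) {
    sum += static_cast<int32_t>(a[k]) * static_cast<int32_t>(b[k]);
  }
  return acc + sum;
}
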
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index d044530..dc9c534 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -9,18 +9,6 @@
namespace fbgemm {
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
- CodeGenBase<TA, TB, TC, accT>::codeCache_;
-
namespace x86 = asmjit::x86;
/**
@@ -31,16 +19,17 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -53,25 +42,27 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
- int leadingDimCRegAssign) {
+ int leadingDimCReg) {
// used for matrix A
- asmjit::X86Ymm AReg = x86::ymm12;
+ x86::Ymm AReg = x86::ymm12;
// used for matrix B
- asmjit::X86Ymm BReg = x86::ymm13;
+ x86::Ymm BReg = x86::ymm13;
// Contains 16-bit 1s
- asmjit::X86Ymm oneReg = x86::ymm15;
+ x86::Ymm oneReg = x86::ymm15;
// temporary register
- asmjit::X86Ymm res1 = x86::ymm14;
+ x86::Ymm res1 = x86::ymm14;
+
+ using CRegs = x86::Ymm;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -83,9 +74,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
a->vpmaddubsw(res1, AReg, BReg);
a->vpmaddwd(res1, oneReg, res1);
a->vpaddd(
- CRegs_avx2_[i * leadingDimCRegAssign + j],
+ CRegs(i * leadingDimCReg + j),
res1,
- CRegs_avx2_[i * leadingDimCRegAssign + j]);
+ CRegs(i * leadingDimCReg + j));
}
a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
}
@@ -99,16 +90,14 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
- int leadingDimCRegAssign) {
- // temp register
- asmjit::X86Ymm tmpReg = x86::ymm14;
-
+ int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
@@ -116,13 +105,21 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
for (int j = 0; j < colRegs; ++j) {
if (accum) {
a->vpaddd(
- CRegs_avx2_[i * leadingDimCRegAssign + j],
- CRegs_avx2_[i * leadingDimCRegAssign + j],
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+#ifdef _MSC_VER
+ x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 8 * sizeof(int32_t)));
+#else
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)));
+#endif
}
a->vmovups(
+#ifdef _MSC_VER
+ x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 8 * sizeof(int32_t)),
+#else
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)),
- CRegs_avx2_[i * leadingDimCRegAssign + j]);
+#endif
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -176,207 +173,178 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx2>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
-
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFrameInfo(ffi);
-
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
-
- asmjit::Label Loopk = a->newLabel();
- asmjit::Label LoopMBlocks = a->newLabel();
-
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp kIdx = a->gpzRef(14);
- // asmjit::X86Gp B_pf = a->gpzRef(8);
-
- asmjit::X86Ymm oneReg = x86::ymm15;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
- a->mov(C_Offset, 0);
-
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- // a->add(B_pf, 32*sizeof(float));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment A for next block
- a->sub(buffer_A, kSize);
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next block
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
+ // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+#ifdef _MSC_VER
+ x86::Gp buffer_A = a->zcx();
+ x86::Gp buffer_B = a->zdx();
+ x86::Gp B_pf = a->gpz(8);
+ x86::Gp CBase = a->gpz(9);
+ x86::Gp kSize = a->zdi();
+ x86::Gp ldcReg = a->zsi();
+#else
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+#endif
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Ymm oneReg = x86::ymm15;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
a->mov(C_Offset, 0);
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
- }
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- asmjit::FuncUtils::emitEpilog(a, layout);
+ auto issueLoopOverK = [&](int rowRegs) {
+ asmjit::Label LoopKLabel = a->newLabel();
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ // Init C (result) vector registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // Loops over K
+ a->mov(kIdx, 0);
+ a->bind(LoopKLabel);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopKLabel);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum,
+ colRegs);
+ };
+
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ issueLoopOverK(mRegBlockSize);
+
+ int rowRegs = mRegBlockSize;
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+ a->mov(C_Offset, 0);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ issueLoopOverK(mRegBlocksRem);
+ }
+
+ a->emitEpilog(frame);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
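
The AVX2 32-bit kernel above keeps the same vpmaddubsw / vpmaddwd(oneReg) / vpaddd sequence; the diff only modernizes the asmjit API and factors the K loop into the issueLoopOverK lambda. A scalar sketch of what that three-instruction sequence computes for one output lane over one group of four interleaved k values (illustration only, assuming ROW_INTERLEAVE is 4 here):

#include <algorithm>
#include <cstdint>

// vpmaddubsw: unsigned A byte times signed B byte, adjacent pairs summed with
// signed 16-bit saturation.
inline int16_t maddubs_pair(uint8_t a0, int8_t b0, uint8_t a1, int8_t b1) {
  int32_t s = static_cast<int32_t>(a0) * b0 + static_cast<int32_t>(a1) * b1;
  s = std::min<int32_t>(std::max<int32_t>(s, INT16_MIN), INT16_MAX);
  return static_cast<int16_t>(s);
}

// vpmaddwd against oneReg (every 16-bit lane holds 0x0001) widens and adds the
// two 16-bit partial sums; vpaddd then accumulates into the 32-bit C register.
inline int32_t acc32_update(int32_t c, const uint8_t a[4], const int8_t b[4]) {
  return c + maddubs_pair(a[0], b[0], a[1], b[1]) +
         maddubs_pair(a[2], b[2], a[3], b[3]);
}
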
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index d1729e4..5037292 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -19,16 +19,17 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -41,26 +42,27 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
- int leadingDimCRegAssign) {
+ int leadingDimCReg) {
// used for matrix A
- asmjit::X86Zmm AReg = x86::zmm31;
+ x86::Zmm AReg = x86::zmm31;
// used for matrix B
- asmjit::X86Zmm BReg = x86::zmm30;
+ x86::Zmm BReg = x86::zmm30;
// Contains 16-bit 1s
- asmjit::X86Zmm oneReg = x86::zmm29;
+ x86::Zmm oneReg = x86::zmm29;
// temporary register
- asmjit::X86Zmm res1 = x86::zmm28;
+ x86::Zmm res1 = x86::zmm28;
+ using CRegs = x86::Zmm;
for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
@@ -71,9 +73,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
a->vpmaddubsw(res1, AReg, BReg);
a->vpmaddwd(res1, oneReg, res1);
a->vpaddd(
- CRegs_avx512_[i * leadingDimCRegAssign + j],
+ CRegs(i * leadingDimCReg + j),
res1,
- CRegs_avx512_[i * leadingDimCRegAssign + j]);
+ CRegs(i * leadingDimCReg + j));
}
a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
}
@@ -87,33 +89,38 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
- int leadingDimCRegAssign) {
- // temp register
- asmjit::X86Zmm tmpReg = x86::zmm28;
-
+ int leadingDimCReg) {
+ using CRegs = x86::Zmm;
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
- }
- else {
+ } else {
a->mov(C_Offset, static_cast<asmjit::Imm>(0));
}
for (int j = 0; j < colRegs; ++j) {
if (accum) {
a->vpaddd(
- CRegs_avx512_[i * leadingDimCRegAssign + j],
- CRegs_avx512_[i * leadingDimCRegAssign + j],
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+#ifdef _MSC_VER
+ x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)));
+#else
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
+#endif
}
a->vmovups(
+#ifdef _MSC_VER
+ x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)),
+#else
x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
- CRegs_avx512_[i * leadingDimCRegAssign + j]);
+#endif
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -167,278 +174,269 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
-
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- maxMRegs * maxNRegs <= 28 &&
- "MR*(NR*ROW_INTERLEAVE*8/512) \
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp,
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFrameInfo(ffi);
-
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp jIdx = a->gpzRef(14);
- asmjit::X86Gp kIdx = a->gpzRef(15);
- // asmjit::X86Gp B_pf = a->gpzRef(8);
-
- asmjit::X86Zmm oneReg = x86::zmm29;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- // a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->mov(buffer_B, buffer_B_saved);
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+#ifdef _MSC_VER
+ x86::Gp buffer_A = a->zcx();
+ x86::Gp buffer_B = a->zdx();
+ x86::Gp B_pf = a->gpz(8);
+ x86::Gp CBase = a->gpz(9);
+ x86::Gp kSize = a->zdi();
+ x86::Gp ldcReg = a->zsi();
+#else
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+#endif
- asmjit::FuncUtils::emitEpilog(a, layout);
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ a->emitEpilog(frame);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
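
One detail in the AVX-512 32-bit kernel above: the AVX2 kernel builds its 16-bit 0x0001 constant with vpcmpeqw, but the EVEX-encoded vpcmpeqw writes a mask register rather than a vector, so this kernel sets every bit of zmm29 with vpternlogd (truth table 0xff) and then shifts each word lane right by 15. A scalar sketch of the same constant construction (illustration only):

#include <cstdint>

// Builds the per-word 0x0001 constant kept in zmm29 for the later vpmaddwd.
inline void make_word_ones(uint16_t lanes[32]) {
  for (int i = 0; i < 32; ++i) {
    uint16_t all_ones = 0xFFFF;                        // vpternlogd zmm29, zmm29, zmm29, 0xff
    lanes[i] = static_cast<uint16_t>(all_ones >> 15);  // vpsrlw zmm29, zmm29, 15
  }
}
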
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
new file mode 100644
index 0000000..bd8be1f
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -0,0 +1,435 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ using CRegs = x86::Zmm;
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCReg) {
+ // used for matrix A
+ x86::Zmm AReg = x86::zmm31;
+
+ // used for matrix B
+ x86::Zmm BReg = x86::zmm30;
+
+ using CRegs = x86::Zmm;
+
+ for (int j = 0; j < colRegs; ++j) {
+ // load B
+ a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ // load A, broadcast and fmas
+ for (int i = 0; i < rowRegs; ++i) {
+ a->vpbroadcastd(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ a->vpdpbusd(CRegs(i * leadingDimCReg + j), AReg, BReg);
+ }
+ a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
+ }
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+ bool accum,
+ int leadingDimCReg) {
+ using CRegs = x86::Zmm;
+ for (int i = 0; i < rowRegs; ++i) {
+ if (i != 0) {
+ a->add(C_Offset, ldcReg);
+ } else {
+ a->mov(C_Offset, static_cast<asmjit::Imm>(0));
+ }
+ for (int j = 0; j < colRegs; ++j) {
+ if (accum) {
+ a->vpaddd(
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+#ifdef _MSC_VER
+ x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)));
+#else
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
+#endif
+ }
+ a->vmovups(
+#ifdef _MSC_VER
+ x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)),
+#else
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
+#endif
+ CRegs(i * leadingDimCReg + j));
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
+ inst_set_t::avx512_vnni>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::KCB;
+ nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NCB;
+ mRegBlockSize =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::MR;
+ nRegBlockSize =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR_MIN;
+ row_interleave = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::
+ ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
+
+#if defined(FBGEMM_LOG_CODE)
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512_vnni>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
+#endif
+
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \
+ must be <= 28 (available registers constraint)");
+
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+#ifdef _MSC_VER
+ x86::Gp buffer_A = a->zcx();
+ x86::Gp buffer_B = a->zdx();
+ x86::Gp B_pf = a->gpz(8);
+ x86::Gp CBase = a->gpz(9);
+ x86::Gp kSize = a->zdi();
+ x86::Gp ldcReg = a->zsi();
+#else
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+#endif
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
+
+ a->emitEpilog(frame);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+
+#if defined(FBGEMM_LOG_CODE)
+ fclose(codeLogfile);
+ delete codeLogger;
+#endif
+
+ return fn;
+ });
+}
+
+} // namespace fbgemm
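
The heart of this new file is genComputeBlock above: one vpdpbusd per broadcast A value and B column replaces the vpmaddubsw + vpmaddwd + vpaddd sequence of the non-VNNI kernels. Per 32-bit lane, vpdpbusd multiplies four unsigned 8-bit values with four signed 8-bit values and accumulates their sum into the lane; a scalar reference of that per-lane behaviour (illustration only, not patch code):

    #include <cstdint>

    // One 32-bit lane of vpdpbusd: dot product of four uint8 activations with
    // four int8 weights, accumulated into a 32-bit lane. The JIT kernel above
    // emits the real instruction over full 512-bit registers.
    inline int32_t vpdpbusd_lane(int32_t acc, const uint8_t a[4],
                                 const int8_t b[4]) {
      for (int k = 0; k < 4; ++k) {
        acc += static_cast<int32_t>(a[k]) * static_cast<int32_t>(b[k]);
      }
      return acc;
    }
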
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 1e6324e..58ee24d 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -10,8 +10,10 @@
#include <cassert>
#include <cstdint>
#include <map>
+#include <mutex>
#include <string>
#include <tuple>
+#include "CodeCache.h"
#include "fbgemm/ConvUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/Utils.h"
@@ -128,60 +130,58 @@ class GenConvKernel {
const conv_param_t<SPATIAL_DIM>& conv_param);
template <inst_set_t instSet>
- void createVector16BitOne(asmjit::X86Emitter* a);
+ void createVector16BitOne(x86::Emitter* a);
template <inst_set_t instSet>
- void createVector8BitOne(asmjit::X86Emitter* a);
+ void createVector8BitOne(x86::Emitter* a);
template <inst_set_t instSet>
- void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg);
+ void setToZeroPt(x86::Emitter* a, x86::Ymm destReg);
template <inst_set_t instSet>
- void
- gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
+ void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg);
template <inst_set_t instSet>
- void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset);
+ void genForLoadingWeights(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genConstForPermutations(asmjit::X86Emitter* a);
+ void genConstForPermutations(x86::Emitter* a);
template <inst_set_t instSet>
- void genForTopEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForTopEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForLeftEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForLeftEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForRightEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForRightEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForBottomEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForBottomEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genCoreInsts(asmjit::X86Emitter* a, int c_offset);
+ void genCoreInsts(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void storeResult(asmjit::X86Emitter* a);
+ void storeResult(x86::Emitter* a);
// for Rowoffset kernel
// Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+ void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg);
// Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void
- gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg);
+ void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg);
// Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit
template <inst_set_t instSet>
void gen8BitSumX16(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg,
- asmjit::X86Ymm cReg,
- asmjit::X86Ymm dReg);
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg,
+ x86::Ymm cReg,
+ x86::Ymm dReg);
// Generate instruction sequence that loads 8-bit values and sum them up.
// Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16
@@ -191,73 +191,78 @@ class GenConvKernel {
// Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_,
// and resultRegAvx2_ are used.
template <inst_set_t instSet>
- void gen8BitSum(
- asmjit::X86Emitter* a,
- int act_offset,
- bool use_scratch_reg1 = true);
+ void
+ gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true);
// Use scratchReg1_ and tmpReg1Avx2_ internally
template <inst_set_t instSet>
- void genZeroPtSum(asmjit::X86Emitter* a, int multiplier);
+ void genZeroPtSum(x86::Emitter* a, int multiplier);
template <inst_set_t instSet>
- void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForTopEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForLeftEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForLeftEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForRightEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForRightEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForBottomEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForBottomEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCorners(asmjit::X86Emitter* a);
+ void genRowoffsetCorners(x86::Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCore(asmjit::X86Emitter* a);
+ void genRowoffsetCore(x86::Emitter* a);
template <inst_set_t instSet>
- void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0);
-
- static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
- static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
- static thread_local std::
- map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- codeCache_; ///< JIT Code Cache for reuse.
- static thread_local std::
- map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
-
- private:
+ void storeResultRowoffset(x86::Emitter* a, int offset = 0);
+
+
+ static asmjit::JitRuntime &runtime() {
+ static asmjit::JitRuntime rt; //< JIT Runtime for asmjit.
+ // It depends on other static variables,
+ // so it lives in a function-local static
+ // to prevent the static initialization order fiasco.
+ return rt;
+ }
+
+ static std::mutex rtMutex_; ///< Controls access to runtime().
+
+ static CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ codeCache_; ///< JIT Code Cache for reuse.
+ static CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
+
+private:
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
// avx2 specific
- asmjit::X86Ymm
+ x86::Ymm
WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel.
- asmjit::X86Ymm zeroPTRegAvx2_;
- asmjit::X86Ymm tmpReg1Avx2_;
- asmjit::X86Ymm stPermRegAvx2_;
- asmjit::X86Ymm actRegAvx2_;
- asmjit::X86Ymm resultRegAvx2_;
- asmjit::X86Ymm oneReg8BitAvx2_;
- asmjit::X86Ymm oneReg16BitAvx2_;
+ x86::Ymm zeroPTRegAvx2_;
+ x86::Ymm tmpReg1Avx2_;
+ x86::Ymm stPermRegAvx2_;
+ x86::Ymm actRegAvx2_;
+ x86::Ymm resultRegAvx2_;
+ x86::Ymm oneReg8BitAvx2_;
+ x86::Ymm oneReg16BitAvx2_;
// arguments to the function created
- asmjit::X86Gp in_acts_R_;
- asmjit::X86Gp wghts_R_;
- asmjit::X86Gp out_acts_R_;
- asmjit::X86Gp a_zero_pt_R_;
- asmjit::X86Gp H_R_;
- asmjit::X86Gp W_R_;
- asmjit::X86Gp row_offset_R_;
+ x86::Gp in_acts_R_;
+ x86::Gp wghts_R_;
+ x86::Gp out_acts_R_;
+ x86::Gp a_zero_pt_R_;
+ x86::Gp H_R_;
+ x86::Gp W_R_;
+ x86::Gp row_offset_R_;
// Used registers
- asmjit::X86Gp loopR1_;
- asmjit::X86Gp loopR2_;
- asmjit::X86Gp scratchReg1_;
- asmjit::X86Gp scratchReg2_;
+ x86::Gp loopR1_;
+ x86::Gp loopR2_;
+ x86::Gp scratchReg1_;
+ x86::Gp scratchReg2_;
// Other parameters
bool isAZeroPointZero_;
@@ -276,4 +281,15 @@ class GenConvKernel {
int W_PAD_; ///< Padding for width (left and right)
};
+template <int SPATIAL_DIM, typename accT>
+std::mutex GenConvKernel<SPATIAL_DIM, accT>::rtMutex_;
+
+template <int SPATIAL_DIM, typename accT>
+CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
+
+template <int SPATIAL_DIM, typename accT>
+CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
+
} // namespace fbgemm
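
The header change above also replaces the thread_local rt_/code_ statics with a runtime() accessor whose asmjit::JitRuntime is a function-local static shared by all generators (add() calls are guarded by rtMutex_). A function-local static is constructed on first use, and C++11 guarantees that construction is thread-safe, which is what avoids the initialization-order problem the comment mentions; a stripped-down illustration of the idiom, independent of asmjit:

    #include <iostream>

    struct Runtime {
      Runtime() { std::cout << "constructed on first use\n"; }
      int add_count = 0;
    };

    // Function-local static: initialized the first time runtime() is called,
    // not at some unspecified point during static initialization, so other
    // static objects may safely call runtime() from their constructors.
    Runtime& runtime() {
      static Runtime instance;
      return instance;
    }

    int main() {
      runtime().add_count += 1;                  // triggers construction
      std::cout << runtime().add_count << "\n";  // prints 1
      return 0;
    }
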
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index e789695..396e792 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -21,20 +21,6 @@ namespace fbgemm {
using namespace std;
-template <int SPATIAL_DIM, typename accT>
-thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
-
namespace x86 = asmjit::x86;
template <int SPATIAL_DIM>
@@ -91,20 +77,19 @@ jit_conv_kernel_fp getOrCreateConvKernel(
// Note: Wrong code is generated if it's not one of the supported convolutions
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) !=
- GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) {
- return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig];
- } else {
- auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
- }
+ return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate(
+ kernelSig, [&]() {
+ auto genObj =
+ GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
+ });
}
template <>
template <>
void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 8-bit 1s
// i.e., oneReg8BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
// 0x01 and so on
@@ -115,7 +100,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 16-bit 1s
// i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
// contains 0x0001 and so on
@@ -125,11 +110,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm destReg) {
+ x86::Emitter* a,
+ x86::Ymm destReg) {
// make destReg all zeros
a->vxorps(destReg, destReg, destReg);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Xmm const_reg_xmm = x86::xmm10;
// move zero point to xmm10
a->movq(const_reg_xmm, a_zero_pt_R_);
// make copies of zero point
@@ -143,9 +128,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
- asmjit::X86Gp permute_const_reg = a->gpzRef(12);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Emitter* a) {
+ x86::Gp permute_const_reg = a->gpz(12);
+ x86::Xmm const_reg_xmm = x86::xmm10;
// We have 1st group in even lanes and 2nd group in odd lanes.
// Permute to put 1st group to lower 128-bit and 2nd group in upper
// 128-bit.
@@ -159,8 +144,7 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) {
if (C_per_G_ == 4) {
// store with permutation
a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
@@ -171,7 +155,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int offset) {
// store
if (C_per_G_ == 4) {
@@ -198,7 +182,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// load weights
for (int r = 0; r < R_; ++r) {
@@ -225,9 +209,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm wReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm wReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -236,8 +220,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -246,9 +230,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// Let a[0] denote 0th (LSB) 8-bit of aReg
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
@@ -267,11 +251,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg,
- asmjit::X86Ymm cReg,
- asmjit::X86Ymm dReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg,
+ x86::Ymm cReg,
+ x86::Ymm dReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
// a[8:10] = a[8] + ... + a[15]
@@ -319,7 +303,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int act_offset,
bool use_scratch_reg1 /*=true*/) {
if (use_scratch_reg1) {
@@ -385,11 +369,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int multiplier) {
a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
// tmpReg1Avx2_ also uses xmm11
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, scratchReg1_);
a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
@@ -399,7 +383,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// top-left corner code
if (c_offset == 0) {
@@ -559,7 +543,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
@@ -626,7 +610,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -714,7 +698,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// bottom-left corner
// we are updating the last row
@@ -906,7 +890,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// main compute
asmjit::Label LoopH = a->newLabel();
@@ -1010,10 +994,10 @@ template <>
template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1021,25 +1005,34 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code_.setLogger(codeLogger);
+ code.setLogger(codeLogger);
}
#endif
// arguments to the function created
+#ifdef _MSC_VER
+ in_acts_R_ = a->zcx();
+ wghts_R_ = a->zdx();
+ out_acts_R_ = a->gpz(8);
+ a_zero_pt_R_ = a->gpz(9);
+ H_R_ = a->zdi();
+ W_R_ = a->zsi();
+#else
in_acts_R_ = a->zdi();
wghts_R_ = a->zsi();
out_acts_R_ = a->zdx();
a_zero_pt_R_ = a->zcx();
- H_R_ = a->gpzRef(8);
- W_R_ = a->gpzRef(9);
- row_offset_R_ = a->gpzRef(10);
+ H_R_ = a->gpz(8);
+ W_R_ = a->gpz(9);
+#endif
+ row_offset_R_ = a->gpz(10);
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
asmjit::FuncDetail func;
- func.init(asmjit::FuncSignature6<
+ func.init(asmjit::FuncSignatureT<
void,
uint8_t*,
int8_t*,
@@ -1048,29 +1041,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
int32_t,
int32_t>(asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
createVector16BitOne<inst_set_t::avx2>(a);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
if (!isAZeroPointZero_) {
setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
@@ -1095,16 +1088,18 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
genCoreInsts<inst_set_t::avx2>(a, c);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_conv_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
- auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
- codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
fclose(codeLogfile);
@@ -1117,7 +1112,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// top-left corner code
// zero out the results register
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1213,7 +1208,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
@@ -1256,7 +1251,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -1326,7 +1321,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// bottom-left corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1429,7 +1424,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// number of uint8 elements in input channels should be a multiple of 32
assert(C_ % 32 == 0);
@@ -1490,10 +1485,10 @@ template <>
jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1501,54 +1496,62 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code_.setLogger(codeLogger);
+ code.setLogger(codeLogger);
}
#endif
// arguments to the function created
+#ifdef _MSC_VER
+ in_acts_R_ = a->zcx();
+ a_zero_pt_R_ = a->zdx();
+ H_R_ = a->gpz(8);
+ W_R_ = a->gpz(9);
+ row_offset_R_ = a->zdi();
+#else
in_acts_R_ = a->zdi();
a_zero_pt_R_ = a->zsi();
H_R_ = a->zdx();
W_R_ = a->zcx();
- row_offset_R_ = a->gpzRef(8);
+ row_offset_R_ = a->gpz(8);
+#endif
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
+ FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
// This uses xmm10 register temporarily. Should come before
// createVector8BitOne
if (!isAZeroPointZero_) {
// we can use xmm11 because ymm11 is used by tmpReg1Avx2_
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, a_zero_pt_R_);
a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm);
@@ -1569,16 +1572,18 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
genRowoffsetCore<inst_set_t::avx2>(a);
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
+ asmjit::Error err;
jit_rowoffset_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
- auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
- codeCacheRowOffset_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
delete codeLogger;
@@ -1781,7 +1786,8 @@ template <
typename outType,
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
- int SPATIAL_DIM>
+ int SPATIAL_DIM,
+ typename BIAS_TYPE>
void fbgemmGroupwiseConv(
const conv_param_t<SPATIAL_DIM>& conv_param,
const std::uint8_t* activations,
@@ -1790,10 +1796,10 @@ void fbgemmGroupwiseConv(
packed_W& packed_weights,
outType* out,
int32_t* outBuffer,
- const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+ const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
int thread_id,
int num_threads) {
- typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType;
+ typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE> processOutputType;
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
@@ -1884,15 +1890,17 @@ void fbgemmGroupwiseConv(
outProcess.getBZeroPoint()[0] == 0) ||
rowOffsetBuf == nullptr;
- requantizationParams_t r = {a_zero_point,
- outProcess.getBZeroPoint(),
- outProcess.getCZeroPoint(),
- outProcess.getCMultiplier(),
- rowOffsetBuf,
- outProcess.getColOffsets(),
- outProcess.getBias(),
- outProcess.getNCols(),
- G};
+ requantizationParams_t<typename processOutputType::BIAS_T> r = {
+ a_zero_point,
+ outProcess.getBZeroPoint(),
+ outProcess.getCZeroPoint(),
+ outProcess.getCMultiplier(),
+ rowOffsetBuf,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.getNCols(),
+ G,
+ outProcess.getActWScale()};
const std::int32_t* inp = outBuffer;
block_type_t block{i * oh_ow, oh_ow, gOuter * K_per_G, 8 * K_per_G};
@@ -2163,15 +2171,14 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
// Note: Wrong code is generated if it's not one of the supported convolutions
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find(
- kernelSig) !=
- GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) {
- return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig];
- } else {
- auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param);
- }
+ return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate(
+ kernelSig, [&]() {
+ auto genObj =
+ GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(
+ conv_param);
+ });
}
template <int SPATIAL_DIM>
@@ -2215,7 +2222,7 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param);
template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param);
-#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM) \
+#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, BIAS_TYPE) \
template void fbgemmGroupwiseConv( \
const conv_param_t<SPATIAL_DIM>& conv_param, \
const uint8_t* activations, \
@@ -2224,13 +2231,17 @@ template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param);
PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>& packed_weights, \
uint8_t* out, \
int32_t* outBuffer, \
- const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
int thread_id, \
int num_threads);
+#define INSTANTIATE_BIAS_T(RELU, Q_GRAN, SPATIAL_DIM) \
+ INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, float); \
+ INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, int32_t);
+
#define INSTANTIATE_SPATIAL_DIM(RELU, Q_GRAN) \
- INSTANTIATE_BASE(RELU, Q_GRAN, 2); \
- INSTANTIATE_BASE(RELU, Q_GRAN, 3);
+ INSTANTIATE_BIAS_T(RELU, Q_GRAN, 2); \
+ INSTANTIATE_BIAS_T(RELU, Q_GRAN, 3);
#define INSTANTIATE_Q_GRANS(RELU) \
INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::TENSOR); \
@@ -2242,6 +2253,7 @@ INSTANTIATE_Q_GRANS(true);
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_SPATIAL_DIM
+#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
template void fbgemmGroupwiseConv(
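
The instantiation macros at the end of the file above gain a BIAS_TYPE level, so every (RELU, Q_GRAN, SPATIAL_DIM) combination of fbgemmGroupwiseConv is now instantiated twice, with float and with int32_t bias. The toy below (made-up toyKernel, and it drops the quantization-granularity level for brevity) only illustrates how such a macro ladder fans out into explicit instantiations:

    // Toy version of the INSTANTIATE_* ladder: each level multiplies the set
    // of explicit instantiations the compiler emits.
    template <bool RELU, int DIM, typename BIAS>
    int toyKernel(int x) { return RELU && x < 0 ? 0 : x + DIM; }

    #define TOY_BASE(RELU, DIM, BIAS) \
      template int toyKernel<RELU, DIM, BIAS>(int);

    #define TOY_BIAS_T(RELU, DIM) \
      TOY_BASE(RELU, DIM, float)  \
      TOY_BASE(RELU, DIM, int)

    #define TOY_SPATIAL_DIM(RELU) \
      TOY_BIAS_T(RELU, 2)         \
      TOY_BIAS_T(RELU, 3)

    TOY_SPATIAL_DIM(true)   // 4 instantiations: {2,3} x {float,int}
    TOY_SPATIAL_DIM(false)  // 4 more
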
diff --git a/src/OptimizedKernelsAvx2.cc b/src/OptimizedKernelsAvx2.cc
index e8c65c3..326bd72 100644
--- a/src/OptimizedKernelsAvx2.cc
+++ b/src/OptimizedKernelsAvx2.cc
@@ -7,6 +7,7 @@
#include "OptimizedKernelsAvx2.h"
#include <immintrin.h>
+#include "fbgemm/Utils.h"
using namespace std;
@@ -14,37 +15,37 @@ namespace fbgemm {
int32_t reduceAvx2(const uint8_t* A, int len) {
int32_t row_sum = 0;
-#if defined(__AVX2__)
- __m256i sum_v = _mm256_setzero_si256();
- __m256i one_epi16_v = _mm256_set1_epi16(1);
- __m256i one_epi8_v = _mm256_set1_epi8(1);
+ if (fbgemm::fbgemmHasAvx2Support()) {
+ __m256i sum_v = _mm256_setzero_si256();
+ __m256i one_epi16_v = _mm256_set1_epi16(1);
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
- int i;
- // vectorized
- for (i = 0; i < len / 32 * 32; i += 32) {
- __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
- sum_v = _mm256_add_epi32(
+ int i;
+ // vectorized
+ for (i = 0; i < len / 32 * 32; i += 32) {
+ __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+ sum_v = _mm256_add_epi32(
sum_v,
_mm256_madd_epi16(
- _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v));
- }
-
- alignas(64) int32_t temp[8];
- _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
- for (int k = 0; k < 8; ++k) {
- row_sum += temp[k];
- }
+ _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v));
+ }
- // scalar
- for (; i < len; ++i) {
- row_sum += A[i];
- }
+ alignas(64) int32_t temp[8];
+ _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
+ for (int k = 0; k < 8; ++k) {
+ row_sum += temp[k];
+ }
-#else
- for (int i = 0; i < len; ++i) {
- row_sum += A[i];
+ // scalar
+ for (; i < len; ++i) {
+ row_sum += A[i];
+ }
+ } else {
+ for (int i = 0; i < len; ++i) {
+ row_sum += A[i];
+ }
}
-#endif
+
return row_sum;
}
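
reduceAvx2 above swaps the compile-time #if defined(__AVX2__) guard for a run-time fbgemmHasAvx2Support() check, so a single binary carries both the vectorized loop and the scalar fallback and chooses on the host CPU. A sketch of the same dispatch shape, using the GCC/Clang __builtin_cpu_supports query as a stand-in and a scalar stub instead of repeating the intrinsics:

    #include <cstdint>

    // Scalar byte row-sum: the exact contract of reduceAvx2.
    static int32_t reduce_scalar(const uint8_t* a, int len) {
      int32_t row_sum = 0;
      for (int i = 0; i < len; ++i) {
        row_sum += a[i];
      }
      return row_sum;
    }

    // Stand-in for the _mm256_maddubs_epi16 / _mm256_madd_epi16 loop in the
    // diff; kept scalar here so the sketch stays portable.
    static int32_t reduce_avx2_stub(const uint8_t* a, int len) {
      return reduce_scalar(a, len);
    }

    // Run-time dispatch: pick the path from what the CPU reports.
    int32_t reduce_u8(const uint8_t* a, int len) {
      return __builtin_cpu_supports("avx2") ? reduce_avx2_stub(a, len)
                                            : reduce_scalar(a, len);
    }
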
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
index 87adaba..5fabf97 100644
--- a/src/PackAMatrix.cc
+++ b/src/PackAMatrix.cc
@@ -34,31 +34,35 @@ PackAMatrix<T, accT>::PackAMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
+ assert(0 && "unknown architecure");
+ }
+
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
- BaseType::brow_ = params->MCB;
- BaseType::bcol_ = params->KCB;
- row_interleave_B_ = params->ROW_INTERLEAVE;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
- }
+ BaseType::brow_ = params->MCB;
+ BaseType::bcol_ = params->KCB;
+ row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
+ } else {
+ // AVX2
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
row_interleave_B_ =
PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
}
}
+
if (BaseType::numCols() % groups != 0) {
throw std::runtime_error(
"groups = " + std::to_string(groups) +
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index e55dd4e..6101fef 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -49,32 +49,35 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
+ assert(0 && "unknown architecture");
+ }
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
- BaseType::brow_ = params->MCB;
- BaseType::bcol_ = params->KCB;
- row_interleave_B_ = params->ROW_INTERLEAVE;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
- }
+ BaseType::brow_ = params->MCB;
+ BaseType::bcol_ = params->KCB;
+ row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
+ } else {
+ // AVX2
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
row_interleave_B_ =
PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
}
}
+
if (BaseType::numCols() % conv_p.G != 0) {
throw std::runtime_error(
"groups = " + std::to_string(conv_p.G) +
@@ -272,6 +275,7 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
conv_p_.K[1] == 7 && conv_p_.stride[0] == 2 && conv_p_.stride[1] == 2 &&
conv_p_.pad[0] == 3 && conv_p_.pad[1] == 3 && block.col_size == 147 &&
block_p.col_size == 148 && block.col_start == 0 &&
+ conv_p_.dilation[0] == 1 && conv_p_.dilation[1] == 1 &&
std::is_same<T, uint8_t>::value) {
if (BaseType::blockColSize() == 256) {
pack_a_with_im2col_opt<
@@ -347,8 +351,10 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
int r = grs / conv_p_.K[1] % conv_p_.K[0];
int g = grs / conv_p_.K[1] / conv_p_.K[0];
- int h_in = -conv_p_.pad[0] + h * conv_p_.stride[0] + r;
- int w_in = -conv_p_.pad[1] + w * conv_p_.stride[1] + s;
+ int h_in =
+ -conv_p_.pad[0] + h * conv_p_.stride[0] + r * conv_p_.dilation[0];
+ int w_in =
+ -conv_p_.pad[1] + w * conv_p_.stride[1] + s * conv_p_.dilation[1];
if (h_in < 0 || h_in >= conv_p_.IN_DIM[0] || w_in < 0 ||
w_in >= conv_p_.IN_DIM[1]) {
@@ -396,9 +402,12 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
int q = gqrs / conv_p_.K[2] / conv_p_.K[1] % conv_p_.K[0];
int g = gqrs / conv_p_.K[2] / conv_p_.K[1] / conv_p_.K[0];
- int t_in = -conv_p_.pad[0] + t * conv_p_.stride[0] + q;
- int h_in = -conv_p_.pad[1] + h * conv_p_.stride[1] + r;
- int w_in = -conv_p_.pad[2] + w * conv_p_.stride[2] + s;
+ int t_in =
+ -conv_p_.pad[0] + t * conv_p_.stride[0] + q * conv_p_.dilation[0];
+ int h_in =
+ -conv_p_.pad[1] + h * conv_p_.stride[1] + r * conv_p_.dilation[1];
+ int w_in =
+ -conv_p_.pad[2] + w * conv_p_.stride[2] + s * conv_p_.dilation[2];
if (t_in < 0 || t_in >= conv_p_.IN_DIM[0] || h_in < 0 ||
h_in >= conv_p_.IN_DIM[1] || w_in < 0 ||
@@ -481,7 +490,9 @@ int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
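
PackAWithIm2Col above folds dilation into the im2col coordinate math: the input row read by kernel tap r for output row h becomes h_in = -pad[0] + h*stride[0] + r*dilation[0], and likewise for the other spatial axes, so dilation == 1 reproduces the old formula; the fast 7x7/stride-2 path is also restricted to the undilated case. A small scalar illustration of the mapping (hypothetical helper, not patch code):

    #include <cstdio>

    // Input coordinate read by kernel tap `r` when producing output index `h`.
    // With dilation == 1 this is the pre-patch -pad + h * stride + r.
    inline int input_coord(int h, int r, int pad, int stride, int dilation) {
      return -pad + h * stride + r * dilation;
    }

    int main() {
      // pad = 1, stride = 1, dilation = 2, 3-tap kernel, output index h = 0:
      // the taps land on input rows -1, 1 and 3; the -1 is out of range and is
      // handled as padding by the pack loop.
      for (int r = 0; r < 3; ++r) {
        std::printf("r=%d -> h_in=%d\n", r,
                    input_coord(/*h=*/0, r, /*pad=*/1, /*stride=*/1,
                                /*dilation=*/2));
      }
      return 0;
    }
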
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc
index 305a298..13a8fad 100644
--- a/src/PackAWithQuantRowOffset.cc
+++ b/src/PackAWithQuantRowOffset.cc
@@ -45,32 +45,37 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- rowOffsetAllocatedHere = false;
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
+ assert(0 && "unknown architecure");
+ }
+
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
- BaseType::brow_ = params->MCB;
- BaseType::bcol_ = params->KCB;
- row_interleave_B_ = params->ROW_INTERLEAVE;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
- }
+ BaseType::brow_ = params->MCB;
+ BaseType::bcol_ = params->KCB;
+ row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
+ } else {
+ // AVX2
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
row_interleave_B_ =
PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
- } else {
- // TODO: Have default slower path
- assert(0 && "unknown architecure");
}
}
+
+ rowOffsetAllocatedHere = false;
+
if (BaseType::numCols() % groups != 0) {
throw std::runtime_error(
"groups = " + std::to_string(groups) +
@@ -202,7 +207,9 @@ int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc
index b791817..e84c67b 100644
--- a/src/PackAWithRowOffset.cc
+++ b/src/PackAWithRowOffset.cc
@@ -39,32 +39,37 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- rowOffsetAllocatedHere = false;
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
+ assert(0 && "unknown architecure");
+ }
+
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
- BaseType::brow_ = params->MCB;
- BaseType::bcol_ = params->KCB;
- row_interleave_B_ = params->ROW_INTERLEAVE;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
- }
+ BaseType::brow_ = params->MCB;
+ BaseType::bcol_ = params->KCB;
+ row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
+ } else {
+ // AVX2
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
row_interleave_B_ =
PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
- } else {
- // TODO: Have default slower path
- assert(0 && "unknown architecure");
}
}
+
+ rowOffsetAllocatedHere = false;
+
if (BaseType::numCols() % groups != 0) {
throw std::runtime_error(
"groups = " + std::to_string(groups) +
@@ -190,7 +195,9 @@ int PackAWithRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index b6d06ca..c237ac4 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -188,6 +188,76 @@ PackBMatrix<T, accT>::PackBMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
+ assert(0 && "unknown architecure");
+ }
+
+ if (params) {
+ BaseType::brow_ = params->KCB;
+ BaseType::bcol_ = params->NCB;
+ row_interleave_ = params->ROW_INTERLEAVE;
+ } else {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else {
+ // AVX2
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ }
+ }
+
+ if (BaseType::numRows() % groups != 0) {
+ throw std::runtime_error(
+ "groups = " + std::to_string(groups) +
+ " does not divide numRows = " + std::to_string(BaseType::numRows()));
+ }
+
+ // blocking for one group
+ block_type_t block{
+ 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols()};
+ BaseType::packedBlock(block);
+ if (!pmat) {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ = (T*)fbgemmAlignedAlloc(
+ 64,
+ BaseType::numGroups() * BaseType::blockRows() * BaseType::brow_ *
+ BaseType::blockCols() * BaseType::bcol_ * sizeof(T));
+ }
+ pack(block, params);
+}
+
+template <typename T, typename accT>
+PackBMatrix<T, accT>::PackBMatrix(
+ matrix_op_t trans,
+ int32_t nRow,
+ int32_t nCol,
+ inpType* prepackedmat,
+ int32_t ld,
+ int groups,
+ const BlockingFactors* params)
+ : PackMatrix<PackBMatrix<T, accT>, T, accT>(
+ nRow,
+ nCol,
+ prepackedmat,
+ groups,
+ params),
+ trans_(trans),
+ smat_(nullptr),
+ ld_(ld) {
+ if (!cpuinfo_initialize()) {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
if (params) {
if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
BaseType::brow_ = params->KCB;
@@ -221,20 +291,17 @@ PackBMatrix<T, accT>::PackBMatrix(
// blocking for one group
block_type_t block{
- 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols()};
+ 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols() };
BaseType::packedBlock(block);
- if (!pmat) {
- BaseType::bufAllocatedHere_ = true;
- BaseType::buf_ = (T*)fbgemmAlignedAlloc(
- 64,
- BaseType::numGroups() * BaseType::blockRows() * BaseType::brow_ *
- BaseType::blockCols() * BaseType::bcol_ * sizeof(T));
- }
- pack(block);
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::pack(const block_type_t& block) {
+void PackBMatrix<T, accT>::pack_unpack_(
+ const block_type_t& block,
+ T* unpack_buf,
+ T* pack_buf,
+ bool ispack,
+ const BlockingFactors* params) {
assert((BaseType::blockRowSize() % row_interleave_) == 0);
assert((block.row_start % BaseType::blockRowSize()) == 0);
assert((block.col_start % BaseType::blockColSize()) == 0);
@@ -242,8 +309,8 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) {
BaseType::packedBlock(block);
bool tr = (trans_ == matrix_op_t::Transpose);
for (int g = 0; g < BaseType::numGroups(); ++g) {
- T* out = BaseType::getBuf() +
- g * BaseType::packedBufferSize(block.row_size, block.col_size);
+ T* pack_buf_cur = pack_buf +
+ g * BaseType::packedBufferSize(block.row_size, block.col_size, params);
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
(BaseType::blockRowSize() * BaseType::blockColSize()) +
@@ -268,10 +335,16 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) {
c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() +
c_idx_offset * row_interleave_;
- int out_idx = r_offset + c_offset;
- T val = tr ? smat_[i + (g * block.col_size + j) * ld_]
- : smat_[(g * block.row_size + i) * ld_ + j];
- out[out_idx] = val;
+ if (ispack) {
+ pack_buf_cur[r_offset + c_offset] = tr
+ ? unpack_buf[i + (g * block.col_size + j) * ld_]
+ : unpack_buf[(g * block.row_size + i) * ld_ + j];
+ } else {
+ T* unpack_buf_cur = tr
+ ? &(unpack_buf[i + (g * block.col_size + j) * ld_])
+ : &(unpack_buf[(g * block.row_size + i) * ld_ + j]);
+ *unpack_buf_cur = pack_buf_cur[r_offset + c_offset];
+ }
c_idx_offset++;
if (c_idx_offset == BaseType::blockColSize()) {
@@ -280,78 +353,49 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) {
}
}
}
- // fill the remaining with zero.
- // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
- for (int i = block.row_start + block.row_size;
- i < (block.row_start + block.row_size + row_interleave_ - 1) /
- row_interleave_ * row_interleave_;
- ++i) {
- int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
- (BaseType::blockRowSize() * BaseType::blockColSize()) +
- (i % BaseType::blockRowSize() / row_interleave_) *
- BaseType::blockColSize() * row_interleave_ +
- i % row_interleave_;
- for (int j = block.col_start; j < block.col_start + block.col_size; j++) {
- int c_offset = (j / BaseType::blockColSize()) *
- BaseType::blockRowSize() * BaseType::blockColSize() +
- (j % BaseType::blockColSize()) * row_interleave_;
+ if (ispack) {
+ // fill the remaining with zero.
+ // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
+ for (int i = block.row_start + block.row_size;
+ i < (block.row_start + block.row_size + row_interleave_ - 1) /
+ row_interleave_ * row_interleave_;
+ ++i) {
+ int r_offset =
+ ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize()) +
+ (i % BaseType::blockRowSize() / row_interleave_) *
+ BaseType::blockColSize() * row_interleave_ +
+ i % row_interleave_;
+ for (int j = block.col_start; j < block.col_start + block.col_size;
+ j++) {
+ int c_offset = (j / BaseType::blockColSize()) *
+ BaseType::blockRowSize() * BaseType::blockColSize() +
+ (j % BaseType::blockColSize()) * row_interleave_;
- int out_idx = r_offset + c_offset;
- out[out_idx] = 0;
+ int out_idx = r_offset + c_offset;
+ pack_buf_cur[out_idx] = 0;
+ }
}
}
} // for each group
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::unpack(T* origin_buf) {
- bool tr = (trans_ == matrix_op_t::Transpose);
- for (int g = 0; g < this->numGroups(); ++g) {
- T* out = BaseType::getBuf() +
- g *
- BaseType::packedBufferSize(
- BaseType::numPackedRows(), BaseType::numPackedCols());
- for (int i = BaseType::packedRowStart();
- i < BaseType::packedRowStart() + BaseType::numPackedRows();
- ++i) {
- int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
- (BaseType::blockRowSize() * BaseType::blockColSize()) +
- (i % BaseType::blockRowSize() / row_interleave_) *
- BaseType::blockColSize() * row_interleave_ +
- i % row_interleave_;
-
- int c_start_offset =
- (BaseType::packedColStart() / BaseType::blockColSize()) *
- BaseType::blockRowSize() * BaseType::blockColSize() +
- (BaseType::packedColStart() % BaseType::blockColSize()) *
- row_interleave_;
-
- int c_idx_offset = 0;
- int c_blk_offset = 0;
- for (int j = BaseType::packedColStart();
- j < BaseType::packedColStart() + BaseType::numPackedCols();
- ++j) {
- int c_offset = c_start_offset +
- c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() +
- c_idx_offset * row_interleave_;
-
- int out_idx = r_offset + c_offset;
-
- T val = out[out_idx];
- if (tr) {
- origin_buf[i + (g * BaseType::numPackedCols() + j) * ld_] = val;
- } else {
- origin_buf[(g * BaseType::numPackedRows() + i) * ld_ + j] = val;
- }
+void PackBMatrix<T, accT>::pack(
+ const block_type_t& block,
+ const BlockingFactors* params) {
+ pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params);
+}
- c_idx_offset++;
- if (c_idx_offset == BaseType::blockColSize()) {
- c_idx_offset = 0;
- c_blk_offset++;
- }
- }
- }
- } // for each group
+template <typename T, typename accT>
+void PackBMatrix<T, accT>::unpack(
+ T* origin_buf,
+ const BlockingFactors* params) {
+ block_type_t blockB{BaseType::packedRowStart(),
+ BaseType::numPackedRows(),
+ BaseType::packedColStart(),
+ BaseType::numPackedCols()};
+ pack_unpack_(blockB, origin_buf, BaseType::getBuf(), false, params);
}
template <typename T, typename accT>
@@ -374,7 +418,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::printPackedMatrix(std::string name) {
+void PackBMatrix<T, accT>::printPackedMatrix(
+ std::string name,
+ const BlockingFactors* params) {
std::cout << name << ":"
<< "[" << BaseType::numPackedRows() << ", "
<< BaseType::numPackedCols() << "]" << std::endl;
@@ -382,33 +428,39 @@ void PackBMatrix<T, accT>::printPackedMatrix(std::string name) {
<< "[" << BaseType::blockRowSize() << ", "
<< BaseType::blockColSize() << "]" << std::endl;
- T* out = BaseType::getBuf();
-
- for (auto nr = 0; nr < BaseType::blockRows(); ++nr) {
- auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow()
- : BaseType::blockRowSize();
- for (auto nc = 0; nc < BaseType::blockCols(); ++nc) {
- std::cout << "block:" << nr << ", " << nc << std::endl;
- auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol()
- : BaseType::blockColSize();
- for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_;
- ++r) {
- for (auto c = 0; c < cols * row_interleave_; ++c) {
- T val =
- out[nr * BaseType::blockCols() * BaseType::blockRowSize() *
- BaseType::blockColSize() +
- nc * BaseType::blockRowSize() * BaseType::blockColSize() +
- r * BaseType::blockColSize() * row_interleave_ + c];
- if (std::is_integral<T>::value) {
- // cast to int64 because cout doesn't print int8_t type directly
- std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
- } else {
- std::cout << std::setw(5) << val << " ";
+ for (int g = 0; g < BaseType::numGroups(); ++g) {
+ T* out = BaseType::getBuf() +
+ g *
+ BaseType::packedBufferSize(
+ BaseType::numPackedRows(), BaseType::numPackedCols(), params);
+ std::cout << "group: " << g << std::endl;
+ for (auto nr = 0; nr < BaseType::blockRows(); ++nr) {
+ auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow()
+ : BaseType::blockRowSize();
+ for (auto nc = 0; nc < BaseType::blockCols(); ++nc) {
+ std::cout << "block:" << nr << ", " << nc << std::endl;
+ auto cols = (nc == BaseType::blockCols() - 1)
+ ? BaseType::lastBcol()
+ : BaseType::blockColSize();
+ for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_;
+ ++r) {
+ for (auto c = 0; c < cols * row_interleave_; ++c) {
+ T val =
+ out[nr * BaseType::blockCols() * BaseType::blockRowSize() *
+ BaseType::blockColSize() +
+ nc * BaseType::blockRowSize() * BaseType::blockColSize() +
+ r * BaseType::blockColSize() * row_interleave_ + c];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
}
+ std::cout << std::endl;
}
std::cout << std::endl;
}
- std::cout << std::endl;
}
}
}
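The hunks above fold PackBMatrix::pack() and unpack() into a single pack_unpack_() that walks the same blocked index arithmetic in both directions, selected by the ispack flag. A minimal sketch of that pattern, using a hypothetical toy interleaved layout (pack_unpack_toy is illustrative only, not FBGEMM's real indexing), shows why one shared traversal makes pack and unpack exact inverses:

// Minimal sketch (hypothetical toy layout): one traversal that either packs
// or unpacks depending on a direction flag, mirroring how PackBMatrix::pack()
// and unpack() now both call pack_unpack_().
#include <cassert>
#include <cstdint>
#include <vector>

// Toy "packed" layout: interleave pairs of rows, column-major within a pair.
static void pack_unpack_toy(std::vector<std::int8_t>& plain,
                            std::vector<std::int8_t>& packed,
                            int rows, int cols, bool ispack) {
  assert(rows % 2 == 0);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      int packed_idx = (r / 2) * (2 * cols) + c * 2 + (r % 2);
      int plain_idx = r * cols + c;
      if (ispack) {
        packed[packed_idx] = plain[plain_idx];
      } else {
        plain[plain_idx] = packed[packed_idx];
      }
    }
  }
}

int main() {
  const int rows = 4, cols = 3;
  std::vector<std::int8_t> plain(rows * cols), packed(rows * cols), roundtrip(rows * cols);
  for (int i = 0; i < rows * cols; ++i) plain[i] = static_cast<std::int8_t>(i);
  pack_unpack_toy(plain, packed, rows, cols, /*ispack=*/true);
  pack_unpack_toy(roundtrip, packed, rows, cols, /*ispack=*/false);
  assert(roundtrip == plain); // the shared traversal makes the two directions inverses
  return 0;
}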
diff --git a/src/PackDepthwiseConvMatrixAvx2.cc b/src/PackDepthwiseConvMatrixAvx2.cc
new file mode 100644
index 0000000..a84c469
--- /dev/null
+++ b/src/PackDepthwiseConvMatrixAvx2.cc
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/Utils.h"
+#include "fbgemm/Fbgemm.h"
+
+#include <immintrin.h>
+
+using namespace std;
+
+namespace fbgemm {
+
+// clang-format off
+static int masks[8][8] = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ { 0, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, -1, 0, },
+};
+// clang-format on
+
+PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix(
+ int K,
+ int kernel_prod,
+ const int8_t* smat)
+ : K_(K), kernel_prod_(kernel_prod) {
+ // Transpose the input matrix to make packing faster.
+ int8_t* smat_transposed
+ = static_cast<int8_t*>(ALIGNED_MALLOC(K * kernel_prod * sizeof(int8_t), 64));
+
+ for (int i = 0; i < kernel_prod; ++i) {
+ for (int j = 0; j < K; ++j) {
+ smat_transposed[i * K + j] = smat[i + j * kernel_prod];
+ }
+ }
+
+ // Allocate packed arrays
+ int kernel_prod_aligned = (kernel_prod + 1) / 2 * 2;
+ pmat_ = static_cast<int8_t *>(fbgemmAlignedAlloc(64, ((K + 31) / 32) * kernel_prod_aligned * 32 * sizeof(int8_t)));
+
+ // Pack input matrix
+ // The layout is optimized to use vpmaddubsw efficiently (see
+ // madd_epi16x4_packed function).
+ // For a group of 32 channels, we have 10 32B SIMD registers.
+ // Denote ith channel jth filter as (i, j)
+ // 0th SIMD register:
+ // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
+ // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
+ // 1st SIMD register:
+ // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
+ // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
+ // 2nd SIMD register:
+ // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
+ // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
+ // 3rd SIMD register:
+ // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
+ // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
+ // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter
+ // coefficients
+ // ...
+ //
+ // REMAINDER
+ // If kernel_prod % 4 == 1 for example when kernel_prod == 9
+ // 8th SIMD register:
+ // (0, 8), zero, ..., (7, 8), zero
+ // (16, 8), zero, ..., (23, 8), zero
+ // 9th SIMD register:
+ // (8, 8), zero, ..., (15, 8), zero
+ // (24, 8), zero, ..., (31, 8), zero
+ // We use madd_epi16_packed for this case
+ //
+ // If kernel_prod % 4 == 2 for example when kernel_prod == 10
+ // 8th SIMD register:
+ // (0, 8), (0, 9), ..., (7, 8), (7, 9)
+ // (16, 8), (16, 9), ..., (23, 8), (23, 9)
+ // 9th SIMD register:
+ // (8, 8), (8, 9), ..., (15, 8), (15, 9)
+ // (24, 8), (24, 9), ..., (31, 8), (31, 9)
+ //
+ // If kernel_prod % 4 == 3 for example when kernel_prod == 11
+ // 8th SIMD register:
+ // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
+ // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
+ // 9th SIMD register:
+ // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
+ // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
+ // 10th SIMD register:
+ // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
+ // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
+ // 11th SIMD register:
+ // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
+ // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
+ for (int k1 = 0; k1 < K; k1 += 32) {
+ __m256i* b_v = static_cast<__m256i*>(ALIGNED_MALLOC(kernel_prod * sizeof(__m256i), 64));
+ int remainder = K - k1;
+ if (remainder < 32) {
+ __m256i mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(masks[remainder / 4]));
+ for (int i = 0; i < kernel_prod; ++i) {
+ b_v[i] = _mm256_maskload_epi32(
+ reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v);
+ }
+ } else {
+ for (int i = 0; i < kernel_prod; ++i) {
+ b_v[i] = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1));
+ }
+ }
+
+ // Interleave 2 SIMD registers
+ __m256i* b_interleaved_epi16 = static_cast<__m256i*>(ALIGNED_MALLOC(kernel_prod_aligned * sizeof(__m256i), 64));
+ __m256i zero_v = _mm256_setzero_si256();
+ for (int i = 0; i < kernel_prod_aligned / 2; ++i) {
+ if (2 * i + 1 >= kernel_prod) {
+ b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
+ } else {
+ b_interleaved_epi16[2 * i] =
+ _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ }
+ }
+
+ // Interleave 4 SIMD registers
+ __m256i* b_interleaved_epi32 = static_cast<__m256i*>(ALIGNED_MALLOC(kernel_prod_aligned * sizeof(__m256i), 64));
+ for (int i = 0; i < kernel_prod_aligned / 4; ++i) {
+ b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ }
+ for (int i = kernel_prod_aligned / 4 * 4; i < kernel_prod_aligned; ++i) {
+ b_interleaved_epi32[i] = b_interleaved_epi16[i];
+ }
+
+ for (int i = 0; i < kernel_prod_aligned; ++i) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(
+ &pmat_[((k1 / 32) * kernel_prod_aligned + i) * 32]),
+ b_interleaved_epi32[i]);
+ }
+
+ FREE(b_v);
+ FREE(b_interleaved_epi16);
+ FREE(b_interleaved_epi32);
+ }
+ FREE(smat_transposed);
+}
+
+int PackedDepthWiseConvMatrix::addr(int r, int c) {
+ int kernel_prod_aligned = (kernel_prod_ + 1) / 2 * 2;
+ if (c >= kernel_prod_ / 4 * 4 &&
+ (kernel_prod_ % 4 == 1 || kernel_prod_ % 4 == 2)) {
+ int kBlock = r / 32;
+ int reg_idx = (r % 16) / 8 + c / 4 * 4;
+
+ int blk_idx = kBlock * kernel_prod_aligned + reg_idx;
+
+ int r_ = r % 8;
+ int c_ = c % 4;
+
+ int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_;
+ return blk_idx * 32 + in_blk_idx;
+
+ } else {
+ int kBlock = r / 32;
+ int reg_idx = (r % 16) / 4 + c / 4 * 4;
+
+ int blk_idx = kBlock * kernel_prod_aligned + reg_idx;
+
+ int r_ = r % 4;
+ int c_ = c % 4;
+
+ int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_;
+ return blk_idx * 32 + in_blk_idx;
+ }
+}
+
+void PackedDepthWiseConvMatrix::unpack(int8_t* unpacked_data) {
+ for (int r = 0; r < K_; ++r) {
+ for (int c = 0; c < kernel_prod_; ++c) {
+ unpacked_data[r * kernel_prod_ + c] = pmat_[addr(r, c)];
+ }
+ }
+}
+
+PackedDepthWiseConvMatrix::~PackedDepthWiseConvMatrix() {
+#ifdef _MSC_VER
+ _aligned_free(pmat_);
+#else
+ free(pmat_);
+#endif
+}
+
+} // namespace fbgemm
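PackedDepthWiseConvMatrix, added above, packs a K x kernel_prod weight matrix into the vpmaddubsw-friendly layout described in the constructor comment, and addr()/unpack() invert that layout element by element. A minimal round-trip sketch against the interface introduced in this file (sizes and values are arbitrary; it assumes the library is built and run with AVX2 available):

// Sketch of a pack/unpack round trip over PackedDepthWiseConvMatrix,
// using only the interface introduced above.
#include <cassert>
#include <cstdint>
#include <vector>
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"

int main() {
  const int K = 64;          // channels (= groups for depthwise)
  const int kernel_prod = 9; // 3x3 kernel
  std::vector<std::int8_t> weights(K * kernel_prod);
  for (int i = 0; i < K * kernel_prod; ++i) {
    weights[i] = static_cast<std::int8_t>(i % 127 - 63);
  }
  fbgemm::PackedDepthWiseConvMatrix packed(K, kernel_prod, weights.data());
  std::vector<std::int8_t> unpacked(K * kernel_prod);
  packed.unpack(unpacked.data());
  assert(unpacked == weights); // addr(r, c) maps each (channel, filter) slot back exactly
  return 0;
}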
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index c9a68a6..ff7b842 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -36,54 +36,42 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
+ assert(0 && "unknown architecure");
+ }
+
int MCB, KCB, NCB;
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
- MCB = params->MCB;
- NCB = params->NCB;
- KCB = params->KCB;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
- }
+ MCB = params->MCB;
+ NCB = params->NCB;
+ KCB = params->KCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB;
+ NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB;
+ KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB;
+ } else if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
- } else if (fbgemmHasAvx2Support()) {
+ } else {
+ // AVX2
MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB;
NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
- return -1;
}
}
- if (fbgemmHasAvx512Support()) {
- if (isA()) {
- return MCB * KCB;
- } else {
- int rowBlock = KCB;
- int colBlock = NCB;
- return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
- (((cols + colBlock - 1) / colBlock) * colBlock);
- }
- } else if (fbgemmHasAvx2Support()) {
- if (isA()) {
- return MCB * KCB;
- } else {
- int rowBlock = KCB;
- int colBlock = NCB;
- return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
- (((cols + colBlock - 1) / colBlock) * colBlock);
- }
+ if (isA()) {
+ return MCB * KCB;
} else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
+ int rowBlock = KCB;
+ int colBlock = NCB;
+ return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
+ (((cols + colBlock - 1) / colBlock) * colBlock);
}
+
return -1;
}
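With the branches collapsed above, packedBufferSize() is simply MCB * KCB for an A matrix, and rows rounded up to the KCB block times columns rounded up to the NCB block for a B matrix. A scalar restatement of that arithmetic (the KCB/NCB values in the example are illustrative, not a claim about any particular PackingTraits specialization):

// Sketch of the B-buffer size arithmetic used above: round rows up to the
// KCB row block and cols up to the NCB column block, then multiply.
#include <cstdio>

static int packedBufferSizeB(int rows, int cols, int KCB, int NCB) {
  int rowBlocks = (rows + KCB - 1) / KCB; // ceil(rows / KCB)
  int colBlocks = (cols + NCB - 1) / NCB; // ceil(cols / NCB)
  return (rowBlocks * KCB) * (colBlocks * NCB);
}

int main() {
  // e.g. a 100x70 B matrix with hypothetical KCB=256, NCB=32 blocking:
  // rows round up to 256, cols round up to 96 -> 24576 elements.
  std::printf("%d\n", packedBufferSizeB(100, 70, 256, 32));
  return 0;
}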
diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc
index 0fb0e2c..f6ad59e 100644
--- a/src/PackWeightMatrixForGConv.cc
+++ b/src/PackWeightMatrixForGConv.cc
@@ -36,8 +36,61 @@ PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv(
}
/**
- * @brief Pack weight tensor in a suitable format required for the optimized
- * kernel.
+ * @brief Get the index of the unpacked data for a given <r, s, k, g, c, tr>
+ *
+ * Non-transposed: G (R S C/G) K/G
+ * Transposed: G K/G (R S C/G)
+ * Using inline as this will be called frequently
+ */
+template <typename T, typename accT, int SPATIAL_DIM>
+inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::unpacked_index_(
+ int r, int s, int k, int g, int c, bool tr) {
+ // Get the full dimensions
+ int R = conv_param_.K[0];
+ int S = conv_param_.K[1];
+ int G = conv_param_.G;
+ int IC_per_G = conv_param_.IC / G;
+ int OC_per_G = conv_param_.OC / G;
+
+ int idx;
+ if (tr) {
+ idx = (((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c;
+ } else {
+ idx = (((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k;
+ }
+ return idx;
+}
+
+/**
+ * @brief Get the index of the packed data for a given <r, s, k, g, c>
+ *
+ * The index may differ depending on IC_per_G.
+ * Using inline as this will be called frequently
+ */
+template <typename T, typename accT, int SPATIAL_DIM>
+inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::packed_index_(
+ int r, int s, int k, int g, int c) {
+ // Get the full dimensions
+ int R = conv_param_.K[0];
+ int S = conv_param_.K[1];
+ int G = conv_param_.G;
+ int IC_per_G = conv_param_.IC / G;
+ int OC_per_G = conv_param_.OC / G;
+
+ int idx;
+ // For IC_per_G == 4, we need to work on 2 groups at a time
+ if (IC_per_G == 4) {
+ idx = (((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 + (g % 2))
+ * IC_per_G + c;
+ } else {
+ idx = ((((g * (IC_per_G / 4) + (c / 4)) * R + r) * S + s) * OC_per_G + k)
+ * 4 + (c % 4);
+ }
+ return idx;
+}
+
+/**
+ * @brief Pack or unpack matrix
*
* Let IC_per_G be number of input channels per group and OC_per_G be number of
* output channels per group.
@@ -54,14 +107,16 @@ PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv(
* while working on 1 group at a time.
* In this case, the layout is G (C/4) R S K 4
*/
+
template <typename T, typename accT, int SPATIAL_DIM>
-void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
+void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
+ const T* src, T* dst, bool ispack) {
// filters are assumed to be in G RS C/G K/G format
int R = conv_param_.K[0];
int S = conv_param_.K[1];
int G = conv_param_.G;
- int IC_per_G = conv_param_.IC / conv_param_.G;
- int OC_per_G = conv_param_.OC / conv_param_.G;
+ int IC_per_G = conv_param_.IC / G;
+ int OC_per_G = conv_param_.OC / G;
// If transpose option is set, the weight matrix is in layout G K/G (R S C/G)
// instead of G (R S C/G) K/G
@@ -73,25 +128,13 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
for (int k = 0; k < OC_per_G; ++k) {
for (int g = 0; g < G; ++g) {
for (int c = 0; c < IC_per_G; ++c) {
- inpType b = tr
- ? sdata_
- [(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c]
- : sdata_
- [(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k];
- if (IC_per_G == 4) {
- // For IC_per_G == 4, we need to work on 2 groups at a time
- pdata_
- [(((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 +
- (g % 2)) *
- IC_per_G +
- c] = b;
+ int p_idx = packed_index_(r, s, k, g, c);
+ int up_idx = unpacked_index_(r, s, k, g, c, tr);
+ // Pack: src (unpacked) -> dst (packed)
+ if (ispack) {
+ dst[p_idx] = src[up_idx];
} else {
- pdata_
- [((((g * (IC_per_G / 4) + (c / 4)) * R + r) * S + s) *
- OC_per_G +
- k) *
- 4 +
- (c % 4)] = b;
+ dst[up_idx] = src[p_idx];
}
}
}
@@ -99,14 +142,54 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
}
}
} else {
+ // For pack & transposed, call transposeConvWeights()
+ // G K/G (R S C/G) => G (R S C/G) K/G
if (tr) {
- // conv_ref expects weights to be in G (R S C/G) K/G format
- transposeConvWeights(conv_param_, sdata_, pdata_);
+ if (ispack) {
+ transposeConvWeights(conv_param_, src, dst);
+ } else {
+ // TODO: Wrap this as an inverseTransposeConvWeights()?
+ // For unpack & transposed, call transposeConvWeights()
+ // G (R S C/G) K/G => G K/G (R S C/G)
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ for (int k = 0; k < OC_per_G; ++k) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < IC_per_G; ++c) {
+ dst[(((g * OC_per_G + k) * R + r) * S + s)
+ * IC_per_G + c] =
+ src[(((g * R + r) * S + s) * IC_per_G + c)
+ * OC_per_G + k];
+ }
+ }
+ }
+ }
+ }
+ } // end if(ispack)
} else {
// just copy the data for not supported cases
- memcpy(pdata_, sdata_, G * R * S * OC_per_G * IC_per_G * sizeof(inpType));
- }
- }
+ memcpy(dst, src,
+ G * R * S * OC_per_G * IC_per_G * sizeof(inpType));
+ } // end if(tr)
+ } // end if(fbgemmOptimizedGConv(conv_param_))
+}
+
+/**
+ * @brief Pack weight tensor in a suitable format required for the optimized
+ * kernel.
+ */
+template <typename T, typename accT, int SPATIAL_DIM>
+void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
+ pack_unpack_(sdata_, pdata_, true);
+}
+
+/**
+ * @brief Unpack the packed weight tensor (for the optimized kernel)
+ * to the original form.
+ */
+template <typename T, typename accT, int SPATIAL_DIM>
+void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::unpack(T* origin_buf) {
+ pack_unpack_(const_cast<const T*>(pdata_), origin_buf, false);
}
template class PackWeightMatrixForGConv<int8_t, int32_t, 2>;
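packed_index_() and unpacked_index_() above are pure index arithmetic, which is what lets pack_unpack_() serve both directions. A standalone restatement of the two formulas, copied from the code above with arbitrary dimensions, just to sanity-check that both mappings stay inside the weight buffer:

// Scalar restatement of the index math in PackWeightMatrixForGConv:
// unpacked layout is G (R S C/G) K/G (or G K/G (R S C/G) when transposed);
// the packed layout depends on IC_per_G as described in the comments above.
#include <cassert>

static int unpacked_index(int r, int s, int k, int g, int c,
                          int R, int S, int IC_per_G, int OC_per_G, bool tr) {
  return tr ? (((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c
            : (((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k;
}

static int packed_index(int r, int s, int k, int g, int c,
                        int R, int S, int IC_per_G, int OC_per_G) {
  if (IC_per_G == 4) {
    // for IC_per_G == 4, two groups are packed together
    return (((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 + (g % 2)) *
        IC_per_G + c;
  }
  return ((((g * (IC_per_G / 4) + (c / 4)) * R + r) * S + s) * OC_per_G + k) *
      4 + (c % 4);
}

int main() {
  const int R = 3, S = 3, G = 8, IC_per_G = 4, OC_per_G = 8;
  const int total = G * R * S * IC_per_G * OC_per_G;
  for (int r = 0; r < R; ++r)
    for (int s = 0; s < S; ++s)
      for (int k = 0; k < OC_per_G; ++k)
        for (int g = 0; g < G; ++g)
          for (int c = 0; c < IC_per_G; ++c) {
            assert(unpacked_index(r, s, k, g, c, R, S, IC_per_G, OC_per_G, false) < total);
            assert(packed_index(r, s, k, g, c, R, S, IC_per_G, OC_per_G) < total);
          }
  return 0;
}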
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index c811144..192fb00 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -4,6 +4,7 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
+#include <algorithm>
#include <memory>
#include "fbgemm/Fbgemm.h"
@@ -13,7 +14,8 @@ template <int SPATIAL_DIM, typename T, typename accT>
PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
const conv_param_t<SPATIAL_DIM>& conv_p,
const T* sdata,
- const BlockingFactors* blocking_params) {
+ const BlockingFactors* blocking_params)
+ : conv_param_(conv_p) {
static_assert(
SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
"Only 2D and 3D convolutions are supported");
@@ -21,50 +23,153 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
// FbgemmConv.cc
switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
case optimized_conv_t::depthwise: {
- if (SPATIAL_DIM == 3) {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ =
- std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata);
- W_gconv_packed_ = nullptr;
- } else {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ =
- std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata);
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
- }
+ W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>(
+ conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata);
break;
}
case optimized_conv_t::groupwise: {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
W_gconv_packed_ =
std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
- matrix_op_t::NoTranspose, conv_p, sdata, nullptr);
+ matrix_op_t::Transpose, conv_p, sdata, nullptr);
+ break;
+ }
+ case optimized_conv_t::pointwise: {
+ int NDim = conv_p.OC / conv_p.G;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
+ W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>(
+ matrix_op_t::Transpose,
+ KDim,
+ NDim,
+ sdata,
+ KDim / conv_p.G,
+ nullptr,
+ conv_p.G,
+ blocking_params);
break;
}
case optimized_conv_t::im2col: {
int NDim = conv_p.OC / conv_p.G;
int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>(
- matrix_op_t::NoTranspose,
+ matrix_op_t::Transpose,
KDim,
NDim,
sdata,
- NDim,
+ KDim / conv_p.G,
nullptr,
conv_p.G,
blocking_params);
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
break;
}
} // switch
}
+template <int SPATIAL_DIM, typename T, typename accT>
+void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
+ if (W_dw_packed_) {
+ W_dw_packed_->unpack(origin_buf);
+ } else if (W_gconv_packed_) {
+ W_gconv_packed_->unpack(origin_buf);
+ } else if (W_im2col_packed_) {
+ W_im2col_packed_->unpack(origin_buf);
+ } else if (W_pointwise_packed_) {
+ W_pointwise_packed_->unpack(origin_buf);
+ } else {
+ assert(false && "At least one packed weights object should exist");
+ }
+}
+
+template <int SPATIAL_DIM, typename T, typename accT>
+bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant(
+ const conv_param_t<SPATIAL_DIM>& test_conv_p) {
+ return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC &&
+ conv_param_.G == test_conv_p.G &&
+ std::equal(
+ conv_param_.K.begin(),
+ conv_param_.K.end(),
+ test_conv_p.K.begin()) &&
+ std::equal(
+ conv_param_.stride.begin(),
+ conv_param_.stride.end(),
+ test_conv_p.stride.begin()) &&
+ std::equal(
+ conv_param_.pad.begin(),
+ conv_param_.pad.end(),
+ test_conv_p.pad.begin()) &&
+ std::equal(
+ conv_param_.dilation.begin(),
+ conv_param_.dilation.end(),
+ test_conv_p.dilation.begin());
+}
+
+template <int SPATIAL_DIM, typename T, typename accT>
+std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams(
+ const conv_param_t<SPATIAL_DIM>& test_conv_p) {
+ std::string msg = "";
+
+ auto combineStr = [](std::string id, std::string str1, std::string str2) {
+ std::string out = id + std::string(" ");
+ out += str1;
+ out += std::string(" vs ") + str2;
+ out += std::string(";");
+ return out;
+ };
+
+ auto combineInt = [&combineStr](std::string id, int int1, int int2) {
+ return combineStr(id, std::to_string(int1), std::to_string(int2));
+ };
+
+ if (conv_param_.IC != test_conv_p.IC) {
+ msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC);
+ }
+ if (conv_param_.OC != test_conv_p.OC) {
+ msg += combineInt("output_channels", conv_param_.IC, test_conv_p.IC);
+ }
+ if (conv_param_.G != test_conv_p.G) {
+ msg += combineInt("groups", conv_param_.G, test_conv_p.G);
+ }
+
+ if (!std::equal(
+ conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) {
+ msg += combineStr(
+ "kernel",
+ arrayToString<SPATIAL_DIM>(conv_param_.K),
+ arrayToString<SPATIAL_DIM>(test_conv_p.K));
+ }
+
+ if (!std::equal(
+ conv_param_.stride.begin(),
+ conv_param_.stride.end(),
+ test_conv_p.stride.begin())) {
+ msg += combineStr(
+ "stride",
+ arrayToString<SPATIAL_DIM>(conv_param_.stride),
+ arrayToString<SPATIAL_DIM>(test_conv_p.stride));
+ }
+
+ if (!std::equal(
+ conv_param_.pad.begin(),
+ conv_param_.pad.end(),
+ test_conv_p.pad.begin())) {
+ msg += combineStr(
+ "pad",
+ arrayToString<2 * SPATIAL_DIM>(conv_param_.pad),
+ arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad));
+ }
+
+ if (!std::equal(
+ conv_param_.dilation.begin(),
+ conv_param_.dilation.end(),
+ test_conv_p.dilation.begin())) {
+ msg += combineStr(
+ "dilation",
+ arrayToString<SPATIAL_DIM>(conv_param_.dilation),
+ arrayToString<SPATIAL_DIM>(test_conv_p.dilation));
+ }
+
+ return msg;
+}
+
template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;
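isPackingCompliant() above is a field-by-field comparison of the convolution descriptor the weights were packed for against the one being executed, and mismatchingParams() reports the differing fields. A sketch of the same comparison pattern over a simplified stand-in struct (SimpleConvParams is hypothetical; the real check operates on conv_param_t):

// Sketch of the compliance check above, over a simplified stand-in struct.
#include <algorithm>
#include <array>

struct SimpleConvParams {
  int IC, OC, G;
  std::array<int, 2> K, stride, dilation;
  std::array<int, 4> pad;
};

static bool isPackingCompliant(const SimpleConvParams& packed_for,
                               const SimpleConvParams& requested) {
  return packed_for.IC == requested.IC && packed_for.OC == requested.OC &&
      packed_for.G == requested.G &&
      std::equal(packed_for.K.begin(), packed_for.K.end(), requested.K.begin()) &&
      std::equal(packed_for.stride.begin(), packed_for.stride.end(),
                 requested.stride.begin()) &&
      std::equal(packed_for.pad.begin(), packed_for.pad.end(),
                 requested.pad.begin()) &&
      std::equal(packed_for.dilation.begin(), packed_for.dilation.end(),
                 requested.dilation.begin());
}

int main() {
  SimpleConvParams a{64, 64, 1, {3, 3}, {1, 1}, {1, 1}, {1, 1, 1, 1}};
  SimpleConvParams b = a;
  b.stride = {2, 2}; // changing any field breaks compliance
  return isPackingCompliant(a, a) && !isPackingCompliant(a, b) ? 0 : 1;
}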
diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc
index 1ab00d1..a209efc 100644
--- a/src/QuantUtils.cc
+++ b/src/QuantUtils.cc
@@ -164,30 +164,143 @@ void ChooseRequantizationMultiplier(
dst[i] = Quantize<T>(src[i], qparams); \
} \
}
-FBGEMM_SPECIALIZED_QUANTIZE(int8_t)
FBGEMM_SPECIALIZED_QUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_QUANTIZE(int16_t)
FBGEMM_SPECIALIZED_QUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_QUANTIZE
+#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T) \
+template <> \
+void Quantize<T>( \
+ const float* src, \
+ T* dst, \
+ int len, \
+ const TensorQuantizationParams& qparams) { \
+ bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \
+ bool fma_support = cpuinfo_has_x86_fma3(); \
+ if (avx2_support && fma_support && qparams.precision == 8) { \
+ /* fast path */ \
+ QuantizeAvx2<T>(src, dst, len, qparams); \
+ } else { \
+ for (std::size_t i = 0; i < len; ++i) { \
+ dst[i] = Quantize<T>(src[i], qparams); \
+ } \
+ } \
+}
+
+FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t)
+FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t)
+#undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2
+
+#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \
+ template <> \
+ void QuantizeGroupwise<T, layout_t::KCX>( \
+ const float* src, \
+ int N, \
+ int C, \
+ int X, \
+ int G, \
+ const float* scales, \
+ const std::int32_t* zero_points, \
+ T* dst) { \
+ assert(C % G == 0); \
+ int C_per_G = C / G; \
+ for (int i = 0; i < N; ++i) { \
+ for (int g = 0; g < G; ++g) { \
+ float scale = scales[g]; \
+ int32_t zero_point = zero_points[g]; \
+ for (int c = 0; c < C / G; ++c) { \
+ for (int x = 0; x < X; ++x) { \
+ dst[(i * C + g * C_per_G + c) * X + x] = Quantize<T>( \
+ src[(i * C + g * C_per_G + c) * X + x], \
+ zero_point, \
+ scale, \
+ 8 * sizeof(T)); \
+ } \
+ } \
+ } \
+ } \
+ }
+FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int8_t)
+FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int32_t)
+#undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX
+
template <>
-void Quantize<uint8_t>(
+void QuantizeGroupwise<uint8_t, layout_t::KCX>(
const float* src,
- uint8_t* dst,
- int len,
- const TensorQuantizationParams& qparams) {
- bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support();
- bool fma_support = cpuinfo_has_x86_fma3();
- if (avx2_support && fma_support && qparams.precision == 8) {
- // fast path
- QuantizeAvx2(src, dst, len, qparams);
- } else {
- for (std::size_t i = 0; i < len; ++i) {
- dst[i] = Quantize<uint8_t>(src[i], qparams);
+ int K,
+ int C,
+ int X,
+ int G,
+ const float* scales,
+ const std::int32_t* zero_points,
+ uint8_t* dst) {
+ assert(C % G == 0);
+ int C_per_G = C / G;
+ fbgemm::TensorQuantizationParams qparams;
+ qparams.precision = 8 * sizeof(uint8_t);
+ bool takeFastPath =
+ cpuinfo_initialize() && fbgemmHasAvx2Support() && cpuinfo_has_x86_fma3();
+
+ for (int i = 0; i < K; ++i) {
+ for (int g = 0; g < G; ++g) {
+ qparams.scale = scales[g];
+ qparams.zero_point = zero_points[g];
+ if (takeFastPath) {
+ QuantizeAvx2(
+ src + (i * C + g * C_per_G) * X,
+ dst + (i * C + g * C_per_G) * X,
+ C_per_G * X,
+ qparams);
+ } else {
+ for (int c = 0; c < C / G; ++c) {
+ for (int x = 0; x < X; ++x) {
+ dst[(i * C + g * C_per_G + c) * X + x] = Quantize<uint8_t>(
+ src[(i * C + g * C_per_G + c) * X + x],
+ qparams.zero_point,
+ qparams.scale,
+ qparams.precision);
+ }
+ }
+ }
}
}
}
+#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(T) \
+ template <> \
+ void QuantizeGroupwise<T, layout_t::KXC>( \
+ const float* src, \
+ int K, \
+ int C, \
+ int X, \
+ int G, \
+ const float* scales, \
+ const std::int32_t* zero_points, \
+ T* dst) { \
+ assert(C % G == 0); \
+ int C_per_G = C / G; \
+ for (int i = 0; i < K; ++i) { \
+ for (int x = 0; x < X; ++x) { \
+ for (int g = 0; g < G; ++g) { \
+ float scale = scales[g]; \
+ int32_t zero_point = zero_points[g]; \
+ for (int c = 0; c < C / G; ++c) { \
+ dst[(i * X + x) * C + g * C_per_G + c] = Quantize<T>( \
+ src[(i * X + x) * C + g * C_per_G + c], \
+ zero_point, \
+ scale, \
+ 8 * sizeof(T)); \
+ } \
+ } \
+ } \
+ } \
+ }
+FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int8_t)
+FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(uint8_t)
+FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int32_t)
+#undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC
+
////////////////////////////////////////////////////////////////////////////////
// Requantization (pure fixed-point)
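The KCX and KXC specializations above apply one (scale, zero_point) pair per group of channels, using the same affine quantization as the single-tensor path. A scalar sketch of the KCX case for uint8, mirroring the clamp-then-round order of the scalar fallback shown above (the numbers in main are illustrative only):

// Scalar sketch of groupwise quantization in KCX layout: each of the G groups
// of channels gets its own (scale, zero_point); X is the spatial size.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

static void quantize_groupwise_kcx(const float* src, int K, int C, int X, int G,
                                   const float* scales, const std::int32_t* zero_points,
                                   std::uint8_t* dst) {
  const int C_per_G = C / G;
  for (int i = 0; i < K; ++i) {
    for (int g = 0; g < G; ++g) {
      const std::int32_t zp = zero_points[g];
      for (int c = 0; c < C_per_G; ++c) {
        for (int x = 0; x < X; ++x) {
          const int idx = (i * C + g * C_per_G + c) * X + x;
          const float transformed = zp + src[idx] / scales[g];
          const float clipped = std::min(255.0f, std::max(0.0f, transformed));
          dst[idx] = static_cast<std::uint8_t>(std::nearbyint(clipped));
        }
      }
    }
  }
}

int main() {
  const int K = 1, C = 4, X = 2, G = 2;
  const std::vector<float> src = {0.1f, 0.2f, 0.3f, 0.4f, -0.1f, -0.2f, -0.3f, -0.4f};
  const std::vector<float> scales = {0.01f, 0.02f};
  const std::vector<std::int32_t> zps = {0, 128};
  std::vector<std::uint8_t> dst(K * C * X);
  quantize_groupwise_kcx(src.data(), K, C, X, G, scales.data(), zps.data(), dst.data());
  std::printf("%d %d\n", dst[0], dst[4]); // first element of each group
  return 0;
}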
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 821999e..66828ae 100755..100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -18,16 +18,20 @@ using namespace std;
////////////////////////////////////////////////////////////////////////////////
// Utility functions
+template <typename T>
void QuantizeAvx2(
const float* src,
- uint8_t* dst,
+ T* dst,
int len,
const TensorQuantizationParams& qparams) {
-#if defined(__AVX2__) && defined(__FMA__)
- constexpr int VLEN = 8;
- std::size_t i = 0;
- __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale);
- __m256i shuffle_mask_v = _mm256_set_epi8(
+ // original compile condition - #if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER))
+ if (fbgemm::fbgemmHasAvx2Support()) {
+ constexpr int VLEN = 8;
+ constexpr float min_val = std::numeric_limits<T>::min();
+ constexpr float max_val = std::numeric_limits<T>::max();
+ std::size_t i = 0;
+ __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale);
+ __m256i shuffle_mask_v = _mm256_set_epi8(
0xff,
0xff,
0xff,
@@ -60,41 +64,56 @@ void QuantizeAvx2(
0x08,
0x04,
0x00);
- __m256i permute_mask_v =
+ __m256i permute_mask_v =
_mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
- for (; i < len / VLEN * VLEN; i += VLEN) {
- __m256 src_v = _mm256_loadu_ps(src + i);
- __m256 transformed_v = _mm256_fmadd_ps(
+ for (; i < len / VLEN * VLEN; i += VLEN) {
+ __m256 src_v = _mm256_loadu_ps(src + i);
+ __m256 transformed_v = _mm256_fmadd_ps(
src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point));
- __m256 clipped_v = _mm256_min_ps(
- _mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)),
- _mm256_set1_ps(255.f));
- __m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
-
- // An instruction sequence to save 8 32-bit integers as 8 8-bit integers
- rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
- rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v);
- _mm_storel_epi64(
+ __m256 clipped_v = _mm256_min_ps(
+ _mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)),
+ _mm256_set1_ps(max_val));
+ __m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
+
+ // An instruction sequence to save 8 32-bit integers as 8 8-bit integers
+ rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
+ rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v);
+ _mm_storel_epi64(
reinterpret_cast<__m128i*>(dst + i), _mm256_castsi256_si128(rounded_v));
- }
+ }
- for (; i < len; ++i) {
- float transformed = qparams.zero_point + src[i] / qparams.scale;
- float clipped = std::min(std::max(transformed, 0.f), 255.f);
- // Not exactly the same behavior as the vectorized code.
- // The vectorized code above always rounds to even in halfway cases
- // (https://software.intel.com/en-us/node/523819), but std::nearbyint
- // does the same only when the current rounding mode is FE_TONEAREST.
- // However, in practice, this should not be a problem because most cases
- // use the default rounding mode FE_TONEAREST.
- // Note that we cannot implement the same behavior as the vectorized code
- // using std::round because it does rounding away from zero in halfway
- // cases.
- dst[i] = nearbyint(clipped);
+ for (; i < len; ++i) {
+ float transformed = qparams.zero_point + src[i] / qparams.scale;
+ float clipped = std::min(std::max(transformed, min_val), max_val);
+ // Not exactly the same behavior as the vectorized code.
+ // The vectorized code above always rounds to even in halfway cases
+ // (https://software.intel.com/en-us/node/523819), but std::nearbyint
+ // does the same only when the current rounding mode is FE_TONEAREST.
+ // However, in practice, this should not be a problem because most cases
+ // use the default rounding mode FE_TONEAREST.
+ // Note that we cannot implement the same behavior as the vectorized code
+ // using std::round because it does rounding away from zero in halfway
+ // cases.
+ dst[i] = nearbyint(clipped);
+ }
}
-#endif
}
+// Instantiate QuantizeAvx2 for known datatypes
+template
+void QuantizeAvx2<uint8_t>(
+ const float* src,
+ uint8_t* dst,
+ int len,
+ const TensorQuantizationParams& qparams);
+template
+void QuantizeAvx2<int8_t>(
+ const float* src,
+ int8_t* dst,
+ int len,
+ const TensorQuantizationParams& qparams);
+
+
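The comment in the scalar tail above points out that the vectorized path (_mm256_cvtps_epi32) rounds halfway cases to even, which std::nearbyint matches only under the default FE_TONEAREST rounding mode, while std::round goes away from zero. A small example of the difference:

// Halfway rounding: std::nearbyint under FE_TONEAREST rounds ties to even
// (2.5 -> 2, 3.5 -> 4), whereas std::round always goes away from zero.
#include <cfenv>
#include <cmath>
#include <cstdio>

int main() {
  std::fesetround(FE_TONEAREST); // the default, stated explicitly
  std::printf("nearbyint(2.5) = %.0f, round(2.5) = %.0f\n",
              std::nearbyint(2.5), std::round(2.5)); // 2 vs 3
  std::printf("nearbyint(3.5) = %.0f, round(3.5) = %.0f\n",
              std::nearbyint(3.5), std::round(3.5)); // 4 vs 4
  return 0;
}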
void FindMinMax(const float* a, float* min, float* max, int len) {
if (len <= 0) {
*min = 0.0f;
@@ -105,24 +124,24 @@ void FindMinMax(const float* a, float* min, float* max, int len) {
float temp_min = *a, temp_max = *a;
int i = 0;
-#ifdef __AVX__
- __m256 min_v = _mm256_set1_ps(*a), max_v = _mm256_set1_ps(*a);
- constexpr int VLEN = 8;
- if (len >= VLEN) {
- for (; i < len / VLEN * VLEN; i += VLEN) {
- min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(a + i));
- max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(a + i));
- }
+ if (fbgemm::fbgemmHasAvx2Support()) {
+ __m256 min_v = _mm256_set1_ps(*a), max_v = _mm256_set1_ps(*a);
+ constexpr int VLEN = 8;
+ if (len >= VLEN) {
+ for (; i < len / VLEN * VLEN; i += VLEN) {
+ min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(a + i));
+ max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(a + i));
+ }
- float min_buf[VLEN], max_buf[VLEN];
- _mm256_storeu_ps(min_buf, min_v);
- _mm256_storeu_ps(max_buf, max_v);
- for (int j = 0; j < VLEN; ++j) {
- temp_min = std::min(temp_min, min_buf[j]);
- temp_max = std::max(temp_max, max_buf[j]);
+ float min_buf[VLEN], max_buf[VLEN];
+ _mm256_storeu_ps(min_buf, min_v);
+ _mm256_storeu_ps(max_buf, max_v);
+ for (int j = 0; j < VLEN; ++j) {
+ temp_min = std::min(temp_min, min_buf[j]);
+ temp_max = std::max(temp_max, max_buf[j]);
+ }
}
}
-#endif
for (; i < len; i++) {
temp_min = std::min(temp_min, a[i]);
@@ -135,15 +154,15 @@ void FindMinMax(const float* a, float* min, float* max, int len) {
////////////////////////////////////////////////////////////////////////////////
// Requantization (with floats)
-#ifdef __AVX2__
void RequantizeAvx2(
const int32_t* src,
uint8_t* dst,
int len,
const RequantizationParams& params) {
- DoNothing<> doNothingObj{};
- int32_t Bq_zero_point[] = { 0 };
- ReQuantizeOutput<false /* FUSE_RELU */> requantizeObj(
+ if (fbgemm::fbgemmHasAvx2Support()) {
+ DoNothing<> doNothingObj{};
+ int32_t Bq_zero_point[] = { 0 };
+ ReQuantizeOutput<false /* FUSE_RELU */> requantizeObj(
doNothingObj,
&params.real_multiplier,
params.target_qparams.zero_point,
@@ -153,7 +172,8 @@ void RequantizeAvx2(
nullptr, // col_offsets
nullptr, // bias
len); // ncol
- requantizeObj.f<inst_set_t::avx2>(dst, src, {0, 1, 0, len}, 0, 0);
+ requantizeObj.f<inst_set_t::avx2>(dst, src, { 0, 1, 0, len }, 0, 0);
+ }
}
void RequantizeFixedPointAvx2(
@@ -161,24 +181,26 @@ void RequantizeFixedPointAvx2(
uint8_t* dst,
int len,
const RequantizationParams& params) {
- constexpr int VLEN = 8;
+ if (fbgemm::fbgemmHasAvx2Support())
+ {
+ constexpr int VLEN = 8;
- __m256i b = _mm256_set1_epi32(params.multiplier);
+ __m256i b = _mm256_set1_epi32(params.multiplier);
- // AVX2 doesn't support arithmetic right shift.
- // As a work around, we convert 64-bit multiplied results to uint64_t by
- // adding 0x8000000000000000ULL, logical right shift, and subtract by
- // (0x8000000000000000ULL >> right_shift).
- __m256i pre_shift_nudge = _mm256_set1_epi64x(
+ // AVX2 doesn't support arithmetic right shift.
+ // As a work around, we convert 64-bit multiplied results to uint64_t by
+ // adding 0x8000000000000000ULL, logical right shift, and subtract by
+ // (0x8000000000000000ULL >> right_shift).
+ __m256i pre_shift_nudge = _mm256_set1_epi64x(
(1ll << (params.right_shift - 1)) + 0x8000000000000000ULL);
- __m256i post_shift_nudge = _mm256_set1_epi64x(
+ __m256i post_shift_nudge = _mm256_set1_epi64x(
params.target_qparams.zero_point -
(0x8000000000000000ULL >> params.right_shift));
- __m256i min_v = _mm256_set1_epi32(numeric_limits<uint8_t>::min());
- __m256i max_v = _mm256_set1_epi32(numeric_limits<uint8_t>::max());
+ __m256i min_v = _mm256_set1_epi32(numeric_limits<uint8_t>::min());
+ __m256i max_v = _mm256_set1_epi32(numeric_limits<uint8_t>::max());
- __m256i shuffle_mask_v = _mm256_set_epi8(
+ __m256i shuffle_mask_v = _mm256_set_epi8(
0xff,
0xff,
0xff,
@@ -211,75 +233,68 @@ void RequantizeFixedPointAvx2(
0x08,
0x04,
0x00);
- __m256i permute_mask_v =
+ __m256i permute_mask_v =
_mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
- int i = 0;
- for (; i < len / VLEN * VLEN; i += VLEN) {
- __m256i a_v = _mm256_loadu_si256((const __m256i*)(src + i));
+ int i = 0;
+ for (; i < len / VLEN * VLEN; i += VLEN) {
+ __m256i a_v = _mm256_loadu_si256((const __m256i*)(src + i));
- // a = a0 | a1 | a2 | a3 | a4 | a5 | a6 | a7
- // b = b0 | b1 | b3 | b3 | b4 | b5 | b6 | b7
- __m256i a_even_v = a_v;
- __m256i a_odd_v = _mm256_srli_si256(a_v, 4);
+ // a = a0 | a1 | a2 | a3 | a4 | a5 | a6 | a7
+ // b = b0 | b1 | b3 | b3 | b4 | b5 | b6 | b7
+ __m256i a_even_v = a_v;
+ __m256i a_odd_v = _mm256_srli_si256(a_v, 4);
- __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b);
- __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b);
+ __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b);
+ __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b);
- __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, pre_shift_nudge);
- __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, pre_shift_nudge);
+ __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, pre_shift_nudge);
+ __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, pre_shift_nudge);
- __m256i even_result_v = _mm256_add_epi64(
+ __m256i even_result_v = _mm256_add_epi64(
_mm256_srli_epi64(even_rounded_v, params.right_shift),
post_shift_nudge);
- __m256i odd_result_v = _mm256_add_epi64(
+ __m256i odd_result_v = _mm256_add_epi64(
_mm256_srli_epi64(odd_rounded_v, params.right_shift), post_shift_nudge);
- odd_result_v = _mm256_slli_si256(odd_result_v, 4);
+ odd_result_v = _mm256_slli_si256(odd_result_v, 4);
- // even_result_v has numbers we want in its even 32-bit SIMD lanes, and
- // odd_result_v has numbers we want in its odd 32-bit SIMD lanes.
- // Use blend to combine them.
- __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa);
- __m256i clipped_v =
+ // even_result_v has numbers we want in its even 32-bit SIMD lanes, and
+ // odd_result_v has numbers we want in its odd 32-bit SIMD lanes.
+ // Use blend to combine them.
+ __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa);
+ __m256i clipped_v =
_mm256_max_epi32(min_v, _mm256_min_epi32(max_v, result_v));
- clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);
- clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);
- *(int64_t*)(dst + i) = _mm256_extract_epi64(clipped_v, 0);
- }
+ clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);
+ clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);
+ *(int64_t*)(dst + i) = _mm256_extract_epi64(clipped_v, 0);
+ }
- for (; i < len; ++i) {
- int64_t ab_64 =
+ for (; i < len; ++i) {
+ int64_t ab_64 =
static_cast<int64_t>(src[i]) * static_cast<int64_t>(params.multiplier);
- int64_t nudge = 1ll << std::max(0, params.right_shift - 1);
- int64_t quantized_down = params.target_qparams.zero_point +
+ int64_t nudge = 1ll << std::max(0, params.right_shift - 1);
+ int64_t quantized_down = params.target_qparams.zero_point +
((ab_64 + nudge) >> params.right_shift);
- dst[i] = std::min<int64_t>(std::max<int64_t>(quantized_down, 0l), 255l);
+ dst[i] = std::min<int64_t>(std::max<int64_t>(quantized_down, 0l), 255l);
+ }
}
}
-#else
-// dummy implementations to avoid link errors
-void RequantizeAvx2(const int32_t* src, uint8_t* dst, int len, const RequantizationParams& params) {
- assert(false && "RequantizeAvx2() was called unexpectedly in non-AVX2 build");
-}
-void RequantizeFixedPointAvx2(const int32_t* src, uint8_t* dst, int len, const RequantizationParams& params) {
- assert(false && "RequantizeFixedPointAvx2() was called unexpectedly in non-AVX2 build");
-}
-#endif
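The fixed-point path above multiplies each int32 accumulator by an integer multiplier, adds a rounding nudge, right-shifts, adds the output zero point, and saturates to uint8; the extra 0x8000000000000000 bias in the AVX2 body only works around the missing 64-bit arithmetic right shift. A scalar restatement of the per-element math, mirroring the remainder loop above (the multiplier/right_shift pair in main is an arbitrary example):

// Scalar form of the fixed-point requantization in RequantizeFixedPointAvx2's
// remainder loop: multiply, round via a nudge, shift, add zero point, saturate.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static std::uint8_t requantize_fixed_point(std::int32_t src, std::int32_t multiplier,
                                           int right_shift, std::int32_t zero_point) {
  const std::int64_t ab_64 =
      static_cast<std::int64_t>(src) * static_cast<std::int64_t>(multiplier);
  const std::int64_t nudge = 1ll << std::max(0, right_shift - 1);
  const std::int64_t quantized_down = zero_point + ((ab_64 + nudge) >> right_shift);
  return static_cast<std::uint8_t>(
      std::min<std::int64_t>(std::max<std::int64_t>(quantized_down, 0), 255));
}

int main() {
  // multiplier / 2^right_shift ~= 0.0005, so 1000 scales to ~0.5 and rounds to 1
  std::printf("%d\n", static_cast<int>(requantize_fixed_point(1000, 1073742, 31, 0)));
  return 0;
}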
template <
bool A_SYMMETRIC,
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
- bool FUSE_RELU>
+ bool FUSE_RELU,
+ typename BIAS_TYPE>
void requantizeOutputProcessingAvx2(
uint8_t* out,
const int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r) {
+ const requantizationParams_t<BIAS_TYPE>& r) {
// Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
// using AVX2 instructions
int quant_param_idx = 0;
@@ -290,6 +305,15 @@ void requantizeOutputProcessingAvx2(
}
__m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v;
+ if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+ if (is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v =
+ _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
+ }
+ }
+
__m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
__m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
@@ -399,22 +423,76 @@ void requantizeOutputProcessingAvx2(
}
w_v = _mm256_sub_epi32(w_v, row_offset_v);
}
+ __m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
}
/*
@@ -431,22 +509,19 @@ void requantizeOutputProcessingAvx2(
*/
__m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
- y_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(y_v),
- _mm256_loadu_ps(r.C_multiplier + j + VLEN));
- z_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(z_v),
- _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
- w_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(w_v),
- _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+ x_scaled_v =
+ _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN));
+ y_scaled_v =
+ _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN));
+ z_scaled_v =
+ _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+ w_scaled_v =
+ _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
/*
@@ -533,18 +608,35 @@ void requantizeOutputProcessingAvx2(
}
x_v = _mm256_sub_epi32(x_v, row_offset_v);
}
+ __m256 xf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+ _mm256_loadu_ps(r.act_times_w_scale + j));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
}
__m256 x_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
+ x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
}
__m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -582,6 +674,7 @@ void requantizeOutputProcessingAvx2(
int remainder = block.col_start + block.col_size - j;
if (remainder > 0) {
+ // clang-format off
alignas(64) const int masks[8][8] = {
// NOTE: clang-format wants to use a different formatting but the
// current formatting should be easier to read.
@@ -594,6 +687,7 @@ void requantizeOutputProcessingAvx2(
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
+ // clang-format on
__m256i mask_v = _mm256_load_si256(
reinterpret_cast<const __m256i*>(masks[remainder]));
@@ -615,17 +709,40 @@ void requantizeOutputProcessingAvx2(
}
x_v = _mm256_sub_epi32(x_v, row_offset_v);
}
+
+ __m256 xf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_maskload_ps(
+ reinterpret_cast<const float*>(r.bias + j), mask_v),
+ _mm256_maskload_ps(r.act_times_w_scale + j, mask_v));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_maskload_ps(
+ reinterpret_cast<const float*>(r.bias + j), mask_v),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_maskload_epi32(
+ reinterpret_cast<const int*>(r.bias + j), mask_v));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
}
__m256 x_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v),
- _mm256_maskload_ps(r.C_multiplier + j, mask_v));
+ x_scaled_v =
+ _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
}
__m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -767,6 +884,7 @@ void requantizeForFloatAvx2(
int remainder = block.col_start + block.col_size - j;
if (remainder > 0) {
+ // clang-format off
alignas(64) const int masks[8][8] = {
// NOTE: clang-format wants to use a different formatting but the
// current formatting should be easier to read.
@@ -779,6 +897,7 @@ void requantizeForFloatAvx2(
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
+ // clang-format on
__m256i mask_v = _mm256_load_si256(
reinterpret_cast<const __m256i*>(masks[remainder]));
@@ -831,14 +950,15 @@ template <
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
- int C_PER_G>
+ int C_PER_G,
+ typename BIAS_TYPE>
void requantizeOutputProcessingGConvAvx2(
uint8_t* out,
const int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r) {
+ const requantizationParams_t<BIAS_TYPE>& r) {
// Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
// using AVX2 instructions
int quant_param_idx = 0;
@@ -849,6 +969,14 @@ void requantizeOutputProcessingGConvAvx2(
}
__m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v;
+ if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+ if (is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v =
+ _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
+ }
+ }
__m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
__m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
@@ -1095,22 +1223,135 @@ void requantizeOutputProcessingGConvAvx2(
}
w_v = _mm256_sub_epi32(w_v, row_offset_v);
}
+ __m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN));
+ __m256 y_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN));
+ __m256 z_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN));
+ __m256 w_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN));
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ y_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ z_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ w_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+ __m256 diviser_v;
+ if (C_PER_G == 4) {
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]),
+ 1);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 2])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 3]),
+ 1);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 4])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 5]),
+ 1);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 6])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 7]),
+ 1);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+ } else if (C_PER_G == 8) {
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 0]);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+ } else {
+ assert(C_PER_G == 16);
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 0]);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+ }
+ } else {
+ x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(y_bias_v, act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(z_bias_v, act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(w_bias_v, act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
}
/*
@@ -1127,17 +1368,13 @@ void requantizeOutputProcessingGConvAvx2(
*/
__m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
- y_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(y_v),
- _mm256_loadu_ps(r.C_multiplier + j + VLEN));
- z_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(z_v),
- _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
- w_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(w_v),
- _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+ x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
+ y_scaled_v =
+ _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + VLEN));
+ z_scaled_v =
+ _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+ w_scaled_v =
+ _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
if (C_PER_G == 4) {
multiplier_v = _mm256_insertf128_ps(
@@ -1145,70 +1382,70 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_ps(r.C_multiplier[quant_param_idx])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 1]),
1);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 2])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 3]),
1);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 4])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 5]),
1);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 6])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
1);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
} else if (C_PER_G == 8) {
multiplier_v = _mm256_set1_ps(
r.C_multiplier
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 1]);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 2]);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 3]);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
} else {
multiplier_v = _mm256_set1_ps(
r.C_multiplier
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 +
- 1]);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
/*
@@ -1279,46 +1516,69 @@ void requantizeOutputProcessingGConvAvx2(
} // i loop
}
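For the GROUP-granularity branches above, the multiplier and act_times_w_scale indexing follows one rule: each __m256 register holds VLEN = 8 consecutive output channels, and one unrolled iteration handles the four registers x/y/z/w (32 channels starting at column j). A scalar sketch of that index arithmetic (illustration only; this helper and its name are not part of the commit):

// Illustration of the per-group parameter index used by the vector code above.
inline int group_param_idx(int quant_param_idx, int j, int col_start,
                           int C_PER_G, int reg /* 0..3 for x, y, z, w */) {
  constexpr int VLEN = 8;
  int base = (j - col_start) / (VLEN * 4); // which unrolled iteration
  if (C_PER_G == 8)                        // one group per register
    return quant_param_idx + base * 4 + reg;
  if (C_PER_G == 16)                       // one group spans two registers
    return quant_param_idx + base * 2 + reg / 2;
  // C_PER_G == 4: two groups per register; the vector code splices the two
  // scales with _mm256_insertf128_ps using indices quant_param_idx + 2 * reg
  // and quant_param_idx + 2 * reg + 1.
  return quant_param_idx + 2 * reg;
}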
-#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
- template void \
- requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- float* out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationForFloatParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r);
+#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \
+ A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \
+ template void \
+ requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 4, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 8, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 16, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r);
+
+#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
+ INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \
+ INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \
+ template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
+ float* out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationForFloatParams_t& r);
#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \
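The instantiations above now cover both int32_t and float BIAS_TYPE. As a scalar reference of what the float-bias path computes (a sketch only, per-tensor granularity assumed; this helper is not part of the commit):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar mirror of the vector path: convert the int32 accumulator to float,
// add bias / (act_scale * w_scale), scale by C_multiplier, round to nearest,
// add the output zero point, and clamp to the uint8 range.
inline uint8_t requantize_float_bias(
    int32_t acc,
    float bias,
    float act_times_w_scale,
    float C_multiplier,
    int32_t C_zero_point) {
  float x = static_cast<float>(acc) + bias / act_times_w_scale;
  int32_t y =
      static_cast<int32_t>(std::nearbyint(x * C_multiplier)) + C_zero_point;
  return static_cast<uint8_t>(std::min(255, std::max(0, y)));
}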
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index b4b0c2b..dc40d44 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -181,8 +181,7 @@ void cblas_sgemm_ref(
int ldb,
float beta,
float* Cfp32,
- int ldc
- ) {
+ int ldc) {
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
float sum = 0;
@@ -204,7 +203,6 @@ void cblas_sgemm_ref(
}
}
-
void row_offsets_u8acc32_ref(
int M,
int K,
@@ -302,9 +300,11 @@ void im2col_ref(
for (int h = 0; h < OUT_DIM[0]; ++h) {
for (int w = 0; w < OUT_DIM[1]; ++w) {
for (int r = 0; r < K[0]; ++r) {
- int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ int h_in =
+ -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
for (int s = 0; s < K[1]; ++s) {
- int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
+ int w_in =
+ -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1];
if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
w_in >= IN_DIM[1]) {
for (int g = 0; g < G; ++g) {
@@ -365,11 +365,14 @@ void im2col_ref(
for (int h = 0; h < OUT_DIM[1]; ++h) {
for (int w = 0; w < OUT_DIM[2]; ++w) {
for (int q = 0; q < K[0]; ++q) {
- int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ int t_in =
+ -conv_p.pad[0] + t * conv_p.stride[0] + q * conv_p.dilation[0];
for (int r = 0; r < K[1]; ++r) {
- int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
+ r * conv_p.dilation[1];
for (int s = 0; s < K[2]; ++s) {
- int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
+ s * conv_p.dilation[2];
if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) {
for (int g = 0; g < G; ++g) {
@@ -449,9 +452,11 @@ void conv_ref(
for (int m = 0; m < OC / G; ++m) {
int sum = 0;
for (int r = 0; r < K[0]; ++r) {
- int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ int h_in = -conv_p.pad[0] + h * conv_p.stride[0] +
+ r * conv_p.dilation[0];
for (int s = 0; s < K[1]; ++s) {
- int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
+ int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
+ s * conv_p.dilation[1];
for (int c = 0; c < IC / G; ++c) {
int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
w_in >= IN_DIM[1]
@@ -501,11 +506,14 @@ void conv_ref(
for (int m = 0; m < OC / G; ++m) {
int sum = 0;
for (int q = 0; q < K[0]; ++q) {
- int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
+ q * conv_p.dilation[0];
for (int r = 0; r < K[1]; ++r) {
- int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
+ r * conv_p.dilation[1];
for (int s = 0; s < K[2]; ++s) {
- int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
+ s * conv_p.dilation[2];
for (int c = 0; c < IC / G; ++c) {
int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]
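The im2col_ref and conv_ref hunks above now multiply the kernel offset by the dilation. A minimal standalone sketch of the same indexing and of the resulting output size (values are arbitrary; not code from this commit):

#include <cstdio>

int main() {
  int IN = 14, K = 3, pad = 1, stride = 2, dilation = 2;
  // The dilated kernel covers dilation * (K - 1) + 1 input positions.
  int OUT = (IN + 2 * pad - (dilation * (K - 1) + 1)) / stride + 1;
  for (int o = 0; o < OUT; ++o) {
    for (int r = 0; r < K; ++r) {
      int in = -pad + o * stride + r * dilation; // same form as h_in/w_in/t_in
      std::printf("out %d tap %d -> in %d%s\n", o, r, in,
                  (in < 0 || in >= IN) ? " (uses A_zero_point)" : "");
    }
  }
  return 0;
}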
@@ -542,427 +550,55 @@ void transposeConvWeights(
const conv_param_t<SPATIAL_DIM>& conv_p,
const std::int8_t* src,
std::int8_t* dest) {
- assert(SPATIAL_DIM == 2 && "Only 2D supported currently");
- int R = conv_p.K[0];
- int S = conv_p.K[1];
int G = conv_p.G;
int IC_per_G = conv_p.IC / conv_p.G;
int OC_per_G = conv_p.OC / conv_p.G;
- // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
- for (int r = 0; r < R; ++r) {
- for (int s = 0; s < S; ++s) {
- for (int k = 0; k < OC_per_G; ++k) {
- for (int g = 0; g < G; ++g) {
- for (int c = 0; c < IC_per_G; ++c) {
- dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] =
- src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c];
- }
- }
- }
- }
- }
-}
-
-void depthwise_3x3_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int8_t* B,
- int32_t* C) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
-
- for (int n = 0; n < N; ++n) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int r = 0; r < R; ++r) {
- int h_in = -PAD_T + h * stride_h + r;
- for (int s = 0; s < S; ++s) {
- int w_in = -PAD_L + w * stride_w + s;
- int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
- ? A_zero_point
- : A[((n * H + h_in) * W + w_in) * K + k];
- int b = B[(k * R + r) * S + s];
- sum += a * b;
- }
- }
- C[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
- }
- }
- }
- } // for each n
-};
-
-void depthwise_3x3_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const int8_t* B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
-
- vector<int32_t> C_int32(N * H_OUT * W_OUT * K);
- depthwise_3x3_pad_1_ref(
- N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data());
-
- vector<int32_t> row_offsets(N * H_OUT * W_OUT * K);
- for (int n = 0; n < N; ++n) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int r = 0; r < R; ++r) {
- int h_in = -PAD_T + h * stride_h + r;
- for (int s = 0; s < S; ++s) {
- int w_in = -PAD_L + w * stride_w + s;
- int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
- ? A_zero_point
- : A[((n * H + h_in) * W + w_in) * K + k];
- sum += a;
+ assert(
+ (SPATIAL_DIM == 3 || SPATIAL_DIM == 2) &&
+ "Only 2D and 3D convolutions are supported");
+ if (SPATIAL_DIM == 2) {
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+ // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ for (int k = 0; k < OC_per_G; ++k) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < IC_per_G; ++c) {
+ dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] =
+ src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c];
}
}
- row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
}
}
}
- } // for each n
-
- for (int i = 0; i < N * H_OUT * W_OUT; ++i) {
- for (int k = 0; k < K; ++k) {
- requantize_u8acc32_ref(
- 1,
- 1,
- 1,
- C_int32.data() + i * K + k,
- C + i * K + k,
- &C_multiplier,
- C_zero_point,
- A_zero_point,
- &B_zero_point,
- &row_offsets[i * K + k],
- col_offsets + k,
- bias ? bias + k : nullptr,
- 1);
- }
- }
-};
-
-void depthwise_3x3_per_channel_quantization_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
-
- vector<int32_t> C_int32(N * H_OUT * W_OUT * K);
- depthwise_3x3_pad_1_ref(
- N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data());
-
- vector<int32_t> row_offsets(N * H_OUT * W_OUT * K);
- for (int n = 0; n < N; ++n) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int r = 0; r < R; ++r) {
- int h_in = -PAD_T + h * stride_h + r;
- for (int s = 0; s < S; ++s) {
- int w_in = -PAD_L + w * stride_w + s;
- int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
- ? A_zero_point
- : A[((n * H + h_in) * W + w_in) * K + k];
- sum += a;
+ } else {
+ // Transforms weights from G K/G (T R S C/G) to G (T R S C/G) K/G format.
+ int T = conv_p.K[0];
+ int R = conv_p.K[1];
+ int S = conv_p.K[2];
+ for (int t = 0; t < T; ++t) {
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ for (int k = 0; k < OC_per_G; ++k) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < IC_per_G; ++c) {
+ dest
+ [((((g * T + t) * R + r) * S + s) * IC_per_G + c) *
+ OC_per_G +
+ k] =
+ src[((((g * OC_per_G + k) * T + t) * R + r) * S + s) *
+ IC_per_G +
+ c];
+ }
}
}
- row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
}
}
}
- } // for each n
-
- for (int i = 0; i < N * H_OUT * W_OUT; ++i) {
- for (int k = 0; k < K; ++k) {
- requantize_u8acc32_ref(
- 1,
- 1,
- 1,
- C_int32.data() + i * K + k,
- C + i * K + k,
- &C_multiplier[k],
- C_zero_point,
- A_zero_point,
- &B_zero_point[k],
- &row_offsets[i * K + k],
- col_offsets + k,
- bias ? bias + k : nullptr,
- 1);
- }
}
-};
-
-void depthwise_3x3x3_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int8_t* B,
- int32_t* C) {
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
-
- for (int n = 0; n < N; ++n) {
- for (int t = 0; t < T_OUT; ++t) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int k_t = 0; k_t < K_T; ++k_t) {
- int t_in = -PAD_P + t * stride_t + k_t;
- for (int k_h = 0; k_h < K_H; ++k_h) {
- int h_in = -PAD_T + h * stride_h + k_h;
- for (int k_w = 0; k_w < K_W; ++k_w) {
- int w_in = -PAD_L + w * stride_w + k_w;
- int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
- w_in < 0 || w_in >= W
- ? A_zero_point
- : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
- int b = B[((k * K_T + k_t) * K_H + k_h) * K_W + k_w];
- sum += a * b;
- }
- }
- }
- C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = sum;
- }
- } // w
- } // h
- } // t
- } // for each n
-};
-
-void depthwise_3x3x3_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const int8_t* B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
-
- vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K);
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B,
- C_int32.data());
-
- vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K);
- for (int n = 0; n < N; ++n) {
- for (int t = 0; t < T_OUT; ++t) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int k_t = 0; k_t < K_T; ++k_t) {
- int t_in = -PAD_P + t * stride_t + k_t;
- for (int k_h = 0; k_h < K_H; ++k_h) {
- int h_in = -PAD_T + h * stride_h + k_h;
- for (int k_w = 0; k_w < K_W; ++k_w) {
- int w_in = -PAD_L + w * stride_w + k_w;
- int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
- w_in < 0 || w_in >= W
- ? A_zero_point
- : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
- sum += a;
- }
- }
- }
- row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] =
- sum;
- }
- } // w
- } // h
- } // t
- } // for each n
-
- for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) {
- for (int k = 0; k < K; ++k) {
- requantize_u8acc32_ref(
- 1,
- 1,
- 1,
- C_int32.data() + i * K + k,
- C + i * K + k,
- &C_multiplier,
- C_zero_point,
- A_zero_point,
- &B_zero_point,
- &row_offsets[i * K + k],
- col_offsets + k,
- bias ? bias + k : nullptr,
- 1);
- }
- }
-};
-
-void depthwise_3x3x3_per_channel_quantization_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
-
- vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K);
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B,
- C_int32.data());
-
- vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K);
- for (int n = 0; n < N; ++n) {
- for (int t = 0; t < T_OUT; ++t) {
- for (int h = 0; h < H_OUT; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- for (int k = 0; k < K; ++k) {
- int sum = 0;
- for (int k_t = 0; k_t < K_T; ++k_t) {
- int t_in = -PAD_P + t * stride_t + k_t;
- for (int k_h = 0; k_h < K_H; ++k_h) {
- int h_in = -PAD_T + h * stride_h + k_h;
- for (int k_w = 0; k_w < K_W; ++k_w) {
- int w_in = -PAD_L + w * stride_w + k_w;
- int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
- w_in < 0 || w_in >= W
- ? A_zero_point
- : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
- sum += a;
- }
- }
- }
- row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] =
- sum;
- }
- } // w
- } // h
- } // t
- } // for each n
-
- for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) {
- for (int k = 0; k < K; ++k) {
- requantize_u8acc32_ref(
- 1,
- 1,
- 1,
- C_int32.data() + i * K + k,
- C + i * K + k,
- &C_multiplier[k],
- C_zero_point,
- A_zero_point,
- &B_zero_point[k],
- &row_offsets[i * K + k],
- col_offsets + k,
- bias ? bias + k : nullptr,
- 1);
- }
- }
-};
+}
template void transposeConvWeights(
const conv_param_t<2>& conv_p,
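The new 3D branch of transposeConvWeights applies the same transform as the 2D case, one dimension higher: G K/G (T R S C/G) becomes G (T R S C/G) K/G. A compact sketch of the two flat offsets it relates (helper names are hypothetical):

#include <cstddef>

// src layout: G x OC/G x (T x R x S x IC/G)
inline size_t src_offset(int g, int k, int t, int r, int s, int c,
                         int T, int R, int S, int IC_per_G, int OC_per_G) {
  return ((((static_cast<size_t>(g) * OC_per_G + k) * T + t) * R + r) * S + s) *
      IC_per_G + c;
}

// dest layout: G x (T x R x S x IC/G) x OC/G
inline size_t dest_offset(int g, int k, int t, int r, int s, int c,
                          int T, int R, int S, int IC_per_G, int OC_per_G) {
  return ((((static_cast<size_t>(g) * T + t) * R + r) * S + s) * IC_per_G + c) *
      OC_per_G + k;
}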
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
index 082bdf1..a20e348 100644
--- a/src/RefImplementations.h
+++ b/src/RefImplementations.h
@@ -215,124 +215,4 @@ FBGEMM_API void im2col_ref(
std::int32_t A_zero_point,
std::uint8_t* Ao);
-/*
- * @brief Reference implementation of depthwise convolution with a 3x3 filter
- * and padding size 1.
- */
-FBGEMM_API void depthwise_3x3_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int8_t* B,
- std::int32_t* C);
-
-/*
- * @brief Reference implementation of depthwise convolution with a 3x3 filter
- * and padding size 1, followed by requantization. (the same scaling factors and
- * zero points for each channel).
- */
-FBGEMM_API void depthwise_3x3_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- std::int32_t B_zero_point,
- const std::int8_t* B,
- float C_multiplier,
- std::int32_t C_zero_point,
- std::uint8_t* C,
- const std::int32_t* col_offsets,
- const std::int32_t* bias);
-
-/*
- * @brief Reference implementation of depthwise convolution with a 3x3 filter
- * and padding size 1, followed by requantization. (different scaling factors
- * and zero points for each channel).
- */
-FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1_ref(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int32_t* B_zero_point,
- const std::int8_t* B,
- const float* C_multiplier,
- std::int32_t C_zero_point,
- std::uint8_t* C,
- const std::int32_t* col_offsets,
- const std::int32_t* bias);
-
-/*
- * @brief Reference implementation of 3D depthwise convolution with a 3x3x3
- * filter and padding size 1.
- */
-FBGEMM_API void depthwise_3x3x3_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int8_t* B,
- std::int32_t* C);
-
-/*
- * @brief Reference implementation of 3D depthwise convolution with a 3x3x3
- * filter and padding size 1, followed by requantization.
- */
-FBGEMM_API void depthwise_3x3x3_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- std::int32_t B_zero_point,
- const std::int8_t* B,
- float C_multiplier,
- std::int32_t C_zero_point,
- std::uint8_t* C,
- const std::int32_t* col_offsets,
- const std::int32_t* bias);
-
-FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1_ref(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int32_t* B_zero_point,
- const std::int8_t* B,
- const float* C_multiplier,
- std::int32_t C_zero_point,
- std::uint8_t* C,
- const std::int32_t* col_offsets,
- const std::int32_t* bias);
-
} // namespace fbgemm
diff --git a/src/Utils.cc b/src/Utils.cc
index 355a5cb..2e88561 100755
--- a/src/Utils.cc
+++ b/src/Utils.cc
@@ -180,11 +180,7 @@ void transpose_simd(
// Run time CPU detection
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
-#ifdef _MSC_VER
- internal::transpose_8x8(M, N, src, ld_src, dst, ld_dst);
-#else
internal::transpose_16x16(M, N, src, ld_src, dst, ld_dst);
-#endif
} else if (fbgemmHasAvx2Support()) {
internal::transpose_8x8(M, N, src, ld_src, dst, ld_dst);
} else {
@@ -206,4 +202,7 @@ bool fbgemmHasAvx2Support() {
return (cpuinfo_initialize() && cpuinfo_has_x86_avx2());
}
+bool fbgemmHasAvx512VnniSupport() {
+  return (cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni());
+}
} // namespace fbgemm
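transpose_simd and the fbgemmHas*Support helpers follow the usual cpuinfo runtime-dispatch pattern: initialize cpuinfo once, then branch on the reported ISA features. A minimal sketch of such a call site (assumed usage, not part of this commit):

#include <cpuinfo.h>
#include <cstdio>

void pick_kernel() {
  // cpuinfo_initialize() must succeed before any cpuinfo_has_x86_* query.
  if (!cpuinfo_initialize()) {
    std::printf("scalar fallback\n");
    return;
  }
  if (cpuinfo_has_x86_avx512vnni()) {
    std::printf("AVX512-VNNI kernels\n");
  } else if (cpuinfo_has_x86_avx512f()) {
    std::printf("AVX512 kernels\n");
  } else if (cpuinfo_has_x86_avx2()) {
    std::printf("AVX2 kernels\n");
  } else {
    std::printf("scalar fallback\n");
  }
}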
diff --git a/test/FP16Test.cc b/test/FP16Test.cc
index eb49086..3267655 100644
--- a/test/FP16Test.cc
+++ b/test/FP16Test.cc
@@ -27,7 +27,26 @@ using namespace fbgemm;
namespace {
// The template parameter is transpose of A and B
class FBGemmFP16Test
- : public testing::TestWithParam<pair<matrix_op_t, matrix_op_t>> {};
+ : public testing::TestWithParam<pair<matrix_op_t, matrix_op_t>> {
+ protected:
+ vector<vector<int>> GenShapes() const {
+ vector<vector<int>> shapes;
+ random_device r;
+ default_random_engine generator(r());
+ uniform_int_distribution<int> dm(1, 256);
+ uniform_int_distribution<int> dnk(1, 1024);
+ for (int i = 0; i < 10; i++) {
+ int m = dm(generator);
+ int n = dnk(generator);
+ int k = dnk(generator);
+ shapes.push_back({m, n, k});
+ if (m > 10) {
+ shapes.push_back({(m / 10) * 10, n, k});
+ }
+ }
+ return shapes;
+ }
+};
}; // namespace
INSTANTIATE_TEST_CASE_P(
@@ -44,21 +63,75 @@ INSTANTIATE_TEST_CASE_P(
matrix_op_t::Transpose, matrix_op_t::Transpose)*/));
TEST_P(FBGemmFP16Test, Test) {
- vector<vector<int>> shapes;
- random_device r;
- default_random_engine generator(r());
- uniform_int_distribution<int> dm(1, 256);
- uniform_int_distribution<int> dnk(1, 1024);
- for (int i = 0; i < 10; i++) {
- int m = dm(generator);
- int n = dnk(generator);
- int k = dnk(generator);
- shapes.push_back({m, n, k});
- if (m > 10) {
- shapes.push_back({(m / 10) * 10, n, k});
+ auto shapes = GenShapes();
+ float alpha = 1.f, beta = 0.f;
+ matrix_op_t atrans, btrans;
+ tie(atrans, btrans) = GetParam();
+
+ for (auto s : shapes) {
+ int m = s[0];
+ int n = s[1];
+ int k = s[2];
+
+ cerr << "m = " << m << " n = " << n << " k = " << k;
+ if (atrans == matrix_op_t::Transpose) {
+ cerr << " A_transposed";
+ }
+ if (btrans == matrix_op_t::Transpose) {
+ cerr << " B_transposed";
+ }
+ cerr << endl;
+
+ // initialize with small numbers
+ aligned_vector<int> Aint(m * k);
+ aligned_vector<int> Bint(k * n);
+ randFill(Aint, 0, 4);
+ randFill(Bint, 0, 4);
+ aligned_vector<float> A(Aint.begin(), Aint.end());
+ aligned_vector<float> B(Bint.begin(), Bint.end());
+
+ aligned_vector<float> C(m * n, NAN);
+
+ aligned_vector<float> A_ref(A), B_ref(B), C_ref(C);
+
+ if (atrans == matrix_op_t::Transpose) {
+ transpose_matrix(A_ref.data(), k, m);
+ }
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(B_ref.data(), n, k);
+ }
+
+ // Gold via reference sgemm
+ matmul_fp_ref(m, n, k, k, n, n, A_ref.data(), B_ref.data(), C_ref.data());
+
+ // fbgemm fp16
+ PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data());
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
+ {
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+
+ cblas_gemm_compute(
+ atrans, m, A.data(), Bp, beta, C.data(), tid, num_threads);
+ }
+
+ // correctness check
+ for (int i = 0; i < m; ++i) {
+ for (int j = 0; j < n; ++j) {
+ float expected = C_ref[i * n + j];
+ float actual = C[i * n + j];
+ EXPECT_EQ(expected, actual)
+ << "GEMM results differ at (" << i << ", " << j << "). ref "
+ << expected << " FBGemm " << actual;
+ }
}
}
+}
+TEST_P(FBGemmFP16Test, Unpack) {
+ auto shapes = GenShapes();
float alpha = 1.f, beta = 0.f;
matrix_op_t atrans, btrans;
tie(atrans, btrans) = GetParam();
@@ -101,6 +174,23 @@ TEST_P(FBGemmFP16Test, Test) {
// fbgemm fp16
PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data());
+ EXPECT_TRUE(Bp.packed());
+
+ // Test unpack
+ aligned_vector<float16> tmp(Bp.matSize());
+ memcpy(tmp.data(), Bp.pmat(), Bp.matSize() * sizeof(float16));
+ Bp.unpackFromSrc(btrans, tmp.data());
+ EXPECT_FALSE(Bp.packed());
+ memcpy(tmp.data(), Bp.pmat(), Bp.matSize() * sizeof(float16));
+ for (int i = 0; i < k; ++i) {
+ for (int j = 0; j < n; ++j) {
+ EXPECT_EQ(B[i * n + j], cpu_half2float(tmp[i * n + j]));
+ }
+ }
+
+ // Pack it back
+ Bp.packFromSrc(btrans, tmp.data());
+ EXPECT_TRUE(Bp.packed());
#ifdef _OPENMP
#pragma omp parallel
diff --git a/test/GConvTest.cc b/test/GConvTest.cc
index 84f0d52..982208b 100644
--- a/test/GConvTest.cc
+++ b/test/GConvTest.cc
@@ -25,14 +25,6 @@
using namespace std;
using namespace fbgemm;
-vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose,
- matrix_op_t::Transpose};
-
-vector<QuantizationGranularity> qGranularityVals{
- QuantizationGranularity::TENSOR,
- QuantizationGranularity::GROUP,
- QuantizationGranularity::OUT_CHANNEL};
-
namespace {
class fbgemmGConvAcc32Test
: public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t>> {};
@@ -43,6 +35,8 @@ class fbgemmGConvAcc32WithQuantGranularityTest
QuantizationGranularity,
bool,
bool>> {};
+class fbgemmGConvPackTest
+ : public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t>> {};
}; // namespace
INSTANTIATE_TEST_CASE_P(
@@ -61,6 +55,13 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn(qGranularityVals),
::testing::Bool(), // A symmetric
::testing::Bool())); // B symmetric
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ fbgemmGConvPackTest,
+ ::testing::Combine(
+ ::testing::Values(matrix_op_t::NoTranspose),
+ ::testing::ValuesIn(transposeVals)));
/**
* @brief Shapes for unit test.
*/
@@ -413,3 +414,51 @@ TEST_P(fbgemmGConvAcc32Test, NoRequantizeTest) {
static_cast<int32_t>(0));
} // for each shape
}
+
+/**
+ * @brief Unit test for packing and unpacking the weight tensor
+ */
+TEST_P(fbgemmGConvPackTest, PackUnpackTest) {
+ vector<conv_param_t<>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ tie(atrans, btrans) = GetParam();
+
+ for (auto conv_p : shapes) {
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+ int IC_per_G = conv_p.IC / conv_p.G;
+ int OC_per_G = conv_p.OC / conv_p.G;
+
+ // Weights -- test the packing/unpacking of only the weights
+ // when btrans == Transpose, the weight matrix is in layout G K/G (R S C/G)
+ // instead of G (R S C/G) K/G
+ int weight_len = R * S * conv_p.G * IC_per_G * OC_per_G;
+ aligned_vector<int8_t> Bint8(weight_len, 0);
+
+ // Random fill the weights
+ randFill<int8_t>(Bint8, -4, 4);
+
+ // Instantiate the object
+ PackWeightMatrixForGConv<int8_t> packedWeights(
+ btrans, conv_p, Bint8.data(), nullptr);
+
+ // Setup a buffer to get pack -> unpacked results
+ aligned_vector<int8_t> unpack_buf(weight_len, 0);
+
+ // START Actual pack-unpack operations
+ // Perform packing first. This should populate pdata_ of packedWeights
+ packedWeights.pack();
+
+ // Next perform unpacking
+ packedWeights.unpack(unpack_buf.data());
+ // END actual pack-unpack operations
+
+ // Sanity check
+ for (int i = 0; i < weight_len; ++i) {
+ EXPECT_EQ(Bint8.data()[i], unpack_buf.data()[i])
+ << "Pack/Unpack results differ at index " << i
+ << ", Reference: " << static_cast<int>(Bint8.data()[i])
+ << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i]);
+ }
+ } // for each shape
+}
diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc
index 11bd625..9de6943 100644
--- a/test/I8DepthwiseTest.cc
+++ b/test/I8DepthwiseTest.cc
@@ -4,7 +4,6 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
-#include "I8DepthwiseTest.h"
#include <cmath>
#include <cstdio>
@@ -22,6 +21,7 @@ using namespace std;
namespace fbgemm {
// From Xray OCR
+// clang-format off
static vector<vector<int>> shapes = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
@@ -68,9 +68,28 @@ static vector<vector<int>> shapes = {
{ 1, 8, 4, 4, 1, },
};
+static vector<vector<int>> shapes_3d = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ // N, K, T_in, H_in, W_in, stride
+ { 1, 32, 16, 28, 28, 1, },
+ { 1, 128, 8, 14, 14, 2, },
+ { 5, 16, 32, 56, 56, 1, },
+ { 1, 8, 4, 4, 4, 1, },
+};
+// clang-format on
+
namespace {
-class FBGemmDepthWiseTest
- : public testing::TestWithParam<tuple<bool, bool>> {};
+
+class FBGemmDepthWiseTest : public testing::TestWithParam<tuple<bool, bool>> {};
+
+// The two test parameters are K (the number of output channels, which equals
+// the number of groups for depthwise convolution) and kernel_prod (the number
+// of elements in the filter, e.g. 9 for a 3x3 convolution).
+class FBGemmDepthWisePackUnpackTest
+ : public testing::TestWithParam<tuple<int, int>> {};
+
} // namespace
INSTANTIATE_TEST_CASE_P(
@@ -78,6 +97,13 @@ INSTANTIATE_TEST_CASE_P(
FBGemmDepthWiseTest,
::testing::Combine(::testing::Bool(), ::testing::Bool()));
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ FBGemmDepthWisePackUnpackTest,
+ ::testing::Combine(
+ ::testing::ValuesIn({8, 16, 24, 32, 40, 64, 72}),
+ ::testing::ValuesIn({1, 2, 3, 4, 5, 9, 10, 11, 27})));
+
TEST_P(FBGemmDepthWiseTest, Test3x3) {
bool a_symmetric, b_symmetric;
tie(a_symmetric, b_symmetric) = GetParam();
@@ -90,13 +116,29 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) {
int stride_h = shape[4];
int stride_w = stride_h;
constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2,
+ PAD_R = (S - 1) / 2;
+
+ conv_param_t<2> conv_p(
+ N,
+ K,
+ K,
+ {H, W},
+ K,
+ {R, S},
+ {stride_h, stride_w},
+ {PAD_T, PAD_L, PAD_B, PAD_R});
+ int H_OUT = conv_p.OUT_DIM[0];
+ int W_OUT = conv_p.OUT_DIM[1];
+
+ int MDim = N * H_OUT * W_OUT;
+ int KDim = R * S * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * H * W * K);
- aligned_vector<int8_t> B(K * R * S);
- aligned_vector<int32_t> C_ref(N * H_OUT * W_OUT * K), C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = a_symmetric ? 0 : 43;
@@ -104,48 +146,54 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) {
randFill<int8_t>(B, -16, 16);
int32_t B_zero_point = b_symmetric ? 0 : 5;
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
- int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
-
- float C_multiplier = 255. / (maximum - minimum);
+ aligned_vector<float> C_multiplier(1);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+ int32_t C_zero_point = 5;
aligned_vector<int32_t> col_offsets(K);
aligned_vector<int32_t> bias(K);
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- int32_t C_zero_point = 5;
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point,
- B.data(),
- C_multiplier,
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col;
+ if (!b_symmetric) {
+ A_im2col.resize(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
+ }
- Packed3x3ConvMatrix Bp(K, B.data());
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ if (!b_symmetric) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+ }
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data(),
+ C_zero_point,
+ A_zero_point,
+ &B_zero_point,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
+
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
depthwise_3x3_pad_1(
N,
@@ -158,12 +206,13 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) {
A.data(),
B_zero_point,
Bp,
- C_multiplier,
+ C_multiplier[0],
C_zero_point,
C_uint8.data(),
a_symmetric ? nullptr : col_offsets.data(),
bias.data(),
false, /* fuse_relu */
+ 1.0f, /* act_scale * w_scale */
0,
1);
@@ -205,67 +254,83 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) {
constexpr int K_T = 3, K_H = 3, K_W = 3;
constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ conv_param_t<3> conv_p(
+ N,
+ K,
+ K,
+ {T, H, W},
+ K,
+ {K_T, K_H, K_W},
+ {stride_t, stride_h, stride_w},
+ {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R});
+ int T_OUT = conv_p.OUT_DIM[0];
+ int H_OUT = conv_p.OUT_DIM[1];
+ int W_OUT = conv_p.OUT_DIM[2];
+
+ int MDim = N * T_OUT * H_OUT * W_OUT;
+ int KDim = K_T * K_H * K_W * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * T * H * W * K);
- aligned_vector<int8_t> B(K * K_T * K_H * K_W);
- aligned_vector<int32_t> C_ref(N * T_OUT * H_OUT * W_OUT * K),
- C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = a_symmetric ? 0 : 43;
randFill<int8_t>(B, -16, 16);
- int32_t B_zero_point = 5;
-
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- int32_t minimum = *min_element(C_ref.begin(), C_ref.end());
- int32_t maximum = *max_element(C_ref.begin(), C_ref.end());
+ int32_t B_zero_point = b_symmetric ? 0 : 5;
- float C_multiplier = 255. / (maximum - minimum);
+ aligned_vector<float> C_multiplier(1);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+ int32_t C_zero_point = 5;
aligned_vector<int32_t> col_offsets(K);
aligned_vector<int32_t> bias(K);
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- int32_t C_zero_point = 5;
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point,
- B.data(),
- C_multiplier,
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col;
+ if (!b_symmetric) {
+ A_im2col.resize(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
+ }
- Packed3x3x3ConvMatrix Bp(K, B.data());
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ if (!b_symmetric) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+ }
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data(),
+ C_zero_point,
+ A_zero_point,
+ &B_zero_point,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
+
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data());
depthwise_3x3x3_pad_1(
N,
@@ -280,10 +345,10 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) {
A.data(),
B_zero_point,
Bp,
- C_multiplier,
+ C_multiplier[0],
C_zero_point,
C_uint8.data(),
- col_offsets.data(),
+ a_symmetric ? nullptr : col_offsets.data(),
bias.data(),
false, /* fuse_relu */
0,
@@ -297,8 +362,8 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) {
for (int k = 0; k < K; ++k) {
int32_t expected = C_uint8_ref
[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k];
- int32_t actual = C_uint8
- [(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k];
+ int32_t actual =
+ C_uint8[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k];
EXPECT_EQ(expected, actual)
<< "Depthwise 3x3 results differ at (" << n << ", " << t
<< ", " << h << ", " << w << ", " << k << ").";
@@ -319,14 +384,29 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
int stride_h = shape[4];
int stride_w = stride_h;
constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int PAD_T = (R - 1) / 2, PAD_B = (R - 1) / 2, PAD_L = (S - 1) / 2,
+ PAD_R = (S - 1) / 2;
+
+ conv_param_t<2> conv_p(
+ N,
+ K,
+ K,
+ {H, W},
+ K,
+ {R, S},
+ {stride_h, stride_w},
+ {PAD_T, PAD_L, PAD_B, PAD_R});
+ int H_OUT = conv_p.OUT_DIM[0];
+ int W_OUT = conv_p.OUT_DIM[1];
+
+ int MDim = N * H_OUT * W_OUT;
+ int KDim = R * S * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * H * W * K);
- aligned_vector<int8_t> B(K * R * S);
- int32_t C_num_rows = N * H_OUT * W_OUT;
- aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = 43;
@@ -342,28 +422,8 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
B_zero_point[k] = 5 + k;
}
- depthwise_3x3_pad_1_ref(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- aligned_vector<int32_t> C_ref_transpose(C_ref);
- transpose_matrix(C_ref.data(), C_num_rows, K);
- vector<float> C_multiplier(K);
- for (auto k = 0; k < K; ++k) {
- auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows;
- auto C_ref_k_end = C_ref_k_begin + C_num_rows;
- int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end);
- int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end);
- C_multiplier[k] = 255. / (maximum - minimum);
- }
+ aligned_vector<float> C_multiplier(K);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
int32_t C_zero_point = 5;
aligned_vector<int32_t> col_offsets(K);
@@ -371,25 +431,40 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3_per_channel_quantization_pad_1_ref(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point.data(),
- B.data(),
- C_multiplier.data(),
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ // im2col to compute row offset later
+ vector<int32_t> row_offsets(MDim);
+ vector<uint8_t> A_im2col(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
+
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data() + g,
+ C_zero_point,
+ A_zero_point,
+ B_zero_point.data() + g,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
- Packed3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
depthwise_3x3_per_channel_quantization_pad_1(
N,
@@ -442,14 +517,28 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) {
constexpr int K_T = 3, K_H = 3, K_W = 3;
constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ conv_param_t<3> conv_p(
+ N,
+ K,
+ K,
+ {T, H, W},
+ K,
+ {K_T, K_H, K_W},
+ {stride_t, stride_h, stride_w},
+ {PAD_P, PAD_T, PAD_L, PAD_N, PAD_B, PAD_R});
+ int T_OUT = conv_p.OUT_DIM[0];
+ int H_OUT = conv_p.OUT_DIM[1];
+ int W_OUT = conv_p.OUT_DIM[2];
+
+ int MDim = N * T_OUT * H_OUT * W_OUT;
+ int KDim = K_T * K_H * K_W * K;
+ int KDimPerGroup = KDim / conv_p.G;
aligned_vector<uint8_t> A(N * T * H * W * K);
- aligned_vector<int8_t> B(K * K_T * K_H * K_W);
- int32_t C_num_rows = N * T_OUT * H_OUT * W_OUT;
- aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size());
+ aligned_vector<int8_t> B(KDim);
+ aligned_vector<int32_t> C_ref(MDim * K), C(C_ref.size());
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
randFill<uint8_t>(A, 0, 86);
int32_t A_zero_point = 43;
@@ -465,30 +554,8 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) {
B_zero_point[k] = 5 + k;
}
- depthwise_3x3x3_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B.data(),
- C_ref.data());
-
- aligned_vector<int32_t> C_ref_transpose(C_ref);
- transpose_matrix(C_ref.data(), C_num_rows, K);
- vector<float> C_multiplier(K);
- for (auto k = 0; k < K; ++k) {
- auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows;
- auto C_ref_k_end = C_ref_k_begin + C_num_rows;
- int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end);
- int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end);
- C_multiplier[k] = 255. / (maximum - minimum);
- }
+ aligned_vector<float> C_multiplier(K);
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
int32_t C_zero_point = 5;
aligned_vector<int32_t> col_offsets(K);
@@ -496,27 +563,40 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) {
randFill(col_offsets, -100, 100);
randFill(bias, -40, 40);
- aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
- depthwise_3x3x3_per_channel_quantization_pad_1_ref(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A.data(),
- B_zero_point.data(),
- B.data(),
- C_multiplier.data(),
- C_zero_point,
- C_uint8_ref.data(),
- col_offsets.data(),
- bias.data());
+ vector<int32_t> row_offsets(MDim);
+ // im2col to compute row offset later
+ vector<uint8_t> A_im2col(MDim * KDim);
+ im2col_ref(conv_p, A.data(), A_zero_point, A_im2col.data());
+
+ conv_ref(conv_p, A.data(), A_zero_point, B.data(), C_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ // Compute row offset
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ A_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ // Requantization
+ requantize_u8acc32_ref(
+ MDim,
+ 1,
+ conv_p.G,
+ C_ref.data() + g,
+ C_uint8_ref.data() + g,
+ C_multiplier.data() + g,
+ C_zero_point,
+ A_zero_point,
+ B_zero_point.data() + g,
+ row_offsets.data(),
+ col_offsets.data() + g,
+ bias.data() + g,
+ K);
+ }
- Packed3x3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data());
depthwise_3x3x3_per_channel_quantization_pad_1(
N,
@@ -561,4 +641,22 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) {
} // for each shape
} // Test3x3PerChannelQuantization
+TEST_P(FBGemmDepthWisePackUnpackTest, TestPackUnpack) {
+ int K, kernel_prod;
+ tie(K, kernel_prod) = GetParam();
+
+ ASSERT_EQ(K % 8, 0)
+ << "output channels (== groups) should be a multiple of 8";
+ aligned_vector<int8_t> B(K * kernel_prod);
+ randFill<int8_t>(B, -16, 16);
+
+ aligned_vector<int8_t> BUnpacked(K * kernel_prod);
+
+ PackedDepthWiseConvMatrix BPacked(K, kernel_prod, B.data());
+ BPacked.unpack(BUnpacked.data());
+
+ ASSERT_EQ(B, BUnpacked)
+ << "Original and unpacked data elements are not the same";
+} // TestPackUnpack
+
} // namespace fbgemm
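The depthwise tests above now read T_OUT/H_OUT/W_OUT from conv_param_t instead of hand-computing them. A small standalone check (assuming the standard output-size formula with dilation 1) of why the two agree for 3x3 kernels with padding 1:

#include <cassert>
#include <initializer_list>

int out_dim(int in, int pad_begin, int pad_end, int k, int stride,
            int dilation = 1) {
  return (in + pad_begin + pad_end - dilation * (k - 1) - 1) / stride + 1;
}

int main() {
  for (int H : {4, 14, 28, 56}) {
    for (int stride : {1, 2}) {
      int old_style = (H + 1 + 1 - 3) / stride + 1; // previous hard-coded form
      assert(out_dim(H, /*pad_begin=*/1, /*pad_end=*/1, /*k=*/3, stride) ==
             old_style);
    }
  }
  return 0;
}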
diff --git a/test/I8DepthwiseTest.h b/test/I8DepthwiseTest.h
deleted file mode 100644
index d65362a..0000000
--- a/test/I8DepthwiseTest.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- * All rights reserved.
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-#pragma once
-
-#include <vector>
-
-namespace fbgemm {
-
-// From ResNeXt-3D-101
-static std::vector<std::vector<int>> shapes_3d = {
- // NOTE: clang-format wants to use a different formatting but the current
- // formatting should be easier to read.
- // N, K, T_in, H_in, W_in, stride
- { 1, 64, 32, 56, 56, 1, },
- { 1, 128, 16, 28, 28, 1, },
- { 1, 256, 8, 14, 14, 1, },
- { 1, 512, 4, 7, 7, 1, },
-
- { 1, 128, 32, 56, 56, 2, },
- { 1, 256, 16, 28, 28, 2, },
- { 1, 512, 8, 14, 14, 2, },
-
- { 5, 64, 32, 56, 56, 1, },
- { 5, 128, 16, 28, 28, 1, },
- { 5, 256, 8, 14, 14, 1, },
- { 5, 512, 4, 7, 7, 1, },
-
- { 5, 128, 32, 56, 56, 2, },
- { 5, 256, 16, 28, 28, 2, },
- { 5, 512, 8, 14, 14, 2, },
-
- { 1, 8, 4, 4, 4, 1, },
-};
-
-} // namespace fbgemm
diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc
index b14303f..56df3c8 100644
--- a/test/Im2ColFusedRequantizeTest.cc
+++ b/test/Im2ColFusedRequantizeTest.cc
@@ -24,11 +24,6 @@
using namespace std;
using namespace fbgemm;
-vector<QuantizationGranularity> qGranularityVals{
- QuantizationGranularity::TENSOR,
- QuantizationGranularity::GROUP,
- QuantizationGranularity::OUT_CHANNEL};
-
namespace {
class fbgemmIm2colTest
: public testing::TestWithParam<tuple<QuantizationGranularity, bool>> {};
diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc
index 20f860e..8978150 100644
--- a/test/PackedRequantizeAcc16Test.cc
+++ b/test/PackedRequantizeAcc16Test.cc
@@ -26,20 +26,14 @@
using namespace std;
using namespace fbgemm;
-vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose,
- matrix_op_t::Transpose};
-
-vector<QuantizationGranularity> qGranularityVals{
- QuantizationGranularity::TENSOR,
- QuantizationGranularity::GROUP,
- QuantizationGranularity::OUT_CHANNEL};
-
namespace {
class fbgemmu8s8acc16WithQuantGranularityTest
: public testing::TestWithParam<
tuple<matrix_op_t, matrix_op_t, bool, QuantizationGranularity>> {};
class fbgemmu8s8acc16Test
: public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t, bool>> {};
+class fbgemmPackUnpackAcc16Test
+ : public testing::TestWithParam<tuple<matrix_op_t, bool>> {};
}; // namespace
INSTANTIATE_TEST_CASE_P(
@@ -59,6 +53,11 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn(transposeVals),
::testing::Bool()));
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ fbgemmPackUnpackAcc16Test,
+ ::testing::Combine(::testing::ValuesIn(transposeVals), ::testing::Bool()));
+
/**
* @brief Shapes for unit test.
*/
@@ -87,6 +86,8 @@ static vector<vector<int>> GetShapes_() {
{102, 512, 258},
{1024, 512, 258},
+
+ {120, 4, 288},
};
return shapes;
}
@@ -810,3 +811,79 @@ TEST_P(fbgemmu8s8acc16Test, NoRequantizeTest) {
} // for each groups
} // for each shape
}
+
+/**
+ * @brief Unit test for packing and unpacking the weight tensor.
+ */
+TEST_P(fbgemmPackUnpackAcc16Test, TestPackUnpack) {
+ vector<vector<int>> shapes(GetShapes_());
+ matrix_op_t btrans;
+ bool test_ld;
+ tie(btrans, test_ld) = GetParam();
+
+ BlockingFactors params;
+ params.MCB = 48;
+ params.NCB = 16;
+ params.KCB = 256;
+ params.MR = 1;
+ params.NR = 16;
+ params.ROW_INTERLEAVE = 4;
+ params.NR_MIN = 16;
+ vector<BlockingFactors*> vec_params_ptr = {&params, nullptr};
+
+ for (auto shape : shapes) {
+ for (int groups : {1, 3, 4}) {
+ for (auto params_ptr : vec_params_ptr) {
+ int n = shape[1];
+ int k = shape[2];
+
+ if (k % groups != 0) {
+ continue;
+ }
+ int k_per_group = k / groups;
+
+ // kxn matrix
+ aligned_vector<int8_t> Bint8(k * n);
+ randFill<int8_t>(Bint8, -128, 127);
+
+        // To test ld != the number of used columns, pack only n / 2 columns
+        // (n_adjusted) while keeping the original n as the leading dimension.
+ int n_adjusted = n;
+ if (test_ld) {
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
+ }
+
+ // Note that packing for weight is performed during the constructor
+ // stage.
+ PackBMatrix<int8_t, int16_t> packedWeights(
+ btrans,
+ k,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k_per_group : n,
+ nullptr,
+ groups,
+ params_ptr);
+
+ // Setup a buffer to get pack -> unpacked results
+ aligned_vector<int8_t> unpack_buf(k * n, 0);
+
+ // Perform unpacking
+ packedWeights.unpack(unpack_buf.data(), params_ptr);
+
+ // Sanity check
+ for (int i = 0; i < k; i++) {
+ for (int j = 0; j < n_adjusted; j++) {
+ EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j])
+ << "Pack/Unpack results differ at index (" << i << ", " << j
+ << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j])
+ << ", Pack-Unpacked: "
+ << static_cast<int>(unpack_buf.data()[i * n + j]);
+ }
+ }
+ }
+ }
+ }
+}
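In the NoTranspose case this test packs only n_adjusted columns while keeping the full n as the leading dimension, which is what exercises the ld != columns path. A tiny sketch of that addressing (illustration only, not part of the test):

#include <cstdint>
#include <vector>

int main() {
  int k = 4, n = 8, n_adjusted = n / 2; // ld stays n; only n_adjusted columns used
  std::vector<int8_t> B(k * n, 0);
  for (int i = 0; i < k; ++i) {
    for (int j = 0; j < n_adjusted; ++j) {
      B[i * n + j] = static_cast<int8_t>(i - j); // same i * n + j indexing as the checks
    }
  }
  return 0;
}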
diff --git a/test/PackedRequantizeTest.cc b/test/PackedRequantizeTest.cc
index fd827b0..15e7d55 100644
--- a/test/PackedRequantizeTest.cc
+++ b/test/PackedRequantizeTest.cc
@@ -25,20 +25,14 @@
using namespace std;
using namespace fbgemm;
-vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose,
- matrix_op_t::Transpose};
-
-vector<QuantizationGranularity> qGranularityVals{
- QuantizationGranularity::TENSOR,
- QuantizationGranularity::GROUP,
- QuantizationGranularity::OUT_CHANNEL};
-
namespace {
class fbgemmu8s8acc32WithQuantGranularityTest
: public testing::TestWithParam<
tuple<matrix_op_t, matrix_op_t, bool, QuantizationGranularity>> {};
class fbgemmu8s8acc32Test
: public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t, bool>> {};
+class fbgemmPackUnpackAcc32Test
+ : public testing::TestWithParam<tuple<matrix_op_t, bool>> {};
}; // namespace
INSTANTIATE_TEST_CASE_P(
@@ -58,6 +52,11 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn(transposeVals),
::testing::Bool()));
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ fbgemmPackUnpackAcc32Test,
+ ::testing::Combine(::testing::ValuesIn(transposeVals), ::testing::Bool()));
+
/**
* @brief Shapes for unit test.
*/
@@ -86,6 +85,8 @@ static vector<vector<int>> GetShapes_() {
{102, 512, 258},
{1024, 512, 258},
+
+ {120, 4, 288},
};
return shapes;
}
@@ -749,3 +750,79 @@ TEST_P(fbgemmu8s8acc32Test, TestSymmetricQuantizedInputOutput) {
} // for each groups
} // for each shape
}
+
+/**
+ * @brief Unit test for packing and unpacking the weight tensor.
+ */
+TEST_P(fbgemmPackUnpackAcc32Test, TestPackUnpack) {
+ vector<vector<int>> shapes(GetShapes_());
+ matrix_op_t btrans;
+ bool test_ld;
+ tie(btrans, test_ld) = GetParam();
+
+ BlockingFactors params;
+ params.MCB = 48;
+ params.NCB = 16;
+ params.KCB = 256;
+ params.MR = 1;
+ params.NR = 16;
+ params.ROW_INTERLEAVE = 4;
+ params.NR_MIN = 16;
+ vector<BlockingFactors*> vec_params_ptr = {&params, nullptr};
+
+ for (auto shape : shapes) {
+ for (int groups : {1, 3, 4}) {
+ for (auto params_ptr : vec_params_ptr) {
+ int n = shape[1];
+ int k = shape[2];
+
+ if (k % groups != 0) {
+ continue;
+ }
+ int k_per_group = k / groups;
+
+ // kxn matrix
+ aligned_vector<int8_t> Bint8(k * n);
+ randFill<int8_t>(Bint8, -128, 127);
+
+        // To test ld != n, we reduce n by half and keep the original n
+        // as the leading dimension.
+ int n_adjusted = n;
+ if (test_ld) {
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
+ }
+
+ // Note that packing for weight is performed during the constructor
+ // stage.
+ PackBMatrix<int8_t> packedWeights(
+ btrans,
+ k,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k_per_group : n,
+ nullptr,
+ groups,
+ params_ptr);
+
+ // Setup a buffer to get pack -> unpacked results
+ aligned_vector<int8_t> unpack_buf(k * n, 0);
+
+ // Perform unpacking
+ packedWeights.unpack(unpack_buf.data(), params_ptr);
+
+ // Sanity check
+ for (int i = 0; i < k; i++) {
+ for (int j = 0; j < n_adjusted; j++) {
+ EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j])
+                << "Pack/Unpack results differ at index (" << i << ", " << j
+                << "), Reference: " << static_cast<int>(Bint8.data()[i * n + j])
+                << ", Pack-Unpacked: "
+                << static_cast<int>(unpack_buf.data()[i * n + j]);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc
new file mode 100644
index 0000000..ddb1f91
--- /dev/null
+++ b/test/QuantUtilsTest.cc
@@ -0,0 +1,183 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <algorithm>
+#include <limits>
+#include <random>
+
+#include <gtest/gtest.h>
+
+#include "fbgemm/QuantUtils.h"
+#include "fbgemm/Utils.h"
+
+using namespace std;
+using namespace fbgemm;
+
+// tuple represents K, C, X, G, layout_t
+// layout_t can be KCX or KXC
+class QuantizeGroupwiseTest
+ : public testing::TestWithParam<tuple<int, int, int, int, layout_t>> {};
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ QuantizeGroupwiseTest,
+ ::testing::Combine(
+ ::testing::ValuesIn({4, 12, 64}), // K
+ ::testing::ValuesIn({12, 16, 32}), // C
+ ::testing::ValuesIn({1, 10, 15, 30}), // X
+ ::testing::ValuesIn({1, 4}), // G
+ ::testing::ValuesIn({layout_t::KCX, layout_t::KXC})));
+
+template <typename T, layout_t LT>
+void ref_impl(
+ const vector<float>& src,
+ int K,
+ int C,
+ int X,
+ int G,
+ const vector<float>& scales,
+ const vector<int>& zero_points,
+ vector<T>& dst) {
+ int C_per_G = C / G;
+ for (int i = 0; i < K; ++i) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < C / G; ++c) {
+ for (int x = 0; x < X; ++x) {
+ float num;
+ if (LT == layout_t::KCX) {
+ num = src[(i * C + g * C_per_G + c) * X + x];
+ } else {
+ num = src[(i * X + x) * C + g * C_per_G + c];
+ }
+ int res = nearbyint(zero_points[g] + num / scales[g]);
+ T final_res = min<T>(
+ max<T>(res, numeric_limits<T>::min()), numeric_limits<T>::max());
+ if (LT == layout_t::KCX) {
+ dst[(i * C + g * C_per_G + c) * X + x] = final_res;
+ } else {
+ dst[(i * X + x) * C + g * C_per_G + c] = final_res;
+ }
+ }
+ }
+ }
+ }
+}
+
+template <typename T, layout_t LT>
+void runTests(
+ const vector<float>& src,
+ int K,
+ int C,
+ int X,
+ int G,
+ const vector<float>& scales,
+ const vector<int>& zero_points,
+ vector<T>& dst,
+ vector<T>& dst_ref) {
+ QuantizeGroupwise<T, LT>(
+ src.data(), K, C, X, G, scales.data(), zero_points.data(), dst.data());
+
+ ref_impl<T, LT>(src, K, C, X, G, scales, zero_points, dst_ref);
+}
+
+/**
+ * There can be off-by-one error in quantized values due to how the mid-point
+ * cases are rounded-off in vectorized vs scalar codes and due to adding of
+ * zero_point before rounding vs after rounding. We ignore such differences
+ * while comparing results.
+ */
+template <typename T>
+::testing::AssertionResult isNear(
+ const vector<T>& res,
+ const vector<T>& res_ref) {
+  bool match = res.size() == res_ref.size();
+  if (match) {
+    for (size_t i = 0; i < res.size(); ++i) {
+ if (!(res[i] == res_ref[i] || res[i] == res_ref[i] + 1 ||
+ res[i] == res_ref[i] - 1)) {
+ match = false;
+ break;
+ }
+ }
+ }
+ if (match)
+ return ::testing::AssertionSuccess();
+ else
+ return ::testing::AssertionFailure() << " Quantized results do not match";
+}
+
+/**
+ * Test for QuantizeGroupwise
+ */
+TEST_P(QuantizeGroupwiseTest, quantizeTest) {
+ int K, C, X, G;
+ layout_t layout;
+ tie(K, C, X, G, layout) = GetParam();
+
+ random_device rd;
+ mt19937 gen(rd());
+
+ uniform_real_distribution<float> disFP(0.1, 1.1);
+
+ vector<float> inp(K * C * X);
+ generate(inp.begin(), inp.end(), [&, disFP]() mutable { return disFP(gen); });
+
+ vector<float> scales(G);
+ generate(scales.begin(), scales.end(), [&, disFP]() mutable {
+ return disFP(gen);
+ });
+
+ uniform_int_distribution<> disUInt8(0, 8);
+ vector<int> zero_points_uint8(G);
+ generate(
+ zero_points_uint8.begin(),
+ zero_points_uint8.end(),
+ [&, disUInt8]() mutable { return disUInt8(gen); });
+
+ uniform_int_distribution<> disInt8(-64, 63);
+ vector<int> zero_points_int8(G);
+ generate(
+ zero_points_int8.begin(), zero_points_int8.end(), [&, disInt8]() mutable {
+ return disInt8(gen);
+ });
+
+ uniform_int_distribution<> disInt32(-512, 512);
+ vector<int> zero_points_int32(G);
+ generate(
+ zero_points_int32.begin(),
+ zero_points_int32.end(),
+ [&, disInt32]() mutable { return disInt32(gen); });
+
+ vector<uint8_t> dstuint8(K * C * X);
+ vector<uint8_t> dstuint8_ref(K * C * X);
+
+ vector<int8_t> dstint8(K * C * X);
+ vector<int8_t> dstint8_ref(K * C * X);
+
+ vector<int32_t> dstint32(K * C * X);
+ vector<int32_t> dstint32_ref(K * C * X);
+
+ if (layout == layout_t::KCX) {
+ runTests<uint8_t, layout_t::KCX>(
+ inp, K, C, X, G, scales, zero_points_uint8, dstuint8, dstuint8_ref);
+ runTests<int8_t, layout_t::KCX>(
+ inp, K, C, X, G, scales, zero_points_int8, dstint8, dstint8_ref);
+ runTests<int32_t, layout_t::KCX>(
+ inp, K, C, X, G, scales, zero_points_int32, dstint32, dstint32_ref);
+ } else {
+ runTests<uint8_t, layout_t::KXC>(
+ inp, K, C, X, G, scales, zero_points_uint8, dstuint8, dstuint8_ref);
+ runTests<int8_t, layout_t::KXC>(
+ inp, K, C, X, G, scales, zero_points_int8, dstint8, dstint8_ref);
+ runTests<int32_t, layout_t::KXC>(
+ inp, K, C, X, G, scales, zero_points_int32, dstint32, dstint32_ref);
+ }
+
+ EXPECT_TRUE(isNear(dstuint8, dstuint8_ref));
+ EXPECT_TRUE(isNear(dstint8, dstint8_ref));
+ EXPECT_TRUE(isNear(dstint32, dstint32_ref));
+}
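
The tolerance comment in the new test notes that vectorized and scalar results may differ by one at rounding midpoints. A small standalone illustration (assuming the default round-to-nearest-even mode; not tied to the actual AVX2 kernel) of how adding the zero point before versus after rounding can shift a midpoint by one:

    #include <cmath>
    #include <cstdio>

    int main() {
      float x = 0.25f, scale = 0.5f;
      int zp = 1;
      // x / scale == 0.5 exactly, a rounding midpoint.
      int add_zp_then_round =
          static_cast<int>(std::nearbyint(zp + x / scale));      // nearbyint(1.5) == 2
      int round_then_add_zp =
          static_cast<int>(std::nearbyint(x / scale)) + zp;      // nearbyint(0.5) == 0, so 1
      std::printf("%d vs %d\n", add_zp_then_round, round_then_add_zp); // off by one
      return 0;
    }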
diff --git a/test/RequantizeOnlyTest.cc b/test/RequantizeOnlyTest.cc
new file mode 100644
index 0000000..94e8e7d
--- /dev/null
+++ b/test/RequantizeOnlyTest.cc
@@ -0,0 +1,169 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <random>
+#include <stdexcept>
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "TestUtils.h"
+#include "bench/BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+
+using namespace std;
+using namespace fbgemm;
+
+vector<QuantizationGranularity> qGranularityValsLocal{
+ QuantizationGranularity::TENSOR,
+ QuantizationGranularity::OUT_CHANNEL};
+
+namespace {
+
+// tuple represents #rows, #cols, fuse_relu, quantization_granularity
+class FloatRequantizeTest
+ : public testing::TestWithParam<
+ tuple<int, int, bool, QuantizationGranularity>> {};
+
+}; // namespace
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ FloatRequantizeTest,
+ ::testing::Combine(
+ ::testing::ValuesIn({1, 2, 3, 4}), // number of rows
+ ::testing::ValuesIn(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 20, 32}), // number of
+ // cols
+ ::testing::Bool(), // fuse relu
+ ::testing::ValuesIn(qGranularityValsLocal))); // requantization granularity
+
+/**
+ * Test for float bias
+ */
+TEST_P(FloatRequantizeTest, floatBiasTest) {
+ int rows, cols;
+ bool fuse_relu;
+ QuantizationGranularity q_gran;
+ tie(rows, cols, fuse_relu, q_gran) = GetParam();
+
+ int numElements = rows * cols;
+
+ aligned_vector<float> act_times_w_scale(cols);
+ randFill<float>(act_times_w_scale, -8, 8);
+
+ float out_scale = 2.0f;
+
+ aligned_vector<float> C_multiplier(cols);
+ transform(
+ act_times_w_scale.begin(),
+ act_times_w_scale.end(),
+ C_multiplier.begin(),
+ [&out_scale](float i) { return i / out_scale; });
+
+ aligned_vector<int32_t> Bint8_zero_point(cols);
+ randFill<int32_t>(Bint8_zero_point, -8, 8);
+
+ aligned_vector<int32_t> row_offset_buf(rows);
+ randFill<int32_t>(row_offset_buf, -8, 8);
+
+ aligned_vector<int32_t> col_offsets(cols);
+ randFill<int32_t>(col_offsets, -8, 8);
+
+ // quantized bias
+ aligned_vector<int32_t> bias_q(cols);
+ randFill<int32_t>(bias_q, -8, 8);
+
+ // floating point bias
+ aligned_vector<float> bias_f(cols);
+ if (q_gran == QuantizationGranularity::TENSOR) {
+ transform(
+ bias_q.begin(),
+ bias_q.end(),
+ bias_f.begin(),
+ [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; });
+ } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) {
+ transform(
+ act_times_w_scale.begin(),
+ act_times_w_scale.end(),
+ bias_q.begin(),
+ bias_f.begin(),
+ multiplies<float>());
+
+ } else {
+ FAIL();
+ }
+
+ aligned_vector<int32_t> input(numElements);
+ randFill<int32_t>(input, -8, 8);
+
+ aligned_vector<uint8_t> output_q_bias(numElements);
+ aligned_vector<uint8_t> output_f_bias(numElements);
+
+ int32_t C_zero_point = 3;
+ int32_t Aint8_zero_point = 3;
+
+ block_type_t block{0, rows, 0, cols};
+
+ DoNothing<> doNothingObj{};
+
+#define TESTCODE(FUSE_RELU, Q_GRAN) \
+ ReQuantizeOutput<FUSE_RELU, Q_GRAN> reqObj_q( \
+ doNothingObj, \
+ C_multiplier.data(), \
+ C_zero_point, \
+ Aint8_zero_point, \
+ Bint8_zero_point.data(), \
+ row_offset_buf.data(), \
+ col_offsets.data(), \
+ bias_q.data(), \
+ cols); \
+ ReQuantizeOutput<FUSE_RELU, Q_GRAN, float> reqObj_f( \
+ doNothingObj, \
+ C_multiplier.data(), \
+ C_zero_point, \
+ Aint8_zero_point, \
+ Bint8_zero_point.data(), \
+ row_offset_buf.data(), \
+ col_offsets.data(), \
+ bias_f.data(), \
+ cols, \
+ 1, \
+ act_times_w_scale.data()); \
+ reqObj_q.f<inst_set_t::avx2>( \
+ output_q_bias.data(), input.data(), block, cols, cols); \
+ reqObj_f.f<inst_set_t::avx2>( \
+ output_f_bias.data(), input.data(), block, cols, cols);
+
+ if (fuse_relu) {
+ if (q_gran == QuantizationGranularity::TENSOR) {
+ TESTCODE(true, QuantizationGranularity::TENSOR)
+
+ } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) {
+ TESTCODE(true, QuantizationGranularity::OUT_CHANNEL)
+
+ } else {
+ FAIL();
+ }
+
+ } else {
+ if (q_gran == QuantizationGranularity::TENSOR) {
+ TESTCODE(false, QuantizationGranularity::TENSOR)
+
+ } else if (q_gran == QuantizationGranularity::OUT_CHANNEL) {
+ TESTCODE(false, QuantizationGranularity::OUT_CHANNEL)
+
+ } else {
+ FAIL();
+ }
+ }
+#undef TESTCODE
+ ASSERT_EQ(output_q_bias, output_f_bias)
+ << "Requantization with quantized bias and float bias differs";
+}
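
The new test builds the float bias as bias_f = bias_q * act_times_w_scale and expects bit-identical outputs from the two ReQuantizeOutput instantiations. One way to see why the paths can agree, given that the test sets C_multiplier = act_times_w_scale / out_scale, is the scalar sketch below; it is illustrative only, omits the zero-point and row/column-offset corrections of the real kernel, and assumes the float-bias path folds the bias back into the accumulator domain by dividing by act_times_w_scale:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Quantized-bias path: bias is added to the int32 accumulator, then scaled.
    uint8_t requant_int_bias(int32_t acc, int32_t bias_q,
                             float C_multiplier, int32_t C_zp) {
      long r = std::lrintf(C_multiplier * static_cast<float>(acc + bias_q)) + C_zp;
      return static_cast<uint8_t>(std::min(255L, std::max(0L, r)));
    }

    // Float-bias path: with bias_f = bias_q * act_times_w_scale, dividing by
    // act_times_w_scale recovers bias_q, so both paths scale the same quantity.
    uint8_t requant_float_bias(int32_t acc, float bias_f, float act_times_w_scale,
                               float C_multiplier, int32_t C_zp) {
      float corrected = static_cast<float>(acc) + bias_f / act_times_w_scale;
      long r = std::lrintf(C_multiplier * corrected) + C_zp;
      return static_cast<uint8_t>(std::min(255L, std::max(0L, r)));
    }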
diff --git a/test/TestUtils.h b/test/TestUtils.h
index 2cb7b88..d320ae2 100644
--- a/test/TestUtils.h
+++ b/test/TestUtils.h
@@ -7,9 +7,18 @@
#pragma once
#include <cmath>
#include <vector>
+#include "fbgemm/Fbgemm.h"
namespace fbgemm {
+static std::vector<matrix_op_t> transposeVals = {matrix_op_t::NoTranspose,
+                                                 matrix_op_t::Transpose};
+
+static std::vector<QuantizationGranularity> qGranularityVals = {
+    QuantizationGranularity::TENSOR,
+    QuantizationGranularity::GROUP,
+    QuantizationGranularity::OUT_CHANNEL};
+
/*
* @brief Check and validate the buffers for reference and FBGEMM result.
*/
diff --git a/test/UniConvPackingTest.cc b/test/UniConvPackingTest.cc
deleted file mode 100644
index 77552af..0000000
--- a/test/UniConvPackingTest.cc
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- * All rights reserved.
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-#include <algorithm>
-#include <random>
-#include <iostream>
-
-
-#include <gtest/gtest.h>
-
-#include "QuantizationHelpers.h"
-#include "TestUtils.h"
-#include "bench/BenchUtils.h"
-#include "fbgemm/Fbgemm.h"
-#include "src/RefImplementations.h"
-
-using namespace std;
-using namespace fbgemm;
-
-namespace {
-
-// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad
-class convPackingTest
- : public testing::TestWithParam<
- tuple<int, int, int, int, int, int, int, int, int, int>> {};
-
-}; // namespace
-
-INSTANTIATE_TEST_CASE_P(
- InstantiationName,
- convPackingTest,
- ::testing::Combine(
- ::testing::ValuesIn({1, 2}), // MB
- ::testing::ValuesIn({16, 32}), // IC
- ::testing::ValuesIn({16, 32}), // OC
- ::testing::ValuesIn({17}), // IT
- ::testing::ValuesIn({10, 30, 55}), // IH
- ::testing::ValuesIn({10, 30, 55}), // IW
- ::testing::ValuesIn({1, 4, 16}), // G
- ::testing::ValuesIn({3, 7}), // kernel
- ::testing::ValuesIn({1, 2}), // stride
- ::testing::ValuesIn({1, 2}))); // pad
-
-/**
- * Test for conv packing
- */
-TEST_P(convPackingTest, packingTest) {
- int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad;
- tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam();
-
- conv_param_t<2> conv_p_2d(
- MB,
- IC,
- OC,
- {IH, IW},
- G,
- {kernel, kernel},
- {stride, stride},
- {pad, pad, pad, pad});
-
- int kernel_dim_2d = kernel * kernel;
- aligned_vector<int8_t> Bint8_2d(
- kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));
- PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data());
-
- switch (ConvFastPath<2, int32_t>(conv_p_2d)) {
- case optimized_conv_t::depthwise: {
- ASSERT_NE(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
- << "im2col packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
- << "groupwise packed matrix should be null";
- break;
- }
- case optimized_conv_t::groupwise: {
- ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
- << "im2col packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
- ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr)
- << "Groupwise packed matrix is null";
- break;
- }
- case optimized_conv_t::im2col: {
- ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
- << "groupwise packed matrix should be null";
- ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr)
- << "im2col packed matrix is null";
- break;
- }
- }
-
- conv_param_t<3> conv_p_3d(
- MB,
- IC,
- OC,
- {IT, IH, IW},
- G,
- {kernel, kernel, kernel},
- {stride, stride, stride},
- {pad, pad, pad, pad, pad, pad});
-
- int kernel_dim_3d = kernel * kernel * kernel;
- aligned_vector<int8_t> Bint8_3d(
- kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G));
- PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data());
-
- switch (ConvFastPath<3, int32_t>(conv_p_3d)) {
- case optimized_conv_t::depthwise: {
- ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
- << "im2col packed matrix should be null";
- ASSERT_NE(packedB_3D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
- ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
- << "groupwise packed matrix should be null";
- break;
- }
- case optimized_conv_t::groupwise: {
- ASSERT_TRUE(false) << "groupwise are not supported for 3D";
- break;
- }
- case optimized_conv_t::im2col: {
- ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
- ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
- << "groupwise packed matrix should be null";
- ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr)
- << "im2col packed matrix is null";
- break;
- }
- }
-}
diff --git a/test/UniConvTest.cc b/test/UniConvTest.cc
new file mode 100644
index 0000000..e9c7ba5
--- /dev/null
+++ b/test/UniConvTest.cc
@@ -0,0 +1,714 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <iostream>
+#include <random>
+#include <stdexcept>
+
+#include <gtest/gtest.h>
+
+#include "QuantizationHelpers.h"
+#include "TestUtils.h"
+#include "bench/BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+
+using namespace std;
+using namespace fbgemm;
+
+// clang-format off
+static vector<conv_param_t<>> GetShapes_() {
+ vector<conv_param_t<>> shapes = {
+ // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, {pad_t, pad_l,
+ // pad_b, pad_r}
+ // Regular
+ conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 32, {30, 10}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}, {2, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {3, 3}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {1, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {2, 1, 2, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 2, 1, 2}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 5}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+ // groupwise
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 32, {10, 30}, 8, {3, 3}, {2, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 2}, {2, 1, 2, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 2}, {2, 1, 2, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 1}, {2, 1, 2, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 5}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+ // DW
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 2, 1, 2}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 2}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 5}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}),
+ // Pointwise
+ conv_param_t<>(1, 32, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 16, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 2}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 2}, {0, 0, 0, 0}),
+ conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 1}, {0, 0, 0, 0}),
+ };
+ return shapes;
+}
+// clang-format on
+
+namespace {
+
+// tuple represents MB, IC, OC, IT, IH, IW, G, KH/KW, stride, pad
+class uniConvTest
+ : public testing::TestWithParam<
+ tuple<int, int, int, int, int, int, int, int, int, int>> {};
+
+// tuple represents QuantizationGranularity, A symmetric, B symmetric,
+// test_bias, test_float_bias
+class UniConvQGranTest
+ : public testing::TestWithParam<
+ tuple<QuantizationGranularity, bool, bool, bool, bool>> {};
+
+}; // namespace
+
+// Combine only allows at most 10 generators.
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ uniConvTest,
+ ::testing::Combine(
+ ::testing::ValuesIn({1, 2}), // MB
+ ::testing::ValuesIn({16, 32}), // IC
+ ::testing::ValuesIn({16, 32}), // OC
+ ::testing::ValuesIn({17}), // IT
+ ::testing::ValuesIn({10, 30, 55}), // IH
+ ::testing::ValuesIn({10, 30, 55}), // IW
+ ::testing::ValuesIn({1, 4, 16}), // G
+ ::testing::ValuesIn({1, 3, 7}), // kernel
+ ::testing::ValuesIn({1, 2}), // stride
+ ::testing::ValuesIn({0, 1, 2}))); // pad
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ UniConvQGranTest,
+ ::testing::Combine(
+ ::testing::ValuesIn(qGranularityVals),
+ ::testing::Bool(), // A symmetric
+ ::testing::Bool(), // B symmetric
+ ::testing::Bool(), // test_bias
+ ::testing::Bool())); // test_float_bias
+/**
+ * Test for conv packing
+ */
+TEST_P(uniConvTest, packingTest) {
+ int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad;
+ tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam();
+
+ conv_param_t<2> conv_p_2d(
+ MB,
+ IC,
+ OC,
+ {IH, IW},
+ G,
+ {kernel, kernel},
+ {stride, stride},
+ {pad, pad, pad, pad});
+
+ int kernel_dim_2d = kernel * kernel;
+ aligned_vector<int8_t> Bint8_2d(
+ kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));
+ PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data());
+
+ switch (ConvFastPath<2, int32_t>(conv_p_2d)) {
+ case optimized_conv_t::depthwise: {
+ ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
+ << "groupwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
+ << "pointwise packed matrix should be null";
+ ASSERT_NE(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix is null";
+ break;
+ }
+ case optimized_conv_t::groupwise: {
+ ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
+ << "pointwise packed matrix should be null";
+ ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr)
+ << "Groupwise packed matrix is null";
+ break;
+ }
+ case optimized_conv_t::pointwise: {
+ ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
+ << "Groupwise packed matrix should be null";
+ ASSERT_NE(packedB_2D.getPackedWForPointwise(), nullptr)
+ << "pointwise packed matrix is null";
+ break;
+ }
+ case optimized_conv_t::im2col: {
+ ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
+ << "groupwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
+ << "pointwise packed matrix should be null";
+ ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix is null";
+ break;
+ }
+ }
+
+ conv_param_t<3> conv_p_3d(
+ MB,
+ IC,
+ OC,
+ {IT, IH, IW},
+ G,
+ {kernel, kernel, kernel},
+ {stride, stride, stride},
+ {pad, pad, pad, pad, pad, pad});
+
+ int kernel_dim_3d = kernel * kernel * kernel;
+ aligned_vector<int8_t> Bint8_3d(
+ kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G));
+ PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data());
+
+ switch (ConvFastPath<3, int32_t>(conv_p_3d)) {
+ case optimized_conv_t::depthwise: {
+ ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
+ << "groupwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)
+ << "pointwise packed matrix should be null";
+ ASSERT_NE(packedB_3D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix is null";
+ break;
+ }
+ case optimized_conv_t::groupwise: {
+ ASSERT_TRUE(false) << "groupwise are not supported for 3D";
+ break;
+ }
+ case optimized_conv_t::pointwise: {
+ ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
+ << "groupwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix should be null";
+ ASSERT_NE(packedB_3D.getPackedWForPointwise(), nullptr)
+ << "pointwise packed matrix is null";
+ break;
+ }
+ case optimized_conv_t::im2col: {
+ ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
+ << "groupwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)
+ << "pointwise packed matrix should be null";
+ ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix is null";
+ break;
+ }
+ }
+}
+
+/**
+ * Test for packing/unpacking
+ */
+TEST_P(uniConvTest, packUnpackTest) {
+ int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad;
+ tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam();
+
+ conv_param_t<2> conv_p_2d(
+ MB,
+ IC,
+ OC,
+ {IH, IW},
+ G,
+ {kernel, kernel},
+ {stride, stride},
+ {pad, pad, pad, pad});
+
+ int kernel_dim_2d = kernel * kernel;
+
+ aligned_vector<int8_t> Bint8_2d(
+ kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));
+ aligned_vector<int8_t> Bint8_2d_unpacked(
+ kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));
+
+ PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data());
+
+ packedB_2D.unpack(Bint8_2d_unpacked.data());
+
+ ASSERT_EQ(Bint8_2d, Bint8_2d_unpacked)
+ << "Original and unpacked data elements are not the same [2D]";
+
+ conv_param_t<3> conv_p_3d(
+ MB,
+ IC,
+ OC,
+ {IT, IH, IW},
+ G,
+ {kernel, kernel, kernel},
+ {stride, stride, stride},
+ {pad, pad, pad, pad, pad, pad});
+
+ int kernel_dim_3d = kernel * kernel * kernel;
+
+ aligned_vector<int8_t> Bint8_3d(
+ kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G));
+
+ aligned_vector<int8_t> Bint8_3d_unpacked(
+ kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G));
+
+ PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data());
+
+ packedB_3D.unpack(Bint8_3d_unpacked.data());
+
+ ASSERT_EQ(Bint8_3d, Bint8_3d_unpacked)
+ << "Original and unpacked data elements are not the same [3D]";
+}
+
+TEST(uniConvTest, cornerCases) {
+ int stride = 1;
+ conv_param_t<2> conv_p_2d(
+ 1, // mini-batch
+ 16, // input channels
+ 32, // output channels
+ {28, 28}, // input height/width
+ 4, // groups
+ {3, 3}, // kernel height/width
+ {stride, stride}, // strides
+ {1, 1, 1, 1}); // padding
+
+ int kernel_dim_2d = conv_p_2d.K[0] * conv_p_2d.K[1];
+
+ aligned_vector<uint8_t> Aint8(
+ conv_p_2d.MB * conv_p_2d.IN_DIM[0] * conv_p_2d.IN_DIM[1] * conv_p_2d.IC);
+ aligned_vector<int8_t> Bint8_2d(
+ kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));
+ aligned_vector<int32_t> Cint32_fb(
+ conv_p_2d.MB * conv_p_2d.OUT_DIM[0] * conv_p_2d.OUT_DIM[1] *
+ conv_p_2d.OC);
+ aligned_vector<uint8_t> Cint8_fb(Cint32_fb.size(), 0);
+
+ // A matrix (input activations)
+ randFill<uint8_t>(Aint8, 0, 5);
+ int32_t Aint8_zero_point = 4;
+
+ // B matrix (weights)
+ randFill<int8_t>(Bint8_2d, -4, 4);
+ aligned_vector<int32_t> Bint8_zero_point(1);
+ randFill(Bint8_zero_point, -3, -1);
+
+ aligned_vector<float> C_multiplier(Bint8_zero_point.size());
+ randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2);
+ int32_t C_zero_point = 5;
+
+ PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data());
+
+ vector<int32_t> col_offsets(conv_p_2d.OC);
+
+ DoNothing<> doNothingObj{};
+ ReQuantizeOutput<false, QuantizationGranularity::TENSOR> outputProcObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_point,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, // row offsets
+ col_offsets.data(),
+ nullptr, // bias
+ conv_p_2d.OC,
+ conv_p_2d.G);
+
+ try {
+ conv_p_2d.stride[0] = 2;
+ fbgemmConv(
+ conv_p_2d,
+ Aint8.data(),
+ packedB_2D,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ outputProcObj,
+ 0,
+ 1);
+ } catch (std::logic_error const& err) {
+ std::string s(err.what());
+ EXPECT_TRUE(s.rfind("[FBGEMM_CONV_ERROR]", 0) == 0);
+ }
+
+ // reset
+ conv_p_2d.stride[0] = stride;
+ // this should run fine
+ fbgemmConv(
+ conv_p_2d,
+ Aint8.data(),
+ packedB_2D,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ outputProcObj,
+ 0,
+ 1);
+}
+
+/**
+ * @brief Unit test for uint8 activations, int8 weights, and 32-bit
+ * accumulation. Output processing: requantization -> nothing
+ */
+TEST_P(UniConvQGranTest, requantizeTest) {
+ vector<conv_param_t<>> shapes(GetShapes_());
+ QuantizationGranularity q_granularity;
+ bool a_symmetric, b_symmetric;
+ bool test_bias, test_float_bias;
+ tie(q_granularity, a_symmetric, b_symmetric, test_bias, test_float_bias) =
+ GetParam();
+
+ for (auto conv_p : shapes) {
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+ int G = conv_p.G;
+ int OC = conv_p.OC;
+ int OH = conv_p.OUT_DIM[0];
+ int OW = conv_p.OUT_DIM[1];
+ int IC_per_G = conv_p.IC / conv_p.G;
+ int OC_per_G = conv_p.OC / conv_p.G;
+
+ // activations
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
+
+ // weights
+ // The weight matrix is in layout G K/G (R S C/G)
+ aligned_vector<int8_t> Bint8(R * S * conv_p.G * IC_per_G * OC_per_G, 0);
+ aligned_vector<int8_t> Bint8_tr(R * S * G * IC_per_G * OC_per_G, 0);
+
+ aligned_vector<int32_t> Cint32_ref(conv_p.MB * OH * OW * OC, 0);
+ aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0);
+ aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0);
+ aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0);
+
+ randFill<uint8_t>(Aint8, 0, 5);
+ int32_t Aint8_zero_point = a_symmetric ? 0 : 4;
+
+ randFill<int8_t>(Bint8, -4, 4);
+
+ // computing column offset
+ vector<int32_t> col_offsets(G * OC_per_G);
+
+ int ncols_per_quant_group = G * OC_per_G;
+ if (q_granularity == QuantizationGranularity::GROUP) {
+ ncols_per_quant_group = OC_per_G;
+ } else if (q_granularity == QuantizationGranularity::OUT_CHANNEL) {
+ ncols_per_quant_group = 1;
+ }
+
+ aligned_vector<int32_t> Bint8_zero_point(
+ G * OC_per_G / ncols_per_quant_group);
+    if (b_symmetric) {
+      randFill(Bint8_zero_point, 0, 0);
+    } else {
+      randFill(Bint8_zero_point, -3, 3);
+    }
+
+ // matrix dimensions after im2col for each GEMM.
+ // For each group, there is one GEMM of the following dimensions
+ int MDim = conv_p.MB * OH * OW;
+ int NDim = OC_per_G;
+ int KDim = R * S * IC_per_G;
+
+ vector<uint8_t> Aint8_im2col(MDim * KDim * G);
+ im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+ vector<int32_t> row_offsets(MDim);
+
+ // activation_scale * weight_scale
+ aligned_vector<float> act_times_w_scale(Bint8_zero_point.size());
+ randFill(act_times_w_scale, 0.1234f / 2, 0.1234f * 3 / 2);
+
+ float out_scale = 2.0f;
+ aligned_vector<float> C_multiplier(Bint8_zero_point.size());
+ transform(
+ act_times_w_scale.begin(),
+ act_times_w_scale.end(),
+ C_multiplier.begin(),
+ [&out_scale](float i) { return i / out_scale; });
+
+ int32_t C_zero_pt = 5;
+
+ // initialize bias
+ aligned_vector<int32_t> bias_int32(OC);
+ aligned_vector<float> bias_fp32(OC);
+ if (test_bias) {
+ randFill(bias_int32, -8, 8);
+ }
+
+ // floating point bias
+ if (test_float_bias) {
+ if (q_granularity == QuantizationGranularity::TENSOR) {
+ transform(
+ bias_int32.begin(),
+ bias_int32.end(),
+ bias_fp32.begin(),
+ [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; });
+ } else if (q_granularity == QuantizationGranularity::GROUP) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < OC_per_G; ++c) {
+ bias_fp32[g * OC_per_G + c] = act_times_w_scale[g] *
+ static_cast<float>(bias_int32[g * OC_per_G + c]);
+ }
+ }
+ } else { // OUT_CHANNEL
+ transform(
+ act_times_w_scale.begin(),
+ act_times_w_scale.end(),
+ bias_int32.begin(),
+ bias_fp32.begin(),
+ multiplies<float>());
+ }
+ }
+ // reference implementation
+ // conv_ref expects weights to be in G (R S C/G) K/G
+    transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data());
+    int8_t* rightBData = Bint8_tr.data();
+ for (int g = 0; g < G; ++g) {
+ col_offsets_with_zero_pt_s8acc32_ref(
+ R * S * IC_per_G,
+ OC_per_G,
+ OC_per_G,
+ rightBData + g * R * S * IC_per_G * OC_per_G,
+ Bint8_zero_point.data() + g * OC_per_G / ncols_per_quant_group,
+ col_offsets.data() + g * OC_per_G,
+ ncols_per_quant_group);
+ }
+ conv_ref(
+ conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data());
+
+ for (int g = 0; g < G; ++g) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDim,
+ KDim * G,
+ Aint8_im2col.data() + g * KDim,
+ row_offsets.data());
+
+ requantize_u8acc32_ref(
+ MDim,
+ NDim,
+ G * NDim,
+ Cint32_ref.data() + g * NDim,
+ Cint8_ref.data() + g * NDim,
+ C_multiplier.data() + g * NDim / ncols_per_quant_group,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data() + g * NDim / ncols_per_quant_group,
+ row_offsets.data(),
+ col_offsets.data() + g * NDim,
+ test_bias ? bias_int32.data() + g * NDim : nullptr,
+ ncols_per_quant_group);
+ }
+
+ PackWeightsForConv<2> packedWeights(conv_p, Bint8.data());
+
+ // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv
+ // #ifdef _OPENMP
+ // #pragma omp parallel
+ // #endif
+ {
+ vector<int32_t> row_offset_buf(rowOffsetBufferSizeGConv(conv_p));
+
+ DoNothing<> doNothingObj{};
+
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+
+ if (q_granularity == QuantizationGranularity::TENSOR) {
+ if (test_float_bias) {
+ ReQuantizeOutput<false, QuantizationGranularity::TENSOR, float>
+ reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_fp32.data() : nullptr,
+ G * NDim,
+ G,
+ act_times_w_scale.data());
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+
+ } else {
+ ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_int32.data() : nullptr,
+ G * NDim,
+ G);
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+ }
+
+ } else if (q_granularity == QuantizationGranularity::GROUP) {
+ if (test_float_bias) {
+ ReQuantizeOutput<false, QuantizationGranularity::GROUP, float> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_fp32.data() : nullptr,
+ G * NDim,
+ G,
+ act_times_w_scale.data());
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+
+ } else {
+ ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_int32.data() : nullptr,
+ G * NDim,
+ G);
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+ }
+
+ } else {
+ if (test_float_bias) {
+ ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL, float>
+ reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_fp32.data() : nullptr,
+ G * NDim,
+ G,
+ act_times_w_scale.data());
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+
+ } else {
+ ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, /* row offset buffer */
+ col_offsets.data(),
+ test_bias ? bias_int32.data() : nullptr,
+ G * NDim,
+ G);
+
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+ }
+ }
+ } // omp parallel
+
+ compare_validate_buffers(
+ Cint8_ref.data(),
+ Cint8_fb.data(),
+ MDim,
+ NDim * G,
+ NDim * G,
+ static_cast<uint8_t>(0));
+ } // for each shape
+}
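
The requantizeTest above builds its reference per group on im2col'd activations, with one GEMM per group of size MDim x NDim x KDim where MDim = MB * OH * OW, NDim = OC / G, and KDim = R * S * IC / G. As a worked example (illustrative only, assuming the unit dilation that conv_param_t defaults to), here are those dimensions for one of the groupwise shapes in GetShapes_():

    #include <cstdio>

    int main() {
      // conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1})
      int MB = 1, IC = 32, OC = 32, IH = 10, IW = 30, G = 8, R = 3, S = 3;
      int stride = 1, pad = 1;
      int OH = (IH + 2 * pad - R) / stride + 1;  // 10
      int OW = (IW + 2 * pad - S) / stride + 1;  // 30
      std::printf("MDim=%d NDim=%d KDim=%d\n",
                  MB * OH * OW,        // 300 output pixels per image batch
                  OC / G,              // 4 output channels per group
                  R * S * (IC / G));   // 36 reduction elements per group
      return 0;
    }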
diff --git a/third_party/asmjit b/third_party/asmjit
-Subproject 673dcefaa048c5f5a2bf8b85daf8f7b9978d018
+Subproject 4da474ac9aa2689e88d5e40a2f37628f302d7e3