diff options
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | bench/GroupwiseConvRequantizeBenchmark.cc | 507 | ||||
-rw-r--r-- | include/fbgemm/Fbgemm.h | 85 | ||||
-rw-r--r-- | src/Fbgemm.cc | 16 | ||||
-rw-r--r-- | src/GroupwiseConv.h | 248 | ||||
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 1552 | ||||
-rw-r--r-- | src/PackWeightMatrixForGConv.cc | 103 | ||||
-rw-r--r-- | src/RefImplementations.cc | 25 | ||||
-rw-r--r-- | src/RefImplementations.h | 8 | ||||
-rw-r--r-- | test/GConvTest.cc | 382 |
10 files changed, 2928 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 5d889cd..dfb5623 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,12 +34,14 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc src/GenerateKernelU8S8S32ACC16Avx512.cc src/GenerateKernelU8S8S32ACC32.cc src/GenerateKernelU8S8S32ACC32Avx512.cc + src/GroupwiseConvAcc32Avx2.cc src/PackAMatrix.cc src/PackAWithIm2Col.cc src/PackBMatrix.cc src/PackMatrix.cc src/PackAWithQuantRowOffset.cc src/PackAWithRowOffset.cc + src/PackWeightMatrixForGConv.cc src/QuantUtils.cc src/RefImplementations.cc src/Utils.cc) diff --git a/bench/GroupwiseConvRequantizeBenchmark.cc b/bench/GroupwiseConvRequantizeBenchmark.cc new file mode 100644 index 0000000..032d0d3 --- /dev/null +++ b/bench/GroupwiseConvRequantizeBenchmark.cc @@ -0,0 +1,507 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <algorithm> +#include <chrono> +#include <cmath> +#include <iomanip> +#include <iostream> +#include <random> +#include <vector> + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include "BenchUtils.h" +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" + +using namespace std; +using namespace fbgemm; + +void performance_test() { + vector<conv_param_t<>> shapes = { + // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, pad_t, pad_l, + // pad_b, pad_r + // conv_param_t<>(1, 16, 16, {16, 14}, 4, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + // conv_param_t<>(1, 256, 256, {56, 56}, 64, {3, 3}, {1, 1}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 3, 64, {224, 224}, 1, {7, 7}, {2, 
2}, {3, 3, 3, 3}), + // conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {2, 2}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 512, 512, {28, 28}, 32, {3, 3}, {2, 2}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 1024, 1024, {14, 14}, 32, {3, 3}, {2, 2}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 1024, 1024, {7, 7}, 32, {3, 3}, {1, 1}, {1, 1, 1, + // 1}), + // conv_param_t<>(1, 1024, 1024, {7, 7}, 32, {3, 3}, {1, 1}, {1, 1, 1, + // 1}), + // BatchSize > 1 + // conv_param_t<>(2, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, + // 1}), + }; + + bool flush = true; + std::vector<char> llc; + + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 10; + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << "WARNING: the timer may be inaccurate when used by multiple threads." 
+ << endl; + cout << "MB, " + << "IC, " + << "OC, " + << "IH, " + << "IW, " + << "KH, " + << "KW, " + << "stride_h, " + << "stride_w, " + << "pad_h, " + << "pad_w, " + << "Type, " + << "M, " + << "N, " + << "K, " + << "Im2Col (ms), " + << "Packing (ms), " + << "Kernel (ms), " + << "Postprocessing (ms), " + << "fbgemmPacked (ms), " + << "Total (ms), " + << "GOPS" << endl; +#else + cout << setw(8) << "MB, " + << "IC, " + << "OC, " + << "IH, " + << "IW, " + << "KH, " + << "KW, " + << "stride_h, " + << "stride_w, " + << "pad_h, " + << "pad_w, " + << "Type, " + << "M, " + << "N, " + << "K, " << setw(5) << "GOPS" << endl; +#endif + + chrono::time_point<chrono::high_resolution_clock> begin, end; + for (auto conv_p : shapes) { + if (conv_p.IC % conv_p.G != 0) { + cout << "Error: Number of input channels " << conv_p.IC + << " is not a multiple of groups " << conv_p.G << endl; + continue; + } + if (conv_p.OC % conv_p.G != 0) { + cout << "Error: Number of output channels " << conv_p.OC + << " is not a multiple of groups " << conv_p.G << endl; + continue; + } + + int IC_per_G = conv_p.IC / conv_p.G; + int OC_per_G = conv_p.OC / conv_p.G; + + aligned_vector<uint8_t> Aint8( + conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0); + + // aligned_vector<uint8_t> Aint8_im2col( + // conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.K[0] * + // conv_p.K[1] * conv_p.IC, + // 0); + + aligned_vector<int8_t> Bint8( + conv_p.K[0] * conv_p.K[1] * conv_p.G * IC_per_G * OC_per_G, 0); + aligned_vector<int8_t> Bp( + conv_p.K[0] * conv_p.K[1] * conv_p.G * IC_per_G * OC_per_G, 0); + + aligned_vector<int32_t> Cint32_ref( + conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0); + + aligned_vector<uint8_t> Cint8_ref( + conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0); + + aligned_vector<int32_t> Cint32_fb_fused( + conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0); + + aligned_vector<uint8_t> Cint8_fb_fused( + conv_p.MB * 
conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0); + + aligned_vector<int32_t> Cint32_fb_direct( + conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0); + + aligned_vector<uint8_t> Cint8_fb_direct( + conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0); + + // cout << conv_p.toString() << endl; + + // A matrix (input activations) + randFill<uint8_t>(Aint8, 0, 5); + int32_t Aint8_zero_point = 4; + + // B matrix (weights) + randFill<int8_t>(Bint8, -4, 4); + aligned_vector<int32_t> Bint8_zero_point(1); + randFill(Bint8_zero_point, -3, -1); + + aligned_vector<float> C_multiplier(Bint8_zero_point.size()); + randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2); + int32_t C_zero_pt = 5; + + int R = conv_p.K[0]; + int S = conv_p.K[1]; + + // reference implementation + conv_ref( + conv_p, + Aint8.data(), + Aint8_zero_point, + Bint8.data(), + Cint32_ref.data()); + + // matrix dimensions after im2col + int MDim = conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1]; + int NDim = conv_p.OC / conv_p.G; + int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC; + + // computing row offset + vector<int32_t> row_offsets(MDim); + vector<uint8_t> Aint8_im2col(MDim * KDim); + im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); + + // computing column offset + vector<int32_t> col_offsets(conv_p.OC); + for (int g = 0; g < conv_p.G; ++g) { + col_offsets_with_zero_pt_s8acc32_ref( + R * S * IC_per_G, + OC_per_G, + OC_per_G, + Bint8.data() + g * R * S * IC_per_G * OC_per_G, + Bint8_zero_point.data(), + col_offsets.data() + g * OC_per_G, + conv_p.OC); + } + + for (int g = 0; g < conv_p.G; ++g) { + row_offsets_u8acc32_ref( + MDim, + R * S * IC_per_G, + KDim, + Aint8_im2col.data() + g * R * S * IC_per_G, + row_offsets.data()); + + requantize_u8acc32_ref( + MDim, + NDim, + conv_p.G * NDim, + Cint32_ref.data() + g * NDim, + Cint8_ref.data() + g * NDim, + C_multiplier.data() + g * NDim / conv_p.OC, + C_zero_pt, + Aint8_zero_point, + 
Bint8_zero_point.data() + g * NDim / conv_p.OC, + row_offsets.data(), + col_offsets.data() + g * NDim, + nullptr, + conv_p.OC); + } + // printMatrix(matrix_op_t::NoTranspose, Cint8_ref.data(), MDim, NDim, NDim, + // "B unpacked"); + + // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), KDim, NDim, NDim, + // "B unpacked"); + // packedB.printPackedMatrix("B Packed"); + + double nops = 2.0 * static_cast<double>(NITER) * MDim * NDim * KDim; + double ttot = 0.0; + string runType; + + vector<int32_t> row_offset_buf; + row_offset_buf.resize( + PackAWithIm2Col<uint8_t, int32_t>::rowOffsetBufferSize()); + + PackAWithIm2Col<uint8_t, int32_t> packA( + conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data()); + + PackBMatrix<int8_t, int32_t> packedB( + matrix_op_t::NoTranspose, + KDim, + NDim, + Bint8.data(), + NDim, + nullptr, + conv_p.G); + + // no-op output process objects + DoNothing<> doNothingObj{}; + ReQuantizeOutput<false, QuantizationGranularity::TENSOR> outputProcObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + packA.getRowOffsetBuffer(), + col_offsets.data(), + nullptr, + conv_p.G * NDim, + conv_p.G); + + runType = "FusedIm2Col"; + ttot = 0; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + double im2col_time = 0.0; + double total_im2col_time = 0.0; + double total_packing_time = 0.0; + double total_computing_time = 0.0; + double total_kernel_time = 0.0; + double total_postprocessing_time = 0.0; + double total_run_time = 0.0; +#endif + for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif + llc_flush(llc); + begin = chrono::high_resolution_clock::now(); + fbgemmPacked( + packA, + packedB, + Cint8_fb_fused.data(), + Cint32_fb_fused.data(), + conv_p.G * NDim, + outputProcObj, + 0, + 1); + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { 
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin); + ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif + } + } + + cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC + << ", " << conv_p.IN_DIM[0] << ", " << conv_p.IN_DIM[1] << ", " + << conv_p.G << ", " << conv_p.K[0] << ", " << conv_p.K[1] << ", " + << conv_p.stride[0] << ", " << conv_p.stride[1] << ", " + << conv_p.pad[0] << ", " << conv_p.pad[1] << ", "; + + cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5) + << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6) + << KDim << ", "; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << setprecision(6) << setw(8) << 0 << ", " + << total_packing_time / (double)NITER / 1e6 << ", " + << total_kernel_time / (double)NITER / 1e6 << ", " + << total_postprocessing_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", " + << ttot / (double)NITER / 1e6 << ", "; +#endif + cout << setprecision(2) << nops / ttot << endl; + + // correctness check + for (int n = 0; n < conv_p.MB; ++n) { + for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) { + for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) { + for (int k = 0; k < conv_p.OC; ++k) { + int32_t expected = Cint8_ref + [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) * + conv_p.OC + + k]; + int32_t actual = Cint8_fb_fused + [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) * + conv_p.OC + + k]; + if (expected != actual) { + cout << "Im2Col fused results differ at (" << n << ", " << h + << ", " << w << ", " << k << ")." 
+ << " expected:" << expected << " actual:" << actual << endl; + } + } + } + } + } + // compare_buffers(Cint32_ref.data(), Cint32_fb_fused.data(), MDim, NDim * + // conv_p.G, NDim*conv_p.G, 5); + + runType = "direct"; + ttot = 0; + + vector<int32_t> row_offset_buf_direct(rowOffsetBufferSizeGConv(conv_p)); + + PackWeightMatrixForGConv<int8_t> packedWeights( + matrix_op_t::NoTranspose, conv_p, Bint8.data(), nullptr); + + ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + row_offset_buf_direct.data(), + col_offsets.data(), + nullptr, + conv_p.OC, + conv_p.G); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_im2col_time = 0.0; + total_packing_time = 0.0; + total_computing_time = 0.0; + total_kernel_time = 0.0; + total_postprocessing_time = 0.0; + total_run_time = 0.0; +#endif + for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + im2col_time = 0.0; + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif + llc_flush(llc); + begin = chrono::high_resolution_clock::now(); + + // im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, + // Aint8_im2col.data()); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + end = chrono::high_resolution_clock::now(); + im2col_time = + chrono::duration_cast<chrono::nanoseconds>(end - begin).count(); +#endif + + // printMatrix(matrix_op_t::NoTranspose, Aint8_im2col.data(), MDim, KDim, + // KDim, "A_out after im2col unpacked"); + + fbgemmGroupwiseConv( + conv_p, + Aint8.data(), + Aint8_zero_point, + row_offset_buf_direct.data(), + packedWeights, + Cint8_fb_direct.data(), + Cint32_fb_direct.data(), + reqObj, + 0, + 1); + + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin); + ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_im2col_time += 
im2col_time; + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif + } + } + + ((volatile char*)(llc.data())); + + // packedB.printPackedMatrix("bench B Packed"); + // printMatrix(matrix_op_t::NoTranspose, Cint32_fb_fused.data(), MDim, NDim, + // NDim, "C fb fp32"); printMatrix(matrix_op_t::NoTranspose, + // Cint32_fb_direct.data(), MDim, NDim, NDim, "C fb2 fp32"); + // printMatrix(matrix_op_t::NoTranspose, + // Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32"); + + cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC + << ", " << conv_p.IN_DIM[0] << ", " << conv_p.IN_DIM[1] << ", " + << conv_p.G << ", " << conv_p.K[0] << ", " << conv_p.K[1] << ", " + << conv_p.stride[0] << ", " << conv_p.stride[1] << ", " + << conv_p.pad[0] << ", " << conv_p.pad[1] << ", "; + + cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5) + << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6) + << KDim << ", "; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << setprecision(6) << setw(8) + << total_im2col_time / (double)NITER / 1e6 << ", " + << total_packing_time / (double)NITER / 1e6 << ", " + << total_kernel_time / (double)NITER / 1e6 << ", " + << total_postprocessing_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", " + << ttot / (double)NITER / 1e6 << ", "; +#endif + cout << setprecision(2) << nops / ttot << endl; + + // correctness check + for (int n = 0; n < conv_p.MB; ++n) { + for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) { + for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) { + for (int k = 0; k < conv_p.OC; ++k) { + int32_t expected = Cint8_ref + [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) * + conv_p.OC + + k]; + int32_t actual = Cint8_fb_direct + [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) * + conv_p.OC + + k]; + if (expected != 
actual) { + cout << "direct conv results differ at (" << n << ", " << h + << ", " << w << ", " << k << ")." + << " expected:" << expected << " actual:" << actual << endl; + } + } + } + } + } + // compare_buffers(Cint32_ref.data(), Cint32_fb_direct.data(), MDim, + // NDim*conv_p.G, NDim*conv_p.G, 5); + } // shapes +} + +int main() { +#ifdef _OPENMP + omp_set_num_threads(1); +#endif + performance_test(); + return 0; +} diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index bca5347..f49da57 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -453,6 +453,56 @@ class FBGEMM_API PackBMatrix final }; /** + * @brief Matrix packed for direct group convolution. + * The source matrix is already quantized. Default accumulation + * type is int32. + */ +template <typename T, typename accT = std::int32_t, int SPATIAL_DIM = 2> +class FBGEMM_API PackWeightMatrixForGConv { + public: + using This = PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>; + using inpType = T; + using accType = accT; + + PackWeightMatrixForGConv() = delete; // no default constructor + + /** + * @params pmat if nullptr, a buffer is allocated and owned by this class. + * + */ + PackWeightMatrixForGConv( + matrix_op_t trans, + const conv_param_t<SPATIAL_DIM>& conv_param, + const inpType* sdata, + inpType* pdata = nullptr); + + /** + * @brief Packs a block of source matrix into pmat buffer. + */ + void pack(); + + /** + * @brief Return packed data + */ + inpType* getBuf() { + return pdata_; + } + + ~PackWeightMatrixForGConv() { + if (bufAllocatedHere_) { + free(pdata_); + } + } + + private: + matrix_op_t trans_; + const conv_param_t<SPATIAL_DIM> conv_param_; + const T* sdata_; + T* pdata_; + bool bufAllocatedHere_; +}; + +/** * @brief Matrix packed for the first input matrix in GEMM (usually activation), * and row offsets used for requantization is computed during packing. * Im2col is fused with packing here. 
The source matrix is already @@ -1106,6 +1156,35 @@ FBGEMM_API void fbgemmPacked( int num_threads); /** + * @brief Perform small-channels-per-group groupwise convolution + * + */ + +template < + typename packed_W, + typename outType, + typename processOutputType, + int SPATIAL_DIM = 2> +FBGEMM_API void fbgemmGroupwiseConv( + const conv_param_t<SPATIAL_DIM>& conv_param, + const std::uint8_t* activations, + std::int32_t a_zero_point, + std::int32_t* rowOffsetBuf, + packed_W& packed_weights, + outType* out, + std::int32_t* outBuffer, + const processOutputType& outProcess, + int thread_id, + int num_threads); +/** + * @return Size of row offset buffer in number of elements needed for + * fbgemmGroupwiseConv + */ +template <int SPATIAL_DIM = 2> +FBGEMM_API int rowOffsetBufferSizeGConv( + const conv_param_t<SPATIAL_DIM>& conv_param); + +/** * @brief Perform depthwise separable convolution */ template < @@ -1122,6 +1201,12 @@ void convDepthwiseSeparable( const processOutputType& output); /** + * @brief Is this groupwise convolution supported? + */ +template <int SPATIAL_DIM> +FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p); + +/** * @brief Allocate __size bytes of uninitialized storage whose alignment is * specified by __align. 
*/ diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index 45108d0..9384af6 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -192,6 +192,22 @@ void fbgemmPacked( #endif } +template <int SPATIAL_DIM> +FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) { + int C_per_G = conv_p.IC / conv_p.G; + int K_per_G = conv_p.OC / conv_p.G; + + return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) && (C_per_G == 4) && + (conv_p.G % 8 == 0) && (conv_p.K[0] == conv_p.K[1]) && + (conv_p.K[0] == 3) && (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) && + (conv_p.pad[0] == conv_p.pad[2]) && (conv_p.pad[1] == conv_p.pad[3]) && + (conv_p.dilation[0] == 1) && (conv_p.dilation[0] == conv_p.dilation[1]) && + (conv_p.stride[0] == 1) && (conv_p.stride[0] == conv_p.stride[1]); +} + +template bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p); +template bool fbgemmOptimizedGConv(const conv_param_t<3>& conv_p); + bool fbgemmSupportedCPU() { return (cpuinfo_initialize() && cpuinfo_has_x86_avx2()); } diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h new file mode 100644 index 0000000..a46a895 --- /dev/null +++ b/src/GroupwiseConv.h @@ -0,0 +1,248 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once +#include <asmjit/asmjit.h> +#include <cpuinfo.h> +#include <cassert> +#include <cstdint> +#include <map> +#include <string> +#include <tuple> +#include "fbgemm/ConvUtils.h" +#include "fbgemm/Fbgemm.h" +#include "fbgemm/Utils.h" +/*#define FBGEMM_LOG_CODE 1*/ + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +using jit_conv_kernel_fp = void (*)( + const uint8_t* in_acts, + int8_t* wghts, + int32_t* out_acts, + int32_t a_zero_pt, + int32_t height, + int32_t width); + +using jit_rowoffset_kernel_fp = void (*)( + const uint8_t* in_acts, + int32_t a_zero_pt, + int32_t height, + int32_t width, + int32_t* row_offset); + +template <typename accT = int32_t> +class GenConvKernel { + public: + GenConvKernel(const conv_param_t<>& conv_param, std::int32_t zero_point) + : WRegs_avx2_{x86::ymm0, + x86::ymm1, + x86::ymm2, + x86::ymm3, + x86::ymm4, + x86::ymm5, + x86::ymm6, + x86::ymm7, + x86::ymm8} { + // vector width in bits + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + vectorWidth_ = 512; + } else if (cpuinfo_has_x86_avx2()) { + vectorWidth_ = 256; + } else { + // TODO: Have default path + assert(0 && "unsupported architecture"); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + zeroPTRegAvx2_ = x86::ymm9; + oneReg8BitAvx2_ = x86::ymm10; + tmpReg1Avx2_ = x86::ymm11; + stPermRegAvx2_ = x86::ymm12; + actRegAvx2_ = x86::ymm13; + resultRegAvx2_ = x86::ymm14; + oneReg16BitAvx2_ = x86::ymm15; + + // vector width in elements; Each element is int8 or uint8 + VLEN_ = vectorWidth_ / 8; + + if (zero_point == 0) { + isZeroPointZero_ = true; + } else { + isZeroPointZero_ = false; + } + + G_ = conv_param.G; + K_per_G_ = conv_param.OC / conv_param.G; + K_ = conv_param.OC; + C_per_G_ = conv_param.IC / conv_param.G; + C_ = conv_param.IC; + R_ = conv_param.K[0]; + S_ = conv_param.K[1]; + H_ = conv_param.OUT_DIM[0]; + W_ = conv_param.OUT_DIM[1]; + H_PAD_ = conv_param.pad[0]; + W_PAD_ = conv_param.pad[1]; + + 
assert(fbgemmOptimizedGConv(conv_param)); + } + + template <inst_set_t instSet> + std::string getCodeLoggingFile(bool rowOffsetKernel = false) { + std::string fileName = "conv_"; + fileName += "G-" + std::to_string(G_); + fileName += "_K-" + std::to_string(K_); + fileName += "_C-" + std::to_string(C_); + fileName += "_R-" + std::to_string(R_); + fileName += "_S-" + std::to_string(S_); + fileName += "_PADH-" + std::to_string(H_PAD_); + fileName += "_PADW-" + std::to_string(W_PAD_); + fileName += "_isZeroPointZero-" + std::to_string(isZeroPointZero_); + if (rowOffsetKernel) { + fileName += "_rowOffset"; + } + + if (instSet == inst_set_t::avx512) { + fileName += "_avx512"; + } else if (instSet == inst_set_t::avx2) { + fileName += "_avx2"; + } + fileName += ".txt"; + return fileName; + } + + ~GenConvKernel() {} + + template <inst_set_t instSet> + jit_conv_kernel_fp getOrCreate(const conv_param_t<>& conv_param); + + template <inst_set_t instSet> + jit_rowoffset_kernel_fp getOrCreateRowOffset( + const conv_param_t<>& conv_param); + + template <inst_set_t instSet> + void createVector16BitOne(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void createVector8BitOne(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg); + + template <inst_set_t instSet> + void + gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg); + + template <inst_set_t instSet> + void genForLoadingWeights(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genConstForPermutations(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genForTopEdge(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genForLeftEdge(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genForRightEdge(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genForBottomEdge(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void 
genCoreInsts(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void storeResult(asmjit::X86Emitter* a, int offset = 0); + + // for Rowoffset kernel + template <inst_set_t instSet> + void gen8BitSum(asmjit::X86Emitter* a, asmjit::X86Ymm aReg); + + template <inst_set_t instSet> + void genForTopEdgeRowoffset(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genForLeftEdgeRowoffset(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genForRightEdgeRowoffset(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genForBottomEdgeRowoffset(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genRowoffsetCorners(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void genRowoffsetCore(asmjit::X86Emitter* a); + + template <inst_set_t instSet> + void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0); + + static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. + static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. + static thread_local std:: + map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> + codeCache_; ///< JIT Code Cache for reuse. + static thread_local std:: + map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> + codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel. + + private: + int vectorWidth_; ///< Vector width in bits. + int VLEN_; ///< Vector width in elements. + // avx2 specific + asmjit::X86Ymm + WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel. 
+ asmjit::X86Ymm zeroPTRegAvx2_; + asmjit::X86Ymm tmpReg1Avx2_; + asmjit::X86Ymm stPermRegAvx2_; + asmjit::X86Ymm actRegAvx2_; + asmjit::X86Ymm resultRegAvx2_; + asmjit::X86Ymm oneReg8BitAvx2_; + asmjit::X86Ymm oneReg16BitAvx2_; + + // arguments to the function created + asmjit::X86Gp in_acts_R_; + asmjit::X86Gp wghts_R_; + asmjit::X86Gp out_acts_R_; + asmjit::X86Gp a_zero_pt_R_; + asmjit::X86Gp H_R_; + asmjit::X86Gp W_R_; + asmjit::X86Gp row_offset_R_; + + // Used registers + asmjit::X86Gp loopR1_; + asmjit::X86Gp loopR2_; + asmjit::X86Gp scratchReg1_; + asmjit::X86Gp scratchReg2_; + + // Other parameters + bool isZeroPointZero_; + + // current conv parameters + int G_; ///< Number of groups + int K_; ///< Number of output channels + int K_per_G_; ///< Number of output channels per group + int C_; ///< Number of input channels + int C_per_G_; ///< Number of input channels per group + int R_; ///< Filter/Kernel height + int S_; ///< Filter/Kernel width + int H_; + int W_; + int H_PAD_; ///< Padding for height (top and bottom) + int W_PAD_; ///< Padding for width (left and right) +}; + +} // namespace fbgemm diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc new file mode 100644 index 0000000..8298f4c --- /dev/null +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -0,0 +1,1552 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include <asmjit/asmjit.h> +#include <cpuinfo.h> +#include <immintrin.h> +#include <array> +#include <iostream> +#include <map> +#include <stdexcept> +#include <tuple> +#include "GroupwiseConv.h" +#include "RefImplementations.h" +#include "TransposeUtils.h" +#include "fbgemm/Fbgemm.h" + +namespace fbgemm { + +using namespace std; + +template <typename accT> +thread_local asmjit::JitRuntime GenConvKernel<accT>::rt_; + +template <typename accT> +thread_local asmjit::CodeHolder GenConvKernel<accT>::code_; + +template <typename accT> +thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> + GenConvKernel<accT>::codeCache_; + +template <typename accT> +thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> + GenConvKernel<accT>::codeCacheRowOffset_; + +namespace x86 = asmjit::x86; + +void calculateRowOffsets( + const conv_param_t<>& conv_param, + const uint8_t* activations, + int32_t* rowOffsetBuf, + int32_t a_zero_point, + int groupNum) { + int H = conv_param.OUT_DIM[0]; + int W = conv_param.OUT_DIM[1]; + int G = conv_param.G; + int C_per_G = conv_param.IC / conv_param.G; + int H_PAD = conv_param.pad[0]; + int W_PAD = conv_param.pad[1]; + // calculate row offset + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + int32_t sum = 0; + for (int r = 0; r < conv_param.K[0]; ++r) { + int h_in = -H_PAD + h + r; + for (int s = 0; s < conv_param.K[1]; ++s) { + int w_in = -W_PAD + w + s; + for (int c = 0; c < C_per_G; ++c) { + if (h_in < 0 || h_in >= H || w_in < 0 || w_in >= W) { + sum += a_zero_point; + } else { + sum += + activations[((h_in * W + w_in) * G + groupNum) * C_per_G + c]; + } + } + } + } + rowOffsetBuf[h * W + w] = sum; + } + } +} + +tuple<bool, int, int, int> getKernelSig( + const conv_param_t<>& conv_param, + bool isZeroPointZero) { + int C_per_G = conv_param.IC / conv_param.G; + int K_per_G = conv_param.OC / conv_param.G; + auto kernelSig = + std::make_tuple(isZeroPointZero, conv_param.G, C_per_G, 
K_per_G); + return kernelSig; +} + +template <typename accT = int32_t> +jit_conv_kernel_fp getOrCreateConvKernel( + const conv_param_t<>& conv_param, + int a_zero_point) { + // Note: Wrong code is generated if it's not one of the supported convolution + assert(fbgemmOptimizedGConv<2>(conv_param)); + auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); + if (GenConvKernel<accT>::codeCache_.find(kernelSig) != + GenConvKernel<accT>::codeCache_.end()) { + return GenConvKernel<accT>::codeCache_[kernelSig]; + } else { + auto genObj = GenConvKernel<accT>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); + } +} + +template <> +template <> +void GenConvKernel<int32_t>::createVector8BitOne<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // create 8-bit 1s + // i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains + // 0x01 and so on + a->vpcmpeqw(oneReg8BitAvx2_, oneReg8BitAvx2_, oneReg8BitAvx2_); + a->vpabsb(oneReg8BitAvx2_, oneReg8BitAvx2_); +} + +template <> +template <> +void GenConvKernel<int32_t>::createVector16BitOne<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // create 16-bit 1s + // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31] + // contains 0x0001 and so on + a->vpcmpeqw(oneReg16BitAvx2_, oneReg16BitAvx2_, oneReg16BitAvx2_); + a->vpsrlw(oneReg16BitAvx2_, oneReg16BitAvx2_, 15); +} +template <> +template <> +void GenConvKernel<int32_t>::setToZeroPt<inst_set_t::avx2>( + asmjit::X86Emitter* a, + asmjit::X86Ymm destReg) { + // make destReg all zeros + a->vxorps(destReg, destReg, destReg); + asmjit::X86Xmm const_reg_xmm = x86::xmm10; + // move zero point to xmm10 + a->movq(const_reg_xmm, a_zero_pt_R_); + // make copies of zero point + a->vbroadcastsd(x86::ymm10, const_reg_xmm); + // shuffle + // overall impact is that destReg contains 32 8-bit values equal to the lower + // 8-bits of a_zero_pt_R_ + a->vpshufb(destReg, x86::ymm10, 
destReg); +} + +template <> +template <> +void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + asmjit::X86Gp permute_const_reg = a->gpzRef(12); + asmjit::X86Xmm const_reg_xmm = x86::xmm10; + // We have 1st group in even lanes and 2nd group in odd lanes. + // Permute to put 1st group to lower 128-bit and 2nd group in upper + // 128-bit. + // load 7, 5, 3, 1, 6, 4, 2, 0 in a 64-bit reg + a->mov(permute_const_reg, 0x0705030106040200); + a->movq(const_reg_xmm, permute_const_reg); + // Zero extend 8 packed 8-bit integers in the low 8 bytes of const_reg_xmm to + // 8 packed 32-bit integers in stPermRegAvx2_ + a->vpmovzxbd(stPermRegAvx2_, const_reg_xmm); +} + +template <> +template <> +void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>( + asmjit::X86Emitter* a, + int offset) { + // store with permutation + a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_); + a->vmovups(x86::dword_ptr(out_acts_R_, offset), resultRegAvx2_); +} + +template <> +template <> +void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>( + asmjit::X86Emitter* a, + int offset) { + // store + a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_); +} + +template <> +template <> +void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // load weights + for (int r = 0; r < R_; ++r) { + for (int s = 0; s < S_; ++s) { + a->vmovaps( + WRegs_avx2_[r * S_ + s], + x86::dword_ptr( + wghts_R_, + (r * S_ + s) * G_ * K_per_G_ * C_per_G_ * sizeof(int8_t))); + } + } +} + +template <> +template <> +void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>( + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm wReg) { + a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg); + a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_); + a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); +} + +template <> +template <> +void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>( + 
    asmjit::X86Emitter* a,
    asmjit::X86Ymm aReg) {
  a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
  a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
  a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
}

// Emits the convolution code for the top row of the output (both corners plus
// the edge in between), where the filter overlaps the top zero-padding.
// Out-of-image taps contribute via the zero-point register (skipped entirely
// when the zero point is 0).  On exit out_acts_R_ is reset to the row start.
template <>
template <>
void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
    asmjit::X86Emitter* a) {
  // top-left corner code
  // zero out the results register
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  for (int r = 0; r < R_; ++r) {
    int h_in = -H_PAD_ + r;
    if (h_in >= 0) {
      a->imul(
          scratchReg1_,
          W_R_,
          static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
    }
    for (int s = 0; s < S_; ++s) {
      int w_in = -W_PAD_ + s;
      if (h_in >= 0 && w_in >= 0) {
        a->vbroadcastsd(
            actRegAvx2_,
            x86::dword_ptr(
                in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
        gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
      } else {
        if (!isZeroPointZero_) {
          gen8bitFMA<inst_set_t::avx2>(
              a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
        }
      }
    }
  }
  storeResult<inst_set_t::avx2>(a);

  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));

  // top edge excluding corners
  asmjit::Label LoopTopEdge = a->newLabel();
  a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
  a->bind(LoopTopEdge);
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  if (!isZeroPointZero_) {
    for (int r = 0; r < H_PAD_; ++r) {
      for (int s = 0; s < S_; ++s) {
        // NOTE(review): indexes WRegs_avx2_[s] rather than [r * S_ + s];
        // equivalent only while H_PAD_ == 1 (r is always 0) — confirm if
        // larger pads are ever enabled.
        gen8bitFMA<inst_set_t::avx2>(a, zeroPTRegAvx2_, WRegs_avx2_[s]);
      }
    }
  }
  for (int r = H_PAD_; r < R_; ++r) {
    int h_in = -H_PAD_ + r;
    a->imul(
        scratchReg1_,
        W_R_,
        static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
    for (int s = 0; s < S_; ++s) {
      a->vbroadcastsd(
          actRegAvx2_,
          x86::dword_ptr(
              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
    }
  }
  a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));

  storeResult<inst_set_t::avx2>(a);

  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->mov(loopR1_, W_R_);
  a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
  a->inc(loopR2_);
  a->cmp(loopR2_, loopR1_);
  a->jl(LoopTopEdge);
  // rewind in_acts_R_ by the (W_ - 2*W_PAD_) columns the loop advanced
  a->mov(scratchReg2_, W_R_);
  a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  a->sub(
      scratchReg2_,
      static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg2_);

  // top-right corner code

  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  if (!isZeroPointZero_) {
    for (int r = 0; r < H_PAD_; ++r) {
      for (int s = 0; s < S_; ++s) {
        gen8bitFMA<inst_set_t::avx2>(
            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
      }
    }
  }
  for (int r = H_PAD_; r < R_; ++r) {
    int h_in = -H_PAD_ + r;
    for (int s = 0; s < S_ - W_PAD_; ++s) {
      a->imul(
          scratchReg1_,
          W_R_,
          static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
      a->mov(scratchReg2_, W_R_);
      // NOTE(review): column offset uses R_ (filter height) where S_ (filter
      // width) looks intended; benign only because supported kernels have
      // R_ == S_ == 3 — confirm before generalizing.
      a->sub(scratchReg2_, static_cast<asmjit::Imm>(R_ - W_PAD_ - s));
      a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
      a->add(scratchReg1_, scratchReg2_);
      a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
    }
    if (!isZeroPointZero_) {
      for (int s = S_ - W_PAD_; s < S_; ++s) {
        gen8bitFMA<inst_set_t::avx2>(
            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
      }
    }
  }
  storeResult<inst_set_t::avx2>(a);
  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));

  // reset output activation pointer
  a->imul(scratchReg1_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->sub(out_acts_R_, scratchReg1_);
}

// Emits the convolution code for the left edge of the output (rows between
// the corners, column 0), where the filter overlaps the left zero-padding.
template <>
template <>
void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>(
    asmjit::X86Emitter* a) {
  // left edge excluding corners
  asmjit::Label LoopLeftEdge = a->newLabel();
  a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
  a->bind(LoopLeftEdge);
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  // scratchReg1_ = (loopR1_ - H_PAD_) * W_ * C_ : input row base
  a->mov(scratchReg1_, loopR1_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
  a->imul(scratchReg1_, W_R_);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  for (int r = 0; r < R_; ++r) {
    if (!isZeroPointZero_) {
      for (int s = 0; s < W_PAD_; ++s) {
        gen8bitFMA<inst_set_t::avx2>(
            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
      }
    }
    for (int s = W_PAD_; s < S_; ++s) {
      a->vbroadcastsd(
          actRegAvx2_,
          x86::dword_ptr(
              in_acts_R_,
              scratchReg1_,
              0,
              (s - W_PAD_) * C_ * sizeof(uint8_t)));
      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
    }
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(scratchReg1_, scratchReg2_);
  }

  a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->add(out_acts_R_, scratchReg2_);
  storeResult<inst_set_t::avx2>(a);

  a->inc(loopR1_);
  a->mov(loopR2_, H_R_);
  a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
  a->cmp(loopR1_, loopR2_);
  a->jl(LoopLeftEdge);

  // reset output activation pointer
  a->mov(scratchReg2_, H_R_);
  a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
  a->imul(scratchReg2_, W_R_);
  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->sub(out_acts_R_, scratchReg2_);
}

// Emits the convolution code for the right edge of the output (rows between
// the corners, last column), where the filter overlaps the right padding.
template <>
template <>
void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
    asmjit::X86Emitter* a) {
  // right edge excluding corners
  asmjit::Label LoopRightEdge = a->newLabel();

  // output pointer to the right edge
  // (W_ + W_ - 1)*K_*sizeof(int32_t)
  a->mov(scratchReg2_, W_R_);
  a->imul(scratchReg2_, 2);
  a->sub(scratchReg2_, 1);
  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->add(out_acts_R_, scratchReg2_);

  a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
  a->bind(LoopRightEdge);
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  // scratchReg1_ = (loopR1_ - H_PAD_) * W_ * C_ : input row base
  a->mov(scratchReg1_, loopR1_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
  a->imul(scratchReg1_, W_R_);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));

  // advance to column (W_ - 2*W_PAD_), the leftmost input column the
  // right-edge filter window touches
  a->mov(scratchReg2_, W_R_);
  a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * W_PAD_));
  a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  a->add(scratchReg1_, scratchReg2_);
  for (int r = 0; r < R_; ++r) {
    for (int s = 0; s < S_ - W_PAD_; ++s) {
      a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
      a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    }
    if (!isZeroPointZero_) {
      for (int s = S_ - W_PAD_; s < S_; ++s) {
        gen8bitFMA<inst_set_t::avx2>(
            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
      }
    }

    // step to the next input row (undo column advance, add one row)
    a->sub(
        scratchReg1_,
        static_cast<asmjit::Imm>((S_ - W_PAD_) * C_ * sizeof(uint8_t)));
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(scratchReg1_, scratchReg2_);
  }

  // storeResult<inst_set_t::avx2>(a, (W_+W_-1)*K_*sizeof(int32_t));
  storeResult<inst_set_t::avx2>(a);

  a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->add(out_acts_R_, scratchReg2_);
  a->mov(loopR2_, H_R_);
  a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
  a->inc(loopR1_);
  a->cmp(loopR1_, loopR2_);
  a->jl(LoopRightEdge);

  // reset base
  a->mov(scratchReg2_, W_R_);
  a->imul(scratchReg2_, 2);
  a->sub(scratchReg2_, 1);
  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->sub(out_acts_R_, scratchReg2_);

  // reset loop increments
  //(H_ - 2*H_PAD_)*W_*K_*sizeof(int32_t)
  a->mov(scratchReg2_, H_R_);
  a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
  a->imul(scratchReg2_, W_R_);
  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->sub(out_acts_R_, scratchReg2_);
  // a->sub(out_acts_R_, (H_ - 2*H_PAD_)*W_*K_*sizeof(int32_t));
}

// Emits the convolution code for the bottom row of the output (both corners
// plus the edge in between), where the filter overlaps the bottom padding.
template <>
template <>
void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
    asmjit::X86Emitter* a) {
  // bottom-left corner
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
  a->imul(scratchReg1_, W_R_);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  for (int r = 0; r < R_ - H_PAD_; ++r) {
    if (!isZeroPointZero_) {
      for (int s = 0; s < W_PAD_; ++s) {
        gen8bitFMA<inst_set_t::avx2>(
            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
      }
    }
    for (int s = W_PAD_; s < S_; ++s) {
      a->vbroadcastsd(
          actRegAvx2_,
          x86::dword_ptr(
              in_acts_R_,
              scratchReg1_,
              0,
              (s - W_PAD_) * C_ * sizeof(uint8_t)));
      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
    }
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(scratchReg1_, scratchReg2_);
  }
  if (!isZeroPointZero_) {
    for (int r = R_ - H_PAD_; r < R_; ++r) {
      for (int s = 0; s < S_; ++s) {
        gen8bitFMA<inst_set_t::avx2>(
            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
      }
    }
  }

  // we updating the last row
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, 1);
  a->imul(scratchReg1_, W_R_);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->add(out_acts_R_, scratchReg1_);
  // storeResult<inst_set_t::avx2>(a, (H_-1)*W_*K_*sizeof(int32_t));
  storeResult<inst_set_t::avx2>(a);
  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));

  // bottom edge excluding corners
  asmjit::Label LoopBottomEdge = a->newLabel();
  a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
  a->bind(LoopBottomEdge);
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
  a->imul(scratchReg1_, W_R_);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  // NOTE(review): this row loop bounds itself with W_PAD_ although r is a row
  // (height) index; equivalent to H_PAD_ only because the supported configs
  // have H_PAD_ == W_PAD_ == 1 — confirm before generalizing.
  for (int r = 0; r < R_ - W_PAD_; ++r) {
    // int h_in = H_-2*H_PAD_ + r;
    for (int s = 0; s < S_; ++s) {
      a->vbroadcastsd(
          actRegAvx2_,
          x86::dword_ptr(
              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
    }
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(scratchReg1_, scratchReg2_);
  }

  if (!isZeroPointZero_) {
    for (int r = R_ - W_PAD_; r < R_; ++r) {
      for (int s = 0; s < S_; ++s) {
        gen8bitFMA<inst_set_t::avx2>(
            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
      }
    }
  }

  a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*K_*sizeof(int32_t));
  storeResult<inst_set_t::avx2>(a);

  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->inc(loopR2_);
  a->mov(loopR1_, W_R_);
  a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
  a->cmp(loopR2_, loopR1_);
  a->jl(LoopBottomEdge);
  a->mov(scratchReg1_, W_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * W_PAD_));
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg1_);
  // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
  // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));

  // bottom-right corner
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  // input start point
  // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(R_ - H_PAD_));
  a->imul(scratchReg1_, W_R_);
  a->add(scratchReg1_, W_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(S_ - W_PAD_));
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  for (int r = 0; r < R_ - H_PAD_; ++r) {
    for (int s = 0; s < S_ - W_PAD_; ++s) {
      a->vbroadcastsd(
          actRegAvx2_,
          x86::dword_ptr(
              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
    }
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(scratchReg1_, scratchReg2_);
    if (!isZeroPointZero_) {
      for (int s = S_ - W_PAD_; s < S_; ++s) {
        gen8bitFMA<inst_set_t::avx2>(
            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
      }
    }
  }

  if (!isZeroPointZero_) {
    for (int r = R_ - H_PAD_; r < R_; ++r) {
      for (int s = 0; s < S_; ++s) {
        gen8bitFMA<inst_set_t::avx2>(
            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
      }
    }
  }

  storeResult<inst_set_t::avx2>(a);
  // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+W_-1)*K_*sizeof(int32_t));
  // reset output pointer
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, 1);
  a->imul(scratchReg1_, W_R_);
  a->add(scratchReg1_, W_R_);
  a->sub(scratchReg1_, 1);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->sub(out_acts_R_, scratchReg1_);
}

// Emits the convolution code for the interior of the output (all pixels whose
// filter window lies entirely inside the image) as a runtime H x W loop nest.
template <>
template <>
void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
    asmjit::X86Emitter* a) {
  // main compute
  asmjit::Label LoopH = a->newLabel();
  asmjit::Label LoopW = a->newLabel();
  // base for output
  a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
  a->imul(scratchReg2_, W_R_);
  a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_));
  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
  a->add(out_acts_R_, scratchReg2_);

  // scratchReg1_ = W_ - W_PAD_ : the W-loop bound, kept live across the loop
  a->mov(scratchReg1_, W_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(W_PAD_));

  // H loop
  a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
  a->bind(LoopH);
  // W loop
  a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
  a->bind(LoopW);
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  // compute on all filters
  for (int r = 0; r < R_; ++r) {
    for (int s = 0; s < S_; ++s) {
      a->vbroadcastsd(
          actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
    }
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(in_acts_R_, scratchReg2_);
  }
  // rewind the R_ input rows walked above, then advance one column
  a->imul(
      scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg2_);
  // a->add(scratchReg1_, C_*sizeof(uint8_t));
  a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));

  // storeResult<inst_set_t::avx2>(a, (W_+1)*K_*sizeof(int32_t));
  storeResult<inst_set_t::avx2>(a);

  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));

  a->inc(loopR2_);
  a->cmp(loopR2_, scratchReg1_);
  a->jl(LoopW);
  // add (W_ - 2*W_PAD_)*C_*sizeof(uint8_t) and subtract W_*C_*sizeof(uint8_t)
  a->add(
      in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
  // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
  // a->add(in_acts_R_, W_*C_*sizeof(uint8_t));
  a->add(
      out_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * K_ * sizeof(int32_t)));
  // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));
  // a->add(out_acts_R_, W_*K_*sizeof(int32_t));

  a->inc(loopR1_);
  a->mov(scratchReg2_, H_R_);
  a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
  a->cmp(loopR1_, scratchReg2_);
  a->jl(LoopH);
}

// Assembles the full convolution kernel: binds argument registers, sets up
// the asmjit function frame, emits constants/weight loads and the
// edge/corner/core code, then JITs it and stores the result in codeCache_.
// Returns nullptr if asmjit fails to finalize the function.
template <>
template <>
jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>(
    const conv_param_t<>& conv_param) {
  code_.reset(false);
  code_.init(rt_.getCodeInfo());
  asmjit::X86Assembler assembler(&code_);
  asmjit::X86Emitter* a = assembler.asEmitter();

#if defined(FBGEMM_LOG_CODE)
  // log code to a file
  FILE* codeLogfile =
      fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
  asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
  if (codeLogger) {
    code_.setLogger(codeLogger);
  }
#endif

  // arguments to the function created
  in_acts_R_ = a->zdi();
  wghts_R_ = a->zsi();
  out_acts_R_ = a->zdx();
  a_zero_pt_R_ = a->zcx();
  H_R_ = a->gpzRef(8);
  W_R_ = a->gpzRef(9);
  row_offset_R_ = a->gpzRef(10);

  // register for temporary use
  scratchReg1_ = a->gpzRef(12);
  scratchReg2_ = a->gpzRef(13);

  asmjit::FuncDetail func;
  func.init(asmjit::FuncSignature6<
            void,
            uint8_t*,
            int8_t*,
            int32_t*,
            int32_t,
            int32_t,
            int32_t>(asmjit::CallConv::kIdHost));

  asmjit::FuncFrameInfo ffi;
  ffi.setDirtyRegs(
      asmjit::X86Reg::kKindVec,
      asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
          asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
  ffi.setDirtyRegs(
      asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));

  asmjit::FuncArgsMapper args(&func);
  args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);

  args.updateFrameInfo(ffi);

  asmjit::FuncFrameLayout layout;
  layout.init(func, ffi);

  asmjit::FuncUtils::emitProlog(a, layout);
  asmjit::FuncUtils::allocArgs(a, layout, args);

  createVector16BitOne<inst_set_t::avx2>(a);

  loopR1_ = a->gpzRef(14);
  loopR2_ = a->gpzRef(15);

  if (!isZeroPointZero_) {
    setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
  }

  genForLoadingWeights<inst_set_t::avx2>(a);

  genConstForPermutations<inst_set_t::avx2>(a);

  genForTopEdge<inst_set_t::avx2>(a);
  genForLeftEdge<inst_set_t::avx2>(a);
  genForRightEdge<inst_set_t::avx2>(a);
  genForBottomEdge<inst_set_t::avx2>(a);

  genCoreInsts<inst_set_t::avx2>(a);

  asmjit::FuncUtils::emitEpilog(a, layout);

  jit_conv_kernel_fp fn;
  asmjit::Error err = rt_.add(&fn, &code_);
  if (err) {
    std::cout << "Error: in fn add" << std::endl;
    return nullptr;
  }
  auto kernelSig = getKernelSig(conv_param, isZeroPointZero_);
  codeCache_[kernelSig] = fn;

#if defined(FBGEMM_LOG_CODE)
  fclose(codeLogfile);
  delete codeLogger;
#endif

  return fn;
}

// Row-offset counterpart of genForTopEdge: emits code that sums the input
// activations under the filter window for the top output row (corners plus
// edge).  With C_per_G == 4 / K_per_G == 4, each iteration produces row
// offsets for 8 groups at once (one int32 lane per group).
template <>
template <>
void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
    asmjit::X86Emitter* a) {
  // top-left corner code
  // zero out the results register
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  for (int r = 0; r < R_; ++r) {
    int h_in = -H_PAD_ + r;
    if (h_in >= 0) {
      a->imul(
          scratchReg1_,
          W_R_,
          static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
    }
    for (int s = 0; s < S_; ++s) {
      int w_in = -W_PAD_ + s;
      if (h_in >= 0 && w_in >= 0) {
        a->vmovaps(
            actRegAvx2_,
            x86::dword_ptr(
                in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
        gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
      } else {
        if (!isZeroPointZero_) {
          gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
        }
      }
    }
  }
  // store results
  storeResultRowoffset<inst_set_t::avx2>(a);

  // for C_per_G == 4 and K_per_G == 4, 8 groups processed at a time
  a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

  // top edge excluding corners
  asmjit::Label LoopTopEdge = a->newLabel();
  a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
  a->bind(LoopTopEdge);
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  if (!isZeroPointZero_) {
    // padded rows contribute H_PAD_ * S_ zero-point taps
    for (int r = 0; r < H_PAD_; ++r) {
      for (int s = 0; s < S_; ++s) {
        gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
      }
    }
  }
  for (int r = H_PAD_; r < R_; ++r) {
    int h_in = -H_PAD_ + r;
    a->imul(
        scratchReg1_,
        W_R_,
        static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
    for (int s = 0; s < S_; ++s) {
      a->vmovaps(
          actRegAvx2_,
          x86::dword_ptr(
              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
      gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
    }
  }
  a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));

  // store results
  storeResultRowoffset<inst_set_t::avx2>(a);

  a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
  a->mov(loopR1_, W_R_);
  a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
  a->inc(loopR2_);
  a->cmp(loopR2_, loopR1_);
  a->jl(LoopTopEdge);
  // rewind in_acts_R_ by the (W_ - 2*W_PAD_) columns the loop advanced
  a->mov(scratchReg2_, W_R_);
  a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  a->sub(
      scratchReg2_,
      static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg2_);

  // top-right corner code
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  if (!isZeroPointZero_) {
    for (int r = 0; r < H_PAD_; ++r) {
      for (int s = 0; s < S_; ++s) {
        gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
      }
    }
  }
  for (int r = H_PAD_; r < R_; ++r) {
    int h_in = -H_PAD_ + r;
    for (int s = 0; s < S_ - W_PAD_; ++s) {
      a->imul(
          scratchReg1_,
          W_R_,
          static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
      a->mov(scratchReg2_, W_R_);
      // NOTE(review): column offset uses R_ (filter height) where S_ (filter
      // width) looks intended; benign only because supported kernels have
      // R_ == S_ == 3 — confirm before generalizing.
      a->sub(scratchReg2_, static_cast<asmjit::Imm>(R_ - W_PAD_ - s));
      a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
      a->add(scratchReg1_, scratchReg2_);
      a->vmovaps(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
      gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
    }
    if (!isZeroPointZero_) {
      for (int s = S_ - W_PAD_; s < S_; ++s) {
        gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
      }
    }
  }

  // store results
  storeResultRowoffset<inst_set_t::avx2>(a);

  a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

  // reset output pointer
  a->imul(scratchReg1_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
  a->sub(row_offset_R_, scratchReg1_);
}

// Row-offset counterpart of genForLeftEdge: emits activation-sum code for the
// left column of the output, excluding the corners.
template <>
template <>
void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
    asmjit::X86Emitter* a) {
  // left edge excluding corners
  asmjit::Label LoopLeftEdge = a->newLabel();
  a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
  a->bind(LoopLeftEdge);
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  // scratchReg1_ = (loopR1_ - H_PAD_) * W_ * C_ : input row base
  a->mov(scratchReg1_, loopR1_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
a->imul(scratchReg1_, W_R_); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + for (int r = 0; r < R_; ++r) { + if (!isZeroPointZero_) { + for (int s = 0; s < W_PAD_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + for (int s = W_PAD_; s < S_; ++s) { + a->vmovaps( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, + scratchReg1_, + 0, + (s - W_PAD_) * C_ * sizeof(uint8_t))); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + } + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(scratchReg1_, scratchReg2_); + } + + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->add(row_offset_R_, scratchReg2_); + storeResultRowoffset<inst_set_t::avx2>(a); + + a->inc(loopR1_); + a->mov(loopR2_, H_R_); + a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_)); + a->cmp(loopR1_, loopR2_); + a->jl(LoopLeftEdge); + + // reset output pointer + a->mov(scratchReg2_, H_R_); + a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_)); + a->imul(scratchReg2_, W_R_); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->sub(row_offset_R_, scratchReg2_); +} + +template <> +template <> +void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // right edge excluding corners + asmjit::Label LoopRightEdge = a->newLabel(); + + // output pointer to the right edge + // (W_ + W_ - 1)*8*sizeof(int32_t) + a->mov(scratchReg2_, W_R_); + a->imul(scratchReg2_, 2); + a->sub(scratchReg2_, 1); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->add(row_offset_R_, scratchReg2_); + + a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); + a->bind(LoopRightEdge); + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + a->mov(scratchReg1_, loopR1_); + a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_)); + a->imul(scratchReg1_, W_R_); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * 
sizeof(uint8_t))); + + a->mov(scratchReg2_, W_R_); + a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * W_PAD_)); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(scratchReg1_, scratchReg2_); + for (int r = 0; r < R_; ++r) { + for (int s = 0; s < S_ - W_PAD_; ++s) { + a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); + a->vmovaps(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + } + if (!isZeroPointZero_) { + for (int s = S_ - W_PAD_; s < S_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + + a->sub( + scratchReg1_, + static_cast<asmjit::Imm>((S_ - W_PAD_) * C_ * sizeof(uint8_t))); + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(scratchReg1_, scratchReg2_); + } + + storeResultRowoffset<inst_set_t::avx2>(a); + + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->add(row_offset_R_, scratchReg2_); + a->mov(loopR2_, H_R_); + a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_)); + a->inc(loopR1_); + a->cmp(loopR1_, loopR2_); + a->jl(LoopRightEdge); + + // reset base + a->mov(scratchReg2_, W_R_); + a->imul(scratchReg2_, 2); + a->sub(scratchReg2_, 1); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->sub(row_offset_R_, scratchReg2_); + + // reset increments done in the loop + //(H_ - 2*H_PAD_)*W_*8*sizeof(int32_t) + a->mov(scratchReg2_, H_R_); + a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_)); + a->imul(scratchReg2_, W_R_); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->sub(row_offset_R_, scratchReg2_); +} + +template <> +template <> +void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // bottom-left corner + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + 
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
  a->imul(scratchReg1_, W_R_);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  for (int r = 0; r < R_ - H_PAD_; ++r) {
    if (!isZeroPointZero_) {
      // taps over the left zero-padding contribute the zero point
      for (int s = 0; s < W_PAD_; ++s) {
        gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
      }
    }
    for (int s = W_PAD_; s < S_; ++s) {
      a->vmovaps(
          actRegAvx2_,
          x86::dword_ptr(
              in_acts_R_,
              scratchReg1_,
              0,
              (s - W_PAD_) * C_ * sizeof(uint8_t)));
      gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
    }
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(scratchReg1_, scratchReg2_);
  }
  if (!isZeroPointZero_) {
    // rows fully inside the bottom zero-padding
    for (int r = R_ - H_PAD_; r < R_; ++r) {
      for (int s = 0; s < S_; ++s) {
        gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
      }
    }
  }

  // we updating the last row
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, 1);
  a->imul(scratchReg1_, W_R_);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
  a->add(row_offset_R_, scratchReg1_);
  storeResultRowoffset<inst_set_t::avx2>(a);
  a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

  // bottom edge excluding corners
  asmjit::Label LoopBottomEdge = a->newLabel();
  a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
  a->bind(LoopBottomEdge);
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
  a->imul(scratchReg1_, W_R_);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  // NOTE(review): this row loop bounds itself with W_PAD_ although r is a row
  // (height) index; equivalent to H_PAD_ only because the supported configs
  // have H_PAD_ == W_PAD_ == 1 — confirm before generalizing.
  for (int r = 0; r < R_ - W_PAD_; ++r) {
    // int h_in = H_-2*H_PAD_ + r;
    for (int s = 0; s < S_; ++s) {
      a->vmovaps(
          actRegAvx2_,
          x86::dword_ptr(
              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
      gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
    }
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(scratchReg1_, scratchReg2_);
  }

  if (!isZeroPointZero_) {
    for (int r = R_ - W_PAD_; r < R_; ++r) {
      for (int s = 0; s < S_; ++s) {
        gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
      }
    }
  }

  a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*8*sizeof(int32_t));
  storeResultRowoffset<inst_set_t::avx2>(a);

  a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
  a->inc(loopR2_);
  a->mov(loopR1_, W_R_);
  a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
  a->cmp(loopR2_, loopR1_);
  a->jl(LoopBottomEdge);
  // rewind in_acts_R_ by the columns advanced inside the loop
  a->mov(scratchReg1_, W_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * W_PAD_));
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg1_);
  // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
  // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*8*sizeof(int32_t));

  // bottom-right corner
  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  // input start point
  // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(R_ - H_PAD_));
  a->imul(scratchReg1_, W_R_);
  a->add(scratchReg1_, W_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(S_ - W_PAD_));
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  for (int r = 0; r < R_ - H_PAD_; ++r) {
    for (int s = 0; s < S_ - W_PAD_; ++s) {
      a->vmovaps(
          actRegAvx2_,
          x86::dword_ptr(
              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
      gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
    }
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(scratchReg1_, scratchReg2_);
    if (!isZeroPointZero_) {
      for (int s = S_ - W_PAD_; s < S_; ++s) {
        gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
      }
    }
  }

  if (!isZeroPointZero_) {
    for (int r = R_ - H_PAD_; r < R_; ++r) {
      for (int s = 0; s < S_; ++s) {
        gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
      }
    }
  }

  storeResultRowoffset<inst_set_t::avx2>(a);
  // reset output pointer
  a->mov(scratchReg1_, H_R_);
  a->sub(scratchReg1_, 1);
  a->imul(scratchReg1_, W_R_);
  a->add(scratchReg1_, W_R_);
  a->sub(scratchReg1_, 1);
  a->imul(scratchReg1_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
  a->sub(row_offset_R_, scratchReg1_);
}

// Row-offset counterpart of genCoreInsts: emits the interior H x W loop nest
// that sums activations under fully-inside filter windows.
template <>
template <>
void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>(
    asmjit::X86Emitter* a) {
  // number of uint8 elements in input channels should be a multiple of 32
  assert(C_ % 32 == 0);

  asmjit::Label LoopH = a->newLabel();
  asmjit::Label LoopW = a->newLabel();
  // base for output
  a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
  a->imul(scratchReg2_, W_R_);
  a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_));
  a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
  a->add(row_offset_R_, scratchReg2_);

  // scratchReg1_ = W_ - W_PAD_ : the W-loop bound, kept live across the loop
  a->mov(scratchReg1_, W_R_);
  a->sub(scratchReg1_, static_cast<asmjit::Imm>(W_PAD_));

  // H loop
  a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
  a->bind(LoopH);
  // W loop
  a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
  a->bind(LoopW);

  // zero out
  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
  for (int r = 0; r < R_; ++r) {
    for (int s = 0; s < S_; ++s) {
      a->vmovaps(
          actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
      gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
    }
    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
    a->add(in_acts_R_, scratchReg2_);
  }
  // rewind the R_ input rows walked above
  a->imul(
      scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg2_);
  // store results
  storeResultRowoffset<inst_set_t::avx2>(a);

  a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
  a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

  a->inc(loopR2_);
  a->cmp(loopR2_, scratchReg1_);
  a->jl(LoopW);
  a->add(
      in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
  a->add(
      row_offset_R_,
      static_cast<asmjit::Imm>(2 * W_PAD_ * 8 * sizeof(int32_t)));
  a->inc(loopR1_);
  a->mov(scratchReg2_, H_R_);
  a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
  a->cmp(loopR1_, scratchReg2_);
  a->jl(LoopH);
}

// Assembles the full row-offset kernel (activation sums per output pixel for
// 8 groups at a time): binds argument registers, sets up the asmjit frame,
// emits the edge/corner/core code, JITs it, and caches the function pointer
// in codeCacheRowOffset_.  Returns nullptr if asmjit fails to finalize.
template <>
template <>
jit_rowoffset_kernel_fp
GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
    const conv_param_t<>& conv_param) {
  code_.reset(false);
  code_.init(rt_.getCodeInfo());
  asmjit::X86Assembler assembler(&code_);
  asmjit::X86Emitter* a = assembler.asEmitter();

#if defined(FBGEMM_LOG_CODE)
  // log code to a file
  FILE* codeLogfile =
      fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w");
  asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
  if (codeLogger) {
    code_.setLogger(codeLogger);
  }
#endif

  // arguments to the function created
  in_acts_R_ = a->zdi();
  a_zero_pt_R_ = a->zsi();
  H_R_ = a->zdx();
  W_R_ = a->zcx();
  row_offset_R_ = a->gpzRef(8);

  // register for temporary use
  scratchReg1_ = a->gpzRef(12);
  scratchReg2_ = a->gpzRef(13);

  loopR1_ = a->gpzRef(14);
  loopR2_ = a->gpzRef(15);

  asmjit::FuncDetail func;
  func.init(
      asmjit::
          FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
              asmjit::CallConv::kIdHost));

  asmjit::FuncFrameInfo ffi;
  ffi.setDirtyRegs(
      asmjit::X86Reg::kKindVec,
      asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
          asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
  ffi.setDirtyRegs(
      asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));

  asmjit::FuncArgsMapper args(&func);
  args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);

  args.updateFrameInfo(ffi);

  asmjit::FuncFrameLayout layout;
  layout.init(func, ffi);

  asmjit::FuncUtils::emitProlog(a, layout);
  asmjit::FuncUtils::allocArgs(a, layout, args);

  // This uses xmm10 register temporarily. Should come before
  // createVector8BitOne
  if (!isZeroPointZero_) {
    setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
  }

  createVector16BitOne<inst_set_t::avx2>(a);
  // we set ymm10 to contain 8-bit 1s
  createVector8BitOne<inst_set_t::avx2>(a);

  genForTopEdgeRowoffset<inst_set_t::avx2>(a);
  genForLeftEdgeRowoffset<inst_set_t::avx2>(a);
  genForRightEdgeRowoffset<inst_set_t::avx2>(a);
  genForBottomEdgeRowoffset<inst_set_t::avx2>(a);

  genRowoffsetCore<inst_set_t::avx2>(a);

  asmjit::FuncUtils::emitEpilog(a, layout);

  jit_rowoffset_kernel_fp fn;
  asmjit::Error err = rt_.add(&fn, &code_);
  if (err) {
    std::cout << "Error: in fn add" << std::endl;
    return nullptr;
  }
  auto kernelSig = getKernelSig(conv_param, isZeroPointZero_);
  codeCacheRowOffset_[kernelSig] = fn;

#if defined(FBGEMM_LOG_CODE)
  delete codeLogger;
  fclose(codeLogfile);
#endif

  return fn;
}

// Top-level groupwise convolution entry point: dispatches to the JIT-ed conv
// and row-offset kernels when the shape is supported (fbgemmOptimizedGConv),
// then runs output post-processing per group.  (Definition continues beyond
// this view.)
template <
    typename packed_W,
    typename outType,
    typename processOutputType,
    int SPATIAL_DIM>
void fbgemmGroupwiseConv(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    const std::uint8_t* activations,
    std::int32_t a_zero_point,
    std::int32_t* rowOffsetBuf,
    packed_W& packed_weights,
    outType* out,
    int32_t* outBuffer,
    const processOutputType& outProcess,
    int thread_id,
    int num_threads) {

  int MB = conv_param.MB;
  int H = conv_param.OUT_DIM[0];
  int W = conv_param.OUT_DIM[1];
  int G = conv_param.G;
  int K_per_G = conv_param.OC / G;
  int C_per_G = conv_param.IC / G;
  int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];

  static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");

  // scratch area past the raw row offsets holds the transposed (G x IH*IW)
  // row offsets
  int32_t* rowOffsetTrDest =
      rowOffsetBuf + 8 * conv_param.IN_DIM[0] * conv_param.IN_DIM[1];
  if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)) {
    assert(G % 8 == 0);
    // generate convolution kernel
    jit_conv_kernel_fp fpConv =
getOrCreateConvKernel<>(conv_param, a_zero_point); + // generate row offset kernel + jit_rowoffset_kernel_fp fpRowoffset = + getOrCreateRowOffsetKernel(conv_param, a_zero_point); + for (int i = 0; i < MB; ++i) { + const uint8_t* actStartBatch = activations + + i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * conv_param.IC; + for (int gOuter = 0; gOuter < G; gOuter += 8) { + // for C_per_G == 4 and K_per_G == 4, row offset is calculated for 8 + // groups at a time. The result is row offsets in the format IH*IW x G + fpRowoffset( + actStartBatch + gOuter * C_per_G, + a_zero_point, + H, + W, + rowOffsetBuf); + // Transpose to get row offsets in the format G x IH*IW + internal::transpose_8x8( + conv_param.IN_DIM[0] * conv_param.IN_DIM[1], + 8, + (const float*)rowOffsetBuf, + 8, + (float*)rowOffsetTrDest, + conv_param.IN_DIM[0] * conv_param.IN_DIM[1]); + int gLimit = gOuter + 8; + for (int g = gOuter; g < gLimit; g += 2) { + int32_t* currOutBuf = + outBuffer + i * oh_ow * conv_param.OC + g * K_per_G; + const uint8_t* actStartGroup = actStartBatch + g * C_per_G; + + fpConv( + actStartGroup, + packed_weights.getBuf() + g * K_per_G * C_per_G, + currOutBuf, + a_zero_point, + H, + W); + + // Output processing should be called for each group + for (int j = 0; j < 2; ++j) { + // calculateRowOffsets( + // conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j); + int32_t* rowOffsetForCurG = rowOffsetTrDest + + ((g - gOuter) + j) * conv_param.IN_DIM[0] * + conv_param.IN_DIM[1]; + // compare_buffers(rowOffsetBuf, rowOffsetForCurG, + // conv_param.IN_DIM[0]*conv_param.IN_DIM[1], 1, 1, 100); + + // outProcess expects rowOffsetBuf to contain row offsets for the + // current group + memcpy( + rowOffsetBuf, + rowOffsetForCurG, + conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * sizeof(int32_t)); + + if (cpuinfo_has_x86_avx512f()) { + // Currently use avx2 code + outProcess.template f<inst_set_t::avx2>( + out, + currOutBuf + j * K_per_G, + {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G}, 
+ K_per_G * G, + K_per_G * G); + } else if (cpuinfo_has_x86_avx2()) { + outProcess.template f<inst_set_t::avx2>( + out, + currOutBuf + j * K_per_G, + {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G}, + K_per_G * G, + K_per_G * G); + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + } // j loop + } + } + } + } else { + // for the unsupported cases, just execute the naive C implementation + conv_ref( + conv_param, + activations, + a_zero_point, + packed_weights.getBuf(), + outBuffer); + for (int i = 0; i < conv_param.MB; ++i) { + for (int g = 0; g < conv_param.G; ++g) { + calculateRowOffsets( + conv_param, + activations + + i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * conv_param.IC, + rowOffsetBuf, + a_zero_point, + g); + outProcess.template f<inst_set_t::anyarch>( + out, + outBuffer + i * oh_ow * conv_param.OC + g * K_per_G, + {i * oh_ow, oh_ow, g * K_per_G, K_per_G}, + K_per_G * G, + K_per_G * G); + } + } + } +} + +jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel( + const conv_param_t<>& conv_param, + int a_zero_point) { + // Note: Wrong code is generated if it's not one of the supported convolutions + assert(fbgemmOptimizedGConv<2>(conv_param)); + auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); + if (GenConvKernel<int32_t>::codeCacheRowOffset_.find(kernelSig) != + GenConvKernel<int32_t>::codeCacheRowOffset_.end()) { + return GenConvKernel<int32_t>::codeCacheRowOffset_[kernelSig]; + } else { + auto genObj = GenConvKernel<int32_t>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param); + } +} + +template <int SPATIAL_DIM> +int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) { + // row offset buffer should be able to hold row offsets for however + number of groups we process at a time. 
+ if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; + int C_per_G = conv_param.IC / conv_param.G; + int K_per_G = conv_param.OC / conv_param.G; + if (C_per_G == 4 && K_per_G == 4) { + return 2 * 8 * bufferSize; + } else { + return conv_param.G * bufferSize; + } + } else if (cpuinfo_has_x86_avx2()) { + int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; + int C_per_G = conv_param.IC / conv_param.G; + int K_per_G = conv_param.OC / conv_param.G; + if (C_per_G == 4 && K_per_G == 4) { + // row offset is calculated for 8 groups at a time + // 2x is needed for transposing + return 2 * 8 * bufferSize; + } else { + return conv_param.G * bufferSize; + } + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return -1; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param); + +#define INSTANTIATE_BASE(RELU, Q_GRAN) \ + template void fbgemmGroupwiseConv( \ + const conv_param_t<2>& conv_param, \ + const uint8_t* activations, \ + int32_t a_zero_point, \ + std::int32_t* rowOffsetBuf, \ + PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights, \ + uint8_t* out, \ + int32_t* outBuffer, \ + const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + int thread_id, \ + int num_threads); + +#define INSTANTIATE_Q_GRANS(RELU) \ + INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \ + INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \ + INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL); + +INSTANTIATE_Q_GRANS(false); +INSTANTIATE_Q_GRANS(true); + +#undef INSTANTIATE_Q_GRANS +#undef INSTANTIATE_BASE + +template void fbgemmGroupwiseConv( + const conv_param_t<2>& conv_param, + const uint8_t* activations, + int32_t a_zero_point, + std::int32_t* rowOffsetBuf, + PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights, + int32_t* out, 
+ int32_t* outBuffer, + const DoNothing<int32_t, int32_t>& outProcess, + int thread_id, + int num_threads); + +} // namespace fbgemm diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc new file mode 100644 index 0000000..e6c9b7d --- /dev/null +++ b/src/PackWeightMatrixForGConv.cc @@ -0,0 +1,103 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <cpuinfo.h> +#include <cassert> +#include <iomanip> +#include "RefImplementations.h" +#include "fbgemm/Fbgemm.h" + +namespace fbgemm { + +template <typename T, typename accT, int SPATIAL_DIM> +PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv( + matrix_op_t trans, + const conv_param_t<SPATIAL_DIM>& conv_param, + const T* sdata, + T* pdata) + : trans_(trans), conv_param_(conv_param), sdata_(sdata) { + static_assert(SPATIAL_DIM == 2, "3D conv not supported yet"); + + if (!pdata) { + bufAllocatedHere_ = true; + pdata_ = static_cast<T*>(fbgemmAlignedAlloc( + 64, + conv_param_.G * conv_param_.K[0] * conv_param_.K[1] * + (conv_param_.OC / conv_param_.G) * + (conv_param_.IC / conv_param_.G) * sizeof(T))); + } else { + bufAllocatedHere_ = false; + pdata_ = pdata; + } + pack(); +} + +/** + * @brief Pack weight tensor in a suitable format required for the optimized + * kernel. + * + * Let IC_per_G be number of input channels per group and OC_per_G be number of + * output channels per group. + * + * For IC_per_G == 4 && OC_per_G == 4 optimized + * kernel works on 2 groups at a time hence input channels for g and g+1 group + * are laid out sequentially for each output channel, i.e., the layout is R S + * (G/2) K (2C) + * We work on two groups at a time to fully utilize the avx2 SIMD width of + * 256-bits. 
+ * + * For IC_per_G == 8, 16, 32 && OC_per_G == 8, 16, 32 there is no need to work + * on 2 groups at a time and full SIMD width can be efficiently utilized even + * while working on 1 group at a time. + */ +template <typename T, typename accT, int SPATIAL_DIM> +void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() { + // filters are assumed to be in G RS C/G K/G format + int R = conv_param_.K[0]; + int S = conv_param_.K[1]; + int G = conv_param_.G; + int IC_per_G = conv_param_.IC / conv_param_.G; + int OC_per_G = conv_param_.OC / conv_param_.G; + + // If transpose option is set, the weight matrix is in layout G K/G (R S C/G) + // instead of G (R S C/G) K/G + bool tr = (trans_ == matrix_op_t::Transpose); + if (fbgemmOptimizedGConv(conv_param_)) { + // currently only this case is supported + for (int r = 0; r < R; ++r) { + for (int s = 0; s < S; ++s) { + for (int k = 0; k < OC_per_G; ++k) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < IC_per_G; ++c) { + inpType b = tr + ? 
sdata_ + [(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c] + : sdata_ + [(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k]; + pdata_ + [((((r * S + s) * (G / 2) + (g / 2)) * OC_per_G + k) * 2 + + (g % 2)) * + IC_per_G + + c] = b; + } + } + } + } + } + } else { + if (tr) { + // conv_ref expects weights to be in G (R S C/G) K/G format + transposeConvWeights(conv_param_, sdata_, pdata_); + } else { + // just copy the data for not supported cases + memcpy(pdata_, sdata_, G * R * S * OC_per_G * IC_per_G * sizeof(inpType)); + } + } +} + +template class PackWeightMatrixForGConv<int8_t, int32_t, 2>; +template class PackWeightMatrixForGConv<int8_t, int16_t, 2>; +} // namespace fbgemm diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 5168a15..5c6cf1b 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -486,6 +486,31 @@ void conv3d_ref( } // for each n } +void transposeConvWeights( + const conv_param_t<>& conv_p, + const std::int8_t* src, + std::int8_t* dest) { + int R = conv_p.K[0]; + int S = conv_p.K[1]; + int G = conv_p.G; + int IC_per_G = conv_p.IC / conv_p.G; + int OC_per_G = conv_p.OC / conv_p.G; + + // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format. + for (int r = 0; r < R; ++r) { + for (int s = 0; s < S; ++s) { + for (int k = 0; k < OC_per_G; ++k) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < IC_per_G; ++c) { + dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] = + src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c]; + } + } + } + } + } +} + void depthwise_3x3_pad_1_ref( int N, int H, diff --git a/src/RefImplementations.h b/src/RefImplementations.h index fce68e6..62f17e9 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -176,6 +176,14 @@ FBGEMM_API void conv3d_ref( std::int32_t* C); /* + * @brief Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format. 
+ */ +FBGEMM_API void transposeConvWeights( + const conv_param_t<>& conv_p, + const std::int8_t* src, + std::int8_t* dest); + +/* * @brief Reference implementation of im2col operation. * The input A is assumed to be in NHiWiC format. * The output A is assumed to be in NHoWoRSC format. diff --git a/test/GConvTest.cc b/test/GConvTest.cc new file mode 100644 index 0000000..34042c6 --- /dev/null +++ b/test/GConvTest.cc @@ -0,0 +1,382 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <algorithm> +#include <chrono> +#include <cmath> +#include <random> +#include <vector> + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include <gtest/gtest.h> + +#include "QuantizationHelpers.h" +#include "TestUtils.h" +#include "bench/BenchUtils.h" +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" + +using namespace std; +using namespace fbgemm; + +vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose, + matrix_op_t::Transpose}; + +vector<QuantizationGranularity> qGranularityVals{ + QuantizationGranularity::TENSOR, + QuantizationGranularity::GROUP, + QuantizationGranularity::OUT_CHANNEL}; + +namespace { +class fbgemmGConvAcc32Test + : public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t>> {}; +class fbgemmGConvAcc32WithQuantGranularityTest + : public testing::TestWithParam< + tuple<matrix_op_t, matrix_op_t, QuantizationGranularity>> {}; +}; // namespace + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + fbgemmGConvAcc32Test, + ::testing::Combine( + ::testing::Values(matrix_op_t::NoTranspose), + ::testing::ValuesIn(transposeVals))); + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + fbgemmGConvAcc32WithQuantGranularityTest, + ::testing::Combine( + ::testing::Values(matrix_op_t::NoTranspose), + ::testing::ValuesIn(transposeVals), + ::testing::ValuesIn(qGranularityVals))); +/** + * 
@brief Shapes for unit test. + */ +static vector<conv_param_t<>> GetShapes_() { + vector<conv_param_t<>> shapes = { + // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, {pad_t, pad_l, + // pad_b, pad_r} + conv_param_t<>(1, 32, 32, {3, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {4, 4}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {3, 5}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {5, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 8, 8, {5, 5}, 2, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + }; + return shapes; +} + +/** + * @brief Unit test for uint8 activations, int8 weights, and 32-bit + * accumulation. Output processing: requantization -> nothing + */ +TEST_P(fbgemmGConvAcc32WithQuantGranularityTest, requantizeTest) { + vector<conv_param_t<>> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + QuantizationGranularity q_granularity; + tie(atrans, btrans, q_granularity) = GetParam(); + + for (auto conv_p : shapes) { + int R = conv_p.K[0]; + int S = conv_p.K[1]; + int G = conv_p.G; + int OC = conv_p.OC; + int OH = conv_p.OUT_DIM[0]; + int OW = conv_p.OUT_DIM[1]; + int IC_per_G = conv_p.IC / conv_p.G; + int OC_per_G = conv_p.OC / conv_p.G; + + // activations + aligned_vector<uint8_t> Aint8( + conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0); + + // weights + // when btrans == Transpose, the weight matrix is in layout G K/G (R S C/G) + // instead of G (R S C/G) K/G + aligned_vector<int8_t> Bint8(R * S * conv_p.G * IC_per_G * OC_per_G, 0); + aligned_vector<int8_t> Bint8_tr(R * S * G * IC_per_G * OC_per_G, 0); + + aligned_vector<int32_t> Cint32_ref(conv_p.MB * OH * OW * OC, 0); + 
aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0); + aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0); + aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0); + + randFill<uint8_t>(Aint8, 0, 5); + int32_t Aint8_zero_point = 4; + + randFill<int8_t>(Bint8, -4, 4); + + // computing column offset + vector<int32_t> col_offsets(G * OC_per_G); + + int ncols_per_quant_group = G * OC_per_G; + if (q_granularity == QuantizationGranularity::GROUP) { + ncols_per_quant_group = OC_per_G; + } else if (q_granularity == QuantizationGranularity::OUT_CHANNEL) { + ncols_per_quant_group = 1; + } + + aligned_vector<int32_t> Bint8_zero_point( + G * OC_per_G / ncols_per_quant_group); + randFill(Bint8_zero_point, -3, -1); + + // matrix dimensions after im2col for each GEMM. + // For each group, there is one GEMM of the following dimensions + int MDim = conv_p.MB * OH * OW; + int NDim = OC_per_G; + int KDim = R * S * IC_per_G; + + vector<uint8_t> Aint8_im2col(MDim * KDim * G); + im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); + + vector<int32_t> row_offsets(MDim); + + aligned_vector<float> C_multiplier(Bint8_zero_point.size()); + randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2); + int32_t C_zero_pt = 5; + + // reference implementation + // conv_ref expects weights to be in G (R S C/G) K/G + int8_t* rightBData = Bint8.data(); + if (btrans == matrix_op_t::Transpose) { + transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data()); + rightBData = Bint8_tr.data(); + } + for (int g = 0; g < G; ++g) { + col_offsets_with_zero_pt_s8acc32_ref( + R * S * IC_per_G, + OC_per_G, + OC_per_G, + rightBData + g * R * S * IC_per_G * OC_per_G, + Bint8_zero_point.data() + g * OC_per_G / ncols_per_quant_group, + col_offsets.data() + g * OC_per_G, + ncols_per_quant_group); + } + conv_ref( + conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data()); + + for (int g = 0; g < G; ++g) { + row_offsets_u8acc32_ref( + MDim, + KDim, + KDim * G, + Aint8_im2col.data() 
+ g * KDim, + row_offsets.data()); + + requantize_u8acc32_ref( + MDim, + NDim, + G * NDim, + Cint32_ref.data() + g * NDim, + Cint8_ref.data() + g * NDim, + C_multiplier.data() + g * NDim / ncols_per_quant_group, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data() + g * NDim / ncols_per_quant_group, + row_offsets.data(), + col_offsets.data() + g * NDim, + nullptr, + ncols_per_quant_group); + } + + PackWeightMatrixForGConv<int8_t> packedWeights( + btrans, conv_p, Bint8.data(), nullptr); + + // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv + // #ifdef _OPENMP + // #pragma omp parallel + // #endif + { + vector<int32_t> row_offset_buf(rowOffsetBufferSizeGConv(conv_p)); + + DoNothing<> doNothingObj{}; + + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + + if (q_granularity == QuantizationGranularity::TENSOR) { + ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + row_offset_buf.data(), + col_offsets.data(), + nullptr, + G * NDim, + G); + + fbgemmGroupwiseConv( + conv_p, + Aint8.data(), + Aint8_zero_point, + row_offset_buf.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + + } else if (q_granularity == QuantizationGranularity::GROUP) { + ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + row_offset_buf.data(), + col_offsets.data(), + nullptr, + G * NDim, + G); + + fbgemmGroupwiseConv( + conv_p, + Aint8.data(), + Aint8_zero_point, + row_offset_buf.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + + } else { + ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + 
row_offset_buf.data(), + col_offsets.data(), + nullptr, + G * NDim, + G); + + fbgemmGroupwiseConv( + conv_p, + Aint8.data(), + Aint8_zero_point, + row_offset_buf.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + } + } // omp parallel + + compare_validate_buffers( + Cint8_ref.data(), + Cint8_fb.data(), + MDim, + NDim * G, + NDim * G, + static_cast<uint8_t>(0)); + } // for each shape +} + +/** + * @brief Unit test for uint8 activations, int8 weights, and 32-bit + * accumulation. Output processing: nothing + */ +TEST_P(fbgemmGConvAcc32Test, NoRequantizeTest) { + vector<conv_param_t<>> shapes(GetShapes_()); + matrix_op_t atrans, btrans; + tie(atrans, btrans) = GetParam(); + + for (auto conv_p : shapes) { + int R = conv_p.K[0]; + int S = conv_p.K[1]; + int G = conv_p.G; + int OC = conv_p.OC; + int OH = conv_p.OUT_DIM[0]; + int OW = conv_p.OUT_DIM[1]; + int IC_per_G = conv_p.IC / conv_p.G; + int OC_per_G = conv_p.OC / conv_p.G; + + // activations + aligned_vector<uint8_t> Aint8( + conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0); + + // weights + // when btrans == Transpose, the weight matrix is in layout G K/G (R S C/G) + // instead of G (R S C/G) K/G + aligned_vector<int8_t> Bint8(R * S * conv_p.G * IC_per_G * OC_per_G, 0); + aligned_vector<int8_t> Bint8_tr(R * S * conv_p.G * IC_per_G * OC_per_G, 0); + + aligned_vector<int32_t> Cint32_ref(conv_p.MB * OH * OW * OC, 0); + aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0); + + randFill<uint8_t>(Aint8, 0, 5); + int32_t Aint8_zero_point = 4; + + randFill<int8_t>(Bint8, -4, 4); + + // matrix dimensions after im2col for each GEMM. 
+ // For each group, there is one GEMM of the following dimensions + int MDim = conv_p.MB * OH * OW; + int NDim = OC_per_G; + // int KDim = R * S * IC_per_G; + + // reference implementation + // conv_ref expects weights to be in G (R S C/G) K/G + int8_t* rightBData = Bint8.data(); + if (btrans == matrix_op_t::Transpose) { + transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data()); + rightBData = Bint8_tr.data(); + } + conv_ref( + conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data()); + + PackWeightMatrixForGConv<int8_t> packedWeights( + btrans, conv_p, Bint8.data(), nullptr); + + // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv + // #ifdef _OPENMP + // #pragma omp parallel + // #endif + { + vector<int32_t> row_offset_buf(rowOffsetBufferSizeGConv(conv_p)); + + DoNothing<int32_t, int32_t> doNothingObj{}; + + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + + fbgemmGroupwiseConv( + conv_p, + Aint8.data(), + Aint8_zero_point, + row_offset_buf.data(), + packedWeights, + Cint32_fb.data(), + Cint32_fb.data(), + doNothingObj, + tid, + num_threads); + } + + compare_validate_buffers( + Cint32_ref.data(), + Cint32_fb.data(), + MDim, + NDim * G, + NDim * G, + static_cast<int32_t>(0)); + } // for each shape +} |