From f12ec122be12b0647ada3ff2c374cca57aa4ae95 Mon Sep 17 00:00:00 2001 From: Protonu Basu Date: Tue, 2 Apr 2019 05:22:44 -0700 Subject: Exposing tuning parameters in FBGEMM (MCB, NCB, KCB, MR, NR, Row Interleave) (#90) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/90 Exposing tuning parameters in FBGEMM (MCB, NCB, KCB, MR, NR, Row Interleave) Reviewed By: dskhudia Differential Revision: D14358148 fbshipit-source-id: 783fb4653fd696dbbd4075ad56cb8682db3011a5 --- bench/GEMMsTunableBenchmark.cc | 339 ++++++++++++++++++++++++++++++++ include/fbgemm/Fbgemm.h | 37 ++-- include/fbgemm/PackingTraits-inl.h | 9 + include/fbgemm/Utils.h | 59 ++++++ src/ExecuteKernelGeneric.h | 3 +- src/ExecuteKernelU8S8.cc | 68 ++++--- src/ExecuteKernelU8S8.h | 4 +- src/Fbgemm.cc | 110 ++++++----- src/GenerateKernel.h | 26 ++- src/GenerateKernelU8S8S32ACC16.cc | 62 ++++-- src/GenerateKernelU8S8S32ACC16Avx512.cc | 65 ++++-- src/GenerateKernelU8S8S32ACC32.cc | 60 ++++-- src/GenerateKernelU8S8S32ACC32Avx512.cc | 65 ++++-- src/PackAMatrix.cc | 43 ++-- src/PackAWithIm2Col.cc | 53 +++-- src/PackAWithQuantRowOffset.cc | 61 +++--- src/PackAWithRowOffset.cc | 50 +++-- src/PackBMatrix.cc | 39 ++-- src/PackMatrix.cc | 44 ++++- 19 files changed, 949 insertions(+), 248 deletions(-) create mode 100644 bench/GEMMsTunableBenchmark.cc diff --git a/bench/GEMMsTunableBenchmark.cc b/bench/GEMMsTunableBenchmark.cc new file mode 100644 index 0000000..04eed8a --- /dev/null +++ b/bench/GEMMsTunableBenchmark.cc @@ -0,0 +1,339 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#ifdef USE_MKL +#include +#endif + +#include "bench/BenchUtils.h" +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" +#include "test/QuantizationHelpers.h" + +using namespace std; +using namespace fbgemm; + +void performance_test( + const BlockingFactors* tuning_params, + set>& incorrect_configs, + const vector& shape, + array& best_config, + float& giga_ops) { + + bool flush = true; + std::vector llc; + + if (flush) { + llc.resize(128 * 1024 * 1024, 1.0); + } + + constexpr int NWARMUP = 4; + constexpr int NITER = 10; + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << "WARNING: the timer may be inaccurate when used by multiple threads." + << endl; + cout << setw(8) << "M, " << setw(8) << "N, " << setw(8) << "K, " << setw(18) + << "Type, " << setw(18) << "Packing (us), " << setw(18) + << "Kernel (us), " << setw(18) << "Postproc (us), " << setw(18) + << "Total (us), " << setw(5) << "GOPs" << endl; +#else +#endif + + chrono::time_point start, end; + + int m = shape[0]; + int n = shape[1]; + int k = shape[2]; + + aligned_vector Aint8(m * k); + aligned_vector Bint8(k * n); + aligned_vector Cfp32_mkl(m * n); + aligned_vector Cint32_mkl(Cfp32_mkl.size()); + aligned_vector Cint32_ref(Cfp32_mkl.size()); + aligned_vector Cint32_fb_acc32(Cfp32_mkl.size()); + aligned_vector Cint32_fb_acc16(Cfp32_mkl.size()); + + // A matrix + randFill(Aint8, 0, 5); + aligned_vector Afp32(Aint8.begin(), Aint8.end()); + + randFill(Bint8, -4, 4); + avoidOverflow(m, n, k, Aint8.data(), Bint8.data()); + + aligned_vector Bfp32(Bint8.begin(), Bint8.end()); + + double nops = 2.0 * static_cast(NITER) * m * n * k; + double ttot = 0.0; + string runType; + + vector row_offsets(m); + + matmul_u8i8acc32_ref( + m, n, k, k, n, n, Aint8.data(), Bint8.data(), Cint32_ref.data()); + + PackBMatrix packedB_int32( + matrix_op_t::NoTranspose, + k, + n, + Bint8.data(), + n, + nullptr, + 1, + tuning_params); + + ttot = 0.0; + runType = "FBGEMM_i8_acc32"; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + double total_packing_time = 0.0; + double total_computing_time = 0.0; + double total_kernel_time = 0.0; + double total_postprocessing_time = 0.0; + double total_run_time = 0.0; +#endif + + for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif + llc_flush(llc); + start = chrono::high_resolution_clock::now(); + +#ifdef _OPENMP +#pragma omp parallel +#endif + { + PackAMatrix packA_int32( + matrix_op_t::NoTranspose, + m, + k, + Aint8.data(), + k, + nullptr, + 1, + tuning_params); + + DoNothing doNothing32BitObj; + memCopy<> memcopyObj(doNothing32BitObj); + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + // printf ( "tid: %d, num_threads: %d\n", tid, num_threads ); + fbgemmPacked( + packA_int32, + packedB_int32, + Cint32_fb_acc32.data(), + Cint32_fb_acc32.data(), + n, + memcopyObj, + tid, + num_threads, + tuning_params); + } + + end = chrono::high_resolution_clock::now(); + + if (i >= NWARMUP) { + auto dur = chrono::duration_cast(end - start); + ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif + } + } + ((volatile char*)(llc.data())); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << ", " << setw(16) << total_packing_time / (double)NITER / 1e3 << ", " + << setw(16) << total_kernel_time / (double)NITER / 1e3 << ", " + << setw(16) << total_postprocessing_time / (double)NITER / 1e3 << ", " + << setw(16) << total_run_time / (double)NITER / 1e3; +#endif + + if (compare_buffers( + Cint32_ref.data(), Cint32_fb_acc32.data(), m, n, n, 5)) { + vector config = {tuning_params->MCB, + tuning_params->NCB, + tuning_params->KCB, + tuning_params->MR, + tuning_params->NR, + tuning_params->ROW_INTERLEAVE}; + incorrect_configs.insert(config); + } else { + cout << setw(5) << "MCB, " << setw(5) << "NCB, " << setw(5) << "KCB, " + << setw(5) << "MR, " << setw(5) << "NR, " << setw(5) << "ROW INT." + << endl; + cout << setw(5) << tuning_params->MCB << setw(5) << tuning_params->NCB + << setw(5) << tuning_params->KCB << setw(5) << tuning_params->MR + << setw(5) << tuning_params->NR << setw(5) + << tuning_params->ROW_INTERLEAVE << endl; + + cout << setw(8) << "M, " << setw(8) << "N, " << setw(8) << "K, " + << setw(18) << "Type, " << setw(5) << "GOPS" << endl; + cout << setw(6) << m << ", " << setw(6) << n << ", " << setw(6) << k + << ", " << setw(16) << runType; + cout << ", " << setw(5) << fixed << setw(5) << setprecision(1) + << nops / ttot << endl; + if ((nops/ttot) > giga_ops){ + giga_ops = nops/ttot; + best_config = {tuning_params->MCB, + tuning_params->NCB, + tuning_params->KCB, + tuning_params->MR, + tuning_params->NR, + tuning_params->ROW_INTERLEAVE}; + } + } +} + +int main(int /* unused */, char** /* unused */) { +#ifdef _OPENMP + // Use 1 thread unless OMP_NUM_THREADS is explicit set. + const char* val = getenv("OMP_NUM_THREADS"); + if (val == nullptr || !*val) { + omp_set_num_threads(1); + } +#endif + + vector> shapes = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. + // m, n, k + //warning these take time to run! + {156800, 4, 36}, + {156800, 8, 36}, + {156800, 16, 36}, + {1, 128, 512}, + {1, 1024, 256}, + {1, 2048, 512}, + {1, 4096, 1024}, + {6, 256, 1024}, + {6, 256, 2048}, + {6, 512, 512}, + {6, 1024, 256}, + {6, 2048, 256}, + {6, 2048, 512}, + {6, 4096, 256}, + {6, 4096, 1024}, + {6, 4096, 2048}, + + {10, 2048, 256}, + {10, 4096, 1024}, + + {20, 2048, 256}, + {20, 4096, 1024}, + + {102, 1024, 512}, + {102, 2323, 256}, + {102, 512, 256}, + + {1, 800, 3200}, + {1, 800, 8000}, + + {16, 256, 1500}, + {16, 256, 1567}, + {1, 128, 2876}, + {16, 128, 1567}, + {1, 128, 2722}, + + {16, 256, 512}, + {64, 800, 320}, + {64, 768, 512}, + {16, 256, 512}, + {128, 128, 128}, + {256, 512, 256}, + {1024, 1024, 1024}, +}; + + vector MCBs; + vector NCBs; + vector KCBs; + vector MRs; + int NR = 16; + int NR_MIN = 16; + int ROW_INTERLEAVE = 4; // do 32-bit accumulation for now + + if (cpuinfo_initialize()) { + if (fbgemmHasAvx512Support()) { + NR = 16; + MCBs.insert(MCBs.end(), {48, 96, 144, 192, 240}); + NCBs.insert(NCBs.end(), {16, 32, 64, 128, 48, 98, 192, 384}); + KCBs.insert( + KCBs.end(), + {256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 960, 1024}); + MRs.insert(MRs.end(), {24, 12, 6, 3, 8, 4, 2, 1}); + } else if (fbgemmHasAvx2Support()) { + assert(0 && "Benchmark will be extended for this architecture"); + } else { + assert(0 && "architecture not supported"); + return 0; + } + } + + set> incorrect_configs; + float giga_ops = 0.0; + array best_config = {0, 0, 0, 0, 0, 0}; + BlockingFactors params; + for (auto const& shape : shapes) { + for (auto const& mcb : MCBs) { + for (auto const& ncb : NCBs) { + for (auto const& kcb : KCBs) { + for (auto const& mr : MRs) { + params.MCB = mcb; + params.NCB = ncb; + params.KCB = kcb; + params.MR = mr; + params.NR = NR; + params.ROW_INTERLEAVE = ROW_INTERLEAVE; + params.NR_MIN = NR_MIN; + if (isValidBlockingFactor(¶ms)) { + performance_test( + ¶ms, incorrect_configs, shape, best_config, giga_ops); + } + } + } + } + } + cout << endl << "This is the Best Config!" << endl; + cout << setw(5) << "MCB, " << setw(5) << "NCB, " << setw(5) << "KCB, " + << setw(5) << "MR, " << setw(5) << "NR, " << setw(5) << "ROW INT." + << setw(5) << "GOPS" << endl; + cout << setw(5) << best_config[0] << setw(5) << best_config[1] << setw(5) + << best_config[2] << setw(5) << best_config[3] << setw(5) + << best_config[4] << setw(5) << best_config[5] << giga_ops << endl; + } // end shapes + + cout << endl << "Warning there are configs that didn't work!" << endl; + for (auto const& entry : incorrect_configs) { + cout << setw(5) << "MCB, " << setw(5) << "NCB, " << setw(5) << "KCB, " + << setw(5) << "MR, " << setw(5) << "NR, " << setw(5) << "ROW INT." + << endl; + cout << setw(5) << entry[0] << setw(5) << entry[1] << setw(5) << entry[2] + << setw(5) << entry[3] << setw(5) << entry[4] << setw(5) << entry[5] + << endl; + } + return 0; +} diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 4f3c92e..48f7255 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -98,7 +98,8 @@ class PackMatrix { std::int32_t rows, std::int32_t cols, inpType* pmat, - int groups = 1); + int groups = 1, + const BlockingFactors* params = nullptr); /** * @return true usually when the matrix is constant matrix (e.g., weight @@ -124,7 +125,10 @@ class PackMatrix { * client code can use this function to query how big the buffer used for * packing should be. */ - static int packedBufferSize(int rows = 0, int cols = 0); + static int packedBufferSize( + int rows = 0, + int cols = 0, + const BlockingFactors* params = nullptr); /** * @return Pointer to a buffer containing row offset results. Some packing @@ -281,6 +285,8 @@ class PackMatrix { std::int32_t nbrow_; ///< the number of blocks along rows std::int32_t nbcol_; ///< the number of blocks along columns bool bufAllocatedHere_; + const BlockingFactors* + blocking_params; ///< MCB, KCB, NCB, MR, NR, NR_MIN, ROW_INTERLEAVE; private: std::int32_t nrows_, ncols_; @@ -312,7 +318,8 @@ class FBGEMM_API PackAMatrix final const inpType* smat, std::int32_t ld, inpType* pmat = nullptr, - int groups = 1); + int groups = 1, + const BlockingFactors* params = nullptr); /** * Activation matrices are not constant so cannot amortize the cost of @@ -393,7 +400,8 @@ class FBGEMM_API PackBMatrix final const inpType* smat, std::int32_t ld, inpType* pmat = nullptr, - int groups = 1); + int groups = 1, + const BlockingFactors* params = nullptr); /** * Weight matrices are usually constant so worth pre-packing. @@ -532,7 +540,8 @@ class FBGEMM_API PackAWithIm2Col inpType* pmat = nullptr, std::int32_t a_zero_pt = 0, std::int32_t* row_offset = nullptr, - bool b_symmetric = false); + bool b_symmetric = false, + const BlockingFactors* params = nullptr); /** * Activation matrices are not constant so cannot amortize the cost of @@ -569,7 +578,8 @@ class FBGEMM_API PackAWithIm2Col /** * @return Size of row offset buffer in number of elements */ - static int rowOffsetBufferSize(); + static int rowOffsetBufferSize( + const BlockingFactors* params = nullptr); ~PackAWithIm2Col() { if (rowOffsetAllocatedHere) { @@ -615,7 +625,8 @@ class FBGEMM_API PackAWithRowOffset final std::uint32_t ld, inpType* pmat = nullptr, int groups = 1, - std::int32_t* row_offset = nullptr); + std::int32_t* row_offset = nullptr, + const BlockingFactors* params = nullptr); /** * Activation matrices are not constant so cannot amortize the cost of @@ -658,7 +669,8 @@ class FBGEMM_API PackAWithRowOffset final /** * @return size of row offset buffer in number of elements */ - static int rowOffsetBufferSize(); + static int rowOffsetBufferSize( + const BlockingFactors* params = nullptr); ~PackAWithRowOffset() { if (rowOffsetAllocatedHere) { @@ -706,7 +718,8 @@ class FBGEMM_API PackAWithQuantRowOffset final float scale = 1.0f, std::int32_t zero_pt = 0, int groups = 1, - std::int32_t* row_offset = nullptr); + std::int32_t* row_offset = nullptr, + const BlockingFactors* params = nullptr); /** * Activation matrices are not constant so cannot amortize the cost of @@ -749,7 +762,8 @@ class FBGEMM_API PackAWithQuantRowOffset final /** * @return Size of row offset buffer in number of elements */ - static int rowOffsetBufferSize(); + static int rowOffsetBufferSize( + const BlockingFactors* params = nullptr); ~PackAWithQuantRowOffset() { if (rowOffsetAllocatedHere) { @@ -1174,7 +1188,8 @@ FBGEMM_API void fbgemmPacked( std::uint32_t ldc, const processOutputType& outProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* blocking_params = nullptr); /** * @brief Perform small-channels-per-group groupwise convolution diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h index 5b50bc9..76eb425 100644 --- a/include/fbgemm/PackingTraits-inl.h +++ b/include/fbgemm/PackingTraits-inl.h @@ -57,6 +57,10 @@ struct PackingTraits< inst_set_t::avx2, typename std::enable_if::value>::type> { static constexpr int MR{12}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 8}; ///< Minimum register block for N dimension. + ///< 8 because 8*ROW_INTERLEAVE int8 elements + ///< completely fill a 256-bit wide vector. static constexpr int NR{8}; ///< Register block for N dimension. ///< NR = VLEN/8/ROW_INTERLEAVE = 256 / 8 / 4 = 8. ///< Total registers used for N dimension: NCB/NR. @@ -88,6 +92,11 @@ struct PackingTraits< inst_set_t::avx2, typename std::enable_if::value>::type> { static constexpr int MR{3}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 256-bit wide vector. + static constexpr int NR{ 16}; ///< Register block for N dimension; ///< NR = VLEN/8/ROW_INTERLEAVE = 256 / 8 / 2 = 16. diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index 2813629..58232e7 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -87,4 +87,63 @@ FBGEMM_API bool fbgemmHasAvx512Support(); */ FBGEMM_API bool fbgemmHasAvx2Support(); +/** + * @brief Helper struct to enable autotuning of FBGEMM packing and kernels. + * + * This structure is optional. If not used, the default values for these + * parameters are picked up from PackingTraits-inl.h. Please see this + * file for details on these parameters. + */ +struct FBGEMM_API BlockingFactors { + int MR; + int NR; + int NR_MIN; + int ROW_INTERLEAVE; + int MCB; + int KCB; + int NCB; +}; + +template +FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { + constexpr bool is_32bit = std::is_same::value; + constexpr bool is_16bit = std::is_same::value; + + if (is_32bit) { + if (param->ROW_INTERLEAVE != 4) + return false; + + if (fbgemmHasAvx512Support()) { + if (param->NR != 16) + return false; + } else if (fbgemmHasAvx2Support()) { + if (param->NR != 8) + return false; + } + } else if (is_16bit) { + if (param->ROW_INTERLEAVE != 2) + return false; + + if (fbgemmHasAvx512Support()) { + if (param->NR != 32) + return false; + } else if (fbgemmHasAvx2Support()) { + if (param->NR != 16) + return false; + } + } + + if (param->MCB % param->MR) + return false; + if (param->NCB % param->NR) + return false; + if (fbgemmHasAvx512Support()) { + if (param->MR * (param->NCB / param->NR) > 24) + return false; + } else if (fbgemmHasAvx2Support()) { + if (param->MR * (param->NCB / param->NR) > 16) + return false; + } + return true; +} } // namespace fbgemm diff --git a/src/ExecuteKernelGeneric.h b/src/ExecuteKernelGeneric.h index 667b0ef..ce9a7bb 100644 --- a/src/ExecuteKernelGeneric.h +++ b/src/ExecuteKernelGeneric.h @@ -40,7 +40,8 @@ class ExecuteKernel : public CodeGenBase< int32_t ldc, const processOutputType& outputProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* params = nullptr); void execute(int kBlock); private: diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index 9b0ea41..4175d65 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -33,8 +33,11 @@ ExecuteKernel< int32_t ldc, const processOutputType& outputProcess, int thread_id, - int num_threads) - : packedA_(packA), + int num_threads, + const BlockingFactors* params) + : CodeGenBase( + params), + packedA_(packA), packedB_(packB), matC_(matC), C_buffer_(C_buffer), @@ -42,34 +45,41 @@ ExecuteKernel< outputProcess_(outputProcess), thread_id_(thread_id), num_threads_(num_threads) { - if (fbgemmHasAvx512Support()) { - mbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512>::MCB; - nbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512>::NCB; - nrMinSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512>::NR_MIN; - } else if (fbgemmHasAvx2Support()) { - mbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx2>::MCB; - nbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx2>::NCB; - nrMinSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx2>::NR; + if (params) { + mbSize_ = params->MCB; + nbSize_ = params->NCB; + nrMinSize_ = params->NR_MIN; + nrSize_ = params->NR; } else { - assert(0 && "unsupported architecure"); + if (fbgemmHasAvx512Support()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::NCB; + nrMinSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::NR_MIN; + } else if (fbgemmHasAvx2Support()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::NCB; + nrMinSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::NR; + } else { + assert(0 && "unsupported architecure"); + } } C_tile_ = new int32_t[mbSize_ * nbSize_]; } diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h index b56f54c..bb20134 100644 --- a/src/ExecuteKernelU8S8.h +++ b/src/ExecuteKernelU8S8.h @@ -44,7 +44,8 @@ class ExecuteKernel< int32_t ldc, const processOutputType& outputProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* params = nullptr); void execute(int kBlock); ~ExecuteKernel() { @@ -70,6 +71,7 @@ class ExecuteKernel< int mbSize_; ///< block size in the m dimension. int nbSize_; ///< block size in the n dimension. int nrMinSize_; ///< minimum register size in the n dimension. + int nrSize_; ///< register size in the n dimension. }; } // namespace fbgemm diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index a90dd2d..a40f38a 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -36,7 +36,8 @@ void fbgemmPacked( uint32_t ldc, const processOutputType& outProcess, int thread_id, - int num_threads) { + int num_threads, + const BlockingFactors* blocking_params) { static_assert( std::is_same< typename packingAMatrix::accType, @@ -48,36 +49,43 @@ void fbgemmPacked( // Run time CPU detection if (cpuinfo_initialize()) { - if (fbgemmHasAvx512Support()) { - MCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx512>::MCB; - KCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx512>::KCB; - MR = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx512>::MR; - } else if (fbgemmHasAvx2Support()) { - MCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx2>::MCB; - KCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx2>::KCB; - MR = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx2>::MR; + if (blocking_params) { + MCB = blocking_params->MCB; + KCB = blocking_params->KCB; + MR = blocking_params->MR; } else { - // TODO: Have default slower path - assert(0 && "unsupported architecture"); - return; + if (fbgemmHasAvx512Support()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::KCB; + MR = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::MR; + } else if (fbgemmHasAvx2Support()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::KCB; + MR = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::MR; + + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return; + } } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); @@ -149,7 +157,8 @@ void fbgemmPacked( ldc, outProcess, thread_id, - num_threads); + num_threads, + blocking_params); for (int i = i_begin; i < i_end; i += MCB) { // i is the element index mc = std::min(i_end - i, MCB); for (int kb = 0; kb < kBlocks; ++kb) { // kb is the block index @@ -209,7 +218,7 @@ template bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p); template bool fbgemmOptimizedGConv(const conv_param_t<3>& conv_p); bool fbgemmSupportedCPU() { - return (cpuinfo_initialize() && cpuinfo_has_x86_avx2()); + return (cpuinfo_initialize() && fbgemmHasAvx2Support()); } //////////////////////////////////////////////////////////////////////////////// @@ -223,7 +232,8 @@ bool fbgemmSupportedCPU() { uint32_t ldc, \ const ReQuantizeOutput& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \ INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \ @@ -258,7 +268,8 @@ INSTANTIATE_ACC_T(PackAWithRowOffset); uint32_t ldc, \ const ReQuantizeOutput& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ @@ -293,7 +304,8 @@ INSTANTIATE_RELU(int16_t); uint32_t ldc, \ const ReQuantizeForFloat& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(PACK_A, RELU) \ INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR); \ @@ -323,7 +335,8 @@ INSTANTIATE_RELU(PackAWithQuantRowOffset); uint32_t ldc, \ const ReQuantizeForFloat& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ @@ -355,7 +368,8 @@ template void fbgemmPacked( uint32_t ldc, const ReQuantizeForFloat& outProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* blocking_params); //////////////////////////////////////////////////////////////////////////////// // DoSpmdmOnInpBuffer @@ -371,7 +385,8 @@ template void fbgemmPacked( int32_t, \ ReQuantizeOutput>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(PACK_A, RELU) \ INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR); \ @@ -401,7 +416,8 @@ INSTANTIATE_RELU(PackAWithRowOffset); int32_t, \ ReQuantizeOutput>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(RELU) \ INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \ @@ -423,7 +439,8 @@ template void fbgemmPacked( const DoSpmdmOnInpBuffer>& outProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* blocking_params); //////////////////////////////////////////////////////////////////////////////// // memCopy @@ -436,7 +453,8 @@ template void fbgemmPacked( uint32_t ldc, \ const memCopy<>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_ACC_T(PACK_A) \ INSTANTIATE_BASE(PACK_A, int32_t) \ @@ -460,7 +478,8 @@ INSTANTIATE_ACC_T(PackAWithRowOffset); uint32_t ldc, \ const memCopy<>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_SPATIAL_DIM(ACC_T) \ INSTANTIATE_BASE(ACC_T, 2); \ @@ -481,7 +500,8 @@ template void fbgemmPacked( uint32_t ldc, const memCopy<>& outProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* blocking_params); template void fbgemmPacked( PackMatrix, uint8_t, int16_t>& packA, @@ -491,6 +511,8 @@ template void fbgemmPacked( uint32_t ldc, const DoNothing& outProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* blocking_params); + } // namespace fbgemm diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index 7d8ac05..dccdfc5 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -39,8 +39,9 @@ class CodeGenBase { /** * @brief Constructor for initializing AVX2/AVX512 registers. */ - CodeGenBase() - : CRegs_avx2_{x86::ymm0, + CodeGenBase(const BlockingFactors* params = nullptr) + : blocking_params(params), + CRegs_avx2_{x86::ymm0, x86::ymm1, x86::ymm2, x86::ymm3, @@ -136,12 +137,21 @@ class CodeGenBase { bool accum, int leadingDimCRegAssign = 4); + const BlockingFactors* blocking_params; /** * @brief Generate filename to dump generated code * (debug-only) */ template - std::string getCodeLoggingFile(bool accum, int mc, int nc) { + std::string getCodeLoggingFile( + bool accum, + int mc, + int nc, + int NCB, + int KCB, + int MR, + int NR, + int NR_MIN) { std::string fileName = "gemm_"; if (std::is_same::value) { fileName += "acc16_"; @@ -153,6 +163,11 @@ class CodeGenBase { fileName += "accum-" + std::to_string(accum); fileName += "_MC-" + std::to_string(mc); fileName += "_NC-" + std::to_string(nc); + fileName += "_NCB-" + std::to_string(NCB); + fileName += "_NCB-" + std::to_string(KCB); + fileName += "_MR-" + std::to_string(MR); + fileName += "_NR-" + std::to_string(NR); + fileName += "_NR_MIN-" + std::to_string(NR_MIN); if (instSet == inst_set_t::avx512) { fileName += "_avx512"; } else if (instSet == inst_set_t::avx2) { @@ -174,7 +189,10 @@ class CodeGenBase { int VLEN_; ///< Vector width in elements. static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. - static thread_local std::map, jit_micro_kernel_fp> + // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min + static thread_local std::map< + std::tuple, + jit_micro_kernel_fp> codeCache_; ///< JIT Code Cache for reuse. }; diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index e5980b9..082518c 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -17,7 +17,7 @@ thread_local asmjit::CodeHolder CodeGenBase::code_; template thread_local std::map< - std::tuple, + std::tuple, typename CodeGenBase::jit_micro_kernel_fp> CodeGenBase::codeCache_; @@ -136,11 +136,45 @@ CodeGenBase::getOrCreate( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits::KCB; + nBlock = PackingTraits::NCB; + mRegBlockSize = PackingTraits::MR; + nRegBlockSize = PackingTraits::NR; + nRegBlockSizeMin = + PackingTraits::NR_MIN; + row_interleave = + PackingTraits::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } - code_.reset(false); code_.init(rt_.getCodeInfo()); asmjit::X86Assembler assembler(&code_); @@ -148,22 +182,24 @@ CodeGenBase::getOrCreate( #if defined(FBGEMM_LOG_CODE) // generated code logging - FILE* codeLogfile = - fopen(getCodeLoggingFile(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = PackingTraits::KCB; - constexpr int nBlock = PackingTraits::NCB; - constexpr int mRegBlockSize = - PackingTraits::MR; - // constexpr int nRegBlockSize = - // PackingTraits::NR; - constexpr int row_interleave = - PackingTraits::ROW_INTERLEAVE; int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index 2ded242..505fec1 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -131,7 +131,42 @@ CodeGenBase::getOrCreate( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits::KCB; + nBlock = PackingTraits::NCB; + mRegBlockSize = PackingTraits::MR; + nRegBlockSize = PackingTraits::NR; + nRegBlockSizeMin = + PackingTraits::NR_MIN; + row_interleave = + PackingTraits::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } @@ -143,27 +178,24 @@ CodeGenBase::getOrCreate( #if defined(FBGEMM_LOG_CODE) // generated code logging - FILE* codeLogfile = - fopen(getCodeLoggingFile(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = - PackingTraits::KCB; - constexpr int nBlock = - PackingTraits::NCB; - constexpr int mRegBlockSize = - PackingTraits::MR; - constexpr int nRegBlockSize = - PackingTraits::NR; - constexpr int nRegBlockSizeMin = - PackingTraits::NR_MIN; - constexpr int row_interleave = - PackingTraits::ROW_INTERLEAVE; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); int maxMRegs = mRegBlockSize; @@ -172,7 +204,6 @@ CodeGenBase::getOrCreate( maxMRegs * maxNRegs <= 24 && "MR*(NR*ROW_INTERLEAVE*8/512) \ must be <= 24(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index 203dd9a..ca750d9 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -17,7 +17,7 @@ thread_local asmjit::CodeHolder CodeGenBase::code_; template thread_local std::map< - std::tuple, + std::tuple, typename CodeGenBase::jit_micro_kernel_fp> CodeGenBase::codeCache_; @@ -140,11 +140,45 @@ CodeGenBase::getOrCreate( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits::KCB; + nBlock = PackingTraits::NCB; + mRegBlockSize = PackingTraits::MR; + nRegBlockSize = PackingTraits::NR; + nRegBlockSizeMin = + PackingTraits::NR_MIN; + row_interleave = + PackingTraits::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } - code_.reset(false); code_.init(rt_.getCodeInfo()); asmjit::X86Assembler assembler(&code_); @@ -152,20 +186,24 @@ CodeGenBase::getOrCreate( #if defined(FBGEMM_LOG_CODE) // generated code logging FILE* codeLogfile = - fopen(getCodeLoggingFile(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = PackingTraits::KCB; - constexpr int nBlock = PackingTraits::NCB; - constexpr int mRegBlockSize = - PackingTraits::MR; - constexpr int row_interleave = - PackingTraits::ROW_INTERLEAVE; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index 333aa9d..d1729e4 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -131,11 +131,45 @@ CodeGenBase::getOrCreate( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits::KCB; + nBlock = PackingTraits::NCB; + mRegBlockSize = PackingTraits::MR; + nRegBlockSize = PackingTraits::NR; + nRegBlockSizeMin = + PackingTraits::NR_MIN; + row_interleave = + PackingTraits::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } - code_.reset(false); code_.init(rt_.getCodeInfo()); asmjit::X86Assembler assembler(&code_); @@ -143,27 +177,24 @@ CodeGenBase::getOrCreate( #if defined(FBGEMM_LOG_CODE) // generated code logging - FILE* codeLogfile = - fopen(getCodeLoggingFile(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = - PackingTraits::KCB; - constexpr int nBlock = - PackingTraits::NCB; - constexpr int mRegBlockSize = - PackingTraits::MR; - constexpr int nRegBlockSize = - PackingTraits::NR; - constexpr int nRegBlockSizeMin = - PackingTraits::NR_MIN; - constexpr int row_interleave = - PackingTraits::ROW_INTERLEAVE; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); int maxMRegs = mRegBlockSize; diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc index db019db..89ec13e 100644 --- a/src/PackAMatrix.cc +++ b/src/PackAMatrix.cc @@ -20,24 +20,36 @@ PackAMatrix::PackAMatrix( const T* smat, int32_t ld, inpType* pmat, - int groups) - : PackMatrix, T, accT>(nRow, nCol, pmat, groups), + int groups, + const BlockingFactors* params) + : PackMatrix, T, accT>( + nRow, + nCol, + pmat, + groups, + params), trans_(trans), smat_(smat), ld_(ld) { - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; + if (params) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } } if (BaseType::numCols() % groups != 0) { throw std::runtime_error( @@ -46,8 +58,7 @@ PackAMatrix::PackAMatrix( } if (pmat) { BaseType::buf_ = pmat; - } - else { + } else { BaseType::bufAllocatedHere_ = true; BaseType::buf_ = (T*)fbgemmAlignedAlloc( 64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)); diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc index 93408da..fb4556c 100644 --- a/src/PackAWithIm2Col.cc +++ b/src/PackAWithIm2Col.cc @@ -23,7 +23,8 @@ PackAWithIm2Col::PackAWithIm2Col( inpType* pmat, int32_t a_zero_pt, int32_t* row_offset, - bool b_symmetric) + bool b_symmetric, + const BlockingFactors* params) : PackMatrix, T, accT>( conv_p.MB * std::accumulate( @@ -38,25 +39,33 @@ PackAWithIm2Col::PackAWithIm2Col( std::multiplies()) * conv_p.IC, pmat, - conv_p.G), + conv_p.G, + params), conv_p_(conv_p), sdata_(sdata), a_zero_pt_(a_zero_pt) { static_assert( SPATIAL_DIM == 2 || SPATIAL_DIM == 3, "unsupported conv dimension "); - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; + + if (params) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } } if (BaseType::numCols() % conv_p.G != 0) { throw std::runtime_error( @@ -145,8 +154,7 @@ void pack_a_with_im2col_opt( std::memcpy( out + (i - block.row_start) * BCOL + j + s * IC, sdata + - ((n * IN_DIM_H + h_in) * IN_DIM_W + -PAD_W + w * STRIDE_W + - s) * + ((n * IN_DIM_H + h_in) * IN_DIM_W + -PAD_W + w * STRIDE_W + s) * IC, sizeof(uint8_t) * mid_len * IC); s += mid_len; @@ -459,17 +467,22 @@ void PackAWithIm2Col::printPackedMatrix( } template -int PackAWithIm2Col::rowOffsetBufferSize() { +int PackAWithIm2Col::rowOffsetBufferSize( + const BlockingFactors* params) { if (cpuinfo_initialize()) { + if (params){ + return params->MCB; + } else { if (fbgemmHasAvx512Support()) { - return PackingTraits::MCB; + return PackingTraits::MCB; } else if (fbgemmHasAvx2Support()) { - return PackingTraits::MCB; - } else { + return PackingTraits::MCB; + } else { // TODO: Have default slower path assert(0 && "unsupported architecture"); return -1; } + } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); } diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index 2929ebb..175425f 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -28,12 +28,14 @@ PackAWithQuantRowOffset::PackAWithQuantRowOffset( float scale, int32_t zero_pt, int groups, - int32_t* row_offset) + int32_t* row_offset, + const BlockingFactors* params) : PackMatrix, T, accT>( nRow, nCol, pmat, - groups), + groups, + params), trans_(trans), smat_(smat), ld_(ld), @@ -41,20 +43,30 @@ PackAWithQuantRowOffset::PackAWithQuantRowOffset( zero_pt_(zero_pt), row_offset_(row_offset) { rowOffsetAllocatedHere = false; - - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; + if (params) { + if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } } else { - // TODO: Have default slower path - assert(0 && "unknown architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unknown architecure"); + } } if (BaseType::numCols() % groups != 0) { throw std::runtime_error( @@ -179,15 +191,20 @@ void PackAWithQuantRowOffset::printPackedMatrix(std::string name) { } template -int PackAWithQuantRowOffset::rowOffsetBufferSize() { +int PackAWithQuantRowOffset::rowOffsetBufferSize( + const BlockingFactors* params) { if (cpuinfo_initialize()) { - if (fbgemmHasAvx512Support()) { - return PackingTraits::MCB; - } else if (fbgemmHasAvx2Support()) { - return PackingTraits::MCB; + if (params) { + return params->MCB; } else { - assert(0 && "unsupported architecture"); - return -1; + if (fbgemmHasAvx512Support()) { + return PackingTraits::MCB; + } else if (fbgemmHasAvx2Support()) { + return PackingTraits::MCB; + } else { + assert(0 && "unsupported architecture"); + return -1; + } } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc index 7777f1a..139a6d3 100644 --- a/src/PackAWithRowOffset.cc +++ b/src/PackAWithRowOffset.cc @@ -24,31 +24,38 @@ PackAWithRowOffset::PackAWithRowOffset( uint32_t ld, inpType* pmat, int groups, - int32_t* row_offset) + int32_t* row_offset, + const BlockingFactors* params) : PackMatrix, T, accT>( nRow, nCol, pmat, - groups), + groups, + params), trans_(trans), smat_(smat), ld_(ld), row_offset_(row_offset) { rowOffsetAllocatedHere = false; - - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; + if (params) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - // TODO: Have default slower path - assert(0 && "unknown architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits::MCB; + BaseType::bcol_ = PackingTraits::KCB; + row_interleave_B_ = + PackingTraits::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unknown architecure"); + } } if (BaseType::numCols() % groups != 0) { throw std::runtime_error( @@ -169,17 +176,22 @@ void PackAWithRowOffset::printPackedMatrix(std::string name) { } template -int PackAWithRowOffset::rowOffsetBufferSize() { +int PackAWithRowOffset::rowOffsetBufferSize( + const BlockingFactors* params) { if (cpuinfo_initialize()) { + if (params){ + return params->MCB; + } else { if (fbgemmHasAvx512Support()) { - return PackingTraits::MCB; - } else if (fbgemmHasAvx2Support()) { - return PackingTraits::MCB; + return PackingTraits::MCB; + } else if (fbgemmHasAvx2Support()) { + return PackingTraits::MCB; } else { // TODO: Have default slower path assert(0 && "unsupported architecture"); return -1; } + } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); } diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index 48641ff..472c802 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -174,23 +174,36 @@ PackBMatrix::PackBMatrix( const T* smat, int32_t ld, inpType* pmat, - int groups) - : PackMatrix, T, accT>(nRow, nCol, pmat, groups), + int groups, + const BlockingFactors* params) + : PackMatrix, T, accT>( + nRow, + nCol, + pmat, + groups, + params), trans_(trans), smat_(smat), ld_(ld) { - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits::KCB; - BaseType::bcol_ = PackingTraits::NCB; - row_interleave_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits::KCB; - BaseType::bcol_ = PackingTraits::NCB; - row_interleave_ = PackingTraits::ROW_INTERLEAVE; + if (params) { + BaseType::brow_ = params->KCB; + BaseType::bcol_ = params->NCB; + row_interleave_ = params->ROW_INTERLEAVE; } else { - // Error - assert(0 && "unknown architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits::KCB; + BaseType::bcol_ = PackingTraits::NCB; + row_interleave_ = + PackingTraits::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits::KCB; + BaseType::bcol_ = PackingTraits::NCB; + row_interleave_ = + PackingTraits::ROW_INTERLEAVE; + } else { + // Error + assert(0 && "unknown architecure"); + } } if (BaseType::numRows() % groups != 0) { throw std::runtime_error( diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc index 316fc06..e93b97c 100644 --- a/src/PackMatrix.cc +++ b/src/PackMatrix.cc @@ -18,33 +18,57 @@ PackMatrix::PackMatrix( int32_t rows, int32_t cols, inpType* buf, - int groups) + int groups, + const BlockingFactors* params) : buf_(buf), nrows_(rows), ncols_(cols), G_(groups) { bufAllocatedHere_ = false; + blocking_params = params; if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } } template -int PackMatrix::packedBufferSize(int rows, int cols) { +int PackMatrix::packedBufferSize( + int rows, + int cols, + const BlockingFactors* params) { + int MCB, KCB, NCB; + if (params) { + MCB = params->MCB; + NCB = params->NCB; + KCB = params->KCB; + } else { + if (fbgemmHasAvx512Support()) { + MCB = PackingTraits::MCB; + NCB = PackingTraits::NCB; + KCB = PackingTraits::KCB; + } else if (fbgemmHasAvx2Support()) { + MCB = PackingTraits::MCB; + NCB = PackingTraits::NCB; + KCB = PackingTraits::KCB; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + return -1; + } + } + if (fbgemmHasAvx512Support()) { if (isA()) { - return PackingTraits::MCB * - PackingTraits::KCB; + return MCB * KCB; } else { - int rowBlock = PackingTraits::KCB; - int colBlock = PackingTraits::NCB; + int rowBlock = KCB; + int colBlock = NCB; return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * (((cols + colBlock - 1) / colBlock) * colBlock); } } else if (fbgemmHasAvx2Support()) { if (isA()) { - return PackingTraits::MCB * - PackingTraits::KCB; + return MCB * KCB; } else { - int rowBlock = PackingTraits::KCB; - int colBlock = PackingTraits::NCB; + int rowBlock = KCB; + int colBlock = NCB; return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * (((cols + colBlock - 1) / colBlock) * colBlock); } -- cgit v1.2.3