Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorProtonu Basu <protonu@fb.com>2019-04-02 15:22:44 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-02 15:28:21 +0300
commitf12ec122be12b0647ada3ff2c374cca57aa4ae95 (patch)
tree43584749ec09d493ea3a3ec04e407c2b88e8c76c
parentd8e0d440ef80362a786f4ebb68cf1b393c33b52d (diff)
Exposing tuning parameters in FBGEMM (MCB, NCB, KCB, MR, NR, Row Interleave) (#90)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/90 Exposing tuning parameters in FBGEMM (MCB, NCB, KCB, MR, NR, Row Interleave) Reviewed By: dskhudia Differential Revision: D14358148 fbshipit-source-id: 783fb4653fd696dbbd4075ad56cb8682db3011a5
-rw-r--r--bench/GEMMsTunableBenchmark.cc339
-rw-r--r--include/fbgemm/Fbgemm.h37
-rw-r--r--include/fbgemm/PackingTraits-inl.h9
-rw-r--r--include/fbgemm/Utils.h59
-rw-r--r--src/ExecuteKernelGeneric.h3
-rw-r--r--src/ExecuteKernelU8S8.cc68
-rw-r--r--src/ExecuteKernelU8S8.h4
-rw-r--r--src/Fbgemm.cc110
-rw-r--r--src/GenerateKernel.h26
-rw-r--r--src/GenerateKernelU8S8S32ACC16.cc62
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512.cc65
-rw-r--r--src/GenerateKernelU8S8S32ACC32.cc60
-rw-r--r--src/GenerateKernelU8S8S32ACC32Avx512.cc65
-rw-r--r--src/PackAMatrix.cc43
-rw-r--r--src/PackAWithIm2Col.cc53
-rw-r--r--src/PackAWithQuantRowOffset.cc61
-rw-r--r--src/PackAWithRowOffset.cc50
-rw-r--r--src/PackBMatrix.cc39
-rw-r--r--src/PackMatrix.cc44
19 files changed, 949 insertions, 248 deletions
diff --git a/bench/GEMMsTunableBenchmark.cc b/bench/GEMMsTunableBenchmark.cc
new file mode 100644
index 0000000..04eed8a
--- /dev/null
+++ b/bench/GEMMsTunableBenchmark.cc
@@ -0,0 +1,339 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <vector>
+#include<set>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#ifdef USE_MKL
+#include <mkl.h>
+#endif
+
+#include "bench/BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+#include "test/QuantizationHelpers.h"
+
+using namespace std;
+using namespace fbgemm;
+
+void performance_test(
+ const BlockingFactors* tuning_params,
+ set<vector<int>>& incorrect_configs,
+ const vector<int>& shape,
+ array<int, 6>& best_config,
+ float& giga_ops) {
+
+ bool flush = true;
+ std::vector<char> llc;
+
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 10;
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << "WARNING: the timer may be inaccurate when used by multiple threads."
+ << endl;
+ cout << setw(8) << "M, " << setw(8) << "N, " << setw(8) << "K, " << setw(18)
+ << "Type, " << setw(18) << "Packing (us), " << setw(18)
+ << "Kernel (us), " << setw(18) << "Postproc (us), " << setw(18)
+ << "Total (us), " << setw(5) << "GOPs" << endl;
+#else
+#endif
+
+ chrono::time_point<chrono::high_resolution_clock> start, end;
+
+ int m = shape[0];
+ int n = shape[1];
+ int k = shape[2];
+
+ aligned_vector<uint8_t> Aint8(m * k);
+ aligned_vector<int8_t> Bint8(k * n);
+ aligned_vector<float> Cfp32_mkl(m * n);
+ aligned_vector<int32_t> Cint32_mkl(Cfp32_mkl.size());
+ aligned_vector<int32_t> Cint32_ref(Cfp32_mkl.size());
+ aligned_vector<int32_t> Cint32_fb_acc32(Cfp32_mkl.size());
+ aligned_vector<int32_t> Cint32_fb_acc16(Cfp32_mkl.size());
+
+ // A matrix
+ randFill<uint8_t>(Aint8, 0, 5);
+ aligned_vector<float> Afp32(Aint8.begin(), Aint8.end());
+
+ randFill<int8_t>(Bint8, -4, 4);
+ avoidOverflow(m, n, k, Aint8.data(), Bint8.data());
+
+ aligned_vector<float> Bfp32(Bint8.begin(), Bint8.end());
+
+ double nops = 2.0 * static_cast<double>(NITER) * m * n * k;
+ double ttot = 0.0;
+ string runType;
+
+ vector<int32_t> row_offsets(m);
+
+ matmul_u8i8acc32_ref(
+ m, n, k, k, n, n, Aint8.data(), Bint8.data(), Cint32_ref.data());
+
+ PackBMatrix<int8_t> packedB_int32(
+ matrix_op_t::NoTranspose,
+ k,
+ n,
+ Bint8.data(),
+ n,
+ nullptr,
+ 1,
+ tuning_params);
+
+ ttot = 0.0;
+ runType = "FBGEMM_i8_acc32";
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ double total_packing_time = 0.0;
+ double total_computing_time = 0.0;
+ double total_kernel_time = 0.0;
+ double total_postprocessing_time = 0.0;
+ double total_run_time = 0.0;
+#endif
+
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
+ llc_flush(llc);
+ start = chrono::high_resolution_clock::now();
+
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
+ {
+ PackAMatrix<uint8_t> packA_int32(
+ matrix_op_t::NoTranspose,
+ m,
+ k,
+ Aint8.data(),
+ k,
+ nullptr,
+ 1,
+ tuning_params);
+
+ DoNothing<int32_t, int32_t> doNothing32BitObj;
+ memCopy<> memcopyObj(doNothing32BitObj);
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+ // printf ( "tid: %d, num_threads: %d\n", tid, num_threads );
+ fbgemmPacked(
+ packA_int32,
+ packedB_int32,
+ Cint32_fb_acc32.data(),
+ Cint32_fb_acc32.data(),
+ n,
+ memcopyObj,
+ tid,
+ num_threads,
+ tuning_params);
+ }
+
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - start);
+ ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
+ }
+ }
+ ((volatile char*)(llc.data()));
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << ", " << setw(16) << total_packing_time / (double)NITER / 1e3 << ", "
+ << setw(16) << total_kernel_time / (double)NITER / 1e3 << ", "
+ << setw(16) << total_postprocessing_time / (double)NITER / 1e3 << ", "
+ << setw(16) << total_run_time / (double)NITER / 1e3;
+#endif
+
+ if (compare_buffers(
+ Cint32_ref.data(), Cint32_fb_acc32.data(), m, n, n, 5)) {
+ vector<int> config = {tuning_params->MCB,
+ tuning_params->NCB,
+ tuning_params->KCB,
+ tuning_params->MR,
+ tuning_params->NR,
+ tuning_params->ROW_INTERLEAVE};
+ incorrect_configs.insert(config);
+ } else {
+ cout << setw(5) << "MCB, " << setw(5) << "NCB, " << setw(5) << "KCB, "
+ << setw(5) << "MR, " << setw(5) << "NR, " << setw(5) << "ROW INT."
+ << endl;
+ cout << setw(5) << tuning_params->MCB << setw(5) << tuning_params->NCB
+ << setw(5) << tuning_params->KCB << setw(5) << tuning_params->MR
+ << setw(5) << tuning_params->NR << setw(5)
+ << tuning_params->ROW_INTERLEAVE << endl;
+
+ cout << setw(8) << "M, " << setw(8) << "N, " << setw(8) << "K, "
+ << setw(18) << "Type, " << setw(5) << "GOPS" << endl;
+ cout << setw(6) << m << ", " << setw(6) << n << ", " << setw(6) << k
+ << ", " << setw(16) << runType;
+ cout << ", " << setw(5) << fixed << setw(5) << setprecision(1)
+ << nops / ttot << endl;
+ if ((nops/ttot) > giga_ops){
+ giga_ops = nops/ttot;
+ best_config = {tuning_params->MCB,
+ tuning_params->NCB,
+ tuning_params->KCB,
+ tuning_params->MR,
+ tuning_params->NR,
+ tuning_params->ROW_INTERLEAVE};
+ }
+ }
+}
+
+int main(int /* unused */, char** /* unused */) {
+#ifdef _OPENMP
+ // Use 1 thread unless OMP_NUM_THREADS is explicit set.
+ const char* val = getenv("OMP_NUM_THREADS");
+ if (val == nullptr || !*val) {
+ omp_set_num_threads(1);
+ }
+#endif
+
+ vector<vector<int>> shapes = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ // m, n, k
+ //warning these take time to run!
+ {156800, 4, 36},
+ {156800, 8, 36},
+ {156800, 16, 36},
+ {1, 128, 512},
+ {1, 1024, 256},
+ {1, 2048, 512},
+ {1, 4096, 1024},
+ {6, 256, 1024},
+ {6, 256, 2048},
+ {6, 512, 512},
+ {6, 1024, 256},
+ {6, 2048, 256},
+ {6, 2048, 512},
+ {6, 4096, 256},
+ {6, 4096, 1024},
+ {6, 4096, 2048},
+
+ {10, 2048, 256},
+ {10, 4096, 1024},
+
+ {20, 2048, 256},
+ {20, 4096, 1024},
+
+ {102, 1024, 512},
+ {102, 2323, 256},
+ {102, 512, 256},
+
+ {1, 800, 3200},
+ {1, 800, 8000},
+
+ {16, 256, 1500},
+ {16, 256, 1567},
+ {1, 128, 2876},
+ {16, 128, 1567},
+ {1, 128, 2722},
+
+ {16, 256, 512},
+ {64, 800, 320},
+ {64, 768, 512},
+ {16, 256, 512},
+ {128, 128, 128},
+ {256, 512, 256},
+ {1024, 1024, 1024},
+};
+
+ vector<int> MCBs;
+ vector<int> NCBs;
+ vector<int> KCBs;
+ vector<int> MRs;
+ int NR = 16;
+ int NR_MIN = 16;
+ int ROW_INTERLEAVE = 4; // do 32-bit accumulation for now
+
+ if (cpuinfo_initialize()) {
+ if (fbgemmHasAvx512Support()) {
+ NR = 16;
+ MCBs.insert(MCBs.end(), {48, 96, 144, 192, 240});
+ NCBs.insert(NCBs.end(), {16, 32, 64, 128, 48, 98, 192, 384});
+ KCBs.insert(
+ KCBs.end(),
+ {256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 960, 1024});
+ MRs.insert(MRs.end(), {24, 12, 6, 3, 8, 4, 2, 1});
+ } else if (fbgemmHasAvx2Support()) {
+ assert(0 && "Benchmark will be extended for this architecture");
+ } else {
+ assert(0 && "architecture not supported");
+ return 0;
+ }
+ }
+
+ set<vector<int>> incorrect_configs;
+ float giga_ops = 0.0;
+ array<int, 6> best_config = {0, 0, 0, 0, 0, 0};
+ BlockingFactors params;
+ for (auto const& shape : shapes) {
+ for (auto const& mcb : MCBs) {
+ for (auto const& ncb : NCBs) {
+ for (auto const& kcb : KCBs) {
+ for (auto const& mr : MRs) {
+ params.MCB = mcb;
+ params.NCB = ncb;
+ params.KCB = kcb;
+ params.MR = mr;
+ params.NR = NR;
+ params.ROW_INTERLEAVE = ROW_INTERLEAVE;
+ params.NR_MIN = NR_MIN;
+ if (isValidBlockingFactor<int32_t>(&params)) {
+ performance_test(
+ &params, incorrect_configs, shape, best_config, giga_ops);
+ }
+ }
+ }
+ }
+ }
+ cout << endl << "This is the Best Config!" << endl;
+ cout << setw(5) << "MCB, " << setw(5) << "NCB, " << setw(5) << "KCB, "
+ << setw(5) << "MR, " << setw(5) << "NR, " << setw(5) << "ROW INT."
+ << setw(5) << "GOPS" << endl;
+ cout << setw(5) << best_config[0] << setw(5) << best_config[1] << setw(5)
+ << best_config[2] << setw(5) << best_config[3] << setw(5)
+ << best_config[4] << setw(5) << best_config[5] << giga_ops << endl;
+ } // end shapes
+
+ cout << endl << "Warning there are configs that didn't work!" << endl;
+ for (auto const& entry : incorrect_configs) {
+ cout << setw(5) << "MCB, " << setw(5) << "NCB, " << setw(5) << "KCB, "
+ << setw(5) << "MR, " << setw(5) << "NR, " << setw(5) << "ROW INT."
+ << endl;
+ cout << setw(5) << entry[0] << setw(5) << entry[1] << setw(5) << entry[2]
+ << setw(5) << entry[3] << setw(5) << entry[4] << setw(5) << entry[5]
+ << endl;
+ }
+ return 0;
+}
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 4f3c92e..48f7255 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -98,7 +98,8 @@ class PackMatrix {
std::int32_t rows,
std::int32_t cols,
inpType* pmat,
- int groups = 1);
+ int groups = 1,
+ const BlockingFactors* params = nullptr);
/**
* @return true usually when the matrix is constant matrix (e.g., weight
@@ -124,7 +125,10 @@ class PackMatrix {
* client code can use this function to query how big the buffer used for
* packing should be.
*/
- static int packedBufferSize(int rows = 0, int cols = 0);
+ static int packedBufferSize(
+ int rows = 0,
+ int cols = 0,
+ const BlockingFactors* params = nullptr);
/**
* @return Pointer to a buffer containing row offset results. Some packing
@@ -281,6 +285,8 @@ class PackMatrix {
std::int32_t nbrow_; ///< the number of blocks along rows
std::int32_t nbcol_; ///< the number of blocks along columns
bool bufAllocatedHere_;
+ const BlockingFactors*
+ blocking_params; ///< MCB, KCB, NCB, MR, NR, NR_MIN, ROW_INTERLEAVE;
private:
std::int32_t nrows_, ncols_;
@@ -312,7 +318,8 @@ class FBGEMM_API PackAMatrix final
const inpType* smat,
std::int32_t ld,
inpType* pmat = nullptr,
- int groups = 1);
+ int groups = 1,
+ const BlockingFactors* params = nullptr);
/**
* Activation matrices are not constant so cannot amortize the cost of
@@ -393,7 +400,8 @@ class FBGEMM_API PackBMatrix final
const inpType* smat,
std::int32_t ld,
inpType* pmat = nullptr,
- int groups = 1);
+ int groups = 1,
+ const BlockingFactors* params = nullptr);
/**
* Weight matrices are usually constant so worth pre-packing.
@@ -532,7 +540,8 @@ class FBGEMM_API PackAWithIm2Col
inpType* pmat = nullptr,
std::int32_t a_zero_pt = 0,
std::int32_t* row_offset = nullptr,
- bool b_symmetric = false);
+ bool b_symmetric = false,
+ const BlockingFactors* params = nullptr);
/**
* Activation matrices are not constant so cannot amortize the cost of
@@ -569,7 +578,8 @@ class FBGEMM_API PackAWithIm2Col
/**
* @return Size of row offset buffer in number of elements
*/
- static int rowOffsetBufferSize();
+ static int rowOffsetBufferSize(
+ const BlockingFactors* params = nullptr);
~PackAWithIm2Col() {
if (rowOffsetAllocatedHere) {
@@ -615,7 +625,8 @@ class FBGEMM_API PackAWithRowOffset final
std::uint32_t ld,
inpType* pmat = nullptr,
int groups = 1,
- std::int32_t* row_offset = nullptr);
+ std::int32_t* row_offset = nullptr,
+ const BlockingFactors* params = nullptr);
/**
* Activation matrices are not constant so cannot amortize the cost of
@@ -658,7 +669,8 @@ class FBGEMM_API PackAWithRowOffset final
/**
* @return size of row offset buffer in number of elements
*/
- static int rowOffsetBufferSize();
+ static int rowOffsetBufferSize(
+ const BlockingFactors* params = nullptr);
~PackAWithRowOffset() {
if (rowOffsetAllocatedHere) {
@@ -706,7 +718,8 @@ class FBGEMM_API PackAWithQuantRowOffset final
float scale = 1.0f,
std::int32_t zero_pt = 0,
int groups = 1,
- std::int32_t* row_offset = nullptr);
+ std::int32_t* row_offset = nullptr,
+ const BlockingFactors* params = nullptr);
/**
* Activation matrices are not constant so cannot amortize the cost of
@@ -749,7 +762,8 @@ class FBGEMM_API PackAWithQuantRowOffset final
/**
* @return Size of row offset buffer in number of elements
*/
- static int rowOffsetBufferSize();
+ static int rowOffsetBufferSize(
+ const BlockingFactors* params = nullptr);
~PackAWithQuantRowOffset() {
if (rowOffsetAllocatedHere) {
@@ -1174,7 +1188,8 @@ FBGEMM_API void fbgemmPacked(
std::uint32_t ldc,
const processOutputType& outProcess,
int thread_id,
- int num_threads);
+ int num_threads,
+ const BlockingFactors* blocking_params = nullptr);
/**
* @brief Perform small-channels-per-group groupwise convolution
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h
index 5b50bc9..76eb425 100644
--- a/include/fbgemm/PackingTraits-inl.h
+++ b/include/fbgemm/PackingTraits-inl.h
@@ -57,6 +57,10 @@ struct PackingTraits<
inst_set_t::avx2,
typename std::enable_if<is_8bit<T>::value>::type> {
static constexpr int MR{12}; ///< Register block for M dimension.
+ static constexpr int NR_MIN{
+ 8}; ///< Minimum register block for N dimension.
+ ///< 8 because 8*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 256-bit wide vector.
static constexpr int NR{8}; ///< Register block for N dimension.
///< NR = VLEN/8/ROW_INTERLEAVE = 256 / 8 / 4 = 8.
///< Total registers used for N dimension: NCB/NR.
@@ -88,6 +92,11 @@ struct PackingTraits<
inst_set_t::avx2,
typename std::enable_if<is_8bit<T>::value>::type> {
static constexpr int MR{3}; ///< Register block for M dimension.
+ static constexpr int NR_MIN{
+ 16}; ///< Minimum register block for N dimension.
+ ///< 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 256-bit wide vector.
+
static constexpr int NR{
16}; ///< Register block for N dimension;
///< NR = VLEN/8/ROW_INTERLEAVE = 256 / 8 / 2 = 16.
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 2813629..58232e7 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -87,4 +87,63 @@ FBGEMM_API bool fbgemmHasAvx512Support();
*/
FBGEMM_API bool fbgemmHasAvx2Support();
+/**
+ * @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
+ *
+ * This structure is optional. If not used, the default values for these
+ * parameters are picked up from PackingTraits-inl.h. Please see this
+ * file for details on these parameters.
+ */
+struct FBGEMM_API BlockingFactors {
+ int MR;
+ int NR;
+ int NR_MIN;
+ int ROW_INTERLEAVE;
+ int MCB;
+ int KCB;
+ int NCB;
+};
+
+template <typename accT = std::int32_t>
+FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
+ constexpr bool is_32bit = std::is_same<accT, int32_t>::value;
+ constexpr bool is_16bit = std::is_same<accT, int16_t>::value;
+
+ if (is_32bit) {
+ if (param->ROW_INTERLEAVE != 4)
+ return false;
+
+ if (fbgemmHasAvx512Support()) {
+ if (param->NR != 16)
+ return false;
+ } else if (fbgemmHasAvx2Support()) {
+ if (param->NR != 8)
+ return false;
+ }
+ } else if (is_16bit) {
+ if (param->ROW_INTERLEAVE != 2)
+ return false;
+
+ if (fbgemmHasAvx512Support()) {
+ if (param->NR != 32)
+ return false;
+ } else if (fbgemmHasAvx2Support()) {
+ if (param->NR != 16)
+ return false;
+ }
+ }
+
+ if (param->MCB % param->MR)
+ return false;
+ if (param->NCB % param->NR)
+ return false;
+ if (fbgemmHasAvx512Support()) {
+ if (param->MR * (param->NCB / param->NR) > 24)
+ return false;
+ } else if (fbgemmHasAvx2Support()) {
+ if (param->MR * (param->NCB / param->NR) > 16)
+ return false;
+ }
+ return true;
+}
} // namespace fbgemm
diff --git a/src/ExecuteKernelGeneric.h b/src/ExecuteKernelGeneric.h
index 667b0ef..ce9a7bb 100644
--- a/src/ExecuteKernelGeneric.h
+++ b/src/ExecuteKernelGeneric.h
@@ -40,7 +40,8 @@ class ExecuteKernel : public CodeGenBase<
int32_t ldc,
const processOutputType& outputProcess,
int thread_id,
- int num_threads);
+ int num_threads,
+ const BlockingFactors* params = nullptr);
void execute(int kBlock);
private:
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index 9b0ea41..4175d65 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -33,8 +33,11 @@ ExecuteKernel<
int32_t ldc,
const processOutputType& outputProcess,
int thread_id,
- int num_threads)
- : packedA_(packA),
+ int num_threads,
+ const BlockingFactors* params)
+ : CodeGenBase<uint8_t, int8_t, int32_t, typename packingAMatrix::accType>(
+ params),
+ packedA_(packA),
packedB_(packB),
matC_(matC),
C_buffer_(C_buffer),
@@ -42,34 +45,41 @@ ExecuteKernel<
outputProcess_(outputProcess),
thread_id_(thread_id),
num_threads_(num_threads) {
- if (fbgemmHasAvx512Support()) {
- mbSize_ = PackingTraits<
- int8_t,
- typename packingAMatrix::accType,
- inst_set_t::avx512>::MCB;
- nbSize_ = PackingTraits<
- int8_t,
- typename packingAMatrix::accType,
- inst_set_t::avx512>::NCB;
- nrMinSize_ = PackingTraits<
- int8_t,
- typename packingAMatrix::accType,
- inst_set_t::avx512>::NR_MIN;
- } else if (fbgemmHasAvx2Support()) {
- mbSize_ = PackingTraits<
- int8_t,
- typename packingAMatrix::accType,
- inst_set_t::avx2>::MCB;
- nbSize_ = PackingTraits<
- int8_t,
- typename packingAMatrix::accType,
- inst_set_t::avx2>::NCB;
- nrMinSize_ = PackingTraits<
- int8_t,
- typename packingAMatrix::accType,
- inst_set_t::avx2>::NR;
+ if (params) {
+ mbSize_ = params->MCB;
+ nbSize_ = params->NCB;
+ nrMinSize_ = params->NR_MIN;
+ nrSize_ = params->NR;
} else {
- assert(0 && "unsupported architecure");
+ if (fbgemmHasAvx512Support()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::NCB;
+ nrMinSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::NR_MIN;
+ } else if (fbgemmHasAvx2Support()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::NCB;
+ nrMinSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::NR;
+ } else {
+ assert(0 && "unsupported architecure");
+ }
}
C_tile_ = new int32_t[mbSize_ * nbSize_];
}
diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h
index b56f54c..bb20134 100644
--- a/src/ExecuteKernelU8S8.h
+++ b/src/ExecuteKernelU8S8.h
@@ -44,7 +44,8 @@ class ExecuteKernel<
int32_t ldc,
const processOutputType& outputProcess,
int thread_id,
- int num_threads);
+ int num_threads,
+ const BlockingFactors* params = nullptr);
void execute(int kBlock);
~ExecuteKernel() {
@@ -70,6 +71,7 @@ class ExecuteKernel<
int mbSize_; ///< block size in the m dimension.
int nbSize_; ///< block size in the n dimension.
int nrMinSize_; ///< minimum register size in the n dimension.
+ int nrSize_; ///< register size in the n dimension.
};
} // namespace fbgemm
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index a90dd2d..a40f38a 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -36,7 +36,8 @@ void fbgemmPacked(
uint32_t ldc,
const processOutputType& outProcess,
int thread_id,
- int num_threads) {
+ int num_threads,
+ const BlockingFactors* blocking_params) {
static_assert(
std::is_same<
typename packingAMatrix::accType,
@@ -48,36 +49,43 @@ void fbgemmPacked(
// Run time CPU detection
if (cpuinfo_initialize()) {
- if (fbgemmHasAvx512Support()) {
- MCB = PackingTraits<
- typename packingAMatrix::inpType,
- typename packingAMatrix::accType,
- inst_set_t::avx512>::MCB;
- KCB = PackingTraits<
- typename packingAMatrix::inpType,
- typename packingAMatrix::accType,
- inst_set_t::avx512>::KCB;
- MR = PackingTraits<
- typename packingAMatrix::inpType,
- typename packingAMatrix::accType,
- inst_set_t::avx512>::MR;
- } else if (fbgemmHasAvx2Support()) {
- MCB = PackingTraits<
- typename packingAMatrix::inpType,
- typename packingAMatrix::accType,
- inst_set_t::avx2>::MCB;
- KCB = PackingTraits<
- typename packingAMatrix::inpType,
- typename packingAMatrix::accType,
- inst_set_t::avx2>::KCB;
- MR = PackingTraits<
- typename packingAMatrix::inpType,
- typename packingAMatrix::accType,
- inst_set_t::avx2>::MR;
+ if (blocking_params) {
+ MCB = blocking_params->MCB;
+ KCB = blocking_params->KCB;
+ MR = blocking_params->MR;
} else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecture");
- return;
+ if (fbgemmHasAvx512Support()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::KCB;
+ MR = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::MR;
+ } else if (fbgemmHasAvx2Support()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::KCB;
+ MR = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::MR;
+
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return;
+ }
}
} else {
throw std::runtime_error("Failed to initialize cpuinfo!");
@@ -149,7 +157,8 @@ void fbgemmPacked(
ldc,
outProcess,
thread_id,
- num_threads);
+ num_threads,
+ blocking_params);
for (int i = i_begin; i < i_end; i += MCB) { // i is the element index
mc = std::min(i_end - i, MCB);
for (int kb = 0; kb < kBlocks; ++kb) { // kb is the block index
@@ -209,7 +218,7 @@ template bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p);
template bool fbgemmOptimizedGConv(const conv_param_t<3>& conv_p);
bool fbgemmSupportedCPU() {
- return (cpuinfo_initialize() && cpuinfo_has_x86_avx2());
+ return (cpuinfo_initialize() && fbgemmHasAvx2Support());
}
////////////////////////////////////////////////////////////////////////////////
@@ -223,7 +232,8 @@ bool fbgemmSupportedCPU() {
uint32_t ldc, \
const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
int thread_id, \
- int num_threads);
+ int num_threads, \
+ const BlockingFactors* blocking_params);
#define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \
INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
@@ -258,7 +268,8 @@ INSTANTIATE_ACC_T(PackAWithRowOffset);
uint32_t ldc, \
const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
int thread_id, \
- int num_threads);
+ int num_threads, \
+ const BlockingFactors* blocking_params);
#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
@@ -293,7 +304,8 @@ INSTANTIATE_RELU(int16_t);
uint32_t ldc, \
const ReQuantizeForFloat<RELU, Q_GRAN>& outProcess, \
int thread_id, \
- int num_threads);
+ int num_threads, \
+ const BlockingFactors* blocking_params);
#define INSTANTIATE_Q_GRANS(PACK_A, RELU) \
INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR); \
@@ -323,7 +335,8 @@ INSTANTIATE_RELU(PackAWithQuantRowOffset);
uint32_t ldc, \
const ReQuantizeForFloat<RELU, Q_GRAN>& outProcess, \
int thread_id, \
- int num_threads);
+ int num_threads, \
+ const BlockingFactors* blocking_params);
#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
@@ -355,7 +368,8 @@ template void fbgemmPacked(
uint32_t ldc,
const ReQuantizeForFloat<false>& outProcess,
int thread_id,
- int num_threads);
+ int num_threads,
+ const BlockingFactors* blocking_params);
////////////////////////////////////////////////////////////////////////////////
// DoSpmdmOnInpBuffer
@@ -371,7 +385,8 @@ template void fbgemmPacked(
int32_t, \
ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \
int thread_id, \
- int num_threads);
+ int num_threads, \
+ const BlockingFactors* blocking_params);
#define INSTANTIATE_Q_GRANS(PACK_A, RELU) \
INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR); \
@@ -401,7 +416,8 @@ INSTANTIATE_RELU(PackAWithRowOffset);
int32_t, \
ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \
int thread_id, \
- int num_threads);
+ int num_threads, \
+ const BlockingFactors* blocking_params);
#define INSTANTIATE_Q_GRANS(RELU) \
INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \
@@ -423,7 +439,8 @@ template void fbgemmPacked(
const DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>&
outProcess,
int thread_id,
- int num_threads);
+ int num_threads,
+ const BlockingFactors* blocking_params);
////////////////////////////////////////////////////////////////////////////////
// memCopy
@@ -436,7 +453,8 @@ template void fbgemmPacked(
uint32_t ldc, \
const memCopy<>& outProcess, \
int thread_id, \
- int num_threads);
+ int num_threads, \
+ const BlockingFactors* blocking_params);
#define INSTANTIATE_ACC_T(PACK_A) \
INSTANTIATE_BASE(PACK_A, int32_t) \
@@ -460,7 +478,8 @@ INSTANTIATE_ACC_T(PackAWithRowOffset);
uint32_t ldc, \
const memCopy<>& outProcess, \
int thread_id, \
- int num_threads);
+ int num_threads, \
+ const BlockingFactors* blocking_params);
#define INSTANTIATE_SPATIAL_DIM(ACC_T) \
INSTANTIATE_BASE(ACC_T, 2); \
@@ -481,7 +500,8 @@ template void fbgemmPacked(
uint32_t ldc,
const memCopy<>& outProcess,
int thread_id,
- int num_threads);
+ int num_threads,
+ const BlockingFactors* blocking_params);
template void fbgemmPacked(
PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
@@ -491,6 +511,8 @@ template void fbgemmPacked(
uint32_t ldc,
const DoNothing<int32_t, int32_t>& outProcess,
int thread_id,
- int num_threads);
+ int num_threads,
+ const BlockingFactors* blocking_params);
+
} // namespace fbgemm
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index 7d8ac05..dccdfc5 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -39,8 +39,9 @@ class CodeGenBase {
/**
* @brief Constructor for initializing AVX2/AVX512 registers.
*/
- CodeGenBase()
- : CRegs_avx2_{x86::ymm0,
+ CodeGenBase(const BlockingFactors* params = nullptr)
+ : blocking_params(params),
+ CRegs_avx2_{x86::ymm0,
x86::ymm1,
x86::ymm2,
x86::ymm3,
@@ -136,12 +137,21 @@ class CodeGenBase {
bool accum,
int leadingDimCRegAssign = 4);
+ const BlockingFactors* blocking_params;
/**
* @brief Generate filename to dump generated code
* (debug-only)
*/
template <inst_set_t instSet>
- std::string getCodeLoggingFile(bool accum, int mc, int nc) {
+ std::string getCodeLoggingFile(
+ bool accum,
+ int mc,
+ int nc,
+ int NCB,
+ int KCB,
+ int MR,
+ int NR,
+ int NR_MIN) {
std::string fileName = "gemm_";
if (std::is_same<accT, std::int16_t>::value) {
fileName += "acc16_";
@@ -153,6 +163,11 @@ class CodeGenBase {
fileName += "accum-" + std::to_string(accum);
fileName += "_MC-" + std::to_string(mc);
fileName += "_NC-" + std::to_string(nc);
+ fileName += "_NCB-" + std::to_string(NCB);
+ fileName += "_NCB-" + std::to_string(KCB);
+ fileName += "_MR-" + std::to_string(MR);
+ fileName += "_NR-" + std::to_string(NR);
+ fileName += "_NR_MIN-" + std::to_string(NR_MIN);
if (instSet == inst_set_t::avx512) {
fileName += "_avx512";
} else if (instSet == inst_set_t::avx2) {
@@ -174,7 +189,10 @@ class CodeGenBase {
int VLEN_; ///< Vector width in elements.
static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
- static thread_local std::map<std::tuple<bool, int, int>, jit_micro_kernel_fp>
+ // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min
+ static thread_local std::map<
+ std::tuple<bool, int, int, int, int, int, int, int>,
+ jit_micro_kernel_fp>
codeCache_; ///< JIT Code Cache for reuse.
};
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index e5980b9..082518c 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -17,7 +17,7 @@ thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
template <typename TA, typename TB, typename TC, typename accT>
thread_local std::map<
- std::tuple<bool, int, int>,
+ std::tuple<bool, int, int, int, int, int, int, int>,
typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
CodeGenBase<TA, TB, TC, accT>::codeCache_;
@@ -136,11 +136,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
int32_t nc,
int32_t kc,
int32_t /* unused */) {
- auto kernelSig = std::make_tuple(accum, mc, nc);
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::KCB;
+ nBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NCB;
+ mRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::MR;
+ nRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NR_MIN;
+ row_interleave =
+ PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
if (codeCache_.find(kernelSig) != codeCache_.end()) {
return codeCache_[kernelSig];
}
-
code_.reset(false);
code_.init(rt_.getCodeInfo());
asmjit::X86Assembler assembler(&code_);
@@ -148,22 +182,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
#if defined(FBGEMM_LOG_CODE)
// generated code logging
- FILE* codeLogfile =
- fopen(getCodeLoggingFile<inst_set_t::avx2>(accum, mc, nc).c_str(), "w");
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx2>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
code_.setLogger(codeLogger);
}
#endif
- constexpr int kBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::KCB;
- constexpr int nBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::NCB;
- constexpr int mRegBlockSize =
- PackingTraits<int8_t, int16_t, inst_set_t::avx2>::MR;
- // constexpr int nRegBlockSize =
- // PackingTraits<int8_t, int16_t, inst_set_t::avx2>::NR;
- constexpr int row_interleave =
- PackingTraits<int8_t, int16_t, inst_set_t::avx2>::ROW_INTERLEAVE;
int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;
assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index 2ded242..505fec1 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -131,7 +131,42 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
int32_t nc,
int32_t kc,
int32_t /* unused */) {
- auto kernelSig = std::make_tuple(accum, mc, nc);
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::KCB;
+ nBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NCB;
+ mRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::MR;
+ nRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NR_MIN;
+ row_interleave =
+ PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
if (codeCache_.find(kernelSig) != codeCache_.end()) {
return codeCache_[kernelSig];
}
@@ -143,27 +178,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
#if defined(FBGEMM_LOG_CODE)
// generated code logging
- FILE* codeLogfile =
- fopen(getCodeLoggingFile<inst_set_t::avx512>(accum, mc, nc).c_str(), "w");
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx512>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
code_.setLogger(codeLogger);
}
#endif
- constexpr int kBlock =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::KCB;
- constexpr int nBlock =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NCB;
- constexpr int mRegBlockSize =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR;
- constexpr int nRegBlockSize =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR;
- constexpr int nRegBlockSizeMin =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR_MIN;
- constexpr int row_interleave =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE;
-
assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
@@ -172,7 +204,6 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
maxMRegs * maxNRegs <= 24 &&
"MR*(NR*ROW_INTERLEAVE*8/512) \
must be <= 24(available registers constraint)");
-
int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 203dd9a..ca750d9 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -17,7 +17,7 @@ thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
template <typename TA, typename TB, typename TC, typename accT>
thread_local std::map<
- std::tuple<bool, int, int>,
+ std::tuple<bool, int, int, int, int, int, int, int>,
typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
CodeGenBase<TA, TB, TC, accT>::codeCache_;
@@ -140,11 +140,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
int32_t nc,
int32_t kc,
int32_t /* unused */) {
- auto kernelSig = std::make_tuple(accum, mc, nc);
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::KCB;
+ nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NCB;
+ mRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::MR;
+ nRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NR_MIN;
+ row_interleave =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
if (codeCache_.find(kernelSig) != codeCache_.end()) {
return codeCache_[kernelSig];
}
-
code_.reset(false);
code_.init(rt_.getCodeInfo());
asmjit::X86Assembler assembler(&code_);
@@ -152,20 +186,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
#if defined(FBGEMM_LOG_CODE)
// generated code logging
FILE* codeLogfile =
- fopen(getCodeLoggingFile<inst_set_t::avx2>(accum, mc, nc).c_str(), "w");
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx2>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
code_.setLogger(codeLogger);
}
#endif
- constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB;
- constexpr int nBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::NCB;
- constexpr int mRegBlockSize =
- PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR;
- constexpr int row_interleave =
- PackingTraits<int8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE;
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
// assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 333aa9d..d1729e4 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -131,11 +131,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
int32_t nc,
int32_t kc,
int32_t /* unused */) {
- auto kernelSig = std::make_tuple(accum, mc, nc);
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::KCB;
+ nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::NCB;
+ mRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::MR;
+ nRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::NR_MIN;
+ row_interleave =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
if (codeCache_.find(kernelSig) != codeCache_.end()) {
return codeCache_[kernelSig];
}
-
code_.reset(false);
code_.init(rt_.getCodeInfo());
asmjit::X86Assembler assembler(&code_);
@@ -143,27 +177,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
#if defined(FBGEMM_LOG_CODE)
// generated code logging
- FILE* codeLogfile =
- fopen(getCodeLoggingFile<inst_set_t::avx512>(accum, mc, nc).c_str(), "w");
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx512>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
code_.setLogger(codeLogger);
}
#endif
- constexpr int kBlock =
- PackingTraits<int8_t, int32_t, inst_set_t::avx512>::KCB;
- constexpr int nBlock =
- PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NCB;
- constexpr int mRegBlockSize =
- PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR;
- constexpr int nRegBlockSize =
- PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR;
- constexpr int nRegBlockSizeMin =
- PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR_MIN;
- constexpr int row_interleave =
- PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE;
-
assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
index db019db..89ec13e 100644
--- a/src/PackAMatrix.cc
+++ b/src/PackAMatrix.cc
@@ -20,24 +20,36 @@ PackAMatrix<T, accT>::PackAMatrix(
const T* smat,
int32_t ld,
inpType* pmat,
- int groups)
- : PackMatrix<PackAMatrix<T, accT>, T, accT>(nRow, nCol, pmat, groups),
+ int groups,
+ const BlockingFactors* params)
+ : PackMatrix<PackAMatrix<T, accT>, T, accT>(
+ nRow,
+ nCol,
+ pmat,
+ groups,
+ params),
trans_(trans),
smat_(smat),
ld_(ld) {
- if (fbgemmHasAvx512Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ if (params) {
+ BaseType::brow_ = params->MCB;
+ BaseType::bcol_ = params->KCB;
+ row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
+ if (fbgemmHasAvx512Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx2Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecure");
+ }
}
if (BaseType::numCols() % groups != 0) {
throw std::runtime_error(
@@ -46,8 +58,7 @@ PackAMatrix<T, accT>::PackAMatrix(
}
if (pmat) {
BaseType::buf_ = pmat;
- }
- else {
+ } else {
BaseType::bufAllocatedHere_ = true;
BaseType::buf_ = (T*)fbgemmAlignedAlloc(
64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index 93408da..fb4556c 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -23,7 +23,8 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
inpType* pmat,
int32_t a_zero_pt,
int32_t* row_offset,
- bool b_symmetric)
+ bool b_symmetric,
+ const BlockingFactors* params)
: PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT>(
conv_p.MB *
std::accumulate(
@@ -38,25 +39,33 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
std::multiplies<int>()) *
conv_p.IC,
pmat,
- conv_p.G),
+ conv_p.G,
+ params),
conv_p_(conv_p),
sdata_(sdata),
a_zero_pt_(a_zero_pt) {
static_assert(
SPATIAL_DIM == 2 || SPATIAL_DIM == 3, "unsupported conv dimension ");
- if (fbgemmHasAvx512Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+
+ if (params) {
+ BaseType::brow_ = params->MCB;
+ BaseType::bcol_ = params->KCB;
+ row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
+ if (fbgemmHasAvx512Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx2Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecure");
+ }
}
if (BaseType::numCols() % conv_p.G != 0) {
throw std::runtime_error(
@@ -145,8 +154,7 @@ void pack_a_with_im2col_opt(
std::memcpy(
out + (i - block.row_start) * BCOL + j + s * IC,
sdata +
- ((n * IN_DIM_H + h_in) * IN_DIM_W + -PAD_W + w * STRIDE_W +
- s) *
+ ((n * IN_DIM_H + h_in) * IN_DIM_W + -PAD_W + w * STRIDE_W + s) *
IC,
sizeof(uint8_t) * mid_len * IC);
s += mid_len;
@@ -459,17 +467,22 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::printPackedMatrix(
}
template <typename T, typename accT, int SPATIAL_DIM>
-int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize() {
+int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize(
+ const BlockingFactors* params) {
if (cpuinfo_initialize()) {
+ if (params){
+ return params->MCB;
+ } else {
if (fbgemmHasAvx512Support()) {
- return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
- return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
- } else {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else {
// TODO: Have default slower path
assert(0 && "unsupported architecture");
return -1;
}
+ }
} else {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc
index 2929ebb..175425f 100644
--- a/src/PackAWithQuantRowOffset.cc
+++ b/src/PackAWithQuantRowOffset.cc
@@ -28,12 +28,14 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
float scale,
int32_t zero_pt,
int groups,
- int32_t* row_offset)
+ int32_t* row_offset,
+ const BlockingFactors* params)
: PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT>(
nRow,
nCol,
pmat,
- groups),
+ groups,
+ params),
trans_(trans),
smat_(smat),
ld_(ld),
@@ -41,20 +43,30 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
zero_pt_(zero_pt),
row_offset_(row_offset) {
rowOffsetAllocatedHere = false;
-
- if (fbgemmHasAvx512Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ if (params) {
+ if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
+ BaseType::brow_ = params->MCB;
+ BaseType::bcol_ = params->KCB;
+ row_interleave_B_ = params->ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecure");
+ }
} else {
- // TODO: Have default slower path
- assert(0 && "unknown architecure");
+ if (fbgemmHasAvx512Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx2Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unknown architecure");
+ }
}
if (BaseType::numCols() % groups != 0) {
throw std::runtime_error(
@@ -179,15 +191,20 @@ void PackAWithQuantRowOffset<T, accT>::printPackedMatrix(std::string name) {
}
template <typename T, typename accT>
-int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize() {
+int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize(
+ const BlockingFactors* params) {
if (cpuinfo_initialize()) {
- if (fbgemmHasAvx512Support()) {
- return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
- } else if (fbgemmHasAvx2Support()) {
- return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ if (params) {
+ return params->MCB;
} else {
- assert(0 && "unsupported architecture");
- return -1;
+ if (fbgemmHasAvx512Support()) {
+ return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ } else if (fbgemmHasAvx2Support()) {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else {
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
}
} else {
throw std::runtime_error("Failed to initialize cpuinfo!");
diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc
index 7777f1a..139a6d3 100644
--- a/src/PackAWithRowOffset.cc
+++ b/src/PackAWithRowOffset.cc
@@ -24,31 +24,38 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
uint32_t ld,
inpType* pmat,
int groups,
- int32_t* row_offset)
+ int32_t* row_offset,
+ const BlockingFactors* params)
: PackMatrix<PackAWithRowOffset<T, accT>, T, accT>(
nRow,
nCol,
pmat,
- groups),
+ groups,
+ params),
trans_(trans),
smat_(smat),
ld_(ld),
row_offset_(row_offset) {
rowOffsetAllocatedHere = false;
-
- if (fbgemmHasAvx512Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ if (params) {
+ BaseType::brow_ = params->MCB;
+ BaseType::bcol_ = params->KCB;
+ row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- // TODO: Have default slower path
- assert(0 && "unknown architecure");
+ if (fbgemmHasAvx512Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx2Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unknown architecure");
+ }
}
if (BaseType::numCols() % groups != 0) {
throw std::runtime_error(
@@ -169,17 +176,22 @@ void PackAWithRowOffset<T, accT>::printPackedMatrix(std::string name) {
}
template <typename T, typename accT>
-int PackAWithRowOffset<T, accT>::rowOffsetBufferSize() {
+int PackAWithRowOffset<T, accT>::rowOffsetBufferSize(
+ const BlockingFactors* params) {
if (cpuinfo_initialize()) {
+ if (params){
+ return params->MCB;
+ } else {
if (fbgemmHasAvx512Support()) {
- return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
- } else if (fbgemmHasAvx2Support()) {
- return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ } else if (fbgemmHasAvx2Support()) {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
} else {
// TODO: Have default slower path
assert(0 && "unsupported architecture");
return -1;
}
+ }
} else {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index 48641ff..472c802 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -174,23 +174,36 @@ PackBMatrix<T, accT>::PackBMatrix(
const T* smat,
int32_t ld,
inpType* pmat,
- int groups)
- : PackMatrix<PackBMatrix<T, accT>, T, accT>(nRow, nCol, pmat, groups),
+ int groups,
+ const BlockingFactors* params)
+ : PackMatrix<PackBMatrix<T, accT>, T, accT>(
+ nRow,
+ nCol,
+ pmat,
+ groups,
+ params),
trans_(trans),
smat_(smat),
ld_(ld) {
- if (fbgemmHasAvx512Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
- row_interleave_ =
- PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB;
- row_interleave_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ if (params) {
+ BaseType::brow_ = params->KCB;
+ BaseType::bcol_ = params->NCB;
+ row_interleave_ = params->ROW_INTERLEAVE;
} else {
- // Error
- assert(0 && "unknown architecure");
+ if (fbgemmHasAvx512Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx2Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // Error
+ assert(0 && "unknown architecure");
+ }
}
if (BaseType::numRows() % groups != 0) {
throw std::runtime_error(
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index 316fc06..e93b97c 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -18,33 +18,57 @@ PackMatrix<PT, inpType, accType>::PackMatrix(
int32_t rows,
int32_t cols,
inpType* buf,
- int groups)
+ int groups,
+ const BlockingFactors* params)
: buf_(buf), nrows_(rows), ncols_(cols), G_(groups) {
bufAllocatedHere_ = false;
+ blocking_params = params;
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
}
template <typename PT, typename inpType, typename accType>
-int PackMatrix<PT, inpType, accType>::packedBufferSize(int rows, int cols) {
+int PackMatrix<PT, inpType, accType>::packedBufferSize(
+ int rows,
+ int cols,
+ const BlockingFactors* params) {
+ int MCB, KCB, NCB;
+ if (params) {
+ MCB = params->MCB;
+ NCB = params->NCB;
+ KCB = params->KCB;
+ } else {
+ if (fbgemmHasAvx512Support()) {
+ MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
+ NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
+ KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ } else if (fbgemmHasAvx2Support()) {
+ MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB;
+ NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
+ KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecure");
+ return -1;
+ }
+ }
+
if (fbgemmHasAvx512Support()) {
if (isA()) {
- return PackingTraits<inpType, accType, inst_set_t::avx512>::MCB *
- PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ return MCB * KCB;
} else {
- int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
- int colBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
+ int rowBlock = KCB;
+ int colBlock = NCB;
return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
(((cols + colBlock - 1) / colBlock) * colBlock);
}
} else if (fbgemmHasAvx2Support()) {
if (isA()) {
- return PackingTraits<inpType, accType, inst_set_t::avx2>::MCB *
- PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ return MCB * KCB;
} else {
- int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
- int colBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
+ int rowBlock = KCB;
+ int colBlock = NCB;
return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
(((cols + colBlock - 1) / colBlock) * colBlock);
}