Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GenerateKernel.h')
-rw-r--r--src/GenerateKernel.h101
1 files changed, 44 insertions, 57 deletions
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index e52097e..c0fece4 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -8,8 +8,11 @@
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <map>
+#include <mutex>
+#include <sstream>
#include <string>
#include <tuple>
+#include "CodeCache.h"
#include "fbgemm/Fbgemm.h"
/*#define FBGEMM_LOG_CODE 1*/
@@ -40,35 +43,7 @@ class CodeGenBase {
* @brief Constructor for initializing AVX2/AVX512 registers.
*/
CodeGenBase(const BlockingFactors* params = nullptr)
- : blocking_params(params),
- CRegs_avx2_{x86::ymm0,
- x86::ymm1,
- x86::ymm2,
- x86::ymm3,
- x86::ymm4,
- x86::ymm5,
- x86::ymm6,
- x86::ymm7,
- x86::ymm8,
- x86::ymm9,
- x86::ymm10,
- x86::ymm11},
- CRegs_avx512_{
- x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4,
- x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9,
- x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14,
- x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
- x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24,
- x86::zmm25, x86::zmm26, x86::zmm27,
- },
- AllRegs_avx512_{x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3,
- x86::zmm4, x86::zmm5, x86::zmm6, x86::zmm7,
- x86::zmm8, x86::zmm9, x86::zmm10, x86::zmm11,
- x86::zmm12, x86::zmm13, x86::zmm14, x86::zmm15,
- x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
- x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23,
- x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27,
- x86::zmm28, x86::zmm29, x86::zmm30, x86::zmm31} {
+ : blocking_params(params) {
// vector width in bits
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
@@ -143,7 +118,7 @@ class CodeGenBase {
* (debug-only)
*/
template <inst_set_t instSet>
- std::string getCodeLoggingFile(
+ static std::string getCodeLoggingFile(
bool accum,
int mc,
int nc,
@@ -152,48 +127,60 @@ class CodeGenBase {
int MR,
int NR,
int NR_MIN) {
- std::string fileName = "gemm_";
+ std::ostringstream oss;
+ oss << "gemm_";
if (std::is_same<accT, std::int16_t>::value) {
- fileName += "acc16_";
+ oss << "acc16_";
} else if (std::is_same<accT, std::int32_t>::value) {
- fileName += "acc32_";
+ oss << "acc32_";
} else {
- fileName += "unknown_";
+ oss << "unknown_";
}
- fileName += "accum-" + std::to_string(accum);
- fileName += "_MC-" + std::to_string(mc);
- fileName += "_NC-" + std::to_string(nc);
- fileName += "_NCB-" + std::to_string(NCB);
- fileName += "_NCB-" + std::to_string(KCB);
- fileName += "_MR-" + std::to_string(MR);
- fileName += "_NR-" + std::to_string(NR);
- fileName += "_NR_MIN-" + std::to_string(NR_MIN);
+ oss << "accum-" + std::to_string(accum)
+ << "_MC-" + std::to_string(mc)
+ << "_NC-" + std::to_string(nc)
+ << "_NCB-" + std::to_string(NCB)
+ << "_NCB-" + std::to_string(KCB)
+ << "_MR-" + std::to_string(MR)
+ << "_NR-" + std::to_string(NR)
+ << "_NR_MIN-" + std::to_string(NR_MIN);
if (instSet == inst_set_t::avx512_vnni) {
- fileName += "_avx512vnni";
+ oss << "_avx512vnni";
} else if (instSet == inst_set_t::avx512) {
- fileName += "_avx512";
+ oss << "_avx512";
} else if (instSet == inst_set_t::avx2) {
- fileName += "_avx2";
+ oss << "_avx2";
}
- fileName += ".txt";
- return fileName;
+ oss << ".txt";
+ return oss.str();
}
private:
- x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
- x86::Zmm
- CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
- x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
-
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
- static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
- static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+
+ static asmjit::JitRuntime &runtime() {
+ static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
+ // depents on other static
+ // variables. Required to prevent
+ // initialization order fiasco
+ return rt;
+ }
+
+ static std::mutex rtMutex_; ///< Controll access to runtime;
+
// The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min
- static thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- jit_micro_kernel_fp>
+ static CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+ jit_micro_kernel_fp>
codeCache_; ///< JIT Code Cache for reuse.
};
+template <typename TA, typename TB, typename TC, typename accT>
+std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+ typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
+ CodeGenBase<TA, TB, TC, accT>::codeCache_;
+
} // namespace fbgemm