diff options
Diffstat (limited to 'src/GenerateKernel.h')
-rw-r--r-- | src/GenerateKernel.h | 101 |
1 files changed, 44 insertions, 57 deletions
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index e52097e..c0fece4 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -8,8 +8,11 @@ #include <asmjit/asmjit.h> #include <cpuinfo.h> #include <map> +#include <mutex> +#include <sstream> #include <string> #include <tuple> +#include "CodeCache.h" #include "fbgemm/Fbgemm.h" /*#define FBGEMM_LOG_CODE 1*/ @@ -40,35 +43,7 @@ class CodeGenBase { * @brief Constructor for initializing AVX2/AVX512 registers. */ CodeGenBase(const BlockingFactors* params = nullptr) - : blocking_params(params), - CRegs_avx2_{x86::ymm0, - x86::ymm1, - x86::ymm2, - x86::ymm3, - x86::ymm4, - x86::ymm5, - x86::ymm6, - x86::ymm7, - x86::ymm8, - x86::ymm9, - x86::ymm10, - x86::ymm11}, - CRegs_avx512_{ - x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4, - x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9, - x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14, - x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19, - x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24, - x86::zmm25, x86::zmm26, x86::zmm27, - }, - AllRegs_avx512_{x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, - x86::zmm4, x86::zmm5, x86::zmm6, x86::zmm7, - x86::zmm8, x86::zmm9, x86::zmm10, x86::zmm11, - x86::zmm12, x86::zmm13, x86::zmm14, x86::zmm15, - x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19, - x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, - x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27, - x86::zmm28, x86::zmm29, x86::zmm30, x86::zmm31} { + : blocking_params(params) { // vector width in bits if (cpuinfo_initialize()) { if (fbgemmHasAvx512Support()) { @@ -143,7 +118,7 @@ class CodeGenBase { * (debug-only) */ template <inst_set_t instSet> - std::string getCodeLoggingFile( + static std::string getCodeLoggingFile( bool accum, int mc, int nc, @@ -152,48 +127,60 @@ class CodeGenBase { int MR, int NR, int NR_MIN) { - std::string fileName = "gemm_"; + std::ostringstream oss; + oss << "gemm_"; if (std::is_same<accT, std::int16_t>::value) { - fileName += "acc16_"; + oss << "acc16_"; } else if (std::is_same<accT, std::int32_t>::value) { - fileName += "acc32_"; + oss << "acc32_"; } else { - fileName += "unknown_"; + oss << "unknown_"; } - fileName += "accum-" + std::to_string(accum); - fileName += "_MC-" + std::to_string(mc); - fileName += "_NC-" + std::to_string(nc); - fileName += "_NCB-" + std::to_string(NCB); - fileName += "_NCB-" + std::to_string(KCB); - fileName += "_MR-" + std::to_string(MR); - fileName += "_NR-" + std::to_string(NR); - fileName += "_NR_MIN-" + std::to_string(NR_MIN); + oss << "accum-" + std::to_string(accum) + << "_MC-" + std::to_string(mc) + << "_NC-" + std::to_string(nc) + << "_NCB-" + std::to_string(NCB) + << "_NCB-" + std::to_string(KCB) + << "_MR-" + std::to_string(MR) + << "_NR-" + std::to_string(NR) + << "_NR_MIN-" + std::to_string(NR_MIN); if (instSet == inst_set_t::avx512_vnni) { - fileName += "_avx512vnni"; + oss << "_avx512vnni"; } else if (instSet == inst_set_t::avx512) { - fileName += "_avx512"; + oss << "_avx512"; } else if (instSet == inst_set_t::avx2) { - fileName += "_avx2"; + oss << "_avx2"; } - fileName += ".txt"; - return fileName; + oss << ".txt"; + return oss.str(); } private: - x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. - x86::Zmm - CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel. - x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers. - int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. - static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. - static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. + + static asmjit::JitRuntime &runtime() { + static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, + // depents on other static + // variables. Required to prevent + // initialization order fiasco + return rt; + } + + static std::mutex rtMutex_; ///< Controll access to runtime; + // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min - static thread_local std::map< - std::tuple<bool, int, int, int, int, int, int, int>, - jit_micro_kernel_fp> + static CodeCache<std::tuple<bool, int, int, int, int, int, int, int>, + jit_micro_kernel_fp> codeCache_; ///< JIT Code Cache for reuse. }; +template <typename TA, typename TB, typename TC, typename accT> +std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_; + +template <typename TA, typename TB, typename TC, typename accT> +CodeCache<std::tuple<bool, int, int, int, int, int, int, int>, + typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> + CodeGenBase<TA, TB, TC, accT>::codeCache_; + } // namespace fbgemm |