Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GroupwiseConvAcc32Avx2.cc')
-rw-r--r-- src/GroupwiseConvAcc32Avx2.cc 78
1 file changed, 33 insertions, 45 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index b140c83..ef4ba7b 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -21,20 +21,6 @@ namespace fbgemm {
using namespace std;
-template <int SPATIAL_DIM, typename accT>
-thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
-
namespace x86 = asmjit::x86;
template <int SPATIAL_DIM>
@@ -91,14 +77,13 @@ jit_conv_kernel_fp getOrCreateConvKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) !=
- GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) {
- return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig];
- } else {
- auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
- }
+ return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate(
+ kernelSig, [&]() {
+ auto genObj =
+ GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
+ });
}
template <>
@@ -1009,9 +994,9 @@ template <>
template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1020,7 +1005,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code_.setLogger(codeLogger);
+ code.setLogger(codeLogger);
}
#endif
@@ -1097,13 +1082,15 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
a->emitEpilog(frame);
jit_conv_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
- auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
- codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
fclose(codeLogfile);
@@ -1489,9 +1476,9 @@ template <>
jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1500,7 +1487,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code_.setLogger(codeLogger);
+ code.setLogger(codeLogger);
}
#endif
@@ -1570,14 +1557,16 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a->emitEpilog(frame);
+ asmjit::Error err;
jit_rowoffset_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
- auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
- codeCacheRowOffset_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
delete codeLogger;
@@ -2162,15 +2151,14 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find(
- kernelSig) !=
- GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) {
- return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig];
- } else {
- auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param);
- }
+ return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate(
+ kernelSig, [&]() {
+ auto genObj =
+ GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(
+ conv_param);
+ });
}
template <int SPATIAL_DIM>