Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GroupwiseConvAcc32Avx2.cc')
-rw-r--r-- src/GroupwiseConvAcc32Avx2.cc 78
1 files changed, 45 insertions, 33 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index c24c391..b140c83 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -21,6 +21,20 @@ namespace fbgemm {
using namespace std;
+template <int SPATIAL_DIM, typename accT>
+thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_;
+
+template <int SPATIAL_DIM, typename accT>
+thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_;
+
+template <int SPATIAL_DIM, typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
+
+template <int SPATIAL_DIM, typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
+
namespace x86 = asmjit::x86;
template <int SPATIAL_DIM>
@@ -77,13 +91,14 @@ jit_conv_kernel_fp getOrCreateConvKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate(
- kernelSig, [&]() {
- auto genObj =
- GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
- });
+ if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) !=
+ GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) {
+ return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig];
+ } else {
+ auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
+ }
}
template <>
@@ -994,9 +1009,9 @@ template <>
template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- asmjit::CodeHolder code;
- code.init(rt_.codeInfo());
- x86::Assembler assembler(&code);
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1005,7 +1020,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code.setLogger(codeLogger);
+ code_.setLogger(codeLogger);
}
#endif
@@ -1082,15 +1097,13 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
a->emitEpilog(frame);
jit_conv_kernel_fp fn;
- asmjit::Error err;
- {
- std::unique_lock<std::mutex> lock(rtMutex_);
- err = rt_.add(&fn, &code);
- }
+ asmjit::Error err = rt_.add(&fn, &code_);
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
+ auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
+ codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
fclose(codeLogfile);
@@ -1476,9 +1489,9 @@ template <>
jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- asmjit::CodeHolder code;
- code.init(rt_.codeInfo());
- x86::Assembler assembler(&code);
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1487,7 +1500,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code.setLogger(codeLogger);
+ code_.setLogger(codeLogger);
}
#endif
@@ -1557,16 +1570,14 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a->emitEpilog(frame);
- asmjit::Error err;
jit_rowoffset_kernel_fp fn;
- {
- std::unique_lock<std::mutex> lock(rtMutex_);
- err = rt_.add(&fn, &code);
- }
+ asmjit::Error err = rt_.add(&fn, &code_);
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
+ auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
+ codeCacheRowOffset_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
delete codeLogger;
@@ -2151,14 +2162,15 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate(
- kernelSig, [&]() {
- auto genObj =
- GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(
- conv_param);
- });
+ if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find(
+ kernelSig) !=
+ GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) {
+ return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig];
+ } else {
+ auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param);
+ }
}
template <int SPATIAL_DIM>