diff options
Diffstat (limited to 'src/GroupwiseConvAcc32Avx2.cc')
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 78 |
1 files changed, 33 insertions, 45 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index b140c83..ef4ba7b 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -21,20 +21,6 @@ namespace fbgemm { using namespace std; -template <int SPATIAL_DIM, typename accT> -thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_; - -template <int SPATIAL_DIM, typename accT> -thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_; - -template <int SPATIAL_DIM, typename accT> -thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> - GenConvKernel<SPATIAL_DIM, accT>::codeCache_; - -template <int SPATIAL_DIM, typename accT> -thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> - GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_; - namespace x86 = asmjit::x86; template <int SPATIAL_DIM> @@ -91,14 +77,13 @@ jit_conv_kernel_fp getOrCreateConvKernel( // Note: Wrong code is generated if it's not one of the supported convolution assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) != - GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) { - return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig]; - } else { - auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point); - // TODO: Instruction set based dispatch - return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); - } + return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate( + kernelSig, [&]() { + auto genObj = + GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); + }); } template <> @@ -1009,9 +994,9 @@ template <> template <> jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) @@ -1020,7 +1005,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { - code_.setLogger(codeLogger); + code.setLogger(codeLogger); } #endif @@ -1097,13 +1082,15 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( a->emitEpilog(frame); jit_conv_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } if (err) { std::cout << "Error: in fn add" << std::endl; return nullptr; } - auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_); - codeCache_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) fclose(codeLogfile); @@ -1489,9 +1476,9 @@ template <> jit_rowoffset_kernel_fp GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) @@ -1500,7 +1487,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { - code_.setLogger(codeLogger); + code.setLogger(codeLogger); } #endif @@ -1570,14 +1557,16 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( a->emitEpilog(frame); + asmjit::Error err; jit_rowoffset_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } if (err) { std::cout << "Error: in fn add" << std::endl; return nullptr; } - auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_); - codeCacheRowOffset_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) delete codeLogger; @@ -2162,15 +2151,14 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel( // Note: Wrong code is generated if it's not one of the supported convolution assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find( - kernelSig) != - GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) { - return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig]; - } else { - auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point); - // TODO: Instruction set based dispatch - return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param); - } + return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate( + kernelSig, [&]() { + auto genObj = + GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreateRowOffset<inst_set_t::avx2>( + conv_param); + }); } template <int SPATIAL_DIM> |