diff options
Diffstat (limited to 'src/GroupwiseConvAcc32Avx2.cc')
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 78 |
1 files changed, 45 insertions, 33 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index c24c391..b140c83 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -21,6 +21,20 @@ namespace fbgemm { using namespace std; +template <int SPATIAL_DIM, typename accT> +thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_; + +template <int SPATIAL_DIM, typename accT> +thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_; + +template <int SPATIAL_DIM, typename accT> +thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp> + GenConvKernel<SPATIAL_DIM, accT>::codeCache_; + +template <int SPATIAL_DIM, typename accT> +thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp> + GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_; + namespace x86 = asmjit::x86; template <int SPATIAL_DIM> @@ -77,13 +91,14 @@ jit_conv_kernel_fp getOrCreateConvKernel( // Note: Wrong code is generated if it's not one of the supported convolution assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate( - kernelSig, [&]() { - auto genObj = - GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point); - // TODO: Instruction set based dispatch - return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); - }); + if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) != + GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) { + return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig]; + } else { + auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreate<inst_set_t::avx2>(conv_param); + } } template <> @@ -994,9 +1009,9 @@ template <> template <> jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { - asmjit::CodeHolder code; - code.init(rt_.codeInfo()); - x86::Assembler assembler(&code); + code_.reset(false); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) @@ -1005,7 +1020,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { - code.setLogger(codeLogger); + code_.setLogger(codeLogger); } #endif @@ -1082,15 +1097,13 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( a->emitEpilog(frame); jit_conv_kernel_fp fn; - asmjit::Error err; - { - std::unique_lock<std::mutex> lock(rtMutex_); - err = rt_.add(&fn, &code); - } + asmjit::Error err = rt_.add(&fn, &code_); if (err) { std::cout << "Error: in fn add" << std::endl; return nullptr; } + auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_); + codeCache_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) fclose(codeLogfile); @@ -1476,9 +1489,9 @@ template <> jit_rowoffset_kernel_fp GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { - asmjit::CodeHolder code; - code.init(rt_.codeInfo()); - x86::Assembler assembler(&code); + code_.reset(false); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) @@ -1487,7 +1500,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { - code.setLogger(codeLogger); + code_.setLogger(codeLogger); } #endif @@ -1557,16 +1570,14 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( a->emitEpilog(frame); - asmjit::Error err; jit_rowoffset_kernel_fp fn; - { - std::unique_lock<std::mutex> lock(rtMutex_); - err = rt_.add(&fn, &code); - } + asmjit::Error err = rt_.add(&fn, &code_); if (err) { std::cout << "Error: in fn add" << std::endl; return nullptr; } + auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_); + codeCacheRowOffset_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) delete codeLogger; @@ -2151,14 +2162,15 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel( // Note: Wrong code is generated if it's not one of the supported convolution assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)); auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); - return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate( - kernelSig, [&]() { - auto genObj = - GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point); - // TODO: Instruction set based dispatch - return genObj.template getOrCreateRowOffset<inst_set_t::avx2>( - conv_param); - }); + if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find( + kernelSig) != + GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) { + return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig]; + } else { + auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param); + } } template <int SPATIAL_DIM> |