Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAleks Zi <zlateski@fb.com>2019-09-04 21:27:35 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-09-04 21:29:38 +0300
commitab2d5278073fce692b4f3b95d164f8cb3bbb1f72 (patch)
tree07cd1856b513a44566731a6d9a403a129bc50570
parent21782ffd9ede194cdf2395854adc10ba11d0d896 (diff)
Introduced CodeCache container to share the microkernels among different threads.
Summary: CodeCache is thread safe and ensures single creation of each microkernel. Uses a single jitRuntiume written to under a lock. The CodeHolder was removed from the class members as it is only a tmporary class, and can be created/destroyed on demand - no need to keep the metadata of the last generated microkernel. Reviewed By: dskhudia Differential Revision: D16968373 fbshipit-source-id: 22d66e50d9b3173c542e28daa322e7869eb52b14
-rw-r--r--src/CodeCache.h59
-rw-r--r--src/GenerateKernel.h23
-rw-r--r--src/GenerateKernelU8S8S32ACC16.cc341
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512.cc473
-rw-r--r--src/GenerateKernelU8S8S32ACC32.cc380
-rw-r--r--src/GenerateKernelU8S8S32ACC32Avx512.cc504
-rw-r--r--src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc501
-rw-r--r--src/GroupwiseConv.h33
-rw-r--r--src/GroupwiseConvAcc32Avx2.cc78
9 files changed, 1189 insertions, 1203 deletions
diff --git a/src/CodeCache.h b/src/CodeCache.h
new file mode 100644
index 0000000..08e9c9b
--- /dev/null
+++ b/src/CodeCache.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <condition_variable>
+#include <future>
+#include <map>
+#include <mutex>
+
+namespace fbgemm {
+
+/**
+ * @brief Thread safe cache for microkernels, ensures single creation per key.
+ * @tparam Key Type of unique key (typically a tuple)
+ * @tparam Value Type of the microkernel function (Typically a function pointer)
+ */
+template <typename KEY, typename VALUE> class CodeCache {
+private:
+ std::map<KEY, std::shared_future<VALUE>> values_;
+ std::mutex mutex_;
+
+public:
+ CodeCache(const CodeCache &) = delete;
+ CodeCache &operator=(const CodeCache &) = delete;
+
+ CodeCache(){};
+
+ VALUE getOrCreate(const KEY &key, std::function<VALUE()> generatorFunction) {
+ std::shared_future<VALUE> returnFuture;
+ std::promise<VALUE> returnPromise;
+ bool needsToGenerate = false;
+
+ // Check for existance of the key
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ auto it = values_.find(key);
+ if (it != values_.end()) {
+ returnFuture = it->second;
+ } else {
+ values_[key] = returnFuture = returnPromise.get_future().share();
+ needsToGenerate = true;
+ }
+ }
+
+ // The value (code) generation is not happening under a lock
+ if (needsToGenerate) {
+ returnPromise.set_value(generatorFunction());
+ }
+
+ // Wait for the future and return the value
+ return returnFuture.get();
+ }
+};
+
+} // namespace fbgemm
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index e52097e..bd61473 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -8,8 +8,10 @@
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <map>
+#include <mutex>
#include <string>
#include <tuple>
+#include "CodeCache.h"
#include "fbgemm/Fbgemm.h"
/*#define FBGEMM_LOG_CODE 1*/
@@ -187,13 +189,24 @@ class CodeGenBase {
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
- static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
- static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+ static asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
+ static std::mutex rtMutex_; ///< Controll access to rt_;
+
// The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min
- static thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- jit_micro_kernel_fp>
+ static CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+ jit_micro_kernel_fp>
codeCache_; ///< JIT Code Cache for reuse.
};
+template <typename TA, typename TB, typename TC, typename accT>
+asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
+ typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
+ CodeGenBase<TA, TB, TC, accT>::codeCache_;
+
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 1e7e081..6377961 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -9,18 +9,6 @@
namespace fbgemm {
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
- CodeGenBase<TA, TB, TC, accT>::codeCache_;
-
namespace x86 = asmjit::x86;
/**
@@ -172,191 +160,186 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(rt_.codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx2>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- // assert((nc == nRegBlockSize) &&
- //"nc must be equal to the number of register blocks");
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label Loopk = a->newLabel();
- asmjit::Label LoopMBlocks = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp kIdx = a->gpz(14);
-
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
-
- // increment A for next block
- a->sub(buffer_A, kSize);
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
- // increment C for next block
- a->imul(
- C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ // assert((nc == nRegBlockSize) &&
+ //"nc must be equal to the number of register blocks");
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+ // increment C for next block
+ a->imul(C_Offset, ldcReg,
+ static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock);
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
- }
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = rt_.add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index a49e440..c904de0 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -167,262 +167,247 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
-
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(rt_.codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- (maxMRegs+1) * maxNRegs <= 28 &&
- "number of zmm registers for C + one row for loading B: \
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert((maxMRegs + 1) * maxNRegs <= 28 &&
+ "number of zmm registers for C + one row for loading B: \
MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(
- C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg,
+ static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = rt_.add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 6b54743..4f2b160 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -9,18 +9,6 @@
namespace fbgemm {
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
- CodeGenBase<TA, TB, TC, accT>::codeCache_;
-
namespace x86 = asmjit::x86;
/**
@@ -173,206 +161,196 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(rt_.codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx2>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label Loopk = a->newLabel();
- asmjit::Label LoopMBlocks = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp kIdx = a->gpz(14);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Ymm oneReg = x86::ymm15;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
- a->mov(C_Offset, 0);
-
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- // a->add(B_pf, 32*sizeof(float));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment A for next block
- a->sub(buffer_A, kSize);
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next block
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
+ // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Ymm oneReg = x86::ymm15;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
a->mov(C_Offset, 0);
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
- }
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
- a->emitEpilog(frame);
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ // a->add(B_pf, 32*sizeof(float));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum,
+ colRegs);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+ a->mov(C_Offset, 0);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum,
+ colRegs);
+ }
+
+ a->emitEpilog(frame);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = rt_.add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index fe35627..16f78d0 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -163,278 +163,260 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
-
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(rt_.codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- maxMRegs * maxNRegs <= 28 &&
- "MR*(NR*ROW_INTERLEAVE*8/512) \
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Zmm oneReg = x86::zmm29;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- // a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->mov(buffer_B, buffer_B_saved);
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = rt_.add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
index 8ae0745..23103ae 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -155,277 +155,260 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(rt_.codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512_vnni>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512_vnni>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- maxMRegs * maxNRegs <= 28 &&
- "MR*(NR*ROW_INTERLEAVE*8/512) \
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Zmm oneReg = x86::zmm29;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- // a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512_vnni>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512_vnni>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512_vnni>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
- // B for next block
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->mov(buffer_B, buffer_B_saved);
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512_vnni>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = rt_.add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 4c5eea5..cbcf445 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -10,8 +10,10 @@
#include <cassert>
#include <cstdint>
#include <map>
+#include <mutex>
#include <string>
#include <tuple>
+#include "CodeCache.h"
#include "fbgemm/ConvUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/Utils.h"
@@ -217,16 +219,15 @@ class GenConvKernel {
template <inst_set_t instSet>
void storeResultRowoffset(x86::Emitter* a, int offset = 0);
- static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
- static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
- static thread_local std::
- map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- codeCache_; ///< JIT Code Cache for reuse.
- static thread_local std::
- map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
+ static asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
+ static std::mutex rtMutex_; ///< Controll access to rt_;
- private:
+ static CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ codeCache_; ///< JIT Code Cache for reuse.
+ static CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
+
+private:
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
// avx2 specific
@@ -272,4 +273,18 @@ class GenConvKernel {
int W_PAD_; ///< Padding for width (left and right)
};
+template <int SPATIAL_DIM, typename accT>
+asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_;
+
+template <int SPATIAL_DIM, typename accT>
+std::mutex GenConvKernel<SPATIAL_DIM, accT>::rtMutex_;
+
+template <int SPATIAL_DIM, typename accT>
+CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
+
+template <int SPATIAL_DIM, typename accT>
+CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
+
} // namespace fbgemm
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index b140c83..c24c391 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -21,20 +21,6 @@ namespace fbgemm {
using namespace std;
-template <int SPATIAL_DIM, typename accT>
-thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
-
-template <int SPATIAL_DIM, typename accT>
-thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
-
namespace x86 = asmjit::x86;
template <int SPATIAL_DIM>
@@ -91,14 +77,13 @@ jit_conv_kernel_fp getOrCreateConvKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) !=
- GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) {
- return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig];
- } else {
- auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
- }
+ return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate(
+ kernelSig, [&]() {
+ auto genObj =
+ GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
+ });
}
template <>
@@ -1009,9 +994,9 @@ template <>
template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
+ asmjit::CodeHolder code;
+ code.init(rt_.codeInfo());
+ x86::Assembler assembler(&code);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1020,7 +1005,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code_.setLogger(codeLogger);
+ code.setLogger(codeLogger);
}
#endif
@@ -1097,13 +1082,15 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
a->emitEpilog(frame);
jit_conv_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = rt_.add(&fn, &code);
+ }
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
- auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
- codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
fclose(codeLogfile);
@@ -1489,9 +1476,9 @@ template <>
jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
+ asmjit::CodeHolder code;
+ code.init(rt_.codeInfo());
+ x86::Assembler assembler(&code);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1500,7 +1487,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code_.setLogger(codeLogger);
+ code.setLogger(codeLogger);
}
#endif
@@ -1570,14 +1557,16 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a->emitEpilog(frame);
+ asmjit::Error err;
jit_rowoffset_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = rt_.add(&fn, &code);
+ }
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
- auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
- codeCacheRowOffset_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
delete codeLogger;
@@ -2162,15 +2151,14 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find(
- kernelSig) !=
- GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) {
- return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig];
- } else {
- auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param);
- }
+ return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate(
+ kernelSig, [&]() {
+ auto genObj =
+ GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(
+ conv_param);
+ });
}
template <int SPATIAL_DIM>