Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Edward Yang <ezyang@fb.com>, 2019-09-05 01:04:56 +0300
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>, 2019-09-05 01:06:33 +0300
commit: f9078fdd8104603897948f563245f4528b77da5b (patch)
tree: 6a638158c26e8ad1a105a241e18218e3e40795f2
parent: 823284ec8f1763cbfc5d135f8e1ade372ba9126b (diff)
Revert D16968373: Introduced CodeCache container to share the microkernels among different threads.
Differential Revision: D16968373
Original commit changeset: 22d66e50d9b3
fbshipit-source-id: 6163979bdb36cb0b1b95bfa1caeab67e7d23eee5
-rw-r--r-- src/CodeCache.h 59
-rw-r--r-- src/GenerateKernel.h 23
-rw-r--r-- src/GenerateKernelU8S8S32ACC16.cc 341
-rw-r--r-- src/GenerateKernelU8S8S32ACC16Avx512.cc 473
-rw-r--r-- src/GenerateKernelU8S8S32ACC32.cc 380
-rw-r--r-- src/GenerateKernelU8S8S32ACC32Avx512.cc 504
-rw-r--r-- src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc 501
-rw-r--r-- src/GroupwiseConv.h 33
-rw-r--r-- src/GroupwiseConvAcc32Avx2.cc 78
9 files changed, 1203 insertions, 1189 deletions
diff --git a/src/CodeCache.h b/src/CodeCache.h
deleted file mode 100644
index 08e9c9b..0000000
--- a/src/CodeCache.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- * All rights reserved.
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-#pragma once
-#include <condition_variable>
-#include <future>
-#include <map>
-#include <mutex>
-
-namespace fbgemm {
-
-/**
- * @brief Thread safe cache for microkernels, ensures single creation per key.
- * @tparam Key Type of unique key (typically a tuple)
- * @tparam Value Type of the microkernel function (Typically a function pointer)
- */
-template <typename KEY, typename VALUE> class CodeCache {
-private:
- std::map<KEY, std::shared_future<VALUE>> values_;
- std::mutex mutex_;
-
-public:
- CodeCache(const CodeCache &) = delete;
- CodeCache &operator=(const CodeCache &) = delete;
-
- CodeCache(){};
-
- VALUE getOrCreate(const KEY &key, std::function<VALUE()> generatorFunction) {
- std::shared_future<VALUE> returnFuture;
- std::promise<VALUE> returnPromise;
- bool needsToGenerate = false;
-
- // Check for existance of the key
- {
- std::unique_lock<std::mutex> lock(mutex_);
-
- auto it = values_.find(key);
- if (it != values_.end()) {
- returnFuture = it->second;
- } else {
- values_[key] = returnFuture = returnPromise.get_future().share();
- needsToGenerate = true;
- }
- }
-
- // The value (code) generation is not happening under a lock
- if (needsToGenerate) {
- returnPromise.set_value(generatorFunction());
- }
-
- // Wait for the future and return the value
- return returnFuture.get();
- }
-};
-
-} // namespace fbgemm
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index bd61473..e52097e 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -8,10 +8,8 @@
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <map>
-#include <mutex>
#include <string>
#include <tuple>
-#include "CodeCache.h"
#include "fbgemm/Fbgemm.h"
/*#define FBGEMM_LOG_CODE 1*/
@@ -189,24 +187,13 @@ class CodeGenBase {
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
- static asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
- static std::mutex rtMutex_; ///< Controll access to rt_;
-
+ static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
+ static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
// The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min
- static CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
- jit_micro_kernel_fp>
+ static thread_local std::map<
+ std::tuple<bool, int, int, int, int, int, int, int>,
+ jit_micro_kernel_fp>
codeCache_; ///< JIT Code Cache for reuse.
};
-template <typename TA, typename TB, typename TC, typename accT>
-asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
- typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
- CodeGenBase<TA, TB, TC, accT>::codeCache_;
-
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 6377961..1e7e081 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -9,6 +9,18 @@
namespace fbgemm {
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local std::map<
+ std::tuple<bool, int, int, int, int, int, int, int>,
+ typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
+ CodeGenBase<TA, TB, TC, accT>::codeCache_;
+
namespace x86 = asmjit::x86;
/**
@@ -160,186 +172,191 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
nRegBlockSize,
nRegBlockSizeMin);
- return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
- asmjit::CodeHolder code;
- code.init(rt_.codeInfo());
- x86::Assembler assembler(&code);
- x86::Emitter *a = assembler.as<x86::Emitter>();
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>(
- accum, mc, nc, nBlock, kBlock, mRegBlockSize,
- nRegBlockSize, nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx2>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
#endif
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
- assert(kc % row_interleave == 0 &&
- "kc must be a multiple of row_interleave");
- // assert((nc == nRegBlockSize) &&
- //"nc must be equal to the number of register blocks");
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
- int, int>(asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label Loopk = a->newLabel();
- asmjit::Label LoopMBlocks = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp kIdx = a->gpz(14);
-
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
- colRegs, kBlock);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
- accum);
-
- // increment A for next block
- a->sub(buffer_A, kSize);
- a->add(buffer_A,
- static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
- // increment C for next block
- a->imul(C_Offset, ldcReg,
- static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ // assert((nc == nRegBlockSize) &&
+ //"nc must be equal to the number of register blocks");
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+ // increment C for next block
+ a->imul(
+ C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
- genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
- colRegs, kBlock);
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
- accum);
- }
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err;
- {
- std::unique_lock<std::mutex> lock(rtMutex_);
- err = rt_.add(&fn, &code);
- }
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
- });
+ return fn;
}
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index c904de0..a49e440 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -167,247 +167,262 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
nRegBlockSize,
nRegBlockSizeMin);
- return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
- asmjit::CodeHolder code;
- code.init(rt_.codeInfo());
- x86::Assembler assembler(&code);
- x86::Emitter *a = assembler.as<x86::Emitter>();
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>(
- accum, mc, nc, nBlock, kBlock, mRegBlockSize,
- nRegBlockSize, nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx512>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 &&
- "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert((maxMRegs + 1) * maxNRegs <= 28 &&
- "number of zmm registers for C + one row for loading B: \
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(
+ (maxMRegs+1) * maxNRegs <= 28 &&
+ "number of zmm registers for C + one row for loading B: \
MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
- int, int>(asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
- colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
- accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(C_Offset, jIdx,
- static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
- sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(buffer_A,
- static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
- sizeof(int32_t)));
- a->imul(C_Offset, ldcReg,
- static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
- colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(C_Offset, jIdx,
- static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
- sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
- accum, colRegs);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(
+ CBase,
+ static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
+ a->imul(
+ C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err;
- {
- std::unique_lock<std::mutex> lock(rtMutex_);
- err = rt_.add(&fn, &code);
- }
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
- });
+ return fn;
}
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 4f2b160..6b54743 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -9,6 +9,18 @@
namespace fbgemm {
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local std::map<
+ std::tuple<bool, int, int, int, int, int, int, int>,
+ typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
+ CodeGenBase<TA, TB, TC, accT>::codeCache_;
+
namespace x86 = asmjit::x86;
/**
@@ -161,196 +173,206 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
nRegBlockSize,
nRegBlockSizeMin);
- return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
- asmjit::CodeHolder code;
- code.init(rt_.codeInfo());
- x86::Assembler assembler(&code);
- x86::Emitter *a = assembler.as<x86::Emitter>();
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>(
- accum, mc, nc, nBlock, kBlock, mRegBlockSize,
- nRegBlockSize, nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx2>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
#endif
- // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
- int, int>(asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label Loopk = a->newLabel();
- asmjit::Label LoopMBlocks = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp kIdx = a->gpz(14);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Ymm oneReg = x86::ymm15;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
- a->mov(C_Offset, 0);
-
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
- colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
-
- // a->add(B_pf, 32*sizeof(float));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
+ // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+ // x86::Gp B_pf = a->gpz(8);
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum,
- colRegs);
-
- // increment A for next block
- a->sub(buffer_A, kSize);
- a->add(buffer_A,
- static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next block
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
- a->mov(C_Offset, 0);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
- colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
+ x86::Ymm oneReg = x86::ymm15;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+ a->mov(C_Offset, 0);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ // a->add(B_pf, 32*sizeof(float));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+ a->mov(C_Offset, 0);
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum,
- colRegs);
- }
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err;
- {
- std::unique_lock<std::mutex> lock(rtMutex_);
- err = rt_.add(&fn, &code);
- }
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
- });
+ return fn;
}
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 16f78d0..fe35627 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -163,260 +163,278 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
nRegBlockSize,
nRegBlockSizeMin);
- return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
- asmjit::CodeHolder code;
- code.init(rt_.codeInfo());
- x86::Assembler assembler(&code);
- x86::Emitter *a = assembler.as<x86::Emitter>();
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
+
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>(
- accum, mc, nc, nBlock, kBlock, mRegBlockSize,
- nRegBlockSize, nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx512>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 &&
- "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(
+ maxMRegs * maxNRegs <= 28 &&
+ "MR*(NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
- int, int>(asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Zmm oneReg = x86::zmm29;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- // a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
- colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
-
- // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
- accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(C_Offset, jIdx,
- static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
- sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(buffer_A,
- static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
- sizeof(int32_t)));
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
- colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- // using C_Offset as temp reg
- a->imul(C_Offset, jIdx,
- static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
- sizeof(int8_t)));
- a->mov(buffer_B, buffer_B_saved);
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
- accum, colRegs);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
- a->emitEpilog(frame);
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(
+ CBase,
+ static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- jit_micro_kernel_fp fn;
- asmjit::Error err;
- {
- std::unique_lock<std::mutex> lock(rtMutex_);
- err = rt_.add(&fn, &code);
- }
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
+ a->emitEpilog(frame);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
- });
+ return fn;
}
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
index 23103ae..8ae0745 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -155,260 +155,277 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
nRegBlockSize,
nRegBlockSizeMin);
- return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
- asmjit::CodeHolder code;
- code.init(rt_.codeInfo());
- x86::Assembler assembler(&code);
- x86::Emitter *a = assembler.as<x86::Emitter>();
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512_vnni>(
- accum, mc, nc, nBlock, kBlock, mRegBlockSize,
- nRegBlockSize, nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx512_vnni>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 &&
- "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(
+ maxMRegs * maxNRegs <= 28 &&
+ "MR*(NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
- int, int>(asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Zmm oneReg = x86::zmm29;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- // a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512_vnni>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
-
- // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg,
- accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(C_Offset, jIdx,
- static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
- sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(buffer_A,
- static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
- sizeof(int32_t)));
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512_vnni>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
- // B for next block
- // using C_Offset as temp reg
- a->imul(C_Offset, jIdx,
- static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
- sizeof(int8_t)));
- a->mov(buffer_B, buffer_B_saved);
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg,
- accum, colRegs);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(
+ CBase,
+ static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err;
- {
- std::unique_lock<std::mutex> lock(rtMutex_);
- err = rt_.add(&fn, &code);
- }
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
- });
+ return fn;
}
} // namespace fbgemm
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index cbcf445..4c5eea5 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -10,10 +10,8 @@
#include <cassert>
#include <cstdint>
#include <map>
-#include <mutex>
#include <string>
#include <tuple>
-#include "CodeCache.h"
#include "fbgemm/ConvUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/Utils.h"
@@ -219,15 +217,16 @@ class GenConvKernel {
template <inst_set_t instSet>
void storeResultRowoffset(x86::Emitter* a, int offset = 0);
- static asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
- static std::mutex rtMutex_; ///< Controll access to rt_;
+ static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
+ static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+ static thread_local std::
+ map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ codeCache_; ///< JIT Code Cache for reuse.
+ static thread_local std::
+ map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
- static CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- codeCache_; ///< JIT Code Cache for reuse.
- static CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
-
-private:
+ private:
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
// avx2 specific
@@ -273,18 +272,4 @@ private:
int W_PAD_; ///< Padding for width (left and right)
};
-template <int SPATIAL_DIM, typename accT>
-asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_;
-
-template <int SPATIAL_DIM, typename accT>
-std::mutex GenConvKernel<SPATIAL_DIM, accT>::rtMutex_;
-
-template <int SPATIAL_DIM, typename accT>
-CodeCache<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
-
-template <int SPATIAL_DIM, typename accT>
-CodeCache<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
-
} // namespace fbgemm
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index c24c391..b140c83 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -21,6 +21,20 @@ namespace fbgemm {
using namespace std;
+template <int SPATIAL_DIM, typename accT>
+thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_;
+
+template <int SPATIAL_DIM, typename accT>
+thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_;
+
+template <int SPATIAL_DIM, typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
+
+template <int SPATIAL_DIM, typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
+
namespace x86 = asmjit::x86;
template <int SPATIAL_DIM>
@@ -77,13 +91,14 @@ jit_conv_kernel_fp getOrCreateConvKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- return GenConvKernel<SPATIAL_DIM, accT>::codeCache_.getOrCreate(
- kernelSig, [&]() {
- auto genObj =
- GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
- });
+ if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) !=
+ GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) {
+ return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig];
+ } else {
+ auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
+ }
}
template <>
@@ -994,9 +1009,9 @@ template <>
template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- asmjit::CodeHolder code;
- code.init(rt_.codeInfo());
- x86::Assembler assembler(&code);
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1005,7 +1020,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code.setLogger(codeLogger);
+ code_.setLogger(codeLogger);
}
#endif
@@ -1082,15 +1097,13 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
a->emitEpilog(frame);
jit_conv_kernel_fp fn;
- asmjit::Error err;
- {
- std::unique_lock<std::mutex> lock(rtMutex_);
- err = rt_.add(&fn, &code);
- }
+ asmjit::Error err = rt_.add(&fn, &code_);
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
+ auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
+ codeCache_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
fclose(codeLogfile);
@@ -1476,9 +1489,9 @@ template <>
jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
- asmjit::CodeHolder code;
- code.init(rt_.codeInfo());
- x86::Assembler assembler(&code);
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
@@ -1487,7 +1500,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
- code.setLogger(codeLogger);
+ code_.setLogger(codeLogger);
}
#endif
@@ -1557,16 +1570,14 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a->emitEpilog(frame);
- asmjit::Error err;
jit_rowoffset_kernel_fp fn;
- {
- std::unique_lock<std::mutex> lock(rtMutex_);
- err = rt_.add(&fn, &code);
- }
+ asmjit::Error err = rt_.add(&fn, &code_);
if (err) {
std::cout << "Error: in fn add" << std::endl;
return nullptr;
}
+ auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_);
+ codeCacheRowOffset_[kernelSig] = fn;
#if defined(FBGEMM_LOG_CODE)
delete codeLogger;
@@ -2151,14 +2162,15 @@ jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
// Note: Wrong code is generated if it's not one of the supported convolution
assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.getOrCreate(
- kernelSig, [&]() {
- auto genObj =
- GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
- // TODO: Instruction set based dispatch
- return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(
- conv_param);
- });
+ if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find(
+ kernelSig) !=
+ GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) {
+ return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig];
+ } else {
+ auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param);
+ }
}
template <int SPATIAL_DIM>