diff options
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc | 501 |
1 files changed, 259 insertions, 242 deletions
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc index 23103ae..8ae0745 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc @@ -155,260 +155,277 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate< nRegBlockSize, nRegBlockSizeMin); - return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { - asmjit::CodeHolder code; - code.init(rt_.codeInfo()); - x86::Assembler assembler(&code); - x86::Emitter *a = assembler.as<x86::Emitter>(); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + code_.reset(false); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512_vnni>( - accum, mc, nc, nBlock, kBlock, mRegBlockSize, - nRegBlockSize, nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code.setLogger(codeLogger); - } + // generated code logging + FILE* codeLogfile = fopen( + getCodeLoggingFile<inst_set_t::avx512_vnni>( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code_.setLogger(codeLogger); + } #endif - assert(kc % row_interleave == 0 && - "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); - int maxMRegs = mRegBlockSize; - int maxNRegs = nRegBlockSize * row_interleave / VLEN_; - assert(maxMRegs * maxNRegs <= 28 && "MR*(NR*ROW_INTERLEAVE*8/512) \ + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert( + maxMRegs * maxNRegs <= 28 && + "MR*(NR*ROW_INTERLEAVE*8/512) \ must be <= 28(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); - - asmjit::FuncDetail func; - func.init( - asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, - int, int>(asmjit::CallConv::kIdHost)); - - asmjit::FuncFrame frame; - frame.init(func); - - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs(x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFuncFrame(frame); - frame.finalize(); - - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); - - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); - - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); - // x86::Gp B_pf = a->gpz(8); - - x86::Zmm oneReg = x86::zmm29; - // create 16-bit 1s - // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 - // and so on - // a->vpcmpeqw(oneReg, oneReg, oneReg); - a->vpternlogd(oneReg, oneReg, oneReg, 0xff); - a->vpsrlw(oneReg, oneReg, 15); - a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - a->mov(B_pf_saved, B_pf); - - int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - int colRegs = std::min(currColRegs, maxNRegs); - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - a->bind(LoopMBlocks); - a->inc(iIdx); - a->mov(jIdx, 0); - - a->bind(LoopNBlocks); - a->inc(jIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512_vnni>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add(buffer_A, - static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * - sizeof(int8_t))); - a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - sizeof(int8_t))); - - // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg, - accum, colRegs); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul(C_Offset, jIdx, - static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * - sizeof(int8_t))); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // increment C for next B block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNBlocks); - - // increment A for next block - a->add(buffer_A, - static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next A block - a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * - sizeof(int32_t))); - a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); - a->add(CBase, C_Offset); - - // reset B - a->mov(buffer_B, buffer_B_saved); - a->mov(B_pf, B_pf_saved); - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - a->mov(jIdx, 0); - a->bind(LoopNRem); - a->inc(jIdx); - - // init C registers - initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512_vnni>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add(buffer_A, - static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * - sizeof(int8_t))); - a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // reset A - a->sub(buffer_A, kSize); - // B for next block - // using C_Offset as temp reg - a->imul(C_Offset, jIdx, - static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * - sizeof(int8_t))); - a->mov(buffer_B, buffer_B_saved); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // store C matrix - storeCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, C_Offset, ldcReg, - accum, colRegs); - - // increment C for next B block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNRem); - } + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); + + x86::Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + + // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul( + C_Offset, + jIdx, + static_cast<asmjit::Imm>( + nRegBlockSize * row_interleave * sizeof(int8_t))); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add( + buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub( + CBase, + static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); + a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + // B for next block + // using C_Offset as temp reg + a->imul( + C_Offset, + jIdx, + static_cast<asmjit::Imm>( + nRegBlockSize * row_interleave * sizeof(int8_t))); + a->mov(buffer_B, buffer_B_saved); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } - a->emitEpilog(frame); + a->emitEpilog(frame); - jit_micro_kernel_fp fn; - asmjit::Error err; - { - std::unique_lock<std::mutex> lock(rtMutex_); - err = rt_.add(&fn, &code); - } - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; - }); + return fn; } } // namespace fbgemm |