diff options
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC16Avx512.cc')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 473 |
1 files changed, 229 insertions, 244 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index a49e440..b67d8e8 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -167,262 +167,247 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as<x86::Emitter>(); + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx512>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); - int maxMRegs = mRegBlockSize; - int maxNRegs = nRegBlockSize * row_interleave / VLEN_; - assert( - (maxMRegs+1) * maxNRegs <= 28 && - "number of zmm registers for C + one row for loading B: \ + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert((maxMRegs + 1) * maxNRegs <= 28 && + "number of zmm registers for C + one row for loading B: \ MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \ must be <= 28(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrame frame; - frame.init(func); - - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFuncFrame(frame); - frame.finalize(); - - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); - - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); - - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - // x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - // a->mov(B_pf_saved, B_pf); - - int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - int colRegs = std::min(currColRegs, maxNRegs); - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - a->bind(LoopMBlocks); - a->inc(iIdx); - a->mov(jIdx, 0); - - a->bind(LoopNBlocks); - a->inc(jIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - - // increment C for next block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNBlocks); - - // increment A for next block - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next A block - a->sub( - CBase, - static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul( - C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); - a->add(CBase, C_Offset); - - // reset B - a->mov(buffer_B, buffer_B_saved); - // a->mov(B_pf, B_pf_saved); - - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - a->mov(jIdx, 0); - a->bind(LoopNRem); - a->inc(jIdx); - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment C for next block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNRem); - } + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * + sizeof(int32_t))); + a->imul(C_Offset, ldcReg, + static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } - a->emitEpilog(frame); + a->emitEpilog(frame); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm |