diff options
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC16Avx512.cc')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 543 |
1 files changed, 273 insertions, 270 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index 505fec1..819f33b 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -19,16 +19,17 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - int leadingDimCRegAssign) { + int leadingDimCReg) { + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { a->vxorps( - CRegs_avx512_[i * leadingDimCRegAssign + j], - CRegs_avx512_[i * leadingDimCRegAssign + j], - CRegs_avx512_[i * leadingDimCRegAssign + j]); + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j), + CRegs(i * leadingDimCReg + j)); } } } @@ -41,37 +42,38 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< inst_set_t::avx512>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp /* unused (reserved for prefetching)*/, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, - int leadingDimCRegAssign) { + int leadingDimCReg) { // used for matrix A - asmjit::X86Zmm AReg = x86::zmm29; + x86::Zmm AReg = x86::zmm29; - asmjit::X86Zmm tmpReg = x86::zmm30; + x86::Zmm tmpReg = x86::zmm30; // We start allocating BRegs from zmm27 and then allocate zmm26 and so on. for (int j = 0; j < colRegs; ++j) { a->vmovups( - AllRegs_avx512_[27 - j], + x86::Zmm(27 - j), x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); } + using CRegs = x86::Zmm; + for (int i = 0; i < rowRegs; ++i) { // broadcast A a->vpbroadcastw( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); for (int j = 0; j < colRegs; ++j) { - a->vpmaddubsw( - tmpReg, AReg, AllRegs_avx512_[27-j]); + a->vpmaddubsw(tmpReg, AReg, x86::Zmm(27 - j)); a->vpaddsw( - CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs(i * leadingDimCReg + j), tmpReg, - CRegs_avx512_[i * leadingDimCRegAssign + j]); + CRegs(i * leadingDimCReg + j)); // Prefetching is hurting performance in some cases // because prefetch instructions itself consumes a slot // in pipeline issue thus slowing down the kernel. @@ -90,25 +92,31 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, - int leadingDimCRegAssign) { - asmjit::X86Ymm extractDest256 = x86::ymm31; - asmjit::X86Zmm extractDest512 = x86::zmm31; + int leadingDimCReg) { + x86::Ymm extractDest256 = x86::ymm31; + x86::Zmm extractDest512 = x86::zmm31; + using CRegs = x86::Zmm; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); for (int j = 0; j < colRegs; ++j) { for (int idx = 0; idx < 2; ++idx) { a->vextracti32x8( - extractDest256, CRegs_avx512_[i * leadingDimCRegAssign + j], idx); + extractDest256, CRegs(i * leadingDimCReg + j), idx); a->vpmovsxwd(extractDest512, extractDest256); - asmjit::X86Mem destAddr = x86::dword_ptr( + x86::Mem destAddr = x86::dword_ptr( +#ifdef _MSC_VER + a->gpz(9), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); +#else a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); +#endif if (accum) { a->vpaddd(extractDest512, extractDest512, destAddr); } @@ -167,261 +175,256 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( nRegBlockSize, nRegBlockSizeMin); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - - code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().codeInfo()); + x86::Assembler assembler(&code); + x86::Emitter *a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile<inst_set_t::avx512>( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } + // generated code logging + FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>( + accum, mc, nc, nBlock, kBlock, mRegBlockSize, + nRegBlockSize, nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } #endif - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); - int maxMRegs = mRegBlockSize; - int maxNRegs = nRegBlockSize * row_interleave / VLEN_; - assert( - maxMRegs * maxNRegs <= 24 && - "MR*(NR*ROW_INTERLEAVE*8/512) \ - must be <= 24(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFrameInfo(ffi); - - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); - - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); - - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - // asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp jIdx = a->gpzRef(14); - asmjit::X86Gp kIdx = a->gpzRef(15); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - // a->mov(B_pf_saved, B_pf); - - int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - int colRegs = std::min(currColRegs, maxNRegs); - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - a->bind(LoopMBlocks); - a->inc(iIdx); - a->mov(jIdx, 0); - - a->bind(LoopNBlocks); - a->inc(jIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - - // increment C for next block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNBlocks); - - // increment A for next block - a->add( - buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next A block - a->sub( - CBase, - static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul( - C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); - a->add(CBase, C_Offset); - - // reset B - a->mov(buffer_B, buffer_B_saved); - // a->mov(B_pf, B_pf_saved); - - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - a->mov(jIdx, 0); - a->bind(LoopNRem); - a->inc(jIdx); - - // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); - - genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * - // sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast<asmjit::Imm>( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - - // store C matrix - storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment C for next block - a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNRem); - } + assert(kc % row_interleave == 0 && + "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert((maxMRegs + 1) * maxNRegs <= 28 && + "number of zmm registers for C + one row for loading B: \ + MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \ + must be <= 28(available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created +#ifdef _MSC_VER + x86::Gp buffer_A = a->zcx(); + x86::Gp buffer_B = a->zdx(); + x86::Gp B_pf = a->gpz(8); + x86::Gp CBase = a->gpz(9); + x86::Gp kSize = a->zdi(); + x86::Gp ldcReg = a->zsi(); +#else + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); +#endif + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *, + int, int>(asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs(x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add(buffer_A, + static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * + sizeof(int32_t))); + a->imul(C_Offset, ldcReg, + static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs, + colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, + static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave * + sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul(C_Offset, jIdx, + static_cast<asmjit::Imm>(nRegBlockSize * row_interleave * + sizeof(int8_t))); + a->add(buffer_B, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg, + accum, colRegs); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock<std::mutex> lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } #if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; + fclose(codeLogfile); + delete codeLogger; #endif - return fn; + return fn; + }); } } // namespace fbgemm |