diff options
author | Young Jin Kim <youki@microsoft.com> | 2019-09-25 02:55:03 +0300 |
---|---|---|
committer | Young Jin Kim <youki@microsoft.com> | 2019-09-25 02:55:03 +0300 |
commit | 08763b198ef743741560ae42a9c10a3017c7c9ce (patch) | |
tree | 35fc2935e4845e406a7d2d8ea1aaec3ceb24fa41 | |
parent | bb5063533256a8a5a91a812f6a193d7f352a2a3a (diff) | |
parent | 97caeee5af56b3d8ca56499f6107d6b3e7f21684 (diff) |
Fix JIT code (AVX512) on Windows
-rw-r--r-- | include/fbgemm/Fbgemm.h | 12 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512.cc | 30 | ||||
-rw-r--r-- | src/QuantUtilsAvx2.cc | 2 | ||||
-rw-r--r-- | src/Utils.cc | 3 |
4 files changed, 43 insertions, 4 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 70f6294..09ddbe1 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -748,7 +748,11 @@ class FBGEMM_API PackAWithIm2Col ~PackAWithIm2Col() { if (rowOffsetAllocatedHere) { +#ifdef _MSC_VER + _aligned_free(row_offset_); +#else free(row_offset_); +#endif } } @@ -839,7 +843,11 @@ class FBGEMM_API PackAWithRowOffset final ~PackAWithRowOffset() { if (rowOffsetAllocatedHere) { +#ifdef _MSC_VER + _aligned_free(row_offset_); +#else free(row_offset_); +#endif } } @@ -932,7 +940,11 @@ class FBGEMM_API PackAWithQuantRowOffset final ~PackAWithQuantRowOffset() { if (rowOffsetAllocatedHere) { +#ifdef _MSC_VER + _aligned_free(row_offset_); +#else free(row_offset_); +#endif } } diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index fe35627..5986e48 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -105,10 +105,12 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< a->vpaddd( CRegs_avx512_[i * leadingDimCReg + j], CRegs_avx512_[i * leadingDimCReg + j], - x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); + x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t))); + // x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); } a->vmovups( - x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), + x86::dword_ptr(a->gpz(9), C_Offset, 0, j * 16 * sizeof(int32_t)), +// x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), CRegs_avx512_[i * leadingDimCReg + j]); } } @@ -204,18 +206,34 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created +#ifdef _MSC_VER + x86::Gp buffer_A = a->zcx(); + x86::Gp buffer_B = a->zdx(); + x86::Gp B_pf = a->gpz(8); + x86::Gp CBase = a->gpz(9); + x86::Gp kSize = a->zdi(); // a->zsi(); // x86::esi; // 
a->zsi(); + x86::Gp ldcReg = a->zsi(); // a->zdi(); // x86::edi; // a->zdi(); +#else x86::Gp buffer_A = a->zdi(); x86::Gp buffer_B = a->zsi(); x86::Gp B_pf = a->zdx(); x86::Gp CBase = a->zcx(); x86::Gp kSize = a->gpz(8); x86::Gp ldcReg = a->gpz(9); +#endif asmjit::FuncDetail func; +#ifdef _MSC_VER + //func.init(asmjit::FuncSignature4<void, uint8_t*, int8_t*, int8_t*, int32_t*>( + // asmjit::CallConv::kIdHost)); + func.init(asmjit::FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + asmjit::CallConv::kIdHost)); +#else func.init( asmjit:: FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); +#endif asmjit::FuncFrame frame; frame.init(func); @@ -237,6 +255,14 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( a->emitProlog(frame); a->emitArgsAssignment(frame, args); +//#ifdef _MSC_VER +// // retrieve parameters from stack +// a->mov(kSize, asmjit::x86::dword_ptr(asmjit::x86::rsp, func.getArg(4).getStackOffset())); //0x20)); //func.getArg(4).getStackOffset())); +// std::cout << "func.getArg(4).getStackOffset(): " << func.getArg(4).getStackOffset() << std::endl; +// a->mov(ldcReg, asmjit::x86::dword_ptr(asmjit::x86::rsp, func.getArg(5).getStackOffset())); //;0x28)); //func.getArg(5).getStackOffset())); +// std::cout << "func.getArg(5).getStackOffset(): " << func.getArg(5).getStackOffset() << std::endl; +//#endif + asmjit::Label LoopMBlocks = a->newLabel(); asmjit::Label LoopNBlocks = a->newLabel(); asmjit::Label Loopk = a->newLabel(); diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index ecd0be2..875f1dc 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -23,7 +23,7 @@ void QuantizeAvx2( uint8_t* dst, int len, const TensorQuantizationParams& qparams) { -#if defined(__AVX2__) && defined(__FMA__) +#if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER)) constexpr int VLEN = 8; std::size_t i = 0; __m256 inverse_scale_v = _mm256_set1_ps(1.f / 
qparams.scale); diff --git a/src/Utils.cc b/src/Utils.cc index af7d918..27f89bb 100644 --- a/src/Utils.cc +++ b/src/Utils.cc @@ -181,7 +181,8 @@ void transpose_simd( if (cpuinfo_initialize()) { if (fbgemmHasAvx512Support()) { #ifdef _MSC_VER - internal::transpose_8x8(M, N, src, ld_src, dst, ld_dst); +// internal::transpose_8x8(M, N, src, ld_src, dst, ld_dst); + internal::transpose_16x16(M, N, src, ld_src, dst, ld_dst); #else internal::transpose_16x16(M, N, src, ld_src, dst, ld_dst); #endif |