diff options
author | Jianyu Huang <jianyuhuang@fb.com> | 2019-08-06 21:55:17 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-08-06 21:59:00 +0300 |
commit | cf34b9a26b609109b18d6498f0608faddb7a911b (patch) | |
tree | 1ceaddaf942edb9debcafad7491b750fc3a5f066 /src/ExecuteKernelU8S8.cc | |
parent | d8b3323668fdd15dc70e9cb43ab16e96f4846eeb (diff) |
Back out "[fbgemm] Integrate VNNI into FBGEMM master branch"
Summary:
Original commit changeset: fcaa13cc3159
ASMJIT requires CMake version 3.8 or newer.
However, FBGEMM and PyTorch only require CMake 3.5+.
This caused the build failure in FBGEMM:
https://circleci.com/gh/pytorch/FBGEMM/122#build-timing/containers/0
Reviewed By: dskhudia
Differential Revision: D16670547
fbshipit-source-id: 506714c3db1cb82cf98895f58f82f235128f5285
Diffstat (limited to 'src/ExecuteKernelU8S8.cc')
-rw-r--r-- | src/ExecuteKernelU8S8.cc | 47 |
1 file changed, 6 insertions, 41 deletions
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index 0a4ff55..f7292fd 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -49,8 +49,7 @@ ExecuteKernel< throw std::runtime_error("Failed to initialize cpuinfo!"); } if (params) { - if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() || - fbgemmHasAvx2Support()) { + if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { mbSize_ = params->MCB; nbSize_ = params->NCB; nrMinSize_ = params->NR_MIN; @@ -60,20 +59,7 @@ ExecuteKernel< assert(0 && "unsupported architecure"); } } else { - if (fbgemmHasAvx512VnniSupport()) { - mbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512_vnni>::MCB; - nbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512_vnni>::NCB; - nrMinSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512_vnni>::NR_MIN; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { mbSize_ = PackingTraits< int8_t, typename packingAMatrix::accType, @@ -132,25 +118,7 @@ void ExecuteKernel< typename BaseType::jit_micro_kernel_fp fn; - if (fbgemmHasAvx512VnniSupport()) { - if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) { - // For AVX512VNNI, we redirect int16_t to int32_t accumulation. 
- CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; - fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>( - accum, - packed_rows_A, - packedB_.blockColSize(), - packedA_.numPackedCols(), - nbSize_); - } else { - fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( - accum, - packed_rows_A, - packedB_.blockColSize(), - packedA_.numPackedCols(), - nbSize_); - } - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, @@ -180,10 +148,7 @@ void ExecuteKernel< if (jb == bColBlocks - 1) { int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_; if (nc != nbSize_) { - if (fbgemmHasAvx512VnniSupport()) { - fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( - accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); } else if (fbgemmHasAvx2Support()) { @@ -248,7 +213,7 @@ void ExecuteKernel< int32_t nSize = C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols(); if (nSize) { - if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( @@ -273,7 +238,7 @@ void ExecuteKernel< if (C_buffer_start == C_tile_) { // When C_tile_ scratchpad was used to avoid accessing memory past // C_buffer_ . - if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( |