Diffstat (limited to 'src')
28 files changed, 10187 insertions, 0 deletions
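The core of the int8 path added in this change is the two-level blocking loop in src/Fbgemm.cc: fbgemmPacked() repacks one MCB x KCB block of A at a time and hands it to the ExecuteKernel macro-kernel (src/ExecuteKernelU8S8.cc), which sweeps the pre-packed B column blocks with a JIT-generated micro-kernel and applies the output processor on the last k block. A condensed sketch of that flow, with the PackingTraits lookups and timing instrumentation elided (all names are as they appear in the diff below):

ExecuteKernel<packingAMatrix, packingBMatrix, cT, processOutputType>
    exeKernelObj(packA, packB, 0, C, C_buffer, ldc, outProcess);
for (int i = 0; i < mBlocks; ++i) {          // M dimension, MCB rows at a time
  mc = (i != mBlocks - 1 || _mc == 0) ? MCB : _mc;
  for (int k = 0; k < kBlocks; ++k) {        // K dimension, KCB columns at a time
    kc = (k != kBlocks - 1 || _kc == 0) ? KCB : _kc;
    packA.pack({i * MCB, mc, k * KCB, kc});  // repack the current block of A
    exeKernelObj.execute(k);                 // micro-kernel over all B column blocks;
  }                                          // accumulates for k > 0 and runs output
}                                            // processing on the last k block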
diff --git a/src/ExecuteKernel.cc b/src/ExecuteKernel.cc new file mode 100644 index 0000000..0e3d122 --- /dev/null +++ b/src/ExecuteKernel.cc @@ -0,0 +1,12 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "ExecuteKernel.h" +#include <immintrin.h> +#include "fbgemm/Fbgemm.h" +#include "fbgemm/Utils.h" + +namespace fbgemm2 {} // namespace fbgemm2 diff --git a/src/ExecuteKernel.h b/src/ExecuteKernel.h new file mode 100644 index 0000000..55a2581 --- /dev/null +++ b/src/ExecuteKernel.h @@ -0,0 +1,11 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include <cstdint> +#include "fbgemm/Fbgemm.h" +#include "ExecuteKernelGeneric.h" +#include "ExecuteKernelU8S8.h" diff --git a/src/ExecuteKernelGeneric.h b/src/ExecuteKernelGeneric.h new file mode 100644 index 0000000..e83e943 --- /dev/null +++ b/src/ExecuteKernelGeneric.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include <cstdint> +#include "fbgemm/Fbgemm.h" +#include "GenerateKernel.h" + +namespace fbgemm2 { + +/** + * @brief Execute Engine for the macro-kernel and output processing. + * ExecuteKernel is a derived class of CodeGenBase. + */ +template < + typename packingAMatrix, + typename packingBMatrix, + typename cT, + typename processOutputType> +class ExecuteKernel : public CodeGenBase< + typename packingAMatrix::inpType, + typename packingBMatrix::inpType, + cT, + typename packingBMatrix::accType> { + public: + ExecuteKernel( + PackMatrix< + packingAMatrix, + typename packingAMatrix::inpType, + typename packingAMatrix::accType>& packA, + PackMatrix< + packingBMatrix, + typename packingBMatrix::inpType, + typename packingBMatrix::accType>& packB, + int32_t kBlock, + cT* matC, + typename packingBMatrix::accType* C_buffer, + int32_t ldc, + const processOutputType& outputProcess); + void execute(int kBlock); + + private: + PackMatrix< + packingAMatrix, + typename packingAMatrix::inpType, + typename packingAMatrix::accType>& + packedA_; ///< Packed block of matrix A. + PackMatrix< + packingBMatrix, + typename packingBMatrix::inpType, + typename packingBMatrix::accType>& packedB_; ///< Packed matrix B. + int32_t kBlock_; ///< Block ID in the k dimension. + cT* matC_; ///< Output for matrix C. + typename packingAMatrix::accType* + C_buffer_; ///< the accumulation buffer for matrix C. + int32_t ldc_; ///< the leading dimension of matrix C. + const processOutputType& outputProcess_; ///< output processing function for + ///< the C tile in the macro-kernel. +}; + +} // namespace fbgemm2 diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc new file mode 100644 index 0000000..5145869 --- /dev/null +++ b/src/ExecuteKernelU8S8.cc @@ -0,0 +1,354 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "ExecuteKernelU8S8.h" +#include <cpuinfo.h> +#include <chrono> + + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +double kernel_time = 0.0; +double postprocessing_time = 0.0; +#endif + +namespace fbgemm2 { + +template <typename packingAMatrix, typename cT, typename processOutputType> +ExecuteKernel< + packingAMatrix, + PackBMatrix<int8_t, typename packingAMatrix::accType>, + cT, + processOutputType>:: + ExecuteKernel( + PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>& + packA, + PackMatrix< + PackBMatrix<int8_t, typename packingAMatrix::accType>, + int8_t, + typename packingAMatrix::accType>& packB, + int32_t kBlock, + cT* matC, + int32_t* C_buffer, + int32_t ldc, + const processOutputType& outputProcess) + : packedA_(packA), + packedB_(packB), + kBlock_(kBlock), + matC_(matC), + C_buffer_(C_buffer), + ldc_(ldc), + outputProcess_(outputProcess) { + if (cpuinfo_has_x86_avx512f()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::NCB; + } else if (cpuinfo_has_x86_avx2()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::NCB; + } else { + assert(0 && "unsupported architecure"); + } + C_tile_ = new int32_t[mbSize_ * nbSize_]; +} + +template <typename packingAMatrix, typename cT, typename processOutputType> +void ExecuteKernel< + packingAMatrix, + PackBMatrix<int8_t, typename packingAMatrix::accType>, + cT, + processOutputType>::execute(int kBlock) { + // packedA_.printPackedMatrix("packedA from kernel"); + // packedB_.printPackedMatrix("packedB from kernel"); + + int32_t bColBlocks = packedB_.blockCols(); + + int8_t* bBuf; + int8_t* bBuf_pf; + + uint8_t* aBuf = packedA_.getBuf(0); + + int32_t packed_rows_A = packedA_.numPackedRows(); + int32_t row_start_A = packedA_.packedRowStart(); + + bool lastKBlock = packedB_.isThisLastKBlock(kBlock); + bool accum = kBlock > 0; + + typename BaseType::jit_micro_kernel_fp fn; + + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + fn = BaseType::template getOrCreate<inst_set_t::avx512>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } else if (cpuinfo_has_x86_avx2()) { + fn = BaseType::template getOrCreate<inst_set_t::avx2>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + std::chrono::time_point<std::chrono::high_resolution_clock> t_start, t_end; + double dt; + t_start = std::chrono::high_resolution_clock::now(); +#endif + + for (int jb = 0; jb < bColBlocks; ++jb) { + + bBuf = packedB_.getBuf(jb, kBlock); + // prefetch addr of the next packed block of B matrix + bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock); + + // Reuse the first rowblock of C_buffer_ unless when C_buffer_ is same as + // matC_ (inplace output processing) + int32_t* C_buffer_row_start = C_buffer_ + + ((C_buffer_ == reinterpret_cast<int32_t*>(matC_)) ? 
row_start_A * ldc_ + : 0); + int32_t* C_buffer_start = C_buffer_row_start + jb * nbSize_; + int32_t leadingDim = ldc_; + if (packedB_.isThereColRemainder() && (jb == bColBlocks - 1)) { + // In case we will access memory past C_buffer_, we use C_tile_ instead. + C_buffer_start = C_tile_; + leadingDim = nbSize_; + } + + fn(aBuf, + bBuf, + bBuf_pf, + C_buffer_start, + packedA_.numPackedCols(), + leadingDim); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) + .count(); + kernel_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + // Output processing is done only once per rowblock + if (lastKBlock && jb == bColBlocks - 1) { + // When C_tile_ is used for the last column block, we need a separate + // handling for the last column block. + int32_t nSize = + C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols(); + if (nSize) { + if (cpuinfo_has_x86_avx512f()) { + // TODO: avx512 path + // Currently use avx2 code + outputProcess_.template f<inst_set_t::avx2>( + matC_, + C_buffer_row_start, + {row_start_A, packed_rows_A, 0, nSize}, + ldc_, + ldc_); + } else if (cpuinfo_has_x86_avx2()) { + outputProcess_.template f<inst_set_t::avx2>( + matC_, + C_buffer_row_start, + {row_start_A, packed_rows_A, 0, nSize}, + ldc_, + ldc_); + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + } + + if (C_buffer_start == C_tile_) { + if (cpuinfo_has_x86_avx512f()) { + // TODO: avx512 path + // Currently use avx2 code + outputProcess_.template f<inst_set_t::avx2>( + matC_, + C_tile_, + {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()}, + ldc_, + leadingDim); + } else if (cpuinfo_has_x86_avx2()) { + outputProcess_.template f<inst_set_t::avx2>( + matC_, + C_tile_, + {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()}, + ldc_, + leadingDim); + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + } + } // output processing + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) + .count(); + postprocessing_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + } // for each j block +} +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + uint8_t, + ReQuantizeOutput<false /* FUSE_RELU*/>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + uint8_t, + ReQuantizeOutput<true>>; + +template class ExecuteKernel< + PackAWithQuantRowOffset<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + float, + ReQuantizeForFloat<false>>; + +template class ExecuteKernel< + PackAWithQuantRowOffset<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + float, + ReQuantizeForFloat<true>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + float, + ReQuantizeForFloat<false>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + float, + ReQuantizeForFloat<true>>; + +template class ExecuteKernel< + PackAMatrix<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAMatrix<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + uint8_t, + ReQuantizeOutput<false>>; + +template class 
ExecuteKernel< + PackAMatrix<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + uint8_t, + DoSpmdmOnInpBuffer< + ReQuantizeOutput<false>::outType, + int32_t, + ReQuantizeOutput<false>>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + uint8_t, + DoSpmdmOnInpBuffer< + ReQuantizeOutput<true>::outType, + int32_t, + ReQuantizeOutput<true>>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + float, + DoSpmdmOnInpBuffer< + ReQuantizeForFloat<false>::outType, + int32_t, + ReQuantizeForFloat<false>>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + uint8_t, + ReQuantizeOutput<false>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + uint8_t, + ReQuantizeOutput<true>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithIm2Col<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithIm2Col<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithQuantRowOffset<uint8_t, int32_t>, + PackBMatrix<int8_t, int32_t>, + int32_t, + memCopy<>>; + +template class ExecuteKernel< + PackAWithRowOffset<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + float, + ReQuantizeForFloat<false>>; + +template class ExecuteKernel< + PackAMatrix<uint8_t, int16_t>, + PackBMatrix<int8_t, int16_t>, + int32_t, + DoNothing<int32_t, int32_t>>; + +} // namespace fbgemm2 diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h new file mode 100644 index 0000000..0bd7fc5 --- /dev/null +++ b/src/ExecuteKernelU8S8.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include "ExecuteKernel.h" + +namespace fbgemm2 { + +/** + * @brief Execute Engine of uint 8 and int8 matrix + * multiplication for the macro-kernel and output processing. ExecuteKernel is a + * derived class of CodeGenBase. + */ +template <typename packingAMatrix, typename cT, typename processOutputType> +class ExecuteKernel< + packingAMatrix, + PackBMatrix<int8_t, typename packingAMatrix::accType>, + cT, + processOutputType> + : public CodeGenBase< + uint8_t, + int8_t, + int32_t, + typename packingAMatrix::accType> { + public: + using BaseType = + CodeGenBase<uint8_t, int8_t, int32_t, typename packingAMatrix::accType>; + /** + * @brief Constructor for initializing the parameters for macro-kernel and + * output processing type. 
+ */ + ExecuteKernel( + PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>& + packA, + PackMatrix< + PackBMatrix<int8_t, typename packingAMatrix::accType>, + int8_t, + typename packingAMatrix::accType>& packB, + int32_t kBlock, + cT* matC, + int32_t* C_buffer, + int32_t ldc, + const processOutputType& outputProcess); + void execute(int kBlock); + + ~ExecuteKernel() { + delete[] C_tile_; + } + + private: + PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>& + packedA_; ///< Packed uint8 block of matrix A. + PackMatrix< + PackBMatrix<int8_t, typename packingAMatrix::accType>, + int8_t, + typename packingAMatrix::accType>& + packedB_; ///< Packed int8 matrix B. + int32_t kBlock_; ///< Block ID in the k dimension. + cT* matC_; ///< Output for matrix C. + int32_t* C_buffer_; ///< the accumulation buffer for matrix C. + int32_t ldc_; ///< the leading dimension of matrix C. + const processOutputType& outputProcess_; ///< output processing function for + ///< matrix C in the macro-kernel. + int32_t* C_tile_; ///< buffer for the last N block when NCB is not an exact + ///< multiple of N. + int mbSize_; ///< block size in the m dimension. + int nbSize_; ///< block size in the n dimension. +}; + +} // namespace fbgemm2 diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc new file mode 100644 index 0000000..f3bac97 --- /dev/null +++ b/src/Fbgemm.cc @@ -0,0 +1,363 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "fbgemm/Fbgemm.h" +#include <cpuinfo.h> +#include <stdexcept> +#include "ExecuteKernel.h" + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +double packing_time = 0.0; +double computing_time = 0.0; +double run_time = 0.0; +#endif + +using namespace fbgemm2; + +namespace fbgemm2 { + +template < + typename packingAMatrix, + typename packingBMatrix, + typename cT, + typename processOutputType> +void fbgemmPacked( + PackMatrix< + packingAMatrix, + typename packingAMatrix::inpType, + typename packingAMatrix::accType>& packA, + PackMatrix< + packingBMatrix, + typename packingBMatrix::inpType, + typename packingBMatrix::accType>& packB, + cT* C, + int32_t* C_buffer, + uint32_t ldc, + const processOutputType& outProcess, + int thread_id, + int /* num_threads */) { + static_assert( + std::is_same< + typename packingAMatrix::accType, + typename packingBMatrix::accType>::value, + "Accumulation type of both matrices should be the same"); + + int MCB, KCB; + + // Run time CPU detection + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::KCB; + } else if (cpuinfo_has_x86_avx2()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::KCB; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + + int MDim = packA.numRows(); + int KDim = packB.numRows(); + + int mBlocks = (MDim + MCB - 1) / MCB; + int kBlocks = (KDim + KCB - 1) / KCB; + + // remainders + int 
_mc = MDim % MCB; + int _kc = KDim % KCB; + + int kc, mc; + + block_type_t blockA{0, 0, 0, 0}; + + // B must be prepacked + assert(packB.isPrePacked() && "B matrix must be prepacked"); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start, + t_start, t_end; + double dt; + t_start = std::chrono::high_resolution_clock::now(); + t_very_start = std::chrono::high_resolution_clock::now(); +#endif + + ExecuteKernel<packingAMatrix, packingBMatrix, cT, processOutputType> + exeKernelObj(packA, packB, 0, C, C_buffer, ldc, outProcess); + // ToDo: thread based work division + for (int i = 0; i < mBlocks; ++i) { + mc = (i != mBlocks - 1 || _mc == 0) ? MCB : _mc; + for (int k = 0; k < kBlocks; ++k) { + kc = (k != kBlocks - 1 || _kc == 0) ? KCB : _kc; + // pack A matrix + blockA = {i * MCB, mc, k * KCB, kc}; + + packA.pack(blockA); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) + .count(); + packing_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + exeKernelObj.execute(k); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) + .count(); + computing_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + } + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = + std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start) + .count(); + run_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif +} + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput<false>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput<true>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& + packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat<false>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& + packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat<true>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>& packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat<false>& outProcess, + int 
thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat<true>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>& packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& + packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +// 16 bit accumulation functions +template void fbgemmPacked( + PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput<false>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const DoSpmdmOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<false>>& + outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const DoSpmdmOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<true>>& + outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>& + outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput<false>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + uint8_t* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeOutput<true>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + 
PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const DoNothing<int32_t, int32_t>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( + PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + float* C, + int32_t* C_buffer, + uint32_t ldc, + const ReQuantizeForFloat<false>& outProcess, + int thread_id, + int num_threads); + +} // namespace fbgemm2 diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc new file mode 100644 index 0000000..7bbfa54 --- /dev/null +++ b/src/FbgemmFP16.cc @@ -0,0 +1,293 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "fbgemm/FbgemmFP16.h" + +#include <cpuinfo.h> + +#include "FbgemmFP16UKernels.h" + +using namespace std; + +namespace fbgemm2 { + +/// class that performs packing of matrix in +/// row-major or col-major format into +/// internal packed blocked-row major format + +/// Todo: make it fast with AVX2 transpose +inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) { + // for (int r = 0; r < nrow; ++r) { + // for (int c = 0; c < ncol; ++c) { + // to[r + c * nrow] = from[r * ldim + c]; + // } + // } + transpose_simd( nrow, ncol, from, ldim, to, nrow ); +} + +struct KernelInfo { + using knl_ptr = funcptr_fp16; + // optimized kernels to cover all cases + static constexpr array<knl_ptr, 15> kernel = {{ + nullptr, gemmkernel_1x1_AVX2_fA0fB0fC0, + gemmkernel_2x1_AVX2_fA0fB0fC0, gemmkernel_3x1_AVX2_fA0fB0fC0, + gemmkernel_4x1_AVX2_fA0fB0fC0, gemmkernel_5x1_AVX2_fA0fB0fC0, + gemmkernel_6x1_AVX2_fA0fB0fC0, gemmkernel_7x1_AVX2_fA0fB0fC0, + gemmkernel_8x1_AVX2_fA0fB0fC0, gemmkernel_9x1_AVX2_fA0fB0fC0, + gemmkernel_10x1_AVX2_fA0fB0fC0, gemmkernel_11x1_AVX2_fA0fB0fC0, + gemmkernel_12x1_AVX2_fA0fB0fC0, gemmkernel_13x1_AVX2_fA0fB0fC0, + gemmkernel_14x1_AVX2_fA0fB0fC0 + }}; + + // autotuned kernel splits for various cases m = 1:mb_max + // may need re-autotuning for new uarch + static constexpr array<array<pair<int, int>, 2>, 121 > partition = { + { + {{ { 0, 0 }, { 0, 0 } } }, + {{ { 1, 1 }, { 0, 0 } } }, + {{ { 2, 1 }, { 0, 0 } } }, + {{ { 3, 1 }, { 0, 0 } } }, + {{ { 4, 1 }, { 0, 0 } } }, + {{ { 5, 1 }, { 0, 0 } } }, + {{ { 6, 1 }, { 0, 0 } } }, + {{ { 7, 1 }, { 0, 0 } } }, + {{ { 8, 1 }, { 0, 0 } } }, + {{ { 9, 1 }, { 0, 0 } } }, + {{ { 10, 1 }, { 0, 0 } } }, + {{ { 11, 1 }, { 0, 0 } } }, + {{ { 12, 1 }, { 0, 0 } } }, + {{ { 13, 1 }, { 0, 0 } } }, + {{ { 14, 1 }, { 0, 0 } } }, + {{ { 8, 1 }, { 7, 1 } } }, + {{ { 10, 1 }, { 6, 1 } } }, + {{ { 11, 1 }, { 6, 1 } } }, + {{ { 12, 1 }, { 6, 1 } } }, + {{ { 11, 1 }, { 8, 1 } } }, + {{ { 11, 1 
}, { 9, 1 } } }, + {{ { 12, 1 }, { 9, 1 } } }, + {{ { 11, 2 }, { 0, 0 } } }, + {{ { 12, 1 }, { 11, 1 } } }, + {{ { 12, 2 }, { 0, 0 } } }, + {{ { 13, 1 }, { 12, 1 } } }, + {{ { 13, 2 }, { 0, 0 } } }, + {{ { 14, 1 }, { 13, 1 } } }, + {{ { 14, 2 }, { 0, 0 } } }, + {{ { 11, 2 }, { 7, 1 } } }, + {{ { 10, 3 }, { 0, 0 } } }, + {{ { 12, 2 }, { 7, 1 } } }, + {{ { 12, 2 }, { 8, 1 } } }, + {{ { 11, 3 }, { 0, 0 } } }, + {{ { 13, 2 }, { 8, 1 } } }, + {{ { 13, 2 }, { 9, 1 } } }, + {{ { 13, 2 }, { 10, 1 } } }, + {{ { 13, 2 }, { 11, 1 } } }, + {{ { 13, 2 }, { 12, 1 } } }, + {{ { 13, 3 }, { 0, 0 } } }, + {{ { 14, 2 }, { 12, 1 } } }, + {{ { 14, 2 }, { 13, 1 } } }, + {{ { 11, 3 }, { 9, 1 } } }, + {{ { 11, 3 }, { 10, 1 } } }, + {{ { 11, 4 }, { 0, 0 } } }, + {{ { 12, 3 }, { 9, 1 } } }, + {{ { 12, 3 }, { 10, 1 } } }, + {{ { 13, 3 }, { 8, 1 } } }, + {{ { 13, 3 }, { 9, 1 } } }, + {{ { 13, 3 }, { 10, 1 } } }, + {{ { 13, 3 }, { 11, 1 } } }, + {{ { 13, 3 }, { 12, 1 } } }, + {{ { 13, 4 }, { 0, 0 } } }, + {{ { 14, 3 }, { 11, 1 } } }, + {{ { 11, 4 }, { 10, 1 } } }, + {{ { 12, 4 }, { 7, 1 } } }, + {{ { 14, 4 }, { 0, 0 } } }, + {{ { 12, 4 }, { 9, 1 } } }, + {{ { 12, 4 }, { 10, 1 } } }, + {{ { 12, 4 }, { 11, 1 } } }, + {{ { 13, 4 }, { 8, 1 } } }, + {{ { 13, 4 }, { 9, 1 } } }, + {{ { 13, 4 }, { 10, 1 } } }, + {{ { 13, 4 }, { 11, 1 } } }, + {{ { 11, 5 }, { 9, 1 } } }, + {{ { 13, 5 }, { 0, 0 } } }, + {{ { 14, 4 }, { 10, 1 } } }, + {{ { 12, 5 }, { 7, 1 } } }, + {{ { 12, 5 }, { 8, 1 } } }, + {{ { 14, 4 }, { 13, 1 } } }, + {{ { 14, 5 }, { 0, 0 } } }, + {{ { 12, 5 }, { 11, 1 } } }, + {{ { 13, 5 }, { 7, 1 } } }, + {{ { 11, 6 }, { 7, 1 } } }, + {{ { 13, 5 }, { 9, 1 } } }, + {{ { 13, 5 }, { 10, 1 } } }, + {{ { 13, 5 }, { 11, 1 } } }, + {{ { 13, 5 }, { 12, 1 } } }, + {{ { 13, 6 }, { 0, 0 } } }, + {{ { 12, 6 }, { 7, 1 } } }, + {{ { 12, 6 }, { 8, 1 } } }, + {{ { 12, 6 }, { 9, 1 } } }, + {{ { 12, 6 }, { 10, 1 } } }, + {{ { 12, 6 }, { 11, 1 } } }, + {{ { 12, 7 }, { 0, 0 } } }, + {{ { 13, 6 }, { 7, 1 } } }, + {{ { 13, 6 }, { 8, 1 } } }, + {{ { 13, 6 }, { 9, 1 } } }, + {{ { 13, 6 }, { 10, 1 } } }, + {{ { 13, 6 }, { 11, 1 } } }, + {{ { 13, 6 }, { 12, 1 } } }, + {{ { 13, 7 }, { 0, 0 } } }, + {{ { 12, 7 }, { 8, 1 } } }, + {{ { 12, 7 }, { 9, 1 } } }, + {{ { 14, 6 }, { 10, 1 } } }, + {{ { 12, 7 }, { 11, 1 } } }, + {{ { 13, 7 }, { 5, 1 } } }, + {{ { 13, 7 }, { 6, 1 } } }, + {{ { 13, 7 }, { 7, 1 } } }, + {{ { 13, 7 }, { 8, 1 } } }, + {{ { 13, 7 }, { 9, 1 } } }, + {{ { 13, 7 }, { 10, 1 } } }, + {{ { 13, 7 }, { 11, 1 } } }, + {{ { 13, 7 }, { 12, 1 } } }, + {{ { 12, 8 }, { 8, 1 } } }, + {{ { 12, 8 }, { 9, 1 } } }, + {{ { 12, 8 }, { 10, 1 } } }, + {{ { 12, 8 }, { 11, 1 } } }, + {{ { 12, 9 }, { 0, 0 } } }, + {{ { 11, 9 }, { 10, 1 } } }, + {{ { 13, 8 }, { 6, 1 } } }, + {{ { 13, 8 }, { 7, 1 } } }, + {{ { 13, 8 }, { 8, 1 } } }, + {{ { 13, 8 }, { 9, 1 } } }, + {{ { 13, 8 }, { 10, 1 } } }, + {{ { 13, 8 }, { 11, 1 } } }, + {{ { 12, 9 }, { 8, 1 } } }, + {{ { 13, 9 }, { 0, 0 } } }, + {{ { 12, 9 }, { 10, 1 } } }, + {{ { 12, 9 }, { 11, 1 } } }, + {{ { 12, 10 }, { 0, 0 } } } + } + }; +}; +constexpr array<KernelInfo::knl_ptr, 15> KernelInfo::kernel; +constexpr array<array<pair<int, int>, 2>, 121 > KernelInfo::partition; + +// autotuned kernel splits for various cases m = 1:mb_max +void +cblas_gemm_compute(const matrix_op_t transa, const int m, const float *A, + const PackedGemmMatrixFP16 &Bp, const float beta, + float *C) { + // ground truth + assert(cpuinfo_initialize()); + assert(cpuinfo_has_x86_fma3()); + assert(cpuinfo_has_x86_f16c()); + assert(transa == 
matrix_op_t::NoTranspose); + + // constants + const int n = Bp.numCols(), k = Bp.numRows(), ldc = n; + const int mb_max = 120; + constexpr int simd_width = 8; + constexpr int kernel_ncol_blocks = 1; + constexpr int kernel_ncols = kernel_ncol_blocks * simd_width; + + // private scratchpad storage + static thread_local unique_ptr<std::array<float, 256 * 1024> > scratchpad( + new std::array<float, 256 * 1024>()); + + GemmParams gp; + for (auto m0 = 0; m0 < m; m0 += mb_max) { + int mb = std::min(mb_max, m - m0); + assert(mb < KernelInfo::partition.size()); + for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) { + + // set up proper accumulation to avoid "Nan" problem + float beta_; + uint64_t accum; + if (k_ind == 0) { + // accumulate of beta != 0.0 + // do not!!! accumulate otherwise + beta_ = beta; + accum = (beta_ == 0.0f) ? 0 : 1; + } else { + // always accumulate with beta_ = 1.0f + beta_ = 1.0f; + accum = 1; + } + + const int kb = std::min(Bp.blockRowSize(), Bp.numRows() - k_ind); + + auto m1 = 0; + for (auto c = 0; c < 2; c++) { + + auto kernel_nrows = KernelInfo::partition[mb][c].first; + auto nkernel_nrows = KernelInfo::partition[mb][c].second; + + auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows; + for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) { + assert(kernel_nrows * kb < scratchpad->size()); + PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data()); + + int nbcol = n / Bp.blockColSize(); + gp.k = kb; + gp.A = scratchpad->data(); + gp.B = &(Bp(k_ind, 0)); + gp.beta = &beta_; + gp.accum = accum; + gp.C = &C[m2 * ldc]; + gp.ldc = ldc * sizeof(C[0]); + gp.b_block_cols = nbcol; + gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]); + if ((n % Bp.blockColSize()) == 0) { + KernelInfo::kernel[kernel_nrows](&gp); + } else { + int last_blk_col = nbcol * Bp.blockColSize(); + if (nbcol) { + KernelInfo::kernel[kernel_nrows](&gp); + } + + // leftover + int rem = n - last_blk_col; + assert(rem < kernel_ncols); + int b = (rem % simd_width) ? ((rem + simd_width) / simd_width) + : (rem / simd_width); + assert(b == 1); + if ((rem % simd_width) == 0) { + gp.B = &(Bp(k_ind, last_blk_col)); + gp.C = &C[m2 * ldc + last_blk_col]; + gp.b_block_cols = 1; + KernelInfo::kernel[kernel_nrows](&gp); + } else { + // small temporary buffer + float c_tmp[16 * 24] = { 0 }; + assert((16 * 24) > kernel_nrows * kernel_ncols); + + gp.B = &(Bp(k_ind, last_blk_col)); + gp.C = c_tmp; + gp.ldc = 8 * sizeof(C[0]); + gp.b_block_cols = 1; + KernelInfo::kernel[kernel_nrows](&gp); + for (int i = 0; i < kernel_nrows; i++) { + // Todo: use assembly + for (int j = last_blk_col; j < n; j++) { + assert(i * 8 + (j - last_blk_col) < + sizeof(c_tmp) / sizeof(c_tmp[0])); + if (accum == 0) { + C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)]; + } else { + C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] + + c_tmp[i * 8 + (j - last_blk_col)]; + } + } + } + } + } + } + m1 += kernel_nrows * nkernel_nrows; + } + } + } +} + + +} // namespace fbgemm diff --git a/src/FbgemmFP16UKernels.cc b/src/FbgemmFP16UKernels.cc new file mode 100644 index 0000000..ec1b297 --- /dev/null +++ b/src/FbgemmFP16UKernels.cc @@ -0,0 +1,2203 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "FbgemmFP16UKernels.h" + +namespace fbgemm2 { + +void __attribute__ ((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm1,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm1\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm1,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm0,ymm14,ymm1\t\n" +"add r11, 32\t\n" +"add r9,8\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm2,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm2\t\n" +"vbroadcastss ymm2,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm2\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm2,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm0,ymm14,ymm2\t\n" +"vbroadcastss ymm2,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm1,ymm14,ymm2\t\n" +"add r11, 32\t\n" +"add r9,16\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" 
+"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm3\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm3\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm3\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm0,ymm14,ymm3\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm1,ymm14,ymm3\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm3,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm2,ymm14,ymm3\t\n" +"add r9,24\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A 
+"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm4\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm0,ymm14,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm1,ymm14,ymm4\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm2,ymm14,ymm4\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm4,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm3,ymm14,ymm4\t\n" +"add r9,32\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" + 
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm5\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm0,ymm14,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm1,ymm14,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm2,ymm14,ymm5\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm3,ymm14,ymm5\t\n" +"vbroadcastss ymm5,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm4,ymm14,ymm5\t\n" +"add r9,40\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss 
ymm6,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm6\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm0,ymm14,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm1,ymm14,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm2,ymm14,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm3,ymm14,ymm6\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm4,ymm14,ymm6\t\n" +"vbroadcastss ymm6,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm5,ymm14,ymm6\t\n" +"add r9,48\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps 
ymm6,ymm6,ymm6\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm7\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm0,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm1,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm2,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm3,ymm14,ymm7\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm4,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm5,ymm14,ymm7\t\n" +"vbroadcastss ymm7,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm6,ymm14,ymm7\t\n" +"add r9,56\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, 
[r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm8\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm0,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm1,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm2,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm3,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm4,ymm14,ymm8\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm5,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm6,ymm14,ymm8\t\n" +"vbroadcastss ymm8,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm7,ymm14,ymm8\t\n" +"add r9,64\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" 
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm9\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm0,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm1,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm2,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm3,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm4,ymm14,ymm9\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm5,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm6,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm7,ymm14,ymm9\t\n" +"vbroadcastss ymm9,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm8,ymm14,ymm9\t\n" +"add r9,72\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add 
r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm10\t\n" 
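+// (rows 8 and 9 of the 10-row tile follow; each inner-loop pass covers two k steps, with an early exit when k is odd)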
+"vbroadcastss ymm10,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm10\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm0,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm1,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm2,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm3,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm4,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm5,ymm14,ymm10\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm6,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm7,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+72]\t\n" +"vfmadd231ps ymm8,ymm14,ymm10\t\n" +"vbroadcastss ymm10,DWORD PTR [r9+76]\t\n" +"vfmadd231ps ymm9,ymm14,ymm10\t\n" +"add r9,80\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) 
+"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" +"vxorps ymm10,ymm10,ymm10\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm10,ymm15,ymm11\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm0,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm1,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm2,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm3,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm4,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm5,ymm14,ymm11\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm6,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+72]\t\n" +"vfmadd231ps ymm7,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+76]\t\n" +"vfmadd231ps ymm8,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+80]\t\n" +"vfmadd231ps ymm9,ymm14,ymm11\t\n" +"vbroadcastss ymm11,DWORD PTR [r9+84]\t\n" +"vfmadd231ps ymm10,ymm14,ymm11\t\n" +"add r9,88\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], 
ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" +"vxorps ymm10,ymm10,ymm10\t\n" +"vxorps ymm11,ymm11,ymm11\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" 
+"vfmadd231ps ymm5,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm10,ymm15,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm11,ymm15,ymm12\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm0,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm1,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm2,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm3,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm4,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm5,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+72]\t\n" +"vfmadd231ps ymm6,ymm14,ymm12\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+76]\t\n" +"vfmadd231ps ymm7,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+80]\t\n" +"vfmadd231ps ymm8,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+84]\t\n" +"vfmadd231ps ymm9,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+88]\t\n" +"vfmadd231ps ymm10,ymm14,ymm12\t\n" +"vbroadcastss ymm12,DWORD PTR [r9+92]\t\n" +"vfmadd231ps ymm11,ymm14,ymm12\t\n" +"add r9,96\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" 
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void __attribute__ ((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" +"vxorps ymm10,ymm10,ymm10\t\n" +"vxorps ymm11,ymm11,ymm11\t\n" +"vxorps ymm12,ymm12,ymm12\t\n" + +"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" +"mov r11, 16\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" +"inc r14\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm10,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm11,ymm15,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm12,ymm15,ymm13\t\n" +"cmp r14, r8\t\n" +"jge L_exit%=\t\n" +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" +"inc r14\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm0,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+56]\t\n" +"vfmadd231ps ymm1,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+60]\t\n" +"vfmadd231ps ymm2,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+64]\t\n" +"vfmadd231ps ymm3,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+68]\t\n" +"vfmadd231ps ymm4,ymm14,ymm13\t\n" 
+"vbroadcastss ymm13,DWORD PTR [r9+72]\t\n" +"vfmadd231ps ymm5,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+76]\t\n" +"vfmadd231ps ymm6,ymm14,ymm13\t\n" +"add r11, 32\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+80]\t\n" +"vfmadd231ps ymm7,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+84]\t\n" +"vfmadd231ps ymm8,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+88]\t\n" +"vfmadd231ps ymm9,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+92]\t\n" +"vfmadd231ps ymm10,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+96]\t\n" +"vfmadd231ps ymm11,ymm14,ymm13\t\n" +"vbroadcastss ymm13,DWORD PTR [r9+100]\t\n" +"vfmadd231ps ymm12,ymm14,ymm13\t\n" +"add r9,104\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" + +"L_exit%=:\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} +void 
__attribute__ ((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams *gp) +{ +asm volatile +( +#if !defined(__clang__) +"mov r14, %[gp]\t\n" +#else +"mov %[gp], %%r14\t\n" +".intel_syntax noprefix\t\n" +#endif + +// Copy parameters +// k +"mov r8, [r14 + 0]\t\n" +// A +"mov r9, [r14 + 8]\t\n" +// B +"mov r10, [r14 + 16]\t\n" +// beta +"mov r15, [r14 + 24]\t\n" +// accum +"mov rdx, [r14 + 32]\t\n" +// C +"mov r12, [r14 + 40]\t\n" +// ldc +"mov r13, [r14 + 48]\t\n" +// b_block_cols +"mov rdi, [r14 + 56]\t\n" +// b_block_size +"mov rsi, [r14 + 64]\t\n" +// Make copies of A and C +"mov rax, r9\t\n" +"mov rcx, r12\t\n" + + +"mov rbx, 0\t\n" +"loop_outter%=:\t\n" +"mov r14, 0\t\n" +"vxorps ymm0,ymm0,ymm0\t\n" +"vxorps ymm1,ymm1,ymm1\t\n" +"vxorps ymm2,ymm2,ymm2\t\n" +"vxorps ymm3,ymm3,ymm3\t\n" +"vxorps ymm4,ymm4,ymm4\t\n" +"vxorps ymm5,ymm5,ymm5\t\n" +"vxorps ymm6,ymm6,ymm6\t\n" +"vxorps ymm7,ymm7,ymm7\t\n" +"vxorps ymm8,ymm8,ymm8\t\n" +"vxorps ymm9,ymm9,ymm9\t\n" +"vxorps ymm10,ymm10,ymm10\t\n" +"vxorps ymm11,ymm11,ymm11\t\n" +"vxorps ymm12,ymm12,ymm12\t\n" +"vxorps ymm13,ymm13,ymm13\t\n" + +"mov r11, 0\t\n" + +"loop_inner%=:\t\n" + +"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11]\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+0]\t\n" +"vfmadd231ps ymm0,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+4]\t\n" +"vfmadd231ps ymm1,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+8]\t\n" +"vfmadd231ps ymm2,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+12]\t\n" +"vfmadd231ps ymm3,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+16]\t\n" +"vfmadd231ps ymm4,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+20]\t\n" +"vfmadd231ps ymm5,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+24]\t\n" +"vfmadd231ps ymm6,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+28]\t\n" +"vfmadd231ps ymm7,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+32]\t\n" +"vfmadd231ps ymm8,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+36]\t\n" +"vfmadd231ps ymm9,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+40]\t\n" +"vfmadd231ps ymm10,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+44]\t\n" +"vfmadd231ps ymm11,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+48]\t\n" +"vfmadd231ps ymm12,ymm15,ymm14\t\n" +"vbroadcastss ymm14,DWORD PTR [r9+52]\t\n" +"vfmadd231ps ymm13,ymm15,ymm14\t\n" +"add r9,56\t\n" +"add r11, 16\t\n" +"inc r14\t\n" +"cmp r14, r8\t\n" +"jl loop_inner%=\t\n" +"add r10, rsi\t\n" + +"cmp rdx, 1\t\n" +"je L_accum%=\t\n" +// Dump C +"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" +"add r12, r13\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm13\t\n" +"add r12, r13\t\n" +"jmp L_done%=\t\n" + + +"L_accum%=:\t\n" +// Dump C with accumulate +"vbroadcastss ymm15,DWORD PTR [r15]\t\n" +"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 
+ 0], ymm0\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" +"add r12, r13\t\n" +"vfmadd231ps ymm13,ymm15,YMMWORD PTR [r12 + 0]\t\n" +"vmovups YMMWORD PTR [r12 + 0], ymm13\t\n" +"add r12, r13\t\n" + +"L_done%=:\t\n" + +// next outer iteration +"add rcx, 32\t\n" +"mov r12, rcx\t\n" +"mov r9, rax\t\n" +"inc rbx\t\n" +"cmp rbx, rdi\t\n" +"jl loop_outter%=\t\n" +: +: +[gp] "rm" (gp) +: "r8", "r9", "r10", "r11", "r15", "r13", "r14", +"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory" +); +} + +} // namespace fbgemm2 diff --git a/src/FbgemmFP16UKernels.h b/src/FbgemmFP16UKernels.h new file mode 100644 index 0000000..bf7f247 --- /dev/null +++ b/src/FbgemmFP16UKernels.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#ifndef FBGEMM_UKERNELS +#define FBGEMM_UKERNELS +#include <cstdint> +#include <tuple> +#include <vector> +#include "fbgemm/Types.h" + +namespace fbgemm2 { + +using fp16 = float16; +using fp32 = float; +struct GemmParams {uint64_t k; float *A; const fp16 *B; +float *beta; uint64_t accum; float *C; uint64_t ldc; +uint64_t b_block_cols; uint64_t b_block_size;}; +void __attribute__ ((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams *gp); +void __attribute__ ((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams *gp); +typedef void (* funcptr_fp16) (GemmParams *gp); +; + +} // namespace fbgemm2 + +#endif diff --git a/src/FbgemmI8Depthwise.cc b/src/FbgemmI8Depthwise.cc new file mode 100644 index 0000000..54e2272 --- /dev/null +++ b/src/FbgemmI8Depthwise.cc @@ -0,0 +1,1953 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "FbgemmI8Depthwise.h" + +#include <algorithm> +#include <array> +#include <cassert> +#include <cmath> +#include <cstdio> +#include <tuple> +#include <vector> + +#include <x86intrin.h> + +using namespace std; + +namespace fbgemm2 +{ + +static array<array<int, 8>, 8> masks = {{ + { 0, 0, 0, 0, 0, 0, 0, 0, }, + { -1, 0, 0, 0, 0, 0, 0, 0, }, + { -1, -1, 0, 0, 0, 0, 0, 0, }, + { -1, -1, -1, 0, 0, 0, 0, 0, }, + { -1, -1, -1, -1, 0, 0, 0, 0, }, + { -1, -1, -1, -1, -1, 0, 0, 0, }, + { -1, -1, -1, -1, -1, -1, 0, 0, }, + { -1, -1, -1, -1, -1, -1, -1, 0, }, +}}; + +template <int KERNEL_PROD> +PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix( + int K, const int8_t *smat) + : K_(K) { + // Transpose the input matrix to make packing faster. + vector<int8_t> smat_transposed(K * KERNEL_PROD); + for (int i = 0; i < KERNEL_PROD; ++i) { + for (int j = 0; j < K; ++j) { + smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD]; + } + } + + // Allocate packed arrays + constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2; + pmat_ = static_cast<int8_t *>(aligned_alloc( + 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t))); + + // Pack input matrix + // The layout is optimized to use vpmaddubsw efficiently (see + // madd_epi16x4_packed function). + // For a group of 32 channels, we have 10 32B SIMD registers. 
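+ // (10 registers because KERNEL_PROD_ALIGNED rounds the 9 taps of a 3x3 filter up to an even count.)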
+ // Denote ith channel jth filter as (i, j) + // 0th SIMD register: + // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3) + // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3) + // 1st SIMD register: + // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3) + // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3) + // 2nd SIMD register: + // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3) + // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3) + // 3rd SIMD register: + // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3) + // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3) + // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter + // coefficients + // ... + // + // REMAINDER + // If KERNEL_PROD % 4 == 1 for example when KERNEL_PROD == 9 + // 8th SIMD register: + // (0, 8), zero, ..., (7, 8), zero + // (16, 8), zero, ..., (23, 8), zero + // 9th SIMD register: + // (8, 8), zero, ..., (15, 8), zero + // (24, 8), zero, ..., (31, 8), zero + // We use madd_epi16_packed for this case + // + // If KERNEL_PROD % 4 == 2 for example when KERNEL_PROD == 10 + // 8th SIMD register: + // (0, 8), (0, 9), ..., (7, 8), (7, 9) + // (16, 8), (16, 9), ..., (23, 8), (23, 9) + // 9th SIMD register: + // (8, 8), (8, 9), ..., (15, 8), (15, 9) + // (24, 8), (24, 9), ..., (31, 8), (31, 9) + // + // If KERNEL_PROD % 4 == 3 for example when KERNEL_PROD == 11 + // 8th SIMD register: + // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero + // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero + // 9th SIMD register: + // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero + // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero + // 10th SIMD register: + // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero + // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero + // 11th SIMD register: + // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero + // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero + for (int k1 = 0; k1 < K; k1 += 32) { + array<__m256i, KERNEL_PROD> b_v; + int remainder = K - k1; + if (remainder < 32) { + __m256i mask_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i *>(masks[remainder / 4].data())); + for (int i = 0; i < KERNEL_PROD; ++i) { + b_v[i] = _mm256_maskload_epi32( + reinterpret_cast<const int *>(smat_transposed.data() + i * K + k1), + mask_v); + } + } else { + for (int i = 0; i < KERNEL_PROD; ++i) { + b_v[i] = _mm256_lddqu_si256(reinterpret_cast<const __m256i *>( + smat_transposed.data() + i * K + k1)); + } + } + + // Interleave 2 SIMD registers + array<__m256i, KERNEL_PROD_ALIGNED> b_interleaved_epi16; + __m256i zero_v = _mm256_setzero_si256(); + for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) { + if (2 * i + 1 >= KERNEL_PROD) { + b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v); + b_interleaved_epi16[2 * i + 1] = + _mm256_unpackhi_epi8(b_v[2 * i], zero_v); + } else { + b_interleaved_epi16[2 * i] = + _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]); + b_interleaved_epi16[2 * i + 1] = + _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]); + } + } + + // Interleave 4 SIMD registers + array<__m256i, KERNEL_PROD_ALIGNED> b_interleaved_epi32; + for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) { + 
b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16( + b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); + b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16( + b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); + b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16( + b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); + b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16( + b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); + } + for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) { + b_interleaved_epi32[i] = b_interleaved_epi16[i]; + } + + for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) { + _mm256_storeu_si256( + reinterpret_cast<__m256i *>( + &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]), + b_interleaved_epi32[i]); + } + } +} + +template <int KERNEL_PROD> +PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() +{ + free(pmat_); +} + +template class PackedDepthWiseConvMatrix<3 * 3>; +template class PackedDepthWiseConvMatrix<3 * 3 * 3>; + +// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[16:20] +// c1_v: c[4:8], c[20:24] +// c2_v: c[8:12], c[24:28] +// c3_v: c[12:16], c[28:32] +template <bool SUM_A = false> +static inline __attribute__((always_inline)) +void madd_epi16x4_packed( + __m256i a0_v, __m256i a1_v, __m256i a2_v, __m256i a3_v, + const __m256i* b, + __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v); + __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); + __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); + __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); + __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + __m256i b2_v = _mm256_load_si256(b + 2); + __m256i b3_v = _mm256_load_si256(b + 3); + + __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); + __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); + __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); + __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); + + __m256i one_v = _mm256_set1_epi16(1); + *c0_v = _mm256_madd_epi16(ab0, one_v); + *c1_v = _mm256_madd_epi16(ab1, one_v); + *c2_v = _mm256_madd_epi16(ab2, one_v); + *c3_v = _mm256_madd_epi16(ab3, one_v); +} + +// c = a0 * b0 + a1 * b1 + a2 * b2 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[16:20] +// c1_v: c[4:8], c[20:24] +// c2_v: c[8:12], c[24:28] +// c3_v: c[12:16], c[28:32] +template <bool SUM_A = false> +static inline __attribute__((always_inline)) +void 
madd_epi16x3_packed( + __m256i a0_v, __m256i a1_v, __m256i a2_v, + const __m256i* b, + __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v, + __m256i* a_sum = nullptr) { + __m256i zero_v = _mm256_setzero_si256(); + + __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v); + __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); + __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); + __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); + __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + __m256i b2_v = _mm256_load_si256(b + 2); + __m256i b3_v = _mm256_load_si256(b + 3); + + __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); + __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); + __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); + __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); + + __m256i one_v = _mm256_set1_epi16(1); + *c0_v = _mm256_madd_epi16(ab0, one_v); + *c1_v = _mm256_madd_epi16(ab1, one_v); + *c2_v = _mm256_madd_epi16(ab2, one_v); + *c3_v = _mm256_madd_epi16(ab3, one_v); +} + +// c = a0 * b0 + a1 * b1 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[4:8] +// c1_v: c[8:12], c[12:16] +// c2_v: c[16:20], c[20:24] +// c3_v: c[24:28], c[28:32] +template <bool SUM_A = false> +static inline __attribute__((always_inline)) void +madd_epi16x2_packed(__m256i a0_v, __m256i a1_v, const __m256i *b, __m256i *c0_v, + __m256i *c1_v, __m256i *c2_v, __m256i *c3_v, + __m256i *a_sum = nullptr) { + __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); + __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + + __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); + __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); + + *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); + *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); + *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); + *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); +} + +// c = a0 * b0 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[4:8] +// c1_v: c[8:12], c[12:16] +// c2_v: c[16:20], c[20:24] +// c3_v: c[24:28], c[28:32] +template <bool SUM_A = false> +static inline __attribute__((always_inline)) void +madd_epi16_packed(__m256i a_v, const __m256i *b, __m256i *c0_v, __m256i *c1_v, + 
__m256i *c2_v, __m256i *c3_v, __m256i *a_sum = nullptr) { + __m256i zero_v = _mm256_setzero_si256(); + + __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v); + __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v); + + if (SUM_A) { + __m256i one_epi8_v = _mm256_set1_epi8(1); + a_sum[0] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); + a_sum[1] = + _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); + } + + __m256i b0_v = _mm256_load_si256(b + 0); + __m256i b1_v = _mm256_load_si256(b + 1); + + __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); + __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); + + *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); + *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); + *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); + *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); +} + +// K is the number of accumulations we're doing +template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false> +static inline __attribute__((always_inline)) void +inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C, + int remainder, __m256i *a_sum = nullptr) { + array<__m256i, 4> c, c_temp; + array<__m256i, 2> a_sum_temp{}; + + int k = 0; + if (K >= 4) { + madd_epi16x4_packed<SUM_A>(a_v[0], a_v[1], a_v[2], a_v[3], Bp, + &c[0], &c[1], &c[2], &c[3], a_sum_temp.data()); + + for (k = 4; k < K / 4 * 4; k += 4) { + madd_epi16x4_packed<SUM_A>(a_v[k + 0], a_v[k + 1], a_v[k + 2], a_v[k + 3], + Bp + k, &c_temp[0], &c_temp[1], &c_temp[2], + &c_temp[3], a_sum_temp.data()); + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + } else { + c[0] = _mm256_setzero_si256(); + c[1] = _mm256_setzero_si256(); + c[2] = _mm256_setzero_si256(); + c[3] = _mm256_setzero_si256(); + } + + if (K - k == 3) { + madd_epi16x3_packed<SUM_A>(a_v[k], a_v[k + 1], a_v[k + 2], Bp + k, + &c_temp[0], &c_temp[1], &c_temp[2], &c_temp[3], + a_sum_temp.data()); + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + + c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20); + c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20); + c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31); + c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31); + + if (K - k == 0 || K - k == 3) { + c[0] = c_temp[0]; + c[1] = c_temp[1]; + c[2] = c_temp[2]; + c[3] = c_temp[3]; + } else { + if (K - k == 1) { + madd_epi16_packed<SUM_A>(a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], + a_sum_temp.data()); + } else if (K - k == 2) { + madd_epi16x2_packed<SUM_A>(a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], + &c[2], &c[3], a_sum_temp.data()); + } + + c[0] = _mm256_add_epi32(c[0], c_temp[0]); + c[1] = _mm256_add_epi32(c[1], c_temp[1]); + c[2] = _mm256_add_epi32(c[2], c_temp[2]); + c[3] = _mm256_add_epi32(c[3], c_temp[3]); + } + + if (REMAINDER) { + for (int r = 0; r < remainder / 8; ++r) { + if (ACC) { + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C + r * 8), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + r * 8)), + c[r])); + } else { + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + r * 8), c[r]); + } + } + } else { + if (ACC) { + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C), + 
_mm256_add_epi32(_mm256_loadu_si256(reinterpret_cast<__m256i *>(C)), + c[0])); + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C + 8), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 8)), c[1])); + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C + 16), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 16)), c[2])); + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(C + 24), + _mm256_add_epi32( + _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 24)), c[3])); + } else { + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C), c[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 8), c[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 16), c[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 24), c[3]); + } + } + + if (SUM_A) { + a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0])); + a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1])); + a_sum[2] = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1)); + a_sum[3] = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1)); + } +} + +template <bool SUM_A = false, bool REMAINDER = false> +static inline __attribute__((always_inline)) +void inner_prod_3x3_packed_(const __m256i* a_v, + const __m256i* Bp, + int32_t* C, + int remainder, + __m256i* a_sum = nullptr) { + return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, + a_sum); +} + +// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different +// row_offsets for each row because of depth-wise convolution +template <bool FUSE_RELU, bool HAS_BIAS, bool PER_CHANNEL_QUANTIZATION> +static inline __attribute__((always_inline)) void requantize_( + int32_t A_zero_point, + const float* C_multiplier, + int32_t C_zero_point, + const int32_t* C_int32, + uint8_t* C_uint8, + int n, + const int32_t* row_offsets, + const int32_t* col_offsets, + const int32_t* bias) { + __m256 multiplier_v = _mm256_setzero_ps(); + if (!PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_set1_ps(*C_multiplier); + } + + __m256i min_v = _mm256_set1_epi8(numeric_limits<uint8_t>::min()); + __m256i max_v = _mm256_set1_epi8(numeric_limits<uint8_t>::max()); + + __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point); + __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point); + __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point); + + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + + constexpr int VLEN = 8; + int j = 0; + for ( ; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) { + __m256i x_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); + __m256i y_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(C_int32 + j + VLEN)); + __m256i z_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN)); + __m256i w_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN)); + + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j))); + __m256i row_offset_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j)); + x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v); + + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + VLEN))); + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + VLEN)); + y_v = 
_mm256_sub_epi32(_mm256_sub_epi32(y_v, col_off_v), row_offset_v); + + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + col_offsets + j + 2 * VLEN))); + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN)); + z_v = _mm256_sub_epi32(_mm256_sub_epi32(z_v, col_off_v), row_offset_v); + + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256(reinterpret_cast<const __m256i*>( + col_offsets + j + 3 * VLEN))); + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN)); + w_v = _mm256_sub_epi32(_mm256_sub_epi32(w_v, col_off_v), row_offset_v); + + if (HAS_BIAS) { // static if + x_v = _mm256_add_epi32( + x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN))); + } + + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j); + } + __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + VLEN); + } + __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN); + } + __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN); + } + __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); + + __m256i xy_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v); + __m256i zw_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v); + __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); + __m256i xyzw_clamped_v = _mm256_max_epu8( + FUSE_RELU ? 
C_zero_point_epi8_v : min_v, + _mm256_min_epu8(xyzw_packed_v, max_v)); + + xyzw_clamped_v = + _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v); + } // j loop vectorized and unrolled 4x + + for ( ; j < n / VLEN * VLEN; j += VLEN) { + __m256i x_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); + + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j))); + __m256i row_offset_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j)); + x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v); + + if (HAS_BIAS) { // static if + x_v = _mm256_add_epi32( + x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j))); + } + + if (PER_CHANNEL_QUANTIZATION) { + multiplier_v = _mm256_loadu_ps(C_multiplier + j); + } + __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + + __m256i x_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), + C_zero_point_epi16_v); + x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); + __m256i x_clamped_v = _mm256_max_epu8( + FUSE_RELU ? C_zero_point_epi8_v : min_v, + _mm256_min_epu8(x_packed_v, max_v)); + + x_clamped_v = + _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); + + _mm_storel_epi64( + reinterpret_cast<__m128i*>(C_uint8 + j), + _mm256_castsi256_si128(x_clamped_v)); + } // j loop vectorized + + for ( ; j < n; ++j) { + int32_t raw = C_int32[j] - A_zero_point * col_offsets[j] - row_offsets[j]; + if (HAS_BIAS) { // static if + raw += bias[j]; + } + + float ab = raw * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0]; + long rounded = lrintf(ab) + C_zero_point; + + C_uint8[j] = std::max( + FUSE_RELU ? 
static_cast<long>(C_zero_point) : 0l, + std::min(255l, rounded)); + } +} + +template <bool FUSE_RELU, bool HAS_BIAS> +static inline __attribute__((always_inline)) void +requantize_(int32_t A_zero_point, float C_multiplier, + int32_t C_zero_point, const int32_t *C_int32, uint8_t *C_uint8, + int n, const int32_t *row_offsets, const int32_t *col_offsets, + const int32_t *bias) { + requantize_<FUSE_RELU, HAS_BIAS, false /* PER_CHANNEL_QUANTIZATION */>( + A_zero_point, + &C_multiplier, + C_zero_point, + C_int32, + C_uint8, + n, + row_offsets, + col_offsets, + bias); +} + +template <bool FUSE_RELU, bool HAS_BIAS> +static inline __attribute__((always_inline)) void +requantize_per_channel_(int32_t A_zero_point, const float *C_multiplier, + int32_t C_zero_point, const int32_t *C_int32, + uint8_t *C_uint8, int n, const int32_t *row_offsets, + const int32_t *col_offsets, const int32_t *bias) { + requantize_<FUSE_RELU, HAS_BIAS, true /* PER_CHANNEL_QUANTIZATION */>( + A_zero_point, + C_multiplier, + C_zero_point, + C_int32, + C_uint8, + n, + row_offsets, + col_offsets, + bias); +} + +template <bool REMAINDER> +static inline __attribute__((always_inline)) __m256i +load_a(const uint8_t* A, __m256i mask_v) { + if (REMAINDER) { + return _mm256_maskload_epi32(reinterpret_cast<const int *>(A), mask_v); + } else { + return _mm256_lddqu_si256(reinterpret_cast<const __m256i *>(A)); + } +} + +template <bool SUM_A, bool REMAINDER = false, + bool PER_CHANNEL_QUANTIZATION = false> +static inline __attribute__((always_inline)) void +inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in, + const uint8_t *A, int32_t A_zero_point, const int8_t *Bp, + const int32_t *B_zero_point, int32_t *C, int remainder, + int32_t *row_offsets) { + __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point)); + __m256i mask_v = _mm256_setzero_si256(); + if (REMAINDER) { + mask_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i *>(masks[remainder / 4].data())); + } + + // The code below can be written as a simple R*S loop but the compiler + // doesn't unroll so we're manually unrolling it. 
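+  // Note: out-of-image taps are left at A_zero_point_v below rather than 0.
+  // A real-valued zero quantizes to A_zero_point, so the pad=1 zero padding
+  // falls out of the same inner product without a separate branch; the extra
+  // zero-point contributions are cancelled afterwards by the row_offsets
+  // (sum(a) * B_zero_point) and A_zero_point * col_offsets corrections that
+  // requantize_ applies.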
+ // constexpr int R = 3, S = 3; + // array<__m256i, R * S> a_v; + // for (int r = 0; r < R; ++r) { + // for (int s = 0; s < S; ++s) { + // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { + // if (REMAINDER) { + // a_v[r * S + s] = + // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), + // mask_v); + // } else { + // a_v[r * S + s] = + // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); + // } + // } else { + // a_v[r * S + s] = A_zero_point_v; + // } + // } + // } + array<__m256i, 9> a_v = { + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + }; + + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[3] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[4] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[5] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[6] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[7] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[8] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v); + } + } + + array<__m256i, 4> a_sum; + inner_prod_3x3_packed_<SUM_A, REMAINDER>( + a_v.data(), reinterpret_cast<const __m256i *>(Bp), C, remainder, + a_sum.data()); + if (SUM_A) { + __m256i B_zero_point_v; + for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) { + if (PER_CHANNEL_QUANTIZATION) { + B_zero_point_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i *>(B_zero_point + i * 8)); + } else { + B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); + } + _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]), + _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); + } + } +} + +template <bool SUM_A, bool REMAINDER = false, + bool PER_CHANNEL_QUANTIZATION = false> +static inline __attribute__((always_inline)) void +inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in, + int w_in, const uint8_t *A, int32_t A_zero_point, + const int8_t *Bp, const int32_t *B_zero_point, + int32_t *C, int remainder, int32_t *row_offsets) { + __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point)); + __m256i mask_v = _mm256_setzero_si256(); + if (REMAINDER) { + mask_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i *>(masks[remainder / 4].data())); + } + + // The code below can be written as a simple R*S loop but the compiler + // doesn't unroll so we're manually unrolling it. 
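+  // Note: the 27 taps of the 3x3x3 window are consumed in batches of at most
+  // 8 vector registers (8 + 8 + 8 + 3). Each batch after the first calls
+  // inner_prod_packed_ with the acc flag set so partial products accumulate
+  // into the same C values, Bp is advanced by 8, 16 and 24 registers between
+  // batches, and a_sum is likewise accumulated across batches when SUM_A is
+  // set. Capping a_v at 8 registers per pass is presumably what keeps the
+  // working set within the 16 available ymm registers.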
+ // constexpr int R = 3, S = 3; + // array<__m256i, R * S> a_v; + // for (int r = 0; r < R; ++r) { + // for (int s = 0; s < S; ++s) { + // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { + // if (REMAINDER) { + // a_v[r * S + s] = + // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), + // mask_v); + // } else { + // a_v[r * S + s] = + // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); + // } + // } else { + // a_v[r * S + s] = A_zero_point_v; + // } + // } + // } + array<__m256i, 8> a_v; + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in >= 0 && t_in < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v); + } + } + } + + array<__m256i, 4> a_sum; + inner_prod_packed_<8, SUM_A, REMAINDER>(a_v.data(), + reinterpret_cast<const __m256i *>(Bp), + C, remainder, a_sum.data()); + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in >= 0 && t_in < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v); + } + } + } + + if (t_in + 1 >= 0 && t_in + 1 < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v); + } + } + } + + array<__m256i, 4> a_sum_temp; + inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( + a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 8, C, remainder, + a_sum_temp.data()); + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = 
_mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + } + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + a_v[3] = A_zero_point_v; + a_v[4] = A_zero_point_v; + a_v[5] = A_zero_point_v; + a_v[6] = A_zero_point_v; + a_v[7] = A_zero_point_v; + + if (t_in + 1 >= 0 && t_in + 1 < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v); + } + } + } + + if (t_in + 2 >= 0 && t_in + 2 < T) { + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v); + } + } + } + + inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( + a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 16, C, remainder, + a_sum_temp.data()); + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + } + + a_v[0] = A_zero_point_v; + a_v[1] = A_zero_point_v; + a_v[2] = A_zero_point_v; + + if (t_in + 2 >= 0 && t_in + 2 < T) { + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v); + } + } + } + + inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>( + a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 24, C, remainder, + a_sum_temp.data()); + + if (SUM_A) { + a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); + a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); + a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); + a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); + + __m256i B_zero_point_v; + for (int i = 0; i < (REMAINDER ? 
(remainder / 8) : 4); ++i) { + if (PER_CHANNEL_QUANTIZATION) { + B_zero_point_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i *>(B_zero_point + i * 8)); + } else { + B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); + } + _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]), + _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); + } + } +} + +template <bool SUM_A, bool FUSE_RELU> +static inline __attribute__((always_inline)) +void depthwise_3x3_kernel_(int H, int W, int K, int h, int w, + int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const int8_t* Bp, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, uint8_t* C_uint8, + int32_t* row_offsets, + const int32_t *col_offsets, + const int32_t *bias) +{ + constexpr int S = 3; + constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3_packed_<SUM_A>( + H, W, K, h_in, w_in, + A + (h_in * W + w_in) * K + k, A_zero_point, + Bp + k * 10, &B_zero_point, + C_int32 + k, 0, &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3_packed_<SUM_A, true>( + H, W, K, h_in, w_in, + A + (h_in * W + w_in) * K + k, A_zero_point, + Bp + k * 10, &B_zero_point, + C_int32 + k, remainder, &row_offsets[k]); + } + if (SUM_A) { + requantize_<FUSE_RELU, true> + ( + A_zero_point, C_multiplier, C_zero_point, + C_int32, C_uint8 + (h * W_OUT + w) * K, K, + row_offsets, + col_offsets, bias + ); + } +} + +template <bool SUM_A, bool FUSE_RELU> +static inline __attribute__((always_inline)) +void depthwise_3x3x3_kernel_(int T, int H, int W, int K, int t, int h, int w, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const int8_t* Bp, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, uint8_t* C_uint8, + int32_t* row_offsets, + const int32_t *col_offsets, + const int32_t *bias) +{ + constexpr int R = 3, S = 3; + constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int t_in = -PAD_P + t * stride_t; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3x3_packed_<SUM_A>( + T, H, W, K, t_in, h_in, w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point, + Bp + k * 28, &B_zero_point, + C_int32 + k, 0, &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3x3_packed_<SUM_A, true>( + T, H, W, K, t_in, h_in, w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point, + Bp + k * 28, &B_zero_point, + C_int32 + k, remainder, &row_offsets[k]); + } + if (SUM_A) { + requantize_<FUSE_RELU, true> + ( + A_zero_point, C_multiplier, C_zero_point, + C_int32, C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, K, + row_offsets, + col_offsets, bias + ); + } +} + +template <bool SUM_A> +static inline __attribute__((always_inline)) void +depthwise_3x3_per_channel_quantization_kernel_( + int H, int W, int K, int h, int w, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t *A, + const int32_t *B_zero_point, const int8_t *Bp, + const float *C_multiplier, int32_t C_zero_point, + int32_t *C_int32, uint8_t *C_uint8, + int32_t *row_offsets, const int32_t *col_offsets, const int32_t *bias) 
{ + constexpr int S = 3; + constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3_packed_<SUM_A, false/*remainder*/, true/*per-channel*/>( + H, W, K, h_in, w_in, + A + (h_in * W + w_in) * K + k, A_zero_point, + Bp + k * 10, B_zero_point + k, + C_int32 + k, 0, &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3_packed_<SUM_A, true/*remainder*/, true/*per-channel*/>( + H, W, K, h_in, w_in, + A + (h_in * W + w_in) * K + k, A_zero_point, + Bp + k * 10, B_zero_point + k, + C_int32 + k, remainder, &row_offsets[k]); + } + if (SUM_A) { + requantize_per_channel_<false, true> + ( + A_zero_point, C_multiplier, C_zero_point, + C_int32, C_uint8 + (h * W_OUT + w) * K, K, + row_offsets, + col_offsets, bias + ); + } +} + +static pair<int, int> closest_factors_(int n) { + int a = (int)std::sqrt(n); + while (n % a != 0) { + a--; + } + return { a, n / a }; // a <= n / a +} + +// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0 +// This implemntation should be general enough to handle not just 3x3 but other +// filter shapes by parameterizing with R and S but restricting it to just 3x3 +// for now. +template <bool FUSE_RESCALE = true, bool FUSE_RELU = false> +static inline __attribute__((always_inline)) +void depthwise_3x3_pad_1_(int N, int H, int W, int K, + int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t *A, + int32_t B_zero_point, const Packed3x3ConvMatrix &B, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, uint8_t* C_uint8, + const int32_t *col_offsets, const int32_t *bias, + int thread_id, int num_threads) { + assert(K % 8 == 0); + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64))); + int32_t *C_temp; + + int n_begin, n_end; + int h_begin, h_end, w_begin, w_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + h_begin = 0; + h_end = H_OUT; + w_begin = 0; + w_end = W_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_h * num_threads_w 2D grid of threads + int num_threads_h, num_threads_w; + // num_threads_w <= num_threads_h + tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_h = tid_within_n / num_threads_w; + int tid_w = tid_within_n % num_threads_w; + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + + int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w; + w_begin = std::min(tid_w * w_per_thread, W_OUT); + w_end = std::min(w_begin + w_per_thread, W_OUT); 
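+    // Worked example (illustrative): with N = 2 images and num_threads = 8,
+    // nthreads_per_n = 4 and closest_factors_(4) = {2, 2}, so each image's
+    // H_OUT x W_OUT output plane is tiled by a 2 x 2 grid of threads; with
+    // 6 threads per image, closest_factors_(6) = {2, 3} gives a grid of
+    // 3 rows by 2 columns (num_threads_w <= num_threads_h).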
+ } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; + + int h = 0; + int w = 0; + + if (h_begin == 0) { + if (w_begin == 0) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + + for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { + if (w_begin == 0) { + w = 0; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + + if (h_end == H_OUT) { + h = H_OUT - 1; + w = 0; + if (w_begin == 0) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? 
C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + } // for each n +}; + +template <bool FUSE_RESCALE = true, bool FUSE_RELU = false> +static inline __attribute__((always_inline)) +void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t *A, + int32_t B_zero_point, + const Packed3x3x3ConvMatrix &B, + float C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, uint8_t* C_uint8, + const int32_t *col_offsets, const int32_t *bias, + int thread_id, int num_threads) { + assert(K % 8 == 0); + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64))); + int32_t *C_temp; + + int n_begin, n_end; + int t_begin, t_end, h_begin, h_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + t_begin = 0; + t_end = T_OUT; + h_begin = 0; + h_end = H_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_t * num_threads_h 2D grid of threads + int num_threads_t, num_threads_h; + // num_threads_w <= num_threads_h + tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_t = tid_within_n / num_threads_h; + int tid_h = tid_within_n % num_threads_h; + + int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; + t_begin = std::min(tid_t * t_per_thread, T_OUT); + t_end = std::min(t_begin + t_per_thread, T_OUT); + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * T * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; + + for (int t = t_begin; t < t_end; ++t) { + for (int h = h_begin; h < h_end; ++h) { + for (int w = 0; w < W_OUT; ++w) { + C_temp = + FUSE_RESCALE + ? 
C_int32 + : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + T, H, W, K, t, h, w, stride_t, stride_h, stride_w, A_zero_point, + A_base, B_zero_point, Bp, C_multiplier, + C_zero_point, C_temp, C_uint8_base, row_offsets, col_offsets, + bias); + } // w + } // h + } // t + } // for each n +}; + +template <bool FUSE_RESCALE = true> +static inline __attribute__((always_inline)) void +depthwise_3x3_per_channel_quantization_pad_1_( + int N, int H, int W, int K, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t *A, const int32_t *B_zero_point, + const Packed3x3ConvMatrix &B, const float *C_multiplier, + int32_t C_zero_point, int32_t *C_int32, uint8_t *C_uint8, + const int32_t *col_offsets, const int32_t *bias, int thread_id, + int num_threads) { + assert(K % 8 == 0); + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64))); + int32_t *C_temp; + + int n_begin, n_end; + int h_begin, h_end, w_begin, w_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + h_begin = 0; + h_end = H_OUT; + w_begin = 0; + w_end = W_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_h * num_threads_w 2D grid of threads + int num_threads_h, num_threads_w; + // num_threads_w <= num_threads_h + tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_h = tid_within_n / num_threads_w; + int tid_w = tid_within_n % num_threads_w; + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + + int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w; + w_begin = std::min(tid_w * w_per_thread, W_OUT); + w_end = std::min(w_begin + w_per_thread, W_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; + + int h = 0; + int w = 0; + + if (h_begin == 0) { + if (w_begin == 0) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? 
C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + + for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { + if (w_begin == 0) { + w = 0; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + + if (h_end == H_OUT) { + h = H_OUT - 1; + w = 0; + if (w_begin == 0) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { + C_temp = FUSE_RESCALE ? C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + + if (w_end == W_OUT) { + w = W_OUT - 1; + C_temp = FUSE_RESCALE ? 
C_int32 + : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + H, W, K, h, w, stride_h, stride_w, + A_zero_point, A_base, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_temp, C_uint8_base, + row_offsets, col_offsets, bias); + } + } + } // for each n +}; + +// assumption: W > 3 and H > 3 +void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, const uint8_t* A, + const Packed3x3ConvMatrix& B, + int32_t* C, + int thread_id, + int num_threads) { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_<false>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_<false>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_<false>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_<false>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } else { + depthwise_3x3_pad_1_<false>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); + } +} + +void depthwise_3x3_pad_1( + int N, int H, int W, int K, + int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const Packed3x3ConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, int num_threads, + bool fuse_relu) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (fuse_relu) { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else { + depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } + } else { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { 
+ depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else { + depthwise_3x3_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } + } +} + +void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, const uint8_t* A, + const Packed3x3x3ConvMatrix& B, + int32_t* C, + int thread_id, + int num_threads) { + depthwise_3x3x3_pad_1_<false /* FUSE_RESCALE */>( + N, T, H, W, K, + stride_t, stride_h, stride_w, + A_zero_point, A, + 0, B, + 0.0f, 0, C, nullptr, + nullptr, nullptr, + thread_id, num_threads); +} + +static void depthwise_3x3x3_pad_1_( + int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const Packed3x3x3ConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, false /* FUSE_RELU */>( + N, T, H, W, K, + stride_t, stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); +} + +static void depthwise_3x3x3_pad_1_relu_fused_( + int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const Packed3x3x3ConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + N, T, H, W, K, + stride_t, stride_h, stride_w, + A_zero_point, A, + B_zero_point, B, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); +} + +void depthwise_3x3x3_pad_1( + int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const Packed3x3x3ConvMatrix& B, + float C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + int thread_id, int num_threads) { + // If we inline the following two functions, I see stack overflow. 
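+  // (Likely related to the K-dependent stack buffers -- C_int32_temp in the
+  // helpers plus the aligned row_offsets array inside the always_inline
+  // kernels -- but this is an inference from the code, not verified against
+  // the compiler output.)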
+ if (fuse_relu) { + depthwise_3x3x3_pad_1_relu_fused_( + N, T, H, W, K, stride_t, stride_h, stride_w, A_zero_point, A, + B_zero_point, B, C_multiplier, C_zero_point, C, + col_offsets, bias, thread_id, num_threads); + } else { + depthwise_3x3x3_pad_1_(N, T, H, W, K, stride_t, stride_h, stride_w, + A_zero_point, A, B_zero_point, B, C_multiplier, + C_zero_point, C, col_offsets, bias, + thread_id, num_threads); + } +} + +void depthwise_3x3_per_channel_quantization_pad_1( + int N, int H, int W, int K, + int stride_h, int stride_w, + int32_t A_zero_point, const uint8_t* A, + const int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp, + const float *C_multiplier, int32_t C_zero_point, uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } else { + depthwise_3x3_per_channel_quantization_pad_1_( + N, H, W, K, + stride_h, stride_w, + A_zero_point, A, + B_zero_point, Bp, + C_multiplier, C_zero_point, C_int32_temp, C, + col_offsets, bias, + thread_id, num_threads); + } +} + +} // namespace fbgemm2 diff --git a/src/FbgemmI8Depthwise.h b/src/FbgemmI8Depthwise.h new file mode 100644 index 0000000..bc62c84 --- /dev/null +++ b/src/FbgemmI8Depthwise.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include <cstdint> + +namespace fbgemm2 +{ + +// KERNEL_PROD is the product of all kernels. +// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3. 
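+// A minimal usage sketch for the 3x3 case (KERNEL_PROD = 9) using the aliases
+// and entry points declared below. Illustrative only: the buffer names,
+// layouts and quantization parameters here are assumptions of this example,
+// not requirements added by this header.
+//
+//   std::vector<std::int8_t> W_int8(K * 3 * 3);      // 3x3 filters, RSG layout
+//   Packed3x3ConvMatrix Bp(K, W_int8.data());        // pack once, reuse per call
+//   // A_uint8: NHWK activations; with pad = 1, stride = 1 the output stays N*H*W*K.
+//   std::vector<std::int32_t> C_int32(N * H * W * K);
+//   depthwise_3x3_pad_1(N, H, W, K, /*stride_h=*/1, /*stride_w=*/1,
+//                       A_zero_point, A_uint8, Bp, C_int32.data());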
+template <int KERNEL_PROD> +class PackedDepthWiseConvMatrix +{ + public: + // smat in RSG layout + PackedDepthWiseConvMatrix(int K, const std::int8_t *smat); + virtual ~PackedDepthWiseConvMatrix(); + + const std::int8_t* PackedMat() const { + return pmat_; + } + + private: + int K_; + std::int8_t* pmat_; +}; // Packed3x3ConvMatrix + +using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>; +using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>; + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * @params A The input image in NHWK layout + * @params Bp The pre-packed filter + */ +void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + const Packed3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, int num_threads = 1); + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * This version is fused with requantization. + */ +void depthwise_3x3_pad_1( + int N, int H, int W, int K, + int stride_h, int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + std::int32_t B_zero_point, const Packed3x3ConvMatrix& Bp, + float C_multiplier, std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, const std::int32_t* bias, + int thread_id = 0, int num_threads = 1, bool fuse_relu = false); + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * This version is fused with requantization and uses per-channel quantization. + */ +void depthwise_3x3_per_channel_quantization_pad_1( + int N, int H, int W, int K, + int stride_h, int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + const std::int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp, + const float *C_multiplier, std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, const std::int32_t* bias, + int thread_id = 0, int num_threads = 1); + +void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + const Packed3x3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, int num_threads = 1); + +void depthwise_3x3x3_pad_1( + int N, int T, int H, int W, int K, + int stride_t, int stride_h, int stride_w, + std::int32_t A_zero_point, const std::uint8_t* A, + std::int32_t B_zero_point, const Packed3x3x3ConvMatrix& Bp, + float C_multiplier, std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, const std::int32_t* bias, + bool fuse_relu = false, int thread_id = 0, int num_threads = 1); + +} // namespace fbgemm2 diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc new file mode 100644 index 0000000..723a467 --- /dev/null +++ b/src/FbgemmI8Spmdm.cc @@ -0,0 +1,508 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "fbgemm/FbgemmI8Spmdm.h" + +#include <algorithm> +#include <array> +#include <cassert> +#include <cmath> +#include <cstring> + +#include <immintrin.h> + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +double spmdm_initial_time = 0.0; +double spmdm_transpose_uint8_time = 0.0; +double spmdm_transpose_32xN_time = 0.0; +double spmdm_compute_time = 0.0; +double spmdm_transpose_Nx32_time = 0.0; +double spmdm_run_time = 0.0; +#endif + +using namespace std; + +namespace fbgemm2 { + +CompressedSparseColumn::CompressedSparseColumn(int num_of_rows, int num_of_cols) + : num_rows_(num_of_rows), + colptr_(num_of_cols + 1), + hyper_sparse_(false), + old_nnz_(-1) {} + +double CompressedSparseColumn::Density() const { + return (double)NumOfNonZeros() / (NumOfRows() * NumOfCols()); +} + +bool CompressedSparseColumn::IsHyperSparse() const { + if (NumOfNonZeros() != old_nnz_) { + old_nnz_ = NumOfNonZeros(); + // The number of non-zero per row is very small. + hyper_sparse_ = (double)old_nnz_ / NumOfRows() < 0.08; + } + + return hyper_sparse_; +} + +static void transpose_8rows( + int N, + const uint8_t* src, + int ld_src, + uint8_t* dst, + int ld_dst) { + constexpr int M = 8; + int j; + // vectorized loop + for (j = 0; j < N / 32 * 32; j += 32) { + // a : a0 a1 ... a31 + // b : b0 b1 ... b31 + // c : c0 c1 ... c31 + // d : d0 d1 ... d31 + __m256i a = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(src + j + 0 * ld_src)); + __m256i b = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(src + j + 1 * ld_src)); + __m256i c = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(src + j + 2 * ld_src)); + __m256i d = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(src + j + 3 * ld_src)); + __m256i e = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(src + j + 4 * ld_src)); + __m256i f = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(src + j + 5 * ld_src)); + __m256i g = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(src + j + 6 * ld_src)); + __m256i h = _mm256_lddqu_si256( + reinterpret_cast<const __m256i*>(src + j + 7 * ld_src)); + + // even-odd interleaving + // ab_lo : a0 b0 a1 b1 ... a7 b7 | a16 b16 ... a23 b23 + // ab_hi : a8 b8 a9 b9 ... a15 b15 | a24 b24 ... a31 b31 + // cd_lo : c0 d0 c1 d1 ... c7 d7 | c16 d16 ... c23 d23 + // cd_hi : c8 d8 c9 d9 ... c15 d15 | c24 d24 ... c31 d31 + __m256i ab_lo = _mm256_unpacklo_epi8(a, b); + __m256i ab_hi = _mm256_unpackhi_epi8(a, b); + __m256i cd_lo = _mm256_unpacklo_epi8(c, d); + __m256i cd_hi = _mm256_unpackhi_epi8(c, d); + __m256i ef_lo = _mm256_unpacklo_epi8(e, f); + __m256i ef_hi = _mm256_unpackhi_epi8(e, f); + __m256i gh_lo = _mm256_unpacklo_epi8(g, h); + __m256i gh_hi = _mm256_unpackhi_epi8(g, h); + + // 4-row interleaving but permuted at 128-bit granularity + // abcd0 : a0 b0 c0 d0 ... a-d3 | a-d16 ... a-d19 + // abcd1 : a4 b4 c4 d4 ... a-d7 | a-d20 ... a-d23 + // abcd2 : a8 b8 c8 d8 ... a-d11 | a-d24 ... a-d27 + // abcd3 : a12 b12 c12 d12 ... a-d15 | a-d28 ... 
a-d31 + __m256i abcd0 = _mm256_unpacklo_epi16(ab_lo, cd_lo); + __m256i abcd1 = _mm256_unpackhi_epi16(ab_lo, cd_lo); + __m256i abcd2 = _mm256_unpacklo_epi16(ab_hi, cd_hi); + __m256i abcd3 = _mm256_unpackhi_epi16(ab_hi, cd_hi); + __m256i efgh0 = _mm256_unpacklo_epi16(ef_lo, gh_lo); + __m256i efgh1 = _mm256_unpackhi_epi16(ef_lo, gh_lo); + __m256i efgh2 = _mm256_unpacklo_epi16(ef_hi, gh_hi); + __m256i efgh3 = _mm256_unpackhi_epi16(ef_hi, gh_hi); + + // 8-row interleaving + __m256i y0 = _mm256_unpacklo_epi32(abcd0, efgh0); + __m256i y1 = _mm256_unpackhi_epi32(abcd0, efgh0); + __m256i y2 = _mm256_unpacklo_epi32(abcd1, efgh1); + __m256i y3 = _mm256_unpackhi_epi32(abcd1, efgh1); + __m256i y4 = _mm256_unpacklo_epi32(abcd2, efgh2); + __m256i y5 = _mm256_unpackhi_epi32(abcd2, efgh2); + __m256i y6 = _mm256_unpacklo_epi32(abcd3, efgh3); + __m256i y7 = _mm256_unpackhi_epi32(abcd3, efgh3); + + // Storing with 128-bit lanes are permuted so that everything is in order + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 0) * ld_dst), + _mm256_castsi256_si128(y0)); + *reinterpret_cast<int64_t*>(dst + (j + 1) * ld_dst) = + _mm256_extract_epi64(y0, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 2) * ld_dst), + _mm256_castsi256_si128(y1)); + *reinterpret_cast<int64_t*>(dst + (j + 3) * ld_dst) = + _mm256_extract_epi64(y1, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 4) * ld_dst), + _mm256_castsi256_si128(y2)); + *reinterpret_cast<int64_t*>(dst + (j + 5) * ld_dst) = + _mm256_extract_epi64(y2, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 6) * ld_dst), + _mm256_castsi256_si128(y3)); + *reinterpret_cast<int64_t*>(dst + (j + 7) * ld_dst) = + _mm256_extract_epi64(y3, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 8) * ld_dst), + _mm256_castsi256_si128(y4)); + *reinterpret_cast<int64_t*>(dst + (j + 9) * ld_dst) = + _mm256_extract_epi64(y4, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 10) * ld_dst), + _mm256_castsi256_si128(y5)); + *reinterpret_cast<int64_t*>(dst + (j + 11) * ld_dst) = + _mm256_extract_epi64(y5, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 12) * ld_dst), + _mm256_castsi256_si128(y6)); + *reinterpret_cast<int64_t*>(dst + (j + 13) * ld_dst) = + _mm256_extract_epi64(y6, 1); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + (j + 14) * ld_dst), + _mm256_castsi256_si128(y7)); + *reinterpret_cast<int64_t*>(dst + (j + 15) * ld_dst) = + _mm256_extract_epi64(y7, 1); + *reinterpret_cast<int64_t*>(dst + (j + 16) * ld_dst) = + _mm256_extract_epi64(y0, 2); + *reinterpret_cast<int64_t*>(dst + (j + 17) * ld_dst) = + _mm256_extract_epi64(y0, 3); + *reinterpret_cast<int64_t*>(dst + (j + 18) * ld_dst) = + _mm256_extract_epi64(y1, 2); + *reinterpret_cast<int64_t*>(dst + (j + 19) * ld_dst) = + _mm256_extract_epi64(y1, 3); + *reinterpret_cast<int64_t*>(dst + (j + 20) * ld_dst) = + _mm256_extract_epi64(y2, 2); + *reinterpret_cast<int64_t*>(dst + (j + 21) * ld_dst) = + _mm256_extract_epi64(y2, 3); + *reinterpret_cast<int64_t*>(dst + (j + 22) * ld_dst) = + _mm256_extract_epi64(y3, 2); + *reinterpret_cast<int64_t*>(dst + (j + 23) * ld_dst) = + _mm256_extract_epi64(y3, 3); + *reinterpret_cast<int64_t*>(dst + (j + 24) * ld_dst) = + _mm256_extract_epi64(y4, 2); + *reinterpret_cast<int64_t*>(dst + (j + 25) * ld_dst) = + _mm256_extract_epi64(y4, 3); + *reinterpret_cast<int64_t*>(dst + (j + 26) * ld_dst) = + _mm256_extract_epi64(y5, 2); + *reinterpret_cast<int64_t*>(dst + (j + 27) * ld_dst) = + 
_mm256_extract_epi64(y5, 3); + *reinterpret_cast<int64_t*>(dst + (j + 28) * ld_dst) = + _mm256_extract_epi64(y6, 2); + *reinterpret_cast<int64_t*>(dst + (j + 29) * ld_dst) = + _mm256_extract_epi64(y6, 3); + *reinterpret_cast<int64_t*>(dst + (j + 30) * ld_dst) = + _mm256_extract_epi64(y7, 2); + *reinterpret_cast<int64_t*>(dst + (j + 31) * ld_dst) = + _mm256_extract_epi64(y7, 3); + } + + // scalar loop for remainder + for (; j < N; ++j) { + for (int i = 0; i < M; ++i) { + dst[j * ld_dst + i] = src[j + i * ld_src]; + } + } +} + +// TODO: fallback when AVX2 is not available +void CompressedSparseColumn::SpMDM( + const block_type_t& block, + const uint8_t* A, + int lda, + bool accumulation, + int32_t* C, + int ldc) const { + int K = NumOfRows(); + int N = block.col_size; + + if (K == 0 || N == 0) { + return; + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start, + t_start, t_end; + double dt; + t_start = std::chrono::high_resolution_clock::now(); + t_very_start = std::chrono::high_resolution_clock::now(); +#endif + + uint8_t A_buffer[K * 32] __attribute__((aligned(64))); + int32_t C_buffer[N * 32] __attribute__((aligned(64))); + + // If we compute C = C + A * B, where B is a sparse matrix in CSC format, for + // each non-zero in B, we'd need to access the corresponding column in A. + // This results in strided access, which we want to avoid. + // Instead, we pre-transpose A and C, and compute C = (C^T + B^T * A^T)^T + + if (IsHyperSparse()) { + // The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications. + // If NNZ/K is small, it's not worth doing transpose so we just use this + // scalar loop. + if (!accumulation) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; + ++j) { + C[(i - block.row_start) * ldc + j - block.col_start] = 0; + } + } + } + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) { + int row = rowidx_[k]; + int w = values_[k]; + for (int i = block.row_start; i < block.row_start + block.row_size; + ++i) { + C[(i - block.row_start) * ldc + j - block.col_start] += + A[i * lda + row] * w; + } + } + } // for each column of B + return; + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) + .count(); + spmdm_initial_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + // Take 32 rows at a time + int i_end = block.row_start + block.row_size; + for (int i1 = block.row_start; i1 < i_end; i1 += 32) { + // Transpose 32 x K submatrix of A + if (i_end - i1 < 32) { + uint8_t A_temp_buffer[K * 32] __attribute__((aligned(64))); + for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) { + transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); + } + + for (int i2 = (i_end - i1) / 8 * 8; i2 < i_end - i1; ++i2) { + memcpy( + A_temp_buffer + i2 * K, A + (i1 + i2) * lda, K * sizeof(uint8_t)); + } + memset( + A_temp_buffer + (i_end - i1) * K, + 0, + (32 - (i_end - i1)) * K * sizeof(uint8_t)); + for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) { + transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32); + } + } else { + for (int i2 = 0; i2 < 32; i2 += 8) { + transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); + } + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = 
std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) + .count(); + spmdm_transpose_uint8_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + if (accumulation) { + // Transpose 32 x N submatrix of C to fill N x 32 C_buffer + transpose_simd( + std::min(32, i_end - i1), + N, + reinterpret_cast<const float*>(C + (i1 - block.row_start) * ldc), + ldc, + reinterpret_cast<float*>(C_buffer), + 32); + } else { + memset(C_buffer, 0, N * 32 * sizeof(int32_t)); + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) + .count(); + spmdm_transpose_32xN_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + for (int j = 0; j < block.col_size; ++j) { + int j_start = j + block.col_start; + int k = colptr_[j_start]; + int k_end_aligned = + colptr_[j_start] + (colptr_[j_start + 1] - colptr_[j_start]) / 4 * 4; + + for (; k < k_end_aligned; k += 4) { + __m256i w = + _mm256_set1_epi32(*(reinterpret_cast<const int32_t*>(&values_[k]))); + array<__m256i, 4> a; + a[0] = _mm256_load_si256( + reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 0] * 32])); + a[1] = _mm256_load_si256( + reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 1] * 32])); + a[2] = _mm256_load_si256( + reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 2] * 32])); + a[3] = _mm256_load_si256( + reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 3] * 32])); + + __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]); + __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]); + __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]); + __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]); + + a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo); + a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo); + a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi); + a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi); + + array<__m256i, 4> ab; + ab[0] = _mm256_maddubs_epi16(a[0], w); + ab[1] = _mm256_maddubs_epi16(a[1], w); + ab[2] = _mm256_maddubs_epi16(a[2], w); + ab[3] = _mm256_maddubs_epi16(a[3], w); + + __m256i one = _mm256_set1_epi16(1); + ab[0] = _mm256_madd_epi16(ab[0], one); + ab[1] = _mm256_madd_epi16(ab[1], one); + ab[2] = _mm256_madd_epi16(ab[2], one); + ab[3] = _mm256_madd_epi16(ab[3], one); + + array<__m256i, 4> t; + t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20); + t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20); + t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31); + t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31); + + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast<const __m256i*>( + &C_buffer[j * 32 + 0 * 8])), + t[0])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast<const __m256i*>( + &C_buffer[j * 32 + 1 * 8])), + t[1])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast<const __m256i*>( + &C_buffer[j * 32 + 2 * 8])), + t[2])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast<const __m256i*>( + &C_buffer[j * 32 + 3 * 8])), + t[3])); + } + + int remainder = colptr_[j_start + 1] - k; + assert(remainder < 4); + if (remainder > 0) { + int32_t temp_w = 0; + for (int r = 0; r < 
remainder; ++r) { + (reinterpret_cast<int8_t*>(&temp_w))[r] = values_[k + r]; + } + __m256i w = _mm256_set1_epi32(temp_w); + array<__m256i, 4> a; + a[0] = _mm256_load_si256( + reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 0] * 32])); + a[1] = remainder > 1 + ? _mm256_load_si256(reinterpret_cast<const __m256i*>( + &A_buffer[rowidx_[k + 1] * 32])) + : _mm256_setzero_si256(); + a[2] = remainder > 2 + ? _mm256_load_si256(reinterpret_cast<const __m256i*>( + &A_buffer[rowidx_[k + 2] * 32])) + : _mm256_setzero_si256(); + a[3] = _mm256_setzero_si256(); + + __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]); + __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]); + __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]); + __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]); + + a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo); + a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo); + a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi); + a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi); + + array<__m256i, 4> ab; + ab[0] = _mm256_maddubs_epi16(a[0], w); + ab[1] = _mm256_maddubs_epi16(a[1], w); + ab[2] = _mm256_maddubs_epi16(a[2], w); + ab[3] = _mm256_maddubs_epi16(a[3], w); + + __m256i one = _mm256_set1_epi16(1); + ab[0] = _mm256_madd_epi16(ab[0], one); + ab[1] = _mm256_madd_epi16(ab[1], one); + ab[2] = _mm256_madd_epi16(ab[2], one); + ab[3] = _mm256_madd_epi16(ab[3], one); + + array<__m256i, 4> t; + t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20); + t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20); + t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31); + t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31); + + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast<const __m256i*>( + &C_buffer[j * 32 + 0 * 8])), + t[0])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast<const __m256i*>( + &C_buffer[j * 32 + 1 * 8])), + t[1])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast<const __m256i*>( + &C_buffer[j * 32 + 2 * 8])), + t[2])); + _mm256_store_si256( + reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]), + _mm256_add_epi32( + _mm256_load_si256(reinterpret_cast<const __m256i*>( + &C_buffer[j * 32 + 3 * 8])), + t[3])); + } + } // for each column of B + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) + .count(); + spmdm_compute_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + + // Transpose N x 32 C_buffer to fill 32 x N submatrix of C + transpose_simd( + N, + std::min(32, i_end - i1), + reinterpret_cast<const float*>(C_buffer), + 32, + reinterpret_cast<float*>(C + (i1 - block.row_start) * ldc), + ldc); + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) + .count(); + spmdm_transpose_Nx32_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif + } + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + t_end = std::chrono::high_resolution_clock::now(); + dt = + std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start) + .count(); + spmdm_run_time += (dt); + t_start = std::chrono::high_resolution_clock::now(); +#endif +} + +} // namespace fbgemm2 diff --git 
a/src/GenerateKernel.h b/src/GenerateKernel.h new file mode 100644 index 0000000..30160d1 --- /dev/null +++ b/src/GenerateKernel.h @@ -0,0 +1,154 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include <asmjit/asmjit.h> +#include <cpuinfo.h> +#include <map> +#include <tuple> +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * @brief AVX2/AVX512 JIT assembly code generator. + * @tparam TA Type of matrix A. + * @tparam TB Type of matrix B. + * @tparam TC Type of matrix C. + * @tparam accT Accumulation type, currently we support 16-bit (std::int16_t) or + * 32-bit (std::int32_t) accumulation. + */ +template <typename TA, typename TB, typename TC, typename accT> +class CodeGenBase { + public: + using jit_micro_kernel_fp = void (*)( + TA* bufferA, + TB* bufferB, + TB* b_pf, + TC* bufferC, + int kc, + int ldc); + + /** + * @brief Constructor for initializing AVX2/AVX512 registers. + */ + CodeGenBase() + : CRegs_avx2_{x86::ymm0, + x86::ymm1, + x86::ymm2, + x86::ymm3, + x86::ymm4, + x86::ymm5, + x86::ymm6, + x86::ymm7, + x86::ymm8, + x86::ymm9, + x86::ymm10, + x86::ymm11}, + CRegs_avx512_{ + x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4, + x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9, + x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14, + x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19, + x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24, + x86::zmm25, x86::zmm26, x86::zmm27, + } { + // vector width in bits + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + vectorWidth_ = 512; + } else if (cpuinfo_has_x86_avx2()) { + vectorWidth_ = 256; + } else { + // TODO: Have default path + assert(0 && "unsupported architecture"); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + // vector width in elements + VLEN_ = vectorWidth_ / 8 * sizeof(TA); + } + + /** + * @brief Get or Create the instructions for macro-kernel. + * + * If the problem size (mc, nc) and accumulation flag (accum) can be found in + * the code cache (a hash map), then get the macro-kernel instructions + * directly from it. Otherwise, create the instructions for macro-kernel, and + * store that into the code cache. + */ + template <inst_set_t instSet> + jit_micro_kernel_fp + getOrCreate(bool accum, int32_t mc, int32_t nc, int32_t kc, int32_t ldc); + + /** + * @brief Generate instructions for initializing the C registers to 0. + */ + template <inst_set_t instSet> + void initCRegs( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCRegAssign = 4); + + /** + * @brief Generate instructions for computing block in the rank-k update. + */ + template <inst_set_t instSet> + void genComputeBlock( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign = 4); + + /** + * @brief Generate instructions for storing the C registers back to the + * memory. + */ + template <inst_set_t instSet> + void storeCRegs( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign = 4); + + private: + asmjit::X86Ymm + CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. 
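  // A minimal usage sketch of this class (the buffer names are placeholders,
  // and the block sizes mc/nc/kc are assumed to come from PackingTraits):
  //
  //   CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codegen;
  //   auto fn = codegen.getOrCreate<inst_set_t::avx2>(
  //       /*accum=*/true, mc, nc, kc, nc);
  //   fn(packedA, packedB, packedB_next, C_buffer, kc, ldc);
  //
  // where ldc is the leading dimension of the int32 accumulation buffer,
  // in elements.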
+ asmjit::X86Zmm + CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel. + int vectorWidth_; ///< Vector width in bits. + int VLEN_; ///< Vector width in elements. + static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. + static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. + static thread_local std::map<std::tuple<bool, int, int>, jit_micro_kernel_fp> + codeCache_; ///< JIT Code Cache for reuse. +}; + +template <typename TA, typename TB, typename TC, typename accT> +thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_; + +template <typename TA, typename TB, typename TC, typename accT> +thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_; + +template <typename TA, typename TB, typename TC, typename accT> +thread_local std::map< + std::tuple<bool, int, int>, + typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> + CodeGenBase<TA, TB, TC, accT>::codeCache_; + +} // namespace fbgemm2 diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc new file mode 100644 index 0000000..2ffe3ab --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -0,0 +1,292 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX2 instructions for initializing the C registers to 0 in 16-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCRegAssign) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx2_[i * leadingDimCRegAssign + j], + CRegs_avx2_[i * leadingDimCRegAssign + j], + CRegs_avx2_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Generate AVX2 instructions for computing block in the rank-k update of 16-bit + * Accmulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp /* unused (reserved for prefetching)*/, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Ymm AReg = x86::ymm12; + + asmjit::X86Ymm tmpReg = x86::ymm14; + + for (int i = 0; i < rowRegs; ++i) { + // broadcast A + a->vpbroadcastw( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + for (int j = 0; j < colRegs; ++j) { + a->vpmaddubsw( + tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + a->vpaddsw( + CRegs_avx2_[i * leadingDimCRegAssign + j], + tmpReg, + CRegs_avx2_[i * leadingDimCRegAssign + j]); + // Prefetching is hurting performance in some cases + // because prefetch instructions itself consumes a slot + // in pipeline issue thus slowing down the kernel. + // if((i == rowRegs - 1) && j % 2 == 0){ + // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t))); + //} + } + } +} + +/** + * Generate AVX2 instructions for storing the C registers back to the memory in + * 16-bit Accumulation kernel. 
+ */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + asmjit::X86Xmm extractDest128 = x86::xmm15; + asmjit::X86Ymm extractDest256 = x86::ymm15; + + for (int i = 0; i < rowRegs; ++i) { + a->imul(C_Offset, ldcReg, i * sizeof(int32_t)); + for (int j = 0; j < colRegs; ++j) { + for (int idx = 0; idx < 2; ++idx) { + a->vextracti128( + extractDest128, CRegs_avx2_[i * leadingDimCRegAssign + j], idx); + a->vpmovsxwd(extractDest256, extractDest128); + asmjit::X86Mem destAddr = x86::dword_ptr( + a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t)); + if (accum) { + a->vpaddd(extractDest256, extractDest256, destAddr); + } + a->vmovups(destAddr, extractDest256); + } + } + } +} + +/** + * Get or Create the AVX2 instructions for 16-bit Accumulation macro-kernel. + * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::KCB; + constexpr int mRegBlockSize = + PackingTraits<int8_t, int16_t, inst_set_t::avx2>::MR; + // constexpr int nRegBlockSize = + // PackingTraits<int8_t, int16_t, inst_set_t::avx2>::NR; + constexpr int row_interleave = + PackingTraits<int8_t, int16_t, inst_set_t::avx2>::ROW_INTERLEAVE; + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert((nc == nRegBlockSize) && + //"nc must be equal to the number of register blocks"); + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + // asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + 
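  // Shape of the code generated below, as a scalar sketch for intuition (the
  // real kernel keeps the C tile in ymm registers and accumulates with
  // vpmaddubsw + vpaddsw, i.e. saturating 16-bit sums of u8*s8 pairs, so
  // row_interleave is 2 here):
  //
  //   for (mb = 0; mb < mRegBlocks; ++mb)           // LoopMBlocks
  //     for (k = 0; k < kSize; k += row_interleave) // Loopk
  //       for (i = 0; i < rowRegs; ++i)
  //         for (j = 0; j < colRegs; ++j)
  //           C[i][j] += A[i][k] * B[k][j] + A[i][k+1] * B[k+1][j];
  //
  // storeCRegs then widens the 16-bit sums to int32 and writes them to C.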
asmjit::X86Gp kIdx = a->gpzRef(14); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock<inst_set_t::avx2>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); + // increment C for next block + a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t)); + a->add(CBase, C_Offset); + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock<inst_set_t::avx2>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm2 diff --git a/src/GenerateKernelU8S8S32ACC16_avx512.cc b/src/GenerateKernelU8S8S32ACC16_avx512.cc new file mode 100644 index 0000000..e613cf1 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC16_avx512.cc @@ -0,0 +1,295 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit + * Accumulation kernel. 
+ */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCRegAssign) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs_avx512_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 16-bit Accmulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp /* unused (reserved for prefetching)*/, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Zmm AReg = x86::zmm29; + + asmjit::X86Zmm tmpReg = x86::zmm30; + + for (int i = 0; i < rowRegs; ++i) { + // broadcast A + a->vpbroadcastw( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + for (int j = 0; j < colRegs; ++j) { + a->vpmaddubsw( + tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + a->vpaddsw( + CRegs_avx512_[i * leadingDimCRegAssign + j], + tmpReg, + CRegs_avx512_[i * leadingDimCRegAssign + j]); + // Prefetching is hurting performance in some cases + // because prefetch instructions itself consumes a slot + // in pipeline issue thus slowing down the kernel. + // if((i == rowRegs - 1) && j % 2 == 0){ + // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t))); + //} + } + } +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 16-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + asmjit::X86Ymm extractDest256 = x86::ymm31; + asmjit::X86Zmm extractDest512 = x86::zmm31; + + for (int i = 0; i < rowRegs; ++i) { + a->imul(C_Offset, ldcReg, i * sizeof(int32_t)); + for (int j = 0; j < colRegs; ++j) { + for (int idx = 0; idx < 2; ++idx) { + a->vextracti32x8( + extractDest256, CRegs_avx512_[i * leadingDimCRegAssign + j], idx); + a->vpmovsxwd(extractDest512, extractDest256); + asmjit::X86Mem destAddr = x86::dword_ptr( + a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); + if (accum) { + a->vpaddd(extractDest512, extractDest512, destAddr); + } + a->vmovups(destAddr, extractDest512); + } + } + } +} + +/** + * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel. 
+ * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = + PackingTraits<int8_t, int16_t, inst_set_t::avx512>::KCB; + constexpr int mRegBlockSize = + PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR; + // constexpr int nRegBlockSize = + // PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR; + constexpr int row_interleave = + PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE; + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert((nc == nRegBlockSize) && + //"nc must be equal to the number of register blocks"); + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + // asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock<inst_set_t::avx512>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // 
store C matrix + storeCRegs<inst_set_t::avx512>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); + // increment C for next block + a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t)); + a->add(CBase, C_Offset); + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock<inst_set_t::avx512>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs<inst_set_t::avx512>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm2 diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc new file mode 100644 index 0000000..dc8c6d3 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -0,0 +1,310 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX2 instructions for initializing the C registers to 0 in 32-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx2_[i * leadingDimCReg + j], + CRegs_avx2_[i * leadingDimCReg + j], + CRegs_avx2_[i * leadingDimCReg + j]); + } + } +} + +/** + * Generate AVX2 instructions for computing block in the rank-k update of 32-bit + * Accmulation kernel. 
+ */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Ymm AReg = x86::ymm12; + + // used for matrix B + asmjit::X86Ymm BReg = x86::ymm13; + + // Contains 16-bit 1s + asmjit::X86Ymm oneReg = x86::ymm15; + + // temporary register + asmjit::X86Ymm res1 = x86::ymm14; + + for (int j = 0; j < colRegs; ++j) { + // load B + a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + // load A, broadcast and fmas + for (int i = 0; i < rowRegs; ++i) { + a->vpbroadcastd( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + a->vpmaddubsw(res1, AReg, BReg); + a->vpmaddwd(res1, oneReg, res1); + a->vpaddd( + CRegs_avx2_[i * leadingDimCRegAssign + j], + res1, + CRegs_avx2_[i * leadingDimCRegAssign + j]); + } + a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); + } +} + +/** + * Generate AVX2 instructions for storing the C registers back to the memory in + * 32-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< + inst_set_t::avx2>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + // temp register + asmjit::X86Ymm tmpReg = x86::ymm14; + + for (int i = 0; i < rowRegs; ++i) { + if (i != 0) { + a->add(C_Offset, ldcReg); + } + for (int j = 0; j < colRegs; ++j) { + if (accum) { + a->vpaddd( + CRegs_avx2_[i * leadingDimCRegAssign + j], + CRegs_avx2_[i * leadingDimCRegAssign + j], + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t))); + } + a->vmovups( + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)), + CRegs_avx2_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Get or Create the AVX2 instructions for 32-bit Accumulation macro-kernel. 
+ * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB; + constexpr int mRegBlockSize = + PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR; + constexpr int row_interleave = + PackingTraits<int8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + // asmjit::X86Gp B_pf = a->gpzRef(8); + + asmjit::X86Ymm oneReg = x86::ymm15; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, sizeof(int32_t)); + a->mov(C_Offset, 0); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock<inst_set_t::avx2>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * 
sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + // a->add(B_pf, 32*sizeof(float)); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx2>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); + + // increment C for next block + a->imul(C_Offset, ldcReg, rowRegs); + a->add(CBase, C_Offset); + a->mov(C_Offset, 0); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock<inst_set_t::avx2>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs<inst_set_t::avx2>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm2 diff --git a/src/GenerateKernelU8S8S32ACC32_avx512.cc b/src/GenerateKernelU8S8S32ACC32_avx512.cc new file mode 100644 index 0000000..5cd5684 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC32_avx512.cc @@ -0,0 +1,312 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm2 { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j]); + } + } +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 32-bit Accmulation kernel. 
+ */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Zmm AReg = x86::zmm31; + + // used for matrix B + asmjit::X86Zmm BReg = x86::zmm30; + + // Contains 16-bit 1s + asmjit::X86Zmm oneReg = x86::zmm29; + + // temporary register + asmjit::X86Zmm res1 = x86::zmm28; + + for (int j = 0; j < colRegs; ++j) { + // load B + a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + // load A, broadcast and fmas + for (int i = 0; i < rowRegs; ++i) { + a->vpbroadcastd( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + a->vpmaddubsw(res1, AReg, BReg); + a->vpmaddwd(res1, oneReg, res1); + a->vpaddd( + CRegs_avx512_[i * leadingDimCRegAssign + j], + res1, + CRegs_avx512_[i * leadingDimCRegAssign + j]); + } + a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); + } +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 32-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + // temp register + asmjit::X86Zmm tmpReg = x86::zmm28; + + for (int i = 0; i < rowRegs; ++i) { + if (i != 0) { + a->add(C_Offset, ldcReg); + } + for (int j = 0; j < colRegs; ++j) { + if (accum) { + a->vpaddd( + CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs_avx512_[i * leadingDimCRegAssign + j], + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); + } + a->vmovups( + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), + CRegs_avx512_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel. 
+ * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = + PackingTraits<int8_t, int32_t, inst_set_t::avx512>::KCB; + constexpr int mRegBlockSize = + PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR; + constexpr int row_interleave = + PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + // asmjit::X86Gp B_pf = a->gpzRef(8); + + asmjit::X86Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, sizeof(int32_t)); + a->mov(C_Offset, 0); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock<inst_set_t::avx512>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for 
next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + // a->add(B_pf, 32*sizeof(float)); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); + + // increment C for next block + a->imul(C_Offset, ldcReg, rowRegs); + a->add(CBase, C_Offset); + a->mov(C_Offset, 0); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, row_interleave); + + genComputeBlock<inst_set_t::avx512>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add(buffer_A, row_interleave * sizeof(uint8_t)); + + // update buffer_B address for next k iteration + a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t)); + a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs<inst_set_t::avx512>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm2 diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc new file mode 100644 index 0000000..543d99b --- /dev/null +++ b/src/PackAMatrix.cc @@ -0,0 +1,165 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include <cpuinfo.h> +#include <cassert> +#include <iomanip> +#include <iostream> +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template <typename T, typename accT> +PackAMatrix<T, accT>::PackAMatrix( + matrix_op_t trans, + int32_t nRow, + int32_t nCol, + const T* smat, + int32_t ld, + inpType* pmat, + int32_t groups, + accT zero_pt) + : PackMatrix<PackAMatrix<T, accT>, T, accT>(nRow, nCol, pmat, zero_pt), + trans_(trans), + smat_(smat), + ld_(ld), + G_(groups) { + assert(G_ == 1 && "Groups != 1 not supported yet"); + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + if (!pmat) { + BaseType::buf_ = + (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)); + } +} + +template <typename T, typename accT> +void PackAMatrix<T, accT>::pack(const block_type_t& block) { + block_type_t block_p = {block.row_start, + block.row_size, + block.col_start, + (block.col_size + row_interleave_B_ - 1) / + row_interleave_B_ * row_interleave_B_}; + + BaseType::packedBlock(block_p); + bool tr = (trans_ == matrix_op_t::Transpose); + T* out = BaseType::getBuf(); + if (tr) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + T val = smat_[i + ld_ * j]; + out[addr(i, j) - addr(block.row_start, block.col_start)] = val; + } + // zero fill + // Please note that we zero fill, not zero_pt fill, because for + // requantization original, i.e., not padded, dimensions are used. If we + // were to use padded dimensions for requantization, we would zero_pt + // fill. + // For example, consider the following dot product: + // A = .3(5-15), .3(20-15) //.3 is scale and 15 is zero_pt + // B = .4(1+10), .4(4+10) // .4 is scale and -10 is zero_pt + // + // numElements(A) = 2 and numElements(B) = 2 + // + // Dot product is (real): -3*4.4+1.5*5.6 = -4.8 + // Dot product is (quantized): 5*1+20*4 = 85 + // + // requantization: .3*.4(85 - (5+20)*(-10) - (1+4)*(15) + + // numElements(A)*(15)(-10)) = -4.8 + // + // In the above adding one more element zero in the quantized domain, + // i.e., the quantized vectors become: + // A_q = 5, 20, 0 + // B_q = 1, 4, 0 + // + // and requantization with numElements(A) = 2 will produce the same + // answer (-4.8). + // + // Also in the above adding one more element zero_pt in the quantized + // domain, i.e., the quantized vectors become: + // A_q = 5, 20, 15 + // B_q = 1, 4, -10 + // + // and requantization with numElements(A) = 3 will produce the same + // answer (-4.8). 
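      // Working the example above through explicitly:
      //   real:        (-3)*4.4 + 1.5*5.6 = -13.2 + 8.4 = -4.8
      //   quantized:   5*1 + 20*4 = 85
      //   requantized: .3*.4*(85 - (5+20)*(-10) - (1+4)*15 + 2*15*(-10))
      //              = .12*(85 + 250 - 75 - 300) = .12*(-40) = -4.8
      // Appending zeros in the quantized domain changes neither the dot
      // product (0*0 = 0) nor the row/column sums, so the formula with the
      // original numElements(A) = 2 still yields -4.8.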
+ for (int j = block.col_start + block.col_size; + j < block_p.col_start + block_p.col_size; + ++j) { + out[addr(i, j) - addr(block.row_start, block.col_start)] = 0; + } + } + } else { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + int buf_idx = i - block.row_start; + memcpy( + out + buf_idx * BaseType::blockColSize(), + smat_ + i * ld_ + block.col_start, + block.col_size * sizeof(T)); + // zero fill + for (int j = block.col_size; j < block_p.col_size; ++j) { + out[buf_idx * BaseType::blockColSize() + j] = 0; + } + } + } +} + +template <typename T, typename accT> +int32_t PackAMatrix<T, accT>::addr(int32_t r, int32_t c) const { + int32_t block_row_id = r / BaseType::blockRowSize(); + int32_t brow_offset = (block_row_id * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()); + + int32_t block_col_id = c / BaseType::blockColSize(); + int32_t bcol_offset = + block_col_id * BaseType::blockRowSize() * BaseType::blockColSize(); + int32_t block_offset = brow_offset + bcol_offset; + int32_t inblock_offset = + (r % BaseType::blockRowSize()) * BaseType::blockColSize() + + (c % BaseType::blockColSize()); + + int32_t index = block_offset + inblock_offset; + + return index; +} + +template <typename T, typename accT> +void PackAMatrix<T, accT>::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto r = 0; r < BaseType::numPackedRows(); ++r) { + for (auto c = 0; c < BaseType::numPackedCols(); ++c) { + T val = out[addr(r, c)]; + if (std::is_integral<T>::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template class PackAMatrix<uint8_t, int32_t>; +template class PackAMatrix<uint8_t, int16_t>; +} // namespace fbgemm2 diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc new file mode 100644 index 0000000..7012289 --- /dev/null +++ b/src/PackAWithIm2Col.cc @@ -0,0 +1,146 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include <cpuinfo.h> +#include <cassert> +#include <iomanip> +#include <iostream> +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template <typename T, typename accT> +PackAWithIm2Col<T, accT>::PackAWithIm2Col( + const conv_param_t& conv_p, + const T* sdata, + inpType* pmat, + int32_t zero_pt, + int32_t* row_offset) + : PackMatrix<PackAWithIm2Col<T, accT>, T, accT>( + conv_p.MB * conv_p.OH * conv_p.OW, + conv_p.KH * conv_p.KW * conv_p.IC, + pmat, + zero_pt), + conv_p_(conv_p), + sdata_(sdata) { + assert(conv_p.G == 1 && "Groups != 1 not supported yet"); + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + if (pmat) { + BaseType::buf_ = pmat; + } else { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = static_cast<T*>( + aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T))); + } + if (row_offset) { + row_offset_ = row_offset; + } else { + rowOffsetAllocatedHere = true; + row_offset_ = static_cast<int32_t*>( + aligned_alloc(64, BaseType::brow_ * sizeof(int32_t))); + } +} + +template <typename T, typename accT> +void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) { + block_type_t block_p = {block.row_start, + block.row_size, + block.col_start, + (block.col_size + row_interleave_B_ - 1) / + row_interleave_B_ * row_interleave_B_}; + + BaseType::packedBlock(block_p); + T* out = BaseType::getBuf(); + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + int n = i / (conv_p_.OH * conv_p_.OW); + int hw = i % (conv_p_.OH * conv_p_.OW); + int w = hw % conv_p_.OW; + int h = hw / conv_p_.OW; + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + int c = j % conv_p_.IC; + int rs = j / conv_p_.IC; + int s = rs % conv_p_.KW; + int r = rs / conv_p_.KW; + + int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s; + int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r; + // Please note that padding for convolution should be filled with zero_pt + if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) { + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = BaseType::zeroPoint(); + } else { + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = sdata_ + [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC + c]; + } + } + // zero fill + // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill. 
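      // A concrete (hypothetical) instance of the index math above: with
      // OH = OW = IH = IW = 3, KH = KW = 3, IC = 4, pad = 1, stride = 1,
      // packed row i = 5 decomposes to n = 0, h = 1, w = 2, and packed column
      // j = 22 to r = 1, s = 2, c = 2; then h_in = 1 but w_in = -1 + 2 + 2 = 3,
      // which is outside [0, IW), so this entry gets the padding zero point.
      // The zero fill below, by contrast, covers only the extra columns that
      // round col_size up to a multiple of row_interleave_B_.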
+ for (int j = block.col_start + block.col_size; + j < block_p.col_start + block_p.col_size; + ++j) { + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = 0; + } + } +} + +template <typename T, typename accT> +void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto r = 0; r < BaseType::numPackedRows(); ++r) { + for (auto c = 0; c < BaseType::numPackedCols(); ++c) { + T val = out[ r * BaseType::blockColSize() + c ]; + if (std::is_integral<T>::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template <typename T, typename accT> +int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() { + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + return PackingTraits<T, accT, inst_set_t::avx512>::MCB; + } else if (cpuinfo_has_x86_avx2()) { + return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return -1; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template class PackAWithIm2Col<uint8_t, int32_t>; +template class PackAWithIm2Col<uint8_t, int16_t>; +} // namespace fbgemm2 diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc new file mode 100644 index 0000000..30d94f8 --- /dev/null +++ b/src/PackBMatrix.cc @@ -0,0 +1,144 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include <cpuinfo.h> +#include <cassert> +#include <iomanip> +#include <iostream> +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template <typename T, typename accT> +PackBMatrix<T, accT>::PackBMatrix( + matrix_op_t trans, + int32_t nRow, + int32_t nCol, + const T* smat, + int32_t ld, + inpType* pmat, + int32_t groups, + accT zero_pt) + : PackMatrix<PackBMatrix<T, accT>, T, accT>(nRow, nCol, pmat, zero_pt), + trans_(trans), + smat_(smat), + ld_(ld), + G_(groups) { + assert(G_ == 1 && "Groups != 1 not supported yet"); + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB; + row_interleave_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // Error + assert(0 && "unknown architecure"); + } + block_type_t block{0, BaseType::numRows(), 0, BaseType::numCols()}; + BaseType::packedBlock(block); + if (!pmat) { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = (T*)aligned_alloc( + 64, + BaseType::blockRows() * BaseType::brow_ * BaseType::blockCols() * + BaseType::bcol_ * sizeof(T)); + } + pack(block); +} + +template <typename T, typename accT> +void PackBMatrix<T, accT>::pack(const block_type_t& block) { + assert((BaseType::blockRowSize() % row_interleave_) == 0); + + BaseType::packedBlock(block); + T* out = BaseType::getBuf(); + bool tr = (trans_ == matrix_op_t::Transpose); + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + T val = tr ? smat_[i + ld_ * j] : smat_[i * ld_ + j]; + out[addr(i, j) - addr(block.row_start, block.col_start)] = + tconv(val, out[addr(i, j)]); + } + } + // fill the remaining with zero. + // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. 
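  // The packed layout (see addr() below) keeps row_interleave_ consecutive
  // rows of each column contiguous: e.g. with row_interleave_ == 4, entries
  // (r, c) = (0..3, 0) occupy offsets 0..3 of a block and (0..3, 1) occupy
  // offsets 4..7. The micro-kernels consume B in exactly these
  // row_interleave_-row groups (vpmaddubsw pairs, plus vpmaddwd for 32-bit
  // accumulation), so the row count is rounded up to a multiple of
  // row_interleave_ and the tail rows are zero-filled by the loop below.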
+ for (int i = block.row_start + block.row_size; + i < (block.row_start + block.row_size + row_interleave_ - 1) / + row_interleave_ * row_interleave_; + ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; j++) { + out[addr(i, j) - addr(block.row_start, block.col_start)] = + tconv(0, out[addr(i, j)]); + } + } +} + +template <typename T, typename accT> +int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const { + int32_t block_row_id = r / BaseType::blockRowSize(); + int32_t brow_offset = (block_row_id * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()); + + int32_t block_col_id = c / BaseType::blockColSize(); + int32_t bcol_offset = + block_col_id * BaseType::blockRowSize() * BaseType::blockColSize(); + int32_t block_offset = brow_offset + bcol_offset; + int32_t inblock_offset = (r % BaseType::blockRowSize() / row_interleave_) * + BaseType::blockColSize() * row_interleave_ + + (c % BaseType::blockColSize()) * row_interleave_ + r % row_interleave_; + + int32_t index = block_offset + inblock_offset; + + return index; +} + +template <typename T, typename accT> +void PackBMatrix<T, accT>::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + std::cout << "block size:" + << "[" << BaseType::blockRowSize() << ", " + << BaseType::blockColSize() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { + auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() + : BaseType::blockRowSize(); + for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { + std::cout << "block:" << nr << ", " << nc << std::endl; + auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol() + : BaseType::blockColSize(); + for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; + ++r) { + for (auto c = 0; c < cols * row_interleave_; ++c) { + T val = + out[nr * BaseType::blockCols() * BaseType::blockRowSize() * + BaseType::blockColSize() + + nc * BaseType::blockRowSize() * BaseType::blockColSize() + + r * BaseType::blockColSize() * row_interleave_ + c]; + if (std::is_integral<T>::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; + } + } +} + +template class PackBMatrix<int8_t, int32_t>; +template class PackBMatrix<int8_t, int16_t>; +} // namespace fbgemm2 diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc new file mode 100644 index 0000000..85000ac --- /dev/null +++ b/src/PackMatrix.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include <cpuinfo.h> +#include <iomanip> +#include <stdexcept> +#include <type_traits> +#include "fbgemm/ConvUtils.h" +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template <typename PT, typename inpType, typename accType> +PackMatrix<PT, inpType, accType>::PackMatrix( + int32_t rows, + int32_t cols, + inpType* buf, + int32_t zero_pt) + : buf_(buf), nrows_(rows), ncols_(cols), zero_pt_(zero_pt) { + bufAllocatedHere_ = false; + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template <typename PT, typename inpType, typename accType> +int PackMatrix<PT, inpType, accType>::packedBufferSize(int rows, int cols) { + if (cpuinfo_has_x86_avx512f()) { + if (isA) { + return PackingTraits<inpType, accType, inst_set_t::avx512>::MCB * + PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; + } else { + int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; + int colBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB; + return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * + (((cols + colBlock - 1) / colBlock) * colBlock); + } + } else if (cpuinfo_has_x86_avx2()) { + if (isA) { + return PackingTraits<inpType, accType, inst_set_t::avx2>::MCB * + PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; + } else { + int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; + int colBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB; + return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * + (((cols + colBlock - 1) / colBlock) * colBlock); + } + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + return -1; +} + +// int32 accumulation +template class PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>; + +template class PackMatrix< + PackAWithRowOffset<uint8_t, int32_t>, + uint8_t, + int32_t>; + +template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>; + +template class PackMatrix< + PackAWithQuantRowOffset<uint8_t, int32_t>, + uint8_t, + int32_t>; + +template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>; + +// int16 accumulation +template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>; + +template class PackMatrix< + PackAWithRowOffset<uint8_t, int16_t>, + uint8_t, + int16_t>; + +template class PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>; + +template class PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>; +} // namespace fbgemm2 diff --git a/src/PackWithQuantRowOffset.cc b/src/PackWithQuantRowOffset.cc new file mode 100644 index 0000000..74eaade --- /dev/null +++ b/src/PackWithQuantRowOffset.cc @@ -0,0 +1,230 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include <cpuinfo.h> +#include <cassert> +#include <cmath> +#include <cstring> +#include <iomanip> +#include <iostream> +#include <stdexcept> +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template <typename T, typename accT> +PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( + matrix_op_t trans, + int32_t nRow, + int32_t nCol, + const float* smat, + int32_t ld, + inpType* pmat, + float scale, + int32_t zero_pt, + int32_t groups, + int32_t* row_offset) + : PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT>( + nRow, + nCol, + pmat, + zero_pt), + trans_(trans), + smat_(smat), + ld_(ld), + scale_(scale), + G_(groups), + row_offset_(row_offset) { + assert(G_ == 1 && "Groups != 1 not supported yet"); + + rowOffsetAllocatedHere = false; + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unknown architecure"); + } + if (pmat) { + BaseType::buf_ = pmat; + } else { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = + (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)); + } + if (!row_offset_) { + rowOffsetAllocatedHere = true; + row_offset_ = reinterpret_cast<int32_t*>( + aligned_alloc(64, BaseType::brow_ * sizeof(accT))); + } +} + +template <typename T, typename accT> +void PackAWithQuantRowOffset<T, accT>::pack(const block_type_t& block) { + assert(block.row_start % BaseType::blockRowSize() == 0); + assert(block.col_start % BaseType::blockColSize() == 0); + assert(block.row_size <= BaseType::blockRowSize()); + assert(block.col_size <= BaseType::blockColSize()); + + block_type_t block_p = {block.row_start, + block.row_size, + block.col_start, + (block.col_size + row_interleave_B_ - 1) / + row_interleave_B_ * row_interleave_B_}; + assert(block_p.col_size <= BaseType::blockColSize()); + BaseType::packedBlock(block_p); + + T* out = BaseType::getBuf(); + bool tr = (trans_ == matrix_op_t::Transpose); + // accumulate into row offset? + bool row_offset_acc = (block.col_start != 0); + int32_t* row_offset_buf = getRowOffsetBuffer(); + + float smat_transposed[block.row_size * block.col_size]; + if (tr) { + transpose_simd( + block.col_size, + block.row_size, + smat_ + block.col_start * ld_ + block.row_start, + ld_, + smat_transposed, + block.col_size); + } + const float* smat_temp = + tr ? smat_transposed : smat_ + block.row_start * ld_ + block.col_start; + int32_t ld_temp = tr ? block.col_size : ld_; + +#if defined(__AVX2__) && defined(__FMA__) + constexpr int VLEN = 8; + __m256 inverse_scale_v = _mm256_set1_ps(1.0f / scale_); + __m256i shuffle_mask_v = _mm256_set_epi8( + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00); + __m256i permute_mask_v = _mm256_set_epi32( + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); +#endif + + for (int i = 0; i < block.row_size; ++i) { + int32_t row_sum = row_offset_acc ? 
row_offset_buf[i] : 0; + int j = 0; +#if defined(__AVX2__) && defined(__FMA__) + static_assert( + std::is_same<T, uint8_t>::value, + "PackAWithQuantRowOffset<T, accT>::pack only works for T == uint8_t"); + for (; j < block.col_size / VLEN * VLEN; j += VLEN) { + __m256 val_v = _mm256_loadu_ps(smat_temp + i * ld_temp + j); + __m256 transformed_v = _mm256_fmadd_ps( + val_v, inverse_scale_v, _mm256_set1_ps(BaseType::zeroPoint())); + __m256 clipped_v = _mm256_max_ps( + _mm256_set1_ps(std::numeric_limits<uint8_t>::min()), + _mm256_min_ps( + transformed_v, + _mm256_set1_ps(std::numeric_limits<uint8_t>::max()))); + __m256i res_v = _mm256_cvtps_epi32(clipped_v); + + // An instruction sequence to save 8 32-bit integers as 8 8-bit integers + res_v = _mm256_shuffle_epi8(res_v, shuffle_mask_v); + res_v = _mm256_permutevar8x32_epi32(res_v, permute_mask_v); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(out + i * BaseType::blockColSize() + j), + _mm256_castsi256_si128(res_v)); + + for (int j2 = j; j2 < j + VLEN; ++j2) { + row_sum += out[i * BaseType::blockColSize() + j2]; + } + } +#endif + for (; j < block.col_size; ++j) { + float val = smat_temp[i * ld_temp + j]; + float transformed = val / scale_ + BaseType::zeroPoint(); + float clipped = std::min<float>( + std::max<float>(transformed, std::numeric_limits<uint8_t>::min()), + std::numeric_limits<uint8_t>::max()); + T res = round(clipped); + row_sum += res; + out[i * BaseType::blockColSize() + j] = res; + } + // zero fill + // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. + for (; j < block_p.col_size; ++j) { + out[i * BaseType::blockColSize() + j] = 0; + } + row_offset_buf[i] = row_sum; + } +} + +template <typename T, typename accT> +int32_t PackAWithQuantRowOffset<T, accT>::addr(int32_t r, int32_t c) const { + int32_t block_row_id = r / BaseType::blockRowSize(); + int32_t brow_offset = (block_row_id * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()); + + int32_t block_col_id = c / BaseType::blockColSize(); + int32_t bcol_offset = + block_col_id * BaseType::blockRowSize() * BaseType::blockColSize(); + int32_t block_offset = brow_offset + bcol_offset; + int32_t inblock_offset = + (r % BaseType::blockRowSize()) * BaseType::blockColSize() + + (c % BaseType::blockColSize()); + + int32_t index = block_offset + inblock_offset; + + return index; +} + +template <typename T, typename accT> +void PackAWithQuantRowOffset<T, accT>::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto r = 0; r < BaseType::numPackedRows(); ++r) { + for (auto c = 0; c < BaseType::numPackedCols(); ++c) { + T val = out[addr(r, c)]; + if (std::is_integral<T>::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template <typename T, typename accT> +int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize() { + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + // TODO: avx512 path + // Currently use avx2 code + return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + } else if (cpuinfo_has_x86_avx2()) { + return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + } else { + assert(0 && "unsupported architecture"); + return -1; + } + } else { + throw 
std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template class PackAWithQuantRowOffset<uint8_t, int32_t>; + +} // namespace fbgemm2 diff --git a/src/PackWithRowOffset.cc b/src/PackWithRowOffset.cc new file mode 100644 index 0000000..8722723 --- /dev/null +++ b/src/PackWithRowOffset.cc @@ -0,0 +1,211 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <cassert> +#include <cstring> +#include <iomanip> +#include <iostream> +#include <stdexcept> +#include <cpuinfo.h> +#include "fbgemm/Fbgemm.h" + +namespace fbgemm2 { + +template <typename T, typename accT> +PackAWithRowOffset<T, accT>::PackAWithRowOffset( + matrix_op_t trans, + uint32_t nRow, + uint32_t nCol, + const T* smat, + uint32_t ld, + inpType* pmat, + uint32_t groups, + int32_t zero_pt, + int32_t* row_offset) + : PackMatrix<PackAWithRowOffset<T, accT>, T, accT>( + nRow, + nCol, + pmat, + zero_pt), + trans_(trans), + smat_(smat), + ld_(ld), + G_(groups), + row_offset_(row_offset) { + assert(G_ == 1 && "Groups != 1 not supported yet"); + + rowOffsetAllocatedHere = false; + + if (cpuinfo_has_x86_avx512f()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (cpuinfo_has_x86_avx2()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + //TODO: Have default slower path + assert(0 && "unknown architecure"); + } + if (pmat) { + BaseType::buf_ = pmat; + } else { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = + (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)); + } + if (!row_offset_) { + rowOffsetAllocatedHere = true; + row_offset_ = static_cast<int32_t*>(aligned_alloc(64, + BaseType::brow_ * sizeof(int32_t))); + } +} + +template <typename T, typename accT> +void PackAWithRowOffset<T, accT>::pack(const block_type_t& block) { + assert(block.row_start % BaseType::blockRowSize() == 0); + assert(block.col_start % BaseType::blockColSize() == 0); + assert(block.row_size <= BaseType::blockRowSize()); + assert(block.col_size <= BaseType::blockColSize()); + + block_type_t block_p = {block.row_start, + block.row_size, + block.col_start, + (block.col_size + row_interleave_B_ - 1) / + row_interleave_B_ * row_interleave_B_}; + assert(block_p.col_size <= BaseType::blockColSize()); + BaseType::packedBlock(block_p); + + T* out = BaseType::getBuf(); + bool tr = (trans_ == matrix_op_t::Transpose); + // accumulate into row offset? + bool row_offset_acc = (block.col_start != 0); + int32_t* row_offset_buf = getRowOffsetBuffer(); + if (tr) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + int32_t row_sum = row_offset_acc ? + row_offset_buf[i - block.row_start] : 0; + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + T val = smat_[i + ld_ * j]; + row_sum += val; + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = val; + } + row_offset_buf[i - block.row_start] = row_sum; + // zero fill + // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. 
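+      // For illustration: with row_interleave_B_ == 4 and block.col_size == 10,
+      // block_p.col_size becomes 12, so columns 10 and 11 of this packed row
+      // are set to 0; the packed K dimension of A then lines up with the
+      // row-interleaved padding applied when packing B.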
+ for (int j = block.col_start + block.col_size; + j < block_p.col_start + block_p.col_size; ++j) { + out[(i - block.row_start) * BaseType::blockColSize() + + (j - block.col_start)] = 0; + } + } + } else { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + int buf_idx = i - block.row_start; + memcpy( + out + buf_idx * BaseType::blockColSize(), + smat_ + i * ld_ + block.col_start, + block.col_size * sizeof(T)); + // zero fill + for (int j = block.col_size; j < block_p.col_size; ++j) { + out[buf_idx * BaseType::blockColSize() + j] = 0; + } + int32_t row_sum = row_offset_acc ? + row_offset_buf[i - block.row_start] : 0; + __m256i sum_v = _mm256_setzero_si256(); + __m256i one_epi16_v = _mm256_set1_epi16(1); + __m256i one_epi8_v = _mm256_set1_epi8(1); + for (int j = block.col_start; + j < block.col_start + block.col_size / 32 * 32; + j += 32) { + __m256i src_v = _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(smat_ + i * ld_ + j)); + sum_v = _mm256_add_epi32( + sum_v, + _mm256_madd_epi16( + _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v)); + } + for (int j = block.col_start + block.col_size / 32 * 32; + j < block.col_start + block.col_size; + ++j) { + row_sum += smat_[i * ld_ + j]; + } + alignas(64) std::array<int32_t, 8> temp; + _mm256_store_si256(reinterpret_cast<__m256i*>(temp.data()), sum_v); + for (int k = 0; k < 8; ++k) { + row_sum += temp[k]; + } + row_offset_buf[i - block.row_start] = row_sum; + } + } +} + +template <typename T, typename accT> +int32_t PackAWithRowOffset<T, accT>::addr(int32_t r, int32_t c) const { + int32_t block_row_id = r / BaseType::blockRowSize(); + int32_t brow_offset = (block_row_id * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()); + + int32_t block_col_id = c / BaseType::blockColSize(); + int32_t bcol_offset = + block_col_id * BaseType::blockRowSize() * BaseType::blockColSize(); + int32_t block_offset = brow_offset + bcol_offset; + int32_t inblock_offset = + (r % BaseType::blockRowSize()) * BaseType::blockColSize() + + (c % BaseType::blockColSize()); + + int32_t index = block_offset + inblock_offset; + + return index; +} + +template <typename T, typename accT> +void PackAWithRowOffset<T, accT>::printPackedMatrix(std::string name) { + std::cout << name << ":" + << "[" << BaseType::numPackedRows() << ", " + << BaseType::numPackedCols() << "]" << std::endl; + + T* out = BaseType::getBuf(); + for (auto r = 0; r < BaseType::numPackedRows(); ++r) { + for (auto c = 0; c < BaseType::numPackedCols(); ++c) { + T val = out[addr(r, c)]; + if (std::is_integral<T>::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template <typename T, typename accT> +int PackAWithRowOffset<T, accT>::rowOffsetBufferSize() { + if(cpuinfo_initialize()){ + if (cpuinfo_has_x86_avx512f()) { + return PackingTraits<T, accT, inst_set_t::avx512>::MCB; + } else if (cpuinfo_has_x86_avx2()) { + return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + } else { + //TODO: Have default slower path + assert(0 && "unsupported architecture"); + return -1; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template class PackAWithRowOffset<uint8_t, int32_t>; +template class PackAWithRowOffset<uint8_t, int16_t>; + +} // namespace fbgemm2 diff --git a/src/RefImplementations.cc 
b/src/RefImplementations.cc new file mode 100644 index 0000000..9aedc88 --- /dev/null +++ b/src/RefImplementations.cc @@ -0,0 +1,608 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "RefImplementations.h" + +#include <cassert> +#include <cmath> + +using namespace std; + +namespace fbgemm2 { + +void requantize_u8acc32_ref( + int M, + int N, + int ld, + const int32_t* inp, + uint8_t* out, + int32_t C_multiplier, + int32_t C_right_shift, + int32_t C_zero_point, + int32_t A_zero_point, + int32_t B_zero_point, + const int32_t* row_offsets, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu) { + int64_t nudge = 1ll << std::max(0, C_right_shift - 1); + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + int32_t raw = inp[i * ld + j]; + raw -= A_zero_point * col_offsets[j]; + raw -= B_zero_point * row_offsets[i]; + if (bias) { + raw += bias[j]; + } + + int64_t ab_64 = + static_cast<int64_t>(raw) * static_cast<int64_t>(C_multiplier); + int64_t rounded = ((ab_64 + nudge) >> C_right_shift) + C_zero_point; + + out[i * ld + j] = std::max( + fuse_relu ? static_cast<int64_t>(C_zero_point) : 0l, + std::min(255l, rounded)); + } + } +} + +void requantize_u8acc32_ref( + int M, + int N, + int ld, + const int32_t* inp, + uint8_t* out, + float C_multiplier, + int32_t C_zero_point, + int32_t A_zero_point, + int32_t B_zero_point, + const int32_t* row_offsets, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + int32_t raw = inp[i * ld + j]; + raw -= A_zero_point * col_offsets[j]; + raw -= B_zero_point * row_offsets[i]; + if (bias) { + raw += bias[j]; + } + + float result = raw * C_multiplier; + long rounded = lrintf(result) + C_zero_point; + out[i * ld + j] = std::max( + fuse_relu ? 
static_cast<long>(C_zero_point) : 0l, + std::min(255l, rounded)); + } + } +} + +void matmul_u8i8acc32_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + const uint8_t* Aint8, + const int8_t* Bint8, + int32_t* Cint32) { + for (int j = 0; j < N; ++j) { + for (int i = 0; i < M; ++i) { + int32_t sum = 0; + for (int k = 0; k < K; ++k) { + sum += static_cast<int32_t>(Aint8[i * lda + k]) * + static_cast<int32_t>(Bint8[k * ldb + j]); + } + Cint32[i * ldc + j] = sum; + } + } +} + +void matmul_u8i8acc16_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + int brow, + const uint8_t* Aint8, + const int8_t* Bint8, + int32_t* Cint32) { + for (int j = 0; j < N; ++j) { + for (int i = 0; i < M; ++i) { + int32_t sum = 0, sum_32bit = 0; + for (int k = 0; k < K; k += 2) { + int a0 = Aint8[i * lda + k]; + int b0 = Bint8[k * ldb + j]; + int a1 = 0, b1 = 0; + if (k + 1 < K) { + a1 = Aint8[i * lda + k + 1]; + b1 = Bint8[(k + 1) * ldb + j]; + } + sum = clip_16bit(sum + clip_16bit(a0 * b0 + a1 * b1)); + if ((k % brow) == (brow - 2)) { + sum_32bit += sum; + sum = 0; + } + } + Cint32[i * ldc + j] = sum_32bit + sum; + } + } +} + +void matmul_fp_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + const float* Afp32, + const float* Bfp32, + float* Cfp32) { + for (int j = 0; j < N; ++j) { + for (int i = 0; i < M; ++i) { + float sum = 0; + for (int k = 0; k < K; ++k) { + sum += Afp32[i * lda + k] * Bfp32[k * ldb + j]; + } + Cfp32[i * ldc + j] = sum; + } + } +} + +void row_offsets_u8acc32_ref( + int M, + int K, + int ld, + const uint8_t* Aint8, + int32_t* row_offsets) { + // row offset + for (int i = 0; i < M; ++i) { + int32_t sum = 0; + for (int k = 0; k < K; ++k) { + sum += static_cast<int32_t>(Aint8[i * ld + k]); + } + row_offsets[i] = sum; + } +} + +void col_offsets_with_zero_pt_s8acc32_ref( + int K, + int N, + int ld, + const int8_t* Bint8, + int32_t B_zero_point, + int32_t* col_offsets) { + for (int j = 0; j < N; ++j) { + int32_t sum = 0; + for (int k = 0; k < K; ++k) { + sum += Bint8[k * ld + j]; + } + col_offsets[j] = sum - B_zero_point * K; + } +} + +void spmdm_ref( + int M, + const uint8_t* A, + int lda, + fbgemm2::CompressedSparseColumn& B, + bool accumulation, + int32_t* C, + int ldc) { + int N = B.NumOfCols(); + if (!accumulation) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + C[i * ldc + j] = 0; + } + } + } + for (int j = 0; j < N; ++j) { + for (int k = B.ColPtr()[j]; k < B.ColPtr()[j + 1]; ++k) { + int row = B.RowIdx()[k]; + int w = B.Values()[k]; + for (int i = 0; i < M; ++i) { + C[i * ldc + j] += A[i * lda + row] * w; + } + } + } // for each column of B +} + +int32_t clip_16bit(int32_t x) { + if (x > std::numeric_limits<int16_t>::max()) { + return std::min<int>(std::numeric_limits<int16_t>::max(), x); + } else if (x < std::numeric_limits<int16_t>::min()) { + return std::max<int>(std::numeric_limits<int16_t>::min(), x); + } else { + return x; + } +} + +void im2col_ref( + const conv_param_t& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + std::uint8_t* Ao) { + for (int n = 0; n < conv_p.MB; ++n) { + for (int h = 0; h < conv_p.OH; ++h) { + for (int w = 0; w < conv_p.OW; ++w) { + for (int r = 0; r < conv_p.KH; ++r) { + int h_in = -conv_p.pad_h + h * conv_p.stride_h + r; + for (int s = 0; s < conv_p.KW; ++s) { + int w_in = -conv_p.pad_w + w * conv_p.stride_w + s; + for (int c = 0; c < conv_p.IC; ++c) { + // Ai: NHWC: NH_0W_0 x C_0 + std::uint8_t val = + h_in < 0 || h_in >= conv_p.IH || w_in < 0 || w_in >= conv_p.IW + ? 
A_zero_point + : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC + + c]; + // Ao: NHWC: NH_1W_1 x RSC_0 + Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) * + conv_p.KW + + s) * + conv_p.IC + + c] = val; + } // for each c + } // for each s + } // for each r + } // for each w + } // for each h + } // for each n +} + +void conv_ref( + const conv_param_t& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + const std::int8_t* B, + std::int32_t* C) { + // filters are assumed to be in RSCK format + assert(conv_p.G == 1 && "Groups != 1 not supported yet"); + + for (int n = 0; n < conv_p.MB; ++n) { + for (int h = 0; h < conv_p.OH; ++h) { + for (int w = 0; w < conv_p.OW; ++w) { + for (int k = 0; k < conv_p.OC; ++k) { + int sum = 0; + for (int r = 0; r < conv_p.KH; ++r) { + int h_in = -conv_p.pad_h + h * conv_p.stride_h + r; + for (int s = 0; s < conv_p.KW; ++s) { + int w_in = -conv_p.pad_w + w * conv_p.stride_w + s; + for (int c = 0; c < conv_p.IC; ++c) { + int a = h_in < 0 || h_in >= conv_p.IH || w_in < 0 || + w_in >= conv_p.IW + ? A_zero_point + : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * + conv_p.IC + + c]; + int b = + B[((r * conv_p.KW + s) * conv_p.IC + c) * conv_p.OC + k]; + sum += a * b; + } // for each c + } // for each s + } // for each r + C[((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k] = sum; + } // for each k + } // for each w + } // for each h + } // for each n +} + +void depthwise_3x3_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int8_t* B, + int32_t* C) { + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int r = 0; r < R; ++r) { + int h_in = -PAD_T + h * stride_h + r; + for (int s = 0; s < S; ++s) { + int w_in = -PAD_L + w * stride_w + s; + int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W + ? A_zero_point + : A[((n * H + h_in) * W + w_in) * K + k]; + int b = B[(k * R + r) * S + s]; + sum += a * b; + } + } + C[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; + } + } + } + } // for each n +}; + +void depthwise_3x3_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const int8_t* B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias) { + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + + vector<int32_t> C_int32(N * H_OUT * W_OUT * K); + depthwise_3x3_pad_1_ref( + N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data()); + + vector<int32_t> row_offsets(N * H_OUT * W_OUT * K); + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int r = 0; r < R; ++r) { + int h_in = -PAD_T + h * stride_h + r; + for (int s = 0; s < S; ++s) { + int w_in = -PAD_L + w * stride_w + s; + int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W + ? 
A_zero_point + : A[((n * H + h_in) * W + w_in) * K + k]; + sum += a; + } + } + row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; + } + } + } + } // for each n + + for (int i = 0; i < N * H_OUT * W_OUT; ++i) { + for (int k = 0; k < K; ++k) { + requantize_u8acc32_ref( + 1, + 1, + 1, + C_int32.data() + i * K + k, + C + i * K + k, + C_multiplier, + C_zero_point, + A_zero_point, + B_zero_point, + &row_offsets[i * K + k], + col_offsets + k, + bias ? bias + k : nullptr); + } + } +}; + +void depthwise_3x3_per_channel_quantization_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const int8_t* B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias) { + constexpr int R = 3, S = 3; + constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + + vector<int32_t> C_int32(N * H_OUT * W_OUT * K); + depthwise_3x3_pad_1_ref( + N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data()); + + vector<int32_t> row_offsets(N * H_OUT * W_OUT * K); + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int r = 0; r < R; ++r) { + int h_in = -PAD_T + h * stride_h + r; + for (int s = 0; s < S; ++s) { + int w_in = -PAD_L + w * stride_w + s; + int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W + ? A_zero_point + : A[((n * H + h_in) * W + w_in) * K + k]; + sum += a; + } + } + row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum; + } + } + } + } // for each n + + for (int i = 0; i < N * H_OUT * W_OUT; ++i) { + for (int k = 0; k < K; ++k) { + requantize_u8acc32_ref( + 1, + 1, + 1, + C_int32.data() + i * K + k, + C + i * K + k, + C_multiplier[k], + C_zero_point, + A_zero_point, + B_zero_point[k], + &row_offsets[i * K + k], + col_offsets + k, + bias ? bias + k : nullptr); + } + } +}; + +void depthwise_3x3x3_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int8_t* B, + int32_t* C) { + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int k_t = 0; k_t < K_T; ++k_t) { + int t_in = -PAD_P + t * stride_t + k_t; + for (int k_h = 0; k_h < K_H; ++k_h) { + int h_in = -PAD_T + h * stride_h + k_h; + for (int k_w = 0; k_w < K_W; ++k_w) { + int w_in = -PAD_L + w * stride_w + k_w; + int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || + w_in < 0 || w_in >= W + ? 
A_zero_point + : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; + int b = B[((k * K_T + k_t) * K_H + k_h) * K_W + k_w]; + sum += a * b; + } + } + } + C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = sum; + } + } // w + } // h + } // t + } // for each n +}; + +void depthwise_3x3x3_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const int8_t* B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias) { + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K); + depthwise_3x3x3_pad_1_ref( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B, + C_int32.data()); + + vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K); + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int k_t = 0; k_t < K_T; ++k_t) { + int t_in = -PAD_P + t * stride_t + k_t; + for (int k_h = 0; k_h < K_H; ++k_h) { + int h_in = -PAD_T + h * stride_h + k_h; + for (int k_w = 0; k_w < K_W; ++k_w) { + int w_in = -PAD_L + w * stride_w + k_w; + int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || + w_in < 0 || w_in >= W + ? A_zero_point + : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; + sum += a; + } + } + } + row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = + sum; + } + } // w + } // h + } // t + } // for each n + + for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) { + for (int k = 0; k < K; ++k) { + requantize_u8acc32_ref( + 1, + 1, + 1, + C_int32.data() + i * K + k, + C + i * K + k, + C_multiplier, + C_zero_point, + A_zero_point, + B_zero_point, + &row_offsets[i * K + k], + col_offsets + k, + bias ? bias + k : nullptr); + } + } +}; + +} // namespace fbgemm2 diff --git a/src/RefImplementations.h b/src/RefImplementations.h new file mode 100644 index 0000000..e9eaeed --- /dev/null +++ b/src/RefImplementations.h @@ -0,0 +1,268 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include <algorithm> +#include <cstdint> + +#include "fbgemm/ConvUtils.h" +#include "fbgemm/FbgemmI8Spmdm.h" + +namespace fbgemm2 { + +/** + * @brief Reference implementation of requantization step. + * int32 multiplier + * @params bias can be nullptr + */ +void requantize_u8acc32_ref( + int M, + int N, + int ld, + const std::int32_t* inp, + std::uint8_t* out, + std::int32_t C_multiplier, + std::int32_t C_right_shift, + std::int32_t C_zero_point, + std::int32_t A_zero_point, + std::int32_t B_zero_point, + const std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false); + +/** + * @brief Reference implementation of requantization step. 
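+ * Per-element sketch of what this float-multiplier variant computes
+ * (pseudo-code; the bias term applies only when bias != nullptr):
+ *   raw = inp[i*ld + j] - A_zero_point * col_offsets[j]
+ *         - B_zero_point * row_offsets[i] + bias[j]
+ *   out[i*ld + j] = clamp(lrintf(raw * C_multiplier) + C_zero_point, 0, 255)
+ * (the lower clamp bound becomes C_zero_point when fuse_relu is true)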
+ * float multiplier + * @params bias can be nullptr + */ +void requantize_u8acc32_ref( + int M, + int N, + int ld, + const std::int32_t* inp, + std::uint8_t* out, + float C_multiplier, + std::int32_t C_zero_point, + std::int32_t A_zero_point, + std::int32_t B_zero_point, + const std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false); + +/** + * @brief Reference implementation of matrix multiply with uint8 for A, + * int8 for B, and 32-bit accumulation. + */ +void matmul_u8i8acc32_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + const std::uint8_t* Aint8, + const std::int8_t* Bint8, + std::int32_t* Cint32); + +/** + * @brief Reference implementation of matrix multiply with uint 8 for A, + * int8 for B, and 16-bit accumulation. + */ +void matmul_u8i8acc16_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + int brow, + const std::uint8_t* Aint8, + const std::int8_t* Bint8, + std::int32_t* Cint32); + +/** + * @brief Reference implementation of matrix multiply with fp32 (single + * precision) floating point number. + */ +void matmul_fp_ref( + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + const float* Afp32, + const float* Bfp32, + float* Cfp32); + +/** + * @brief Reference implementation to compute row_offsets (sums of rows of A). + */ +void row_offsets_u8acc32_ref( + int M, + int K, + int ld, + const std::uint8_t* Aint8, + std::int32_t* row_offsets); + +/** + * @brief Reference implementation to compute adjusted col_offsets (sum of + * columns of B and adjusted with B_zero_point) + */ +void col_offsets_with_zero_pt_s8acc32_ref( + int K, + int N, + int ld, + const std::int8_t* Bint8, + std::int32_t B_zero_point, + std::int32_t* col_offsets); + +/** + * @brief Reference implementation of SPMDM (sparse matrix times dense matrix). + */ +void spmdm_ref( + int M, + const std::uint8_t* A, + int lda, + CompressedSparseColumn& B, + bool accumulation, + std::int32_t* C, + int ldc); + +/* + * @brief Trim a 32-bit integer to a 16-bit integer. + */ +int32_t clip_16bit(int32_t x); + +/* + * @brief Reference implementation of convolution operation. + * The activations A are assumed to be in NHiWiC format. + * The filters B are assumed to be in RSCK format. + * The output C is assumed to be in NHoWoC format. + */ +void conv_ref( + const conv_param_t& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + const std::int8_t* B, + std::int32_t* C); + +/* + * @brief Reference implementation of im2col operation. + * The input A is assumed to be in NHiWiC format. + * The output A is assumed to be in NHoWoRSC format. + */ +void im2col_ref( + const conv_param_t& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + std::uint8_t* Ao); + +/* + * @brief Reference implementation of depthwise convolution with a 3x3 filter + * and padding size 1. + */ +void depthwise_3x3_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int8_t* B, + std::int32_t* C); + +/* + * @brief Reference implementation of depthwise convolution with a 3x3 filter + * and padding size 1, followed by requantization. (the same scaling factors and + * zero points for each channel). 
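+ * Layouts used by this reference: A is N x H x W x K (channels last), B is
+ * K x 3 x 3, and C is N x H_OUT x W_OUT x K, where
+ * H_OUT = (H + 2 - 3) / stride_h + 1 and W_OUT = (W + 2 - 3) / stride_w + 1.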
+ */ +void depthwise_3x3_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const std::int8_t* B, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias); + +/* + * @brief Reference implementation of depthwise convolution with a 3x3 filter + * and padding size 1, followed by requantization. (different scaling factors + * and zero points for each channel). + */ +void depthwise_3x3_per_channel_quantization_pad_1_ref( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const std::int8_t* B, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias); + +/* + * @brief Reference implementation of 3D depthwise convolution with a 3x3x3 + * filter and padding size 1. + */ +void depthwise_3x3x3_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int8_t* B, + std::int32_t* C); + +/* + * @brief Reference implementation of 3D depthwise convolution with a 3x3x3 + * filter and padding size 1, followed by requantization. + */ +void depthwise_3x3x3_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const std::int8_t* B, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias); + +} // namespace fbgemm2 diff --git a/src/Utils.cc b/src/Utils.cc new file mode 100644 index 0000000..10ab469 --- /dev/null +++ b/src/Utils.cc @@ -0,0 +1,357 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "fbgemm/Utils.h" +#include <cpuinfo.h> +#include <immintrin.h> +#include <cassert> +#include <cinttypes> +#include <cmath> +#include <iomanip> +#include <iostream> +#include <limits> +#include <stdexcept> + +namespace fbgemm2 { + +/** + * @brief Compare the reference and test result matrix to check the correctness. + * @param ref The buffer for the reference result matrix. + * @param test The buffer for the test result matrix. + * @param m The height of the reference and test result matrix. + * @param n The width of the reference and test result matrix. + * @param ld The leading dimension of the reference and test result matrix. + * @param max_mismatches_to_report The maximum number of tolerable mismatches to + * report. + * @param atol The tolerable error. + * @retval false If the number of mismatches for reference and test result + * matrix exceeds max_mismatches_to_report. + * @retval true If the number of mismatches for reference and test result matrix + * is tolerable. 
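+ * Minimal usage sketch (buffer names are illustrative):
+ *   std::vector<int32_t> Cref(m * n), Ctest(m * n);
+ *   // ... fill both buffers ...
+ *   int ret = compare_buffers(Cref.data(), Ctest.data(), m, n, n, 5);
+ *   // ret is nonzero once more than 5 entries differ by more than atol
+ *   // (default 1e-3).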
+ */ +template <typename T> +int compare_buffers( + const T* ref, + const T* test, + int m, + int n, + int ld, + int max_mismatches_to_report, + float atol /*=1e-3*/) { + size_t mismatches = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + T reference = ref[i * ld + j], actual = test[i * ld + j]; + if (std::abs(reference - actual) > atol) { + std::cout << "\tmismatch at (" << i << ", " << j << ")" << std::endl; + if (std::is_integral<T>::value) { + std::cout << "\t reference:" << static_cast<int64_t>(reference) + << " test:" << static_cast<int64_t>(actual) << std::endl; + } else { + std::cout << "\t reference:" << reference << " test:" << actual + << std::endl; + } + + mismatches++; + if (mismatches > max_mismatches_to_report) { + return 1; + } + } + } + } + return 0; +} + + +/** + * @brief Print the matrix. + * @param op Transpose type of the matrix. + * @param R The height of the matrix. + * @param C The width of the matrix. + * @param ld The leading dimension of the matrix. + * @param name The prefix string before printing the matrix. + */ +template <typename T> +void printMatrix( + matrix_op_t op, + const T* inp, + size_t R, + size_t C, + size_t ld, + std::string name) { + // R: number of rows in op(inp) + // C: number of cols in op(inp) + // ld: leading dimension in inp + std::cout << name << ":" + << "[" << R << ", " << C << "]" << std::endl; + bool tr = (op == matrix_op_t::Transpose); + for (auto r = 0; r < R; ++r) { + for (auto c = 0; c < C; ++c) { + T res = tr ? inp[c * ld + r] : inp[r * ld + c]; + if (std::is_integral<T>::value) { + std::cout << std::setw(5) << static_cast<int64_t>(res) << " "; + } else { + std::cout << std::setw(5) << res << " "; + } + } + std::cout << std::endl; + } +} + +template int compare_buffers<float>( + const float* ref, + const float* test, + int m, + int n, + int ld, + int max_mismatches_to_report, + float atol); + +template int compare_buffers<int32_t>( + const int32_t* ref, + const int32_t* test, + int m, + int n, + int ld, + int max_mismatches_to_report, + float atol); + +template int compare_buffers<uint8_t>( + const uint8_t* ref, + const uint8_t* test, + int m, + int n, + int ld, + int max_mismatches_to_report, + float atol); + +template void printMatrix<float>( + matrix_op_t op, + const float* inp, + size_t R, + size_t C, + size_t ld, + std::string name); +template void printMatrix<int8_t>( + matrix_op_t op, + const int8_t* inp, + size_t R, + size_t C, + size_t ld, + std::string name); +template void printMatrix<uint8_t>( + matrix_op_t op, + const uint8_t* inp, + size_t R, + size_t C, + size_t ld, + std::string name); +template void printMatrix<int32_t>( + matrix_op_t op, + const int32_t* inp, + size_t R, + size_t C, + size_t ld, + std::string name); + + +/** + * @brief Reference implementation of matrix transposition: B = A^T. + * @param M The height of the matrix. + * @param N The width of the matrix. + * @param src The memory buffer of the source matrix A. + * @param ld_src The leading dimension of the source matrix A. + * @param dst The memory buffer of the destination matrix B. + * @param ld_dst The leading dimension of the destination matrix B. 
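+ * For example, with M = 2, N = 3, ld_src = 3, ld_dst = 2 and
+ *   src = [ 1 2 3 ]     the result is    dst = [ 1 4 ]
+ *         [ 4 5 6 ]                            [ 2 5 ]
+ *                                              [ 3 6 ]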
+ */ +inline void transpose_ref( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + dst[i + j * ld_dst] = src[i * ld_src + j]; + } + } +} + +inline void +transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) { + // load from src to registers + // a : a0 a1 a2 a3 + // b : b0 b1 b2 b3 + // c : c0 c1 c2 c3 + // d : d0 d1 d2 d3 + __m128 a = _mm_loadu_ps(&src[0 * ld_src]); + __m128 b = _mm_loadu_ps(&src[1 * ld_src]); + __m128 c = _mm_loadu_ps(&src[2 * ld_src]); + __m128 d = _mm_loadu_ps(&src[3 * ld_src]); + + // transpose the 4x4 matrix formed by 32-bit elements: Macro from SSE + // a : a0 b0 c0 d0 + // b : a1 b1 c1 d1 + // c : a2 b2 c2 d2 + // d : a3 b3 c3 d3 + _MM_TRANSPOSE4_PS(a, b, c, d); + + // store from registers to dst + _mm_storeu_ps(&dst[0 * ld_dst], a); + _mm_storeu_ps(&dst[1 * ld_dst], b); + _mm_storeu_ps(&dst[2 * ld_dst], c); + _mm_storeu_ps(&dst[3 * ld_dst], d); +} +inline void transpose_4x4( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + int ib = 0, jb = 0; + for (ib = 0; ib + 4 <= M; ib += 4) { + for (jb = 0; jb + 4 <= N; jb += 4) { + transpose_kernel_4x4_sse( + &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); + } + } + transpose_ref(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst); + transpose_ref(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst); +} + +inline void transpose_kernel_8x8_avx2( + const float* src, + int ld_src, + float* dst, + int ld_dst) { + // load from src to registers + // a : a0 a1 a2 a3 a4 a5 a6 a7 + // b : b0 b1 b2 b3 b4 b5 b6 b7 + // c : c0 c1 c2 c3 c4 c5 c6 c7 + // d : d0 d1 d2 d3 d4 d5 d6 d7 + // e : e0 e1 e2 e3 e4 e5 e6 e7 + // f : f0 f1 f2 f3 f4 f5 f6 f7 + // g : g0 g1 g2 g3 g4 g5 g6 g7 + // h : h0 h1 h2 h3 h4 h5 h6 h7 + __m256 a = _mm256_loadu_ps(&src[0 * ld_src]); + __m256 b = _mm256_loadu_ps(&src[1 * ld_src]); + __m256 c = _mm256_loadu_ps(&src[2 * ld_src]); + __m256 d = _mm256_loadu_ps(&src[3 * ld_src]); + __m256 e = _mm256_loadu_ps(&src[4 * ld_src]); + __m256 f = _mm256_loadu_ps(&src[5 * ld_src]); + __m256 g = _mm256_loadu_ps(&src[6 * ld_src]); + __m256 h = _mm256_loadu_ps(&src[7 * ld_src]); + + __m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367; + __m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37; + // unpacking and interleaving 32-bit elements + // ab0145 : a0 b0 a1 b1 a4 b4 a5 b5 + // ab2367 : a2 b2 a3 b3 a6 b6 a7 b7 + // cd0145 : c0 d0 c1 d1 c4 d4 c5 d5 + // cd2367 : c2 d2 c3 d3 c6 d6 c7 d7 + // ef0145 : e0 f0 e1 f1 e4 f4 e5 f5 + // ef2367 : e2 f2 e3 f3 e6 f6 e7 f7 + // gh0145 : g0 h0 g1 h1 g4 h4 g5 h5 + // gh2367 : g2 h2 g3 h3 g6 h6 g7 h7 + ab0145 = _mm256_unpacklo_ps(a, b); + ab2367 = _mm256_unpackhi_ps(a, b); + cd0145 = _mm256_unpacklo_ps(c, d); + cd2367 = _mm256_unpackhi_ps(c, d); + ef0145 = _mm256_unpacklo_ps(e, f); + ef2367 = _mm256_unpackhi_ps(e, f); + gh0145 = _mm256_unpacklo_ps(g, h); + gh2367 = _mm256_unpackhi_ps(g, h); + + // shuffling the 32-bit elements + // abcd04 : a0 b0 c0 d0 a4 b4 c4 d4 + // abcd15 : a1 b1 c1 d1 a5 b5 c5 d5 + // efgh04 : e0 f0 g0 h0 e4 f4 g4 h4 + // efgh15 : e1 f1 g1 h1 e5 b5 c5 d5 + // abcd26 : a2 b2 c2 d2 a6 b6 c6 d6 + // abcd37 : a3 b3 c3 d3 a7 b7 c7 d7 + // efgh26 : e2 f2 g2 h2 e6 f6 g6 h6 + // efgh37 : e3 f3 g3 h3 e7 f7 g7 h7 + abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44); + abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee); + efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44); + 
efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee); + abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44); + abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee); + efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44); + efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee); + + // shuffling 128-bit elements + // a : a0 b0 c0 d0 e0 f0 g0 h0 + // b : a1 b1 c1 d1 e1 f1 g1 h1 + // c : a2 b2 c2 d2 e2 f2 g2 h2 + // d : a3 b3 c3 d3 e3 f3 g3 h3 + // e : a4 b4 c4 d4 e4 f4 g4 h4 + // f : a5 b5 c5 d5 e5 f5 g5 h5 + // g : a6 b6 c6 d6 e6 f6 g6 h6 + // h : a7 b7 c7 d7 e7 f7 g7 h7 + a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02); + b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02); + c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02); + d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02); + e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13); + f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13); + g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13); + h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13); + + // store from registers to dst + _mm256_storeu_ps(&dst[0 * ld_dst], a); + _mm256_storeu_ps(&dst[1 * ld_dst], b); + _mm256_storeu_ps(&dst[2 * ld_dst], c); + _mm256_storeu_ps(&dst[3 * ld_dst], d); + _mm256_storeu_ps(&dst[4 * ld_dst], e); + _mm256_storeu_ps(&dst[5 * ld_dst], f); + _mm256_storeu_ps(&dst[6 * ld_dst], g); + _mm256_storeu_ps(&dst[7 * ld_dst], h); +} + +void transpose_8x8( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + int ib = 0, jb = 0; + for (ib = 0; ib + 8 <= M; ib += 8) { + for (jb = 0; jb + 8 <= N; jb += 8) { + transpose_kernel_8x8_avx2( + &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); + } + } + transpose_4x4(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst); + transpose_4x4(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst); +} + +void transpose_simd( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + // Run time CPU detection + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + transpose_16x16(M, N, src, ld_src, dst, ld_dst); + } else if (cpuinfo_has_x86_avx2()) { + transpose_8x8(M, N, src, ld_src, dst, ld_dst); + } else { + transpose_ref(M, N, src, ld_src, dst, ld_dst); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +} // namespace fbgemm2 diff --git a/src/Utils_avx512.cc b/src/Utils_avx512.cc new file mode 100644 index 0000000..b6bf413 --- /dev/null +++ b/src/Utils_avx512.cc @@ -0,0 +1,243 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "fbgemm/Utils.h" + +#include <immintrin.h> + +namespace fbgemm2 { + +inline void transpose_kernel_16x16_avx512( + const float* src, + int ld_src, + float* dst, + int ld_dst) { + // load from src to registers + // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 + // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 + // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 + // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 + // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15 + // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15 + // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15 + // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15 + // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15 + // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15 + // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15 + // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15 + // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15 + // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15 + // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15 + // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15 + __m512 a = _mm512_loadu_ps(&src[0 * ld_src]); + __m512 b = _mm512_loadu_ps(&src[1 * ld_src]); + __m512 c = _mm512_loadu_ps(&src[2 * ld_src]); + __m512 d = _mm512_loadu_ps(&src[3 * ld_src]); + __m512 e = _mm512_loadu_ps(&src[4 * ld_src]); + __m512 f = _mm512_loadu_ps(&src[5 * ld_src]); + __m512 g = _mm512_loadu_ps(&src[6 * ld_src]); + __m512 h = _mm512_loadu_ps(&src[7 * ld_src]); + __m512 i = _mm512_loadu_ps(&src[8 * ld_src]); + __m512 j = _mm512_loadu_ps(&src[9 * ld_src]); + __m512 k = _mm512_loadu_ps(&src[10 * ld_src]); + __m512 l = _mm512_loadu_ps(&src[11 * ld_src]); + __m512 m = _mm512_loadu_ps(&src[12 * ld_src]); + __m512 n = _mm512_loadu_ps(&src[13 * ld_src]); + __m512 o = _mm512_loadu_ps(&src[14 * ld_src]); + __m512 p = _mm512_loadu_ps(&src[15 * ld_src]); + + __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq; + // unpacking and interleaving 32-bit elements + // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13 + // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15 + // c0 d0 c1 d1 ... + // c2 d2 c3 d3 ... + // e0 f0 e1 f1 ... + // e2 f2 e3 f3 ... + // g0 h0 g1 h1 ... + // g2 h2 g3 h3 ... + // i0 ... + // i2 ... + // k0 ... + // k2 ... + // m0 ... + // m2 ... + // o0 ... + // o1 ... + ta = _mm512_unpacklo_ps(a, b); + tb = _mm512_unpackhi_ps(a, b); + tc = _mm512_unpacklo_ps(c, d); + td = _mm512_unpackhi_ps(c, d); + te = _mm512_unpacklo_ps(e, f); + tf = _mm512_unpackhi_ps(e, f); + tg = _mm512_unpacklo_ps(g, h); + th = _mm512_unpackhi_ps(g, h); + ti = _mm512_unpacklo_ps(i, j); + tj = _mm512_unpackhi_ps(i, j); + tk = _mm512_unpacklo_ps(k, l); + tl = _mm512_unpackhi_ps(k, l); + tm = _mm512_unpacklo_ps(m, n); + tn = _mm512_unpackhi_ps(m, n); + to = _mm512_unpacklo_ps(o, p); + tq = _mm512_unpackhi_ps(o, p); + + // unpacking and interleaving 64-bit elements + // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12 + // e1 f1 g1 h1 ... + // e2 f2 g2 h2 ... + // e3 f3 g3 h3 ... + // i0 j0 k0 l0 ... + // i1 j1 k1 l1 ... + // i2 j2 k2 l2 ... + // i3 j3 k3 l3 ... + // m0 n0 o0 p0 ... + // m1 n1 o1 p1 ... + // m2 n2 o2 p2 ... + // m3 n3 o3 p3 ... 
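+  // Note: _mm512_castps_pd / _mm512_castpd_ps are pure reinterprets (they emit
+  // no instructions); unpacklo/unpackhi_pd act here as a 64-bit interleave,
+  // i.e. they move pairs of adjacent floats at a time.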
+ a = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc))); + b = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc))); + c = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td))); + d = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td))); + e = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg))); + f = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg))); + g = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th))); + h = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th))); + i = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk))); + j = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk))); + k = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl))); + l = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl))); + m = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to))); + n = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to))); + o = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq))); + p = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq))); + + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // a4 b4 c4 d4 ... + // a5 b5 c5 d5 ... + // a6 b6 c6 d6 ... + // a7 b7 c7 d7 ... + // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8 + // i1 j1 k1 l1 ... + // i2 j2 k2 l2 ... + // i3 j3 k3 l3 ... + // i4 j4 k4 l4 ... + // i5 j5 k5 l5 ... + // i6 j6 k6 l6 ... + // i7 j7 k7 l7 ... + ta = _mm512_shuffle_f32x4(a, e, 0x88); + tb = _mm512_shuffle_f32x4(b, f, 0x88); + tc = _mm512_shuffle_f32x4(c, g, 0x88); + td = _mm512_shuffle_f32x4(d, h, 0x88); + te = _mm512_shuffle_f32x4(a, e, 0xdd); + tf = _mm512_shuffle_f32x4(b, f, 0xdd); + tg = _mm512_shuffle_f32x4(c, g, 0xdd); + th = _mm512_shuffle_f32x4(d, h, 0xdd); + ti = _mm512_shuffle_f32x4(i, m, 0x88); + tj = _mm512_shuffle_f32x4(j, n, 0x88); + tk = _mm512_shuffle_f32x4(k, o, 0x88); + tl = _mm512_shuffle_f32x4(l, p, 0x88); + tm = _mm512_shuffle_f32x4(i, m, 0xdd); + tn = _mm512_shuffle_f32x4(j, n, 0xdd); + to = _mm512_shuffle_f32x4(k, o, 0xdd); + tq = _mm512_shuffle_f32x4(l, p, 0xdd); + + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 ... o0 + // a1 b1 c1 d1 ... o1 + // a2 b2 c2 d2 ... o2 + // a3 b3 c3 d3 ... o3 + // a4 ... + // a5 ... + // a6 ... + // a7 ... + // a8 ... + // a9 ... + // a10 ... + // a11 ... + // a12 ... + // a13 ... + // a14 ... + // a15 b15 c15 d15 ... 
o15 + a = _mm512_shuffle_f32x4(ta, ti, 0x88); + b = _mm512_shuffle_f32x4(tb, tj, 0x88); + c = _mm512_shuffle_f32x4(tc, tk, 0x88); + d = _mm512_shuffle_f32x4(td, tl, 0x88); + e = _mm512_shuffle_f32x4(te, tm, 0x88); + f = _mm512_shuffle_f32x4(tf, tn, 0x88); + g = _mm512_shuffle_f32x4(tg, to, 0x88); + h = _mm512_shuffle_f32x4(th, tq, 0x88); + i = _mm512_shuffle_f32x4(ta, ti, 0xdd); + j = _mm512_shuffle_f32x4(tb, tj, 0xdd); + k = _mm512_shuffle_f32x4(tc, tk, 0xdd); + l = _mm512_shuffle_f32x4(td, tl, 0xdd); + m = _mm512_shuffle_f32x4(te, tm, 0xdd); + n = _mm512_shuffle_f32x4(tf, tn, 0xdd); + o = _mm512_shuffle_f32x4(tg, to, 0xdd); + p = _mm512_shuffle_f32x4(th, tq, 0xdd); + + // store from registers to dst + _mm512_storeu_ps(&dst[0 * ld_dst], a); + _mm512_storeu_ps(&dst[1 * ld_dst], b); + _mm512_storeu_ps(&dst[2 * ld_dst], c); + _mm512_storeu_ps(&dst[3 * ld_dst], d); + _mm512_storeu_ps(&dst[4 * ld_dst], e); + _mm512_storeu_ps(&dst[5 * ld_dst], f); + _mm512_storeu_ps(&dst[6 * ld_dst], g); + _mm512_storeu_ps(&dst[7 * ld_dst], h); + _mm512_storeu_ps(&dst[8 * ld_dst], i); + _mm512_storeu_ps(&dst[9 * ld_dst], j); + _mm512_storeu_ps(&dst[10 * ld_dst], k); + _mm512_storeu_ps(&dst[11 * ld_dst], l); + _mm512_storeu_ps(&dst[12 * ld_dst], m); + _mm512_storeu_ps(&dst[13 * ld_dst], n); + _mm512_storeu_ps(&dst[14 * ld_dst], o); + _mm512_storeu_ps(&dst[15 * ld_dst], p); +} + +void transpose_16x16( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + int ib = 0, jb = 0; + for (ib = 0; ib + 16 <= M; ib += 16) { + for (jb = 0; jb + 16 <= N; jb += 16) { + transpose_kernel_16x16_avx512( + &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); + } + } + transpose_8x8(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst); + transpose_8x8(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst); +} + +} // namespace fbgemm2 diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc new file mode 100644 index 0000000..8e36c85 --- /dev/null +++ b/src/codegen_fp16fp32.cc @@ -0,0 +1,387 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <assert.h> +#include <cpuid.h> +#include <stdlib.h> +#include <string.h> +#include <algorithm> +#include <array> +#include <fstream> +#include <functional> +#include <iostream> +#include <map> +#include <string> + +using namespace std; + +void addi(ofstream& of, string i, bool disable = false) { + if (disable == false) + of << "\"" + i + "\\t\\n\"" + "\n"; +} + +struct ISA { + unsigned avx; // 1, 2 or 3 + string name; + vector<vector<unsigned>> shapes; +}; + +int main() { + bool iaca = false; + bool disable = false; + + bool fixedA = true, fixedB = true, fixedC = true; + + int eax, ebx, ecx, edx; + __cpuid(1 /* ecx = vendor string */, eax, ebx, ecx, edx); + printf("FC16 is %s supported\n", ((ecx & bit_F16C) ? 
" " : "not")); + + string comma = ","; + + vector<ISA> isa = { + // {1, "AVX", {{4, 1, 0}, {4, 2, 0}, {4, 3, 0}, {3, 1, 0}, {3, 2, 0}, {3, + // 3, 0}}}, + { 2, "AVX2", + { { 1, 1, 0 }, + { 2, 1, 0 }, + { 3, 1, 0 }, + { 4, 1, 0 }, + { 5, 1, 0 }, + { 6, 1, 0 }, + { 7, 1, 0 }, + { 8, 1, 0 }, + { 9, 1, 0 }, + { 10, 1, 0 }, + { 11, 1, 0 }, + { 12, 1, 0 }, + { 13, 1, 0 }, + { 14, 1, 0 }, + } + } + }; + + // open all files + ofstream srcfile; + srcfile.open("FbgemmFP16UKernels.cc"); + srcfile << "#include \"FbgemmFP16UKernels.h\"\n"; + if (iaca) + srcfile << "#include \"iacaMarks.h\"\n"; + + ofstream hdrfile; + hdrfile.open("FbgemmFP16UKernels.h"); + + hdrfile << "#ifndef FBGEMM_UKERNELS\n"; + hdrfile << "#define FBGEMM_UKERNELS\n"; + hdrfile << "#include <cstdint>\n"; + hdrfile << "#include <tuple>\n"; + hdrfile << "#include <vector>\n"; + hdrfile << "#include \"fbgemm/Types.h\"\n"; + hdrfile << "using fp16 = fbgemm2::float16;\n"; + hdrfile << "using fp32 = float;\n"; + hdrfile << "struct GemmParams {uint64_t k; float *A; const fp16 *B;\n" + "float *beta; uint64_t accum; float *C; uint64_t ldc;\n" + "uint64_t b_block_cols; uint64_t b_block_size;};\n"; + + std::map<string, string> fptr_typedef; + fptr_typedef["fp16"] = ""; + fptr_typedef["fp32"] = ""; + + unsigned labelId = 0; +#if 1 + for (auto fixedA : {false}) + for (auto fixedB : {false}) + for (auto fixedC : {false}) +#else + for (auto fixedA : {true}) + for (auto fixedB : {true}) + for (auto fixedC : {true}) +#endif + for (auto s : isa) { + vector<vector<unsigned>>& ukernel_shape = s.shapes; + + vector<string> funcname(ukernel_shape.size()), + fheader(ukernel_shape.size()); + string fargs; + + for (auto fp16 : {true}) { + string B_type = ((fp16) ? "fp16" : "fp32"); + string prefix = s.name + /*"_" + B_type */ + "_" + "fA" + + to_string(fixedA) + "fB" + to_string(fixedB) + "fC" + + to_string(fixedC); + cout << "Generating code for " << s.name << " " << B_type << "\n"; + + for (unsigned k = 0; k < ukernel_shape.size(); k++) { + printf( + "shape: %d x %d * 32\n", + ukernel_shape[k][0], + ukernel_shape[k][1]); + + string p1 = "GemmParams *gp"; + + funcname[k] = "gemmkernel_" + to_string(ukernel_shape[k][0]) + + "x" + to_string(ukernel_shape[k][1]) + "_"; + funcname[k] += prefix; + + fargs = "(" + p1 + ")"; + + fheader[k] = + "void __attribute__ ((noinline)) " + funcname[k] + fargs; + srcfile << fheader[k] << "\n"; + srcfile << "{\n"; + + unsigned last_free_ymmreg = 0; + // produce register block of C + vector<vector<string>> vCtile(ukernel_shape[k][0]); + for (auto r = 0; r < ukernel_shape[k][0]; r++) + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vCtile[r].push_back("ymm" + to_string(last_free_ymmreg)); + last_free_ymmreg++; + } + assert(last_free_ymmreg <= 14); + + string vAtmp = "ymm" + to_string(last_free_ymmreg++); + // produce register block of B col + assert(ukernel_shape[k][1] == 1); + vector<string> vBcol(ukernel_shape[k][1]); + + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vBcol[c] = ("ymm" + to_string(last_free_ymmreg)); + last_free_ymmreg++; + } + + assert(last_free_ymmreg <= 16); + + srcfile << "asm volatile\n"; + srcfile << "(\n"; + + srcfile << "#if !defined(__clang__)" << "\n"; + addi(srcfile, "mov r14, %[gp]"); + srcfile << "#else\n"; + addi(srcfile, "mov %[gp], %%r14"); + addi(srcfile, ".intel_syntax noprefix"); + srcfile << "#endif\n"; + + srcfile << "\n// Copy parameters\n"; + srcfile << "// k\n"; + addi(srcfile, "mov r8, [r14 + 0]"); + srcfile << "// A\n"; + addi(srcfile, "mov r9, [r14 + 8]"); + srcfile << "// 
B\n"; + addi(srcfile, "mov r10, [r14 + 16]"); + srcfile << "// beta\n"; + addi(srcfile, "mov r15, [r14 + 24]"); + srcfile << "// accum\n"; + addi(srcfile, "mov rdx, [r14 + 32]"); + srcfile << "// C\n"; + addi(srcfile, "mov r12, [r14 + 40]"); + srcfile << "// ldc\n"; + addi(srcfile, "mov r13, [r14 + 48]"); + srcfile << "// b_block_cols\n"; + addi(srcfile, "mov rdi, [r14 + 56]"); + srcfile << "// b_block_size\n"; + addi(srcfile, "mov rsi, [r14 + 64]"); + srcfile << "// Make copies of A and C\n"; + addi(srcfile, "mov rax, r9"); + addi(srcfile, "mov rcx, r12"); + srcfile << "\n\n"; + + addi(srcfile, "mov rbx, 0"); + + string exitlabel = "L_exit%="; + string label2 = "loop_outter%="; + addi(srcfile, label2 + ":"); + addi(srcfile, "mov r14, 0"); + + // set all vCtile regs to zeros + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi( + srcfile, + "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," + + vCtile[r][c]); + } + } + + // start marker + if (iaca) { + addi(srcfile, "mov ebx, 111"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + srcfile << "\n"; + + if (ukernel_shape[k][0] <= 13) { + addi(srcfile, "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]"); + addi(srcfile, "mov r11, 16"); + } else { + addi(srcfile, "mov r11, 0"); + } + + srcfile << "\n"; + string label = "loop_inner%="; + addi(srcfile, label + ":"); + srcfile << "\n"; + + if (ukernel_shape[k][0] <= 13) { + auto a_offset = 0, unroll_factor = 2; + for (auto u = 0; u < unroll_factor; u++) { + string breg = (u == 0) ? "ymm14" : "ymm15"; + string breg_rev = (u == 0) ? "ymm15" : "ymm14"; + + addi(srcfile, "vcvtph2ps " + breg + + ",XMMWORD PTR [r10 + r11 + " + + to_string(u * 16) + "]"); + addi(srcfile, "inc r14"); + for (auto r = 0; r < vCtile.size(); r++) { + addi(srcfile, "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + + to_string(a_offset) + "]"); + addi(srcfile, "vfmadd231ps " + vCtile[r][0] + "," + + breg_rev + "," + vAtmp); + if (u == 1 && r == vCtile.size() / 2) + addi(srcfile, "add r11, 32"); + a_offset += 4; + } + if (u < unroll_factor - 1) { + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jge " + exitlabel); + } + } + + addi(srcfile, "add r9," + to_string(a_offset)); + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jl " + label); + + srcfile << "\n"; + + addi(srcfile, exitlabel + ":"); + } else { + addi(srcfile, + "vcvtph2ps " + vBcol[0] + ",XMMWORD PTR [r10 + r11]"); + for (auto r = 0; r < vCtile.size(); r++) { + addi(srcfile, "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + + to_string(4 * r) + "]"); + addi(srcfile, "vfmadd231ps " + vCtile[r][0] + "," + vBcol[0] + + "," + vAtmp); + } + + addi(srcfile, "add r9," + to_string(4 * ukernel_shape[k][0]), + fixedA); // move A ptr + addi(srcfile, "add r11, 16"); + + addi(srcfile, "inc r14"); + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jl " + label); + } + + addi(srcfile, "add r10, rsi"); + srcfile << "\n"; + + // end marker + if (iaca) { + addi(srcfile, "mov ebx, 222"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + + addi(srcfile, "cmp rdx, 1"); + addi(srcfile, "je L_accum%="); + srcfile << "// Dump C\n"; + + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi(srcfile, "vmovups YMMWORD PTR [r12 + " + + to_string(32 * c) + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + addi(srcfile, "jmp L_done%="); + + srcfile << "\n\n"; + addi(srcfile, "L_accum%=:"); + srcfile << "// Dump C with accumulate\n"; + + string r_spare = (s.avx == 1) ? 
"ymm14" : "ymm15"; + addi(srcfile, + "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"), + fixedC); + // store out C + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + switch (s.avx) { + case 1: + addi(srcfile, + string("vmulps ymm15, ") + r_spare + comma + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + addi(srcfile, "vaddps " + vCtile[r][c] + "," + + vCtile[r][c] + "," + "ymm15", + fixedC); + break; + case 2: + addi(srcfile, + "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + break; + default: + assert(0); + } + addi(srcfile, "vmovups YMMWORD PTR [r12 + " + + to_string(32 * c) + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + + srcfile << "\n"; + addi(srcfile, "L_done%=:"); + + srcfile << "\n// next outer iteration\n"; + // C + addi(srcfile, "add rcx, " + to_string(32 * ukernel_shape[k][1]), + fixedC); + addi(srcfile, "mov r12, rcx", fixedC); + // A + addi(srcfile, "mov r9, rax"); + + addi(srcfile, "inc rbx"); + addi(srcfile, "cmp rbx, rdi"); + addi(srcfile, "jl " + label2); + + // output + srcfile << ":\n"; + // input + srcfile << ":\n"; + srcfile << "[gp] \"rm\" (gp)\n"; + + // clobbered + srcfile + << (string) ": \"r8\", \"r9\", \"r10\", \"r11\", \"r15\", " + + (string) " \"r13\", \"r14\",\n" + + (string) "\"rax\", \"rcx\", " + "\"rdx\", \"rsi\", \"rdi\", \"rbx\", " + "\"r12\", \"memory\"" + + (string) "\n"; + srcfile << ");\n"; + srcfile << "}\n"; + } + + for (unsigned k = 0; k < ukernel_shape.size(); k++) { + hdrfile << fheader[k] << ";\n"; + } + + fptr_typedef[B_type] = + "typedef void (* funcptr_" + B_type + ") " + fargs; + } + } + + srcfile.close(); + hdrfile << fptr_typedef["fp16"] << ";\n"; + hdrfile << fptr_typedef["fp32"] << ";\n"; + hdrfile << "#endif\n"; + hdrfile.close(); +} |