github.com/marian-nmt/FBGEMM.git
Diffstat (limited to 'src')
-rw-r--r--  src/ExecuteKernel.cc                        12
-rw-r--r--  src/ExecuteKernel.h                         11
-rw-r--r--  src/ExecuteKernelGeneric.h                  64
-rw-r--r--  src/ExecuteKernelU8S8.cc                   354
-rw-r--r--  src/ExecuteKernelU8S8.h                     73
-rw-r--r--  src/Fbgemm.cc                              363
-rw-r--r--  src/FbgemmFP16.cc                          293
-rw-r--r--  src/FbgemmFP16UKernels.cc                 2203
-rw-r--r--  src/FbgemmFP16UKernels.h                    40
-rw-r--r--  src/FbgemmI8Depthwise.cc                  1953
-rw-r--r--  src/FbgemmI8Depthwise.h                    105
-rw-r--r--  src/FbgemmI8Spmdm.cc                       508
-rw-r--r--  src/GenerateKernel.h                       154
-rw-r--r--  src/GenerateKernelU8S8S32ACC16.cc          292
-rw-r--r--  src/GenerateKernelU8S8S32ACC16_avx512.cc   295
-rw-r--r--  src/GenerateKernelU8S8S32ACC32.cc          310
-rw-r--r--  src/GenerateKernelU8S8S32ACC32_avx512.cc   312
-rw-r--r--  src/PackAMatrix.cc                         165
-rw-r--r--  src/PackAWithIm2Col.cc                     146
-rw-r--r--  src/PackBMatrix.cc                         144
-rw-r--r--  src/PackMatrix.cc                           86
-rw-r--r--  src/PackWithQuantRowOffset.cc              230
-rw-r--r--  src/PackWithRowOffset.cc                   211
-rw-r--r--  src/RefImplementations.cc                  608
-rw-r--r--  src/RefImplementations.h                   268
-rw-r--r--  src/Utils.cc                               357
-rw-r--r--  src/Utils_avx512.cc                        243
-rw-r--r--  src/codegen_fp16fp32.cc                    387
28 files changed, 10187 insertions, 0 deletions
diff --git a/src/ExecuteKernel.cc b/src/ExecuteKernel.cc
new file mode 100644
index 0000000..0e3d122
--- /dev/null
+++ b/src/ExecuteKernel.cc
@@ -0,0 +1,12 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "ExecuteKernel.h"
+#include <immintrin.h>
+#include "fbgemm/Fbgemm.h"
+#include "fbgemm/Utils.h"
+
+namespace fbgemm2 {} // namespace fbgemm2
diff --git a/src/ExecuteKernel.h b/src/ExecuteKernel.h
new file mode 100644
index 0000000..55a2581
--- /dev/null
+++ b/src/ExecuteKernel.h
@@ -0,0 +1,11 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <cstdint>
+#include "fbgemm/Fbgemm.h"
+#include "ExecuteKernelGeneric.h"
+#include "ExecuteKernelU8S8.h"
diff --git a/src/ExecuteKernelGeneric.h b/src/ExecuteKernelGeneric.h
new file mode 100644
index 0000000..e83e943
--- /dev/null
+++ b/src/ExecuteKernelGeneric.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <cstdint>
+#include "fbgemm/Fbgemm.h"
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+/**
+ * @brief Execute Engine for the macro-kernel and output processing.
+ * ExecuteKernel is a derived class of CodeGenBase.
+ */
+template <
+ typename packingAMatrix,
+ typename packingBMatrix,
+ typename cT,
+ typename processOutputType>
+class ExecuteKernel : public CodeGenBase<
+ typename packingAMatrix::inpType,
+ typename packingBMatrix::inpType,
+ cT,
+ typename packingBMatrix::accType> {
+ public:
+ ExecuteKernel(
+ PackMatrix<
+ packingAMatrix,
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType>& packA,
+ PackMatrix<
+ packingBMatrix,
+ typename packingBMatrix::inpType,
+ typename packingBMatrix::accType>& packB,
+ int32_t kBlock,
+ cT* matC,
+ typename packingBMatrix::accType* C_buffer,
+ int32_t ldc,
+ const processOutputType& outputProcess);
+ void execute(int kBlock);
+
+ private:
+ PackMatrix<
+ packingAMatrix,
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType>&
+ packedA_; ///< Packed block of matrix A.
+ PackMatrix<
+ packingBMatrix,
+ typename packingBMatrix::inpType,
+ typename packingBMatrix::accType>& packedB_; ///< Packed matrix B.
+ int32_t kBlock_; ///< Block ID in the k dimension.
+ cT* matC_; ///< Output for matrix C.
+ typename packingAMatrix::accType*
+ C_buffer_; ///< the accumulation buffer for matrix C.
+ int32_t ldc_; ///< the leading dimension of matrix C.
+ const processOutputType& outputProcess_; ///< output processing function for
+ ///< the C tile in the macro-kernel.
+};
+
+} // namespace fbgemm2
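The class above is the generic interface: the driver constructs one ExecuteKernel per output panel and calls execute(kBlock) once per k-block, and every call after the first accumulates partial products into the C buffer. The standalone toy below is illustrative only (not FBGEMM code; all names are made up) and shows that accumulate-across-k-blocks contract with plain loops.

#include <cstdio>

// Toy stand-in for ExecuteKernel::execute(kBlock): multiply one k-slice of A
// by the matching k-slice of B and either overwrite or accumulate into C,
// following the same rule the real macro-kernel uses (accumulate iff kBlock > 0).
static void executeToy(const int* A, int lda, const int* B, int ldb,
                       int* C, int ldc, int m, int n, int kc, bool accum) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int sum = accum ? C[i * ldc + j] : 0;
      for (int p = 0; p < kc; ++p) {
        sum += A[i * lda + p] * B[p * ldb + j];
      }
      C[i * ldc + j] = sum;
    }
  }
}

int main() {
  const int m = 2, n = 2, K = 4, KCB = 2;  // two k-blocks of size KCB
  int A[m * K] = {1, 2, 3, 4, 5, 6, 7, 8}; // 2x4, row-major
  int B[K * n] = {1, 0, 0, 1, 1, 1, 1, 1}; // 4x2, row-major
  int C[m * n] = {-1, -1, -1, -1};         // first k-block overwrites this
  for (int kBlock = 0; kBlock < K / KCB; ++kBlock) {
    bool accum = kBlock > 0; // same decision as in the u8s8 specialization below
    executeToy(A + kBlock * KCB, K, B + kBlock * KCB * n, n, C, n, m, n, KCB, accum);
  }
  // Full product A*B = {{8, 9}, {20, 21}}
  printf("%d %d\n%d %d\n", C[0], C[1], C[2], C[3]);
  return 0;
}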
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
new file mode 100644
index 0000000..5145869
--- /dev/null
+++ b/src/ExecuteKernelU8S8.cc
@@ -0,0 +1,354 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "ExecuteKernelU8S8.h"
+#include <cpuinfo.h>
+#include <chrono>
+
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+double kernel_time = 0.0;
+double postprocessing_time = 0.0;
+#endif
+
+namespace fbgemm2 {
+
+template <typename packingAMatrix, typename cT, typename processOutputType>
+ExecuteKernel<
+ packingAMatrix,
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ cT,
+ processOutputType>::
+ ExecuteKernel(
+ PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
+ packA,
+ PackMatrix<
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ int8_t,
+ typename packingAMatrix::accType>& packB,
+ int32_t kBlock,
+ cT* matC,
+ int32_t* C_buffer,
+ int32_t ldc,
+ const processOutputType& outputProcess)
+ : packedA_(packA),
+ packedB_(packB),
+ kBlock_(kBlock),
+ matC_(matC),
+ C_buffer_(C_buffer),
+ ldc_(ldc),
+ outputProcess_(outputProcess) {
+ if (cpuinfo_has_x86_avx512f()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::NCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::NCB;
+ } else {
+    assert(0 && "unsupported architecture");
+ }
+ C_tile_ = new int32_t[mbSize_ * nbSize_];
+}
+
+template <typename packingAMatrix, typename cT, typename processOutputType>
+void ExecuteKernel<
+ packingAMatrix,
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ cT,
+ processOutputType>::execute(int kBlock) {
+ // packedA_.printPackedMatrix("packedA from kernel");
+ // packedB_.printPackedMatrix("packedB from kernel");
+
+ int32_t bColBlocks = packedB_.blockCols();
+
+ int8_t* bBuf;
+ int8_t* bBuf_pf;
+
+ uint8_t* aBuf = packedA_.getBuf(0);
+
+ int32_t packed_rows_A = packedA_.numPackedRows();
+ int32_t row_start_A = packedA_.packedRowStart();
+
+ bool lastKBlock = packedB_.isThisLastKBlock(kBlock);
+ bool accum = kBlock > 0;
+
+ typename BaseType::jit_micro_kernel_fp fn;
+
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ } else if (cpuinfo_has_x86_avx2()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx2>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ std::chrono::time_point<std::chrono::high_resolution_clock> t_start, t_end;
+ double dt;
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ for (int jb = 0; jb < bColBlocks; ++jb) {
+
+ bBuf = packedB_.getBuf(jb, kBlock);
+ // prefetch addr of the next packed block of B matrix
+ bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock);
+
+    // Reuse the first rowblock of C_buffer_ unless C_buffer_ is the same as
+    // matC_ (in-place output processing).
+ int32_t* C_buffer_row_start = C_buffer_ +
+ ((C_buffer_ == reinterpret_cast<int32_t*>(matC_)) ? row_start_A * ldc_
+ : 0);
+ int32_t* C_buffer_start = C_buffer_row_start + jb * nbSize_;
+ int32_t leadingDim = ldc_;
+ if (packedB_.isThereColRemainder() && (jb == bColBlocks - 1)) {
+ // In case we will access memory past C_buffer_, we use C_tile_ instead.
+ C_buffer_start = C_tile_;
+ leadingDim = nbSize_;
+ }
+
+ fn(aBuf,
+ bBuf,
+ bBuf_pf,
+ C_buffer_start,
+ packedA_.numPackedCols(),
+ leadingDim);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ kernel_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ // Output processing is done only once per rowblock
+ if (lastKBlock && jb == bColBlocks - 1) {
+      // When C_tile_ is used for the last column block, that block needs
+      // separate handling below.
+ int32_t nSize =
+ C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols();
+ if (nSize) {
+ if (cpuinfo_has_x86_avx512f()) {
+ // TODO: avx512 path
+ // Currently use avx2 code
+ outputProcess_.template f<inst_set_t::avx2>(
+ matC_,
+ C_buffer_row_start,
+ {row_start_A, packed_rows_A, 0, nSize},
+ ldc_,
+ ldc_);
+ } else if (cpuinfo_has_x86_avx2()) {
+ outputProcess_.template f<inst_set_t::avx2>(
+ matC_,
+ C_buffer_row_start,
+ {row_start_A, packed_rows_A, 0, nSize},
+ ldc_,
+ ldc_);
+ } else {
+ // TODO: Have default slower path
+          assert(0 && "unsupported architecture");
+ }
+ }
+
+ if (C_buffer_start == C_tile_) {
+ if (cpuinfo_has_x86_avx512f()) {
+ // TODO: avx512 path
+ // Currently use avx2 code
+ outputProcess_.template f<inst_set_t::avx2>(
+ matC_,
+ C_tile_,
+ {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()},
+ ldc_,
+ leadingDim);
+ } else if (cpuinfo_has_x86_avx2()) {
+ outputProcess_.template f<inst_set_t::avx2>(
+ matC_,
+ C_tile_,
+ {row_start_A, packed_rows_A, jb * nbSize_, packedB_.lastBcol()},
+ ldc_,
+ leadingDim);
+ } else {
+ // TODO: Have default slower path
+          assert(0 && "unsupported architecture");
+ }
+ }
+ } // output processing
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ postprocessing_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ } // for each j block
+}
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ uint8_t,
+ ReQuantizeOutput<false /* FUSE_RELU*/>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ uint8_t,
+ ReQuantizeOutput<true>>;
+
+template class ExecuteKernel<
+ PackAWithQuantRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ float,
+ ReQuantizeForFloat<false>>;
+
+template class ExecuteKernel<
+ PackAWithQuantRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ float,
+ ReQuantizeForFloat<true>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ float,
+ ReQuantizeForFloat<false>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ float,
+ ReQuantizeForFloat<true>>;
+
+template class ExecuteKernel<
+ PackAMatrix<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAMatrix<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ ReQuantizeOutput<false>>;
+
+template class ExecuteKernel<
+ PackAMatrix<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ DoSpmdmOnInpBuffer<
+ ReQuantizeOutput<false>::outType,
+ int32_t,
+ ReQuantizeOutput<false>>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ DoSpmdmOnInpBuffer<
+ ReQuantizeOutput<true>::outType,
+ int32_t,
+ ReQuantizeOutput<true>>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ float,
+ DoSpmdmOnInpBuffer<
+ ReQuantizeForFloat<false>::outType,
+ int32_t,
+ ReQuantizeForFloat<false>>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ ReQuantizeOutput<false>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ uint8_t,
+ ReQuantizeOutput<true>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithIm2Col<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithIm2Col<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithQuantRowOffset<uint8_t, int32_t>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ float,
+ ReQuantizeForFloat<false>>;
+
+template class ExecuteKernel<
+ PackAMatrix<uint8_t, int16_t>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ DoNothing<int32_t, int32_t>>;
+
+} // namespace fbgemm2
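One detail of execute() above that is easy to lose in the template noise is the column-remainder path: when the output width is not a multiple of NCB, the last column block is computed into the private C_tile_ buffer with leadingDim = nbSize_ and only the valid columns are then copied out, so the micro-kernel never writes past the end of a row of C. A minimal standalone sketch of that idea (illustrative only, not FBGEMM code):

#include <cstdio>
#include <vector>

// Stand-in for the JIT micro-kernel: it always writes a full m x NCB tile,
// using whatever leading dimension it is given.
static void microKernelToy(int* out, int ld, int m, int NCB, int value) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < NCB; ++j) {
      out[i * ld + j] = value;
    }
  }
}

int main() {
  const int m = 2, n = 10, NCB = 4; // n is not a multiple of NCB
  std::vector<int> C(m * n, 0);
  std::vector<int> C_tile(m * NCB); // private tile, leading dimension NCB

  int fullBlocks = n / NCB; // 2 full blocks, remainder of 2 columns
  int rem = n % NCB;
  for (int jb = 0; jb < fullBlocks; ++jb) {
    microKernelToy(&C[jb * NCB], /*ld=*/n, m, NCB, jb + 1);
  }
  if (rem) {
    // Writing directly into C would spill past the row end, so redirect the
    // kernel into C_tile and copy back only the rem valid columns.
    microKernelToy(C_tile.data(), /*ld=*/NCB, m, NCB, 9);
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < rem; ++j) {
        C[i * n + fullBlocks * NCB + j] = C_tile[i * NCB + j];
      }
    }
  }
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      printf("%d ", C[i * n + j]);
    }
    printf("\n");
  }
  return 0;
}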
diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h
new file mode 100644
index 0000000..0bd7fc5
--- /dev/null
+++ b/src/ExecuteKernelU8S8.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include "ExecuteKernel.h"
+
+namespace fbgemm2 {
+
+/**
+ * @brief Execute Engine of uint8 and int8 matrix
+ * multiplication for the macro-kernel and output processing. ExecuteKernel is a
+ * derived class of CodeGenBase.
+ */
+template <typename packingAMatrix, typename cT, typename processOutputType>
+class ExecuteKernel<
+ packingAMatrix,
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ cT,
+ processOutputType>
+ : public CodeGenBase<
+ uint8_t,
+ int8_t,
+ int32_t,
+ typename packingAMatrix::accType> {
+ public:
+ using BaseType =
+ CodeGenBase<uint8_t, int8_t, int32_t, typename packingAMatrix::accType>;
+ /**
+ * @brief Constructor for initializing the parameters for macro-kernel and
+ * output processing type.
+ */
+ ExecuteKernel(
+ PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
+ packA,
+ PackMatrix<
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ int8_t,
+ typename packingAMatrix::accType>& packB,
+ int32_t kBlock,
+ cT* matC,
+ int32_t* C_buffer,
+ int32_t ldc,
+ const processOutputType& outputProcess);
+ void execute(int kBlock);
+
+ ~ExecuteKernel() {
+ delete[] C_tile_;
+ }
+
+ private:
+ PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
+ packedA_; ///< Packed uint8 block of matrix A.
+ PackMatrix<
+ PackBMatrix<int8_t, typename packingAMatrix::accType>,
+ int8_t,
+ typename packingAMatrix::accType>&
+ packedB_; ///< Packed int8 matrix B.
+ int32_t kBlock_; ///< Block ID in the k dimension.
+ cT* matC_; ///< Output for matrix C.
+ int32_t* C_buffer_; ///< the accumulation buffer for matrix C.
+ int32_t ldc_; ///< the leading dimension of matrix C.
+ const processOutputType& outputProcess_; ///< output processing function for
+ ///< matrix C in the macro-kernel.
+  int32_t* C_tile_; ///< buffer for the last N block when N is not an exact
+                    ///< multiple of NCB.
+ int mbSize_; ///< block size in the m dimension.
+ int nbSize_; ///< block size in the n dimension.
+};
+
+} // namespace fbgemm2
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
new file mode 100644
index 0000000..f3bac97
--- /dev/null
+++ b/src/Fbgemm.cc
@@ -0,0 +1,363 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/Fbgemm.h"
+#include <cpuinfo.h>
+#include <stdexcept>
+#include "ExecuteKernel.h"
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+double packing_time = 0.0;
+double computing_time = 0.0;
+double run_time = 0.0;
+#endif
+
+using namespace fbgemm2;
+
+namespace fbgemm2 {
+
+template <
+ typename packingAMatrix,
+ typename packingBMatrix,
+ typename cT,
+ typename processOutputType>
+void fbgemmPacked(
+ PackMatrix<
+ packingAMatrix,
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType>& packA,
+ PackMatrix<
+ packingBMatrix,
+ typename packingBMatrix::inpType,
+ typename packingBMatrix::accType>& packB,
+ cT* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const processOutputType& outProcess,
+ int thread_id,
+ int /* num_threads */) {
+ static_assert(
+ std::is_same<
+ typename packingAMatrix::accType,
+ typename packingBMatrix::accType>::value,
+ "Accumulation type of both matrices should be the same");
+
+ int MCB, KCB;
+
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::KCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::KCB;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+
+ int MDim = packA.numRows();
+ int KDim = packB.numRows();
+
+ int mBlocks = (MDim + MCB - 1) / MCB;
+ int kBlocks = (KDim + KCB - 1) / KCB;
+
+ // remainders
+ int _mc = MDim % MCB;
+ int _kc = KDim % KCB;
+
+ int kc, mc;
+
+ block_type_t blockA{0, 0, 0, 0};
+
+ // B must be prepacked
+ assert(packB.isPrePacked() && "B matrix must be prepacked");
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start,
+ t_start, t_end;
+ double dt;
+ t_start = std::chrono::high_resolution_clock::now();
+ t_very_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ ExecuteKernel<packingAMatrix, packingBMatrix, cT, processOutputType>
+ exeKernelObj(packA, packB, 0, C, C_buffer, ldc, outProcess);
+  // TODO: thread-based work division
+ for (int i = 0; i < mBlocks; ++i) {
+ mc = (i != mBlocks - 1 || _mc == 0) ? MCB : _mc;
+ for (int k = 0; k < kBlocks; ++k) {
+ kc = (k != kBlocks - 1 || _kc == 0) ? KCB : _kc;
+ // pack A matrix
+ blockA = {i * MCB, mc, k * KCB, kc};
+
+ packA.pack(blockA);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ packing_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ exeKernelObj.execute(k);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ computing_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+ }
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt =
+ std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start)
+ .count();
+ run_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+}
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<true>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>&
+ packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>&
+ packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<true>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<true>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>&
+ packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+// 16 bit accumulation functions
+template void fbgemmPacked(
+ PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const DoSpmdmOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<false>>&
+ outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const DoSpmdmOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<true>>&
+ outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>&
+ outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ uint8_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeOutput<true>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const DoNothing<int32_t, int32_t>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
+ PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ float* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const ReQuantizeForFloat<false>& outProcess,
+ int thread_id,
+ int num_threads);
+
+} // namespace fbgemm2
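The traversal in fbgemmPacked above is plain block-index arithmetic over the rows of A and the shared k dimension; the real MCB and KCB come from PackingTraits and depend on the instruction set and accumulation type. The self-contained sketch below (toy sizes, not the actual traits) prints the {row start, rows, col start, cols} tuples that the blockA assignment would produce, including the remainder blocks at the edges:

#include <cstdio>

int main() {
  // Same arithmetic as the i/k loops in fbgemmPacked, with made-up block sizes.
  const int MDim = 70, KDim = 50; // rows of A, rows of B (shared k dimension)
  const int MCB = 32, KCB = 16;   // toy cache-block sizes
  int mBlocks = (MDim + MCB - 1) / MCB;
  int kBlocks = (KDim + KCB - 1) / KCB;
  int _mc = MDim % MCB; // remainder rows in the last row block
  int _kc = KDim % KCB; // remainder depth in the last k block
  for (int i = 0; i < mBlocks; ++i) {
    int mc = (i != mBlocks - 1 || _mc == 0) ? MCB : _mc;
    for (int k = 0; k < kBlocks; ++k) {
      int kc = (k != kBlocks - 1 || _kc == 0) ? KCB : _kc;
      // blockA = {row start, number of rows, col start, number of cols}
      printf("blockA = {%d, %d, %d, %d}\n", i * MCB, mc, k * KCB, kc);
    }
  }
  return 0;
}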
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
new file mode 100644
index 0000000..7bbfa54
--- /dev/null
+++ b/src/FbgemmFP16.cc
@@ -0,0 +1,293 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmFP16.h"
+
+#include <cpuinfo.h>
+
+#include "FbgemmFP16UKernels.h"
+
+using namespace std;
+
+namespace fbgemm2 {
+
+/// Packs a matrix in row-major or col-major format into the
+/// internal packed blocked-row-major format.
+
+/// Todo: make it fast with AVX2 transpose
+inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) {
+ // for (int r = 0; r < nrow; ++r) {
+ // for (int c = 0; c < ncol; ++c) {
+ // to[r + c * nrow] = from[r * ldim + c];
+ // }
+ // }
+  transpose_simd(nrow, ncol, from, ldim, to, nrow);
+}
+
+struct KernelInfo {
+ using knl_ptr = funcptr_fp16;
+ // optimized kernels to cover all cases
+ static constexpr array<knl_ptr, 15> kernel = {{
+ nullptr, gemmkernel_1x1_AVX2_fA0fB0fC0,
+ gemmkernel_2x1_AVX2_fA0fB0fC0, gemmkernel_3x1_AVX2_fA0fB0fC0,
+ gemmkernel_4x1_AVX2_fA0fB0fC0, gemmkernel_5x1_AVX2_fA0fB0fC0,
+ gemmkernel_6x1_AVX2_fA0fB0fC0, gemmkernel_7x1_AVX2_fA0fB0fC0,
+ gemmkernel_8x1_AVX2_fA0fB0fC0, gemmkernel_9x1_AVX2_fA0fB0fC0,
+ gemmkernel_10x1_AVX2_fA0fB0fC0, gemmkernel_11x1_AVX2_fA0fB0fC0,
+ gemmkernel_12x1_AVX2_fA0fB0fC0, gemmkernel_13x1_AVX2_fA0fB0fC0,
+ gemmkernel_14x1_AVX2_fA0fB0fC0
+ }};
+
+ // autotuned kernel splits for various cases m = 1:mb_max
+ // may need re-autotuning for new uarch
+ static constexpr array<array<pair<int, int>, 2>, 121 > partition = {
+ {
+ {{ { 0, 0 }, { 0, 0 } } },
+ {{ { 1, 1 }, { 0, 0 } } },
+ {{ { 2, 1 }, { 0, 0 } } },
+ {{ { 3, 1 }, { 0, 0 } } },
+ {{ { 4, 1 }, { 0, 0 } } },
+ {{ { 5, 1 }, { 0, 0 } } },
+ {{ { 6, 1 }, { 0, 0 } } },
+ {{ { 7, 1 }, { 0, 0 } } },
+ {{ { 8, 1 }, { 0, 0 } } },
+ {{ { 9, 1 }, { 0, 0 } } },
+ {{ { 10, 1 }, { 0, 0 } } },
+ {{ { 11, 1 }, { 0, 0 } } },
+ {{ { 12, 1 }, { 0, 0 } } },
+ {{ { 13, 1 }, { 0, 0 } } },
+ {{ { 14, 1 }, { 0, 0 } } },
+ {{ { 8, 1 }, { 7, 1 } } },
+ {{ { 10, 1 }, { 6, 1 } } },
+ {{ { 11, 1 }, { 6, 1 } } },
+ {{ { 12, 1 }, { 6, 1 } } },
+ {{ { 11, 1 }, { 8, 1 } } },
+ {{ { 11, 1 }, { 9, 1 } } },
+ {{ { 12, 1 }, { 9, 1 } } },
+ {{ { 11, 2 }, { 0, 0 } } },
+ {{ { 12, 1 }, { 11, 1 } } },
+ {{ { 12, 2 }, { 0, 0 } } },
+ {{ { 13, 1 }, { 12, 1 } } },
+ {{ { 13, 2 }, { 0, 0 } } },
+ {{ { 14, 1 }, { 13, 1 } } },
+ {{ { 14, 2 }, { 0, 0 } } },
+ {{ { 11, 2 }, { 7, 1 } } },
+ {{ { 10, 3 }, { 0, 0 } } },
+ {{ { 12, 2 }, { 7, 1 } } },
+ {{ { 12, 2 }, { 8, 1 } } },
+ {{ { 11, 3 }, { 0, 0 } } },
+ {{ { 13, 2 }, { 8, 1 } } },
+ {{ { 13, 2 }, { 9, 1 } } },
+ {{ { 13, 2 }, { 10, 1 } } },
+ {{ { 13, 2 }, { 11, 1 } } },
+ {{ { 13, 2 }, { 12, 1 } } },
+ {{ { 13, 3 }, { 0, 0 } } },
+ {{ { 14, 2 }, { 12, 1 } } },
+ {{ { 14, 2 }, { 13, 1 } } },
+ {{ { 11, 3 }, { 9, 1 } } },
+ {{ { 11, 3 }, { 10, 1 } } },
+ {{ { 11, 4 }, { 0, 0 } } },
+ {{ { 12, 3 }, { 9, 1 } } },
+ {{ { 12, 3 }, { 10, 1 } } },
+ {{ { 13, 3 }, { 8, 1 } } },
+ {{ { 13, 3 }, { 9, 1 } } },
+ {{ { 13, 3 }, { 10, 1 } } },
+ {{ { 13, 3 }, { 11, 1 } } },
+ {{ { 13, 3 }, { 12, 1 } } },
+ {{ { 13, 4 }, { 0, 0 } } },
+ {{ { 14, 3 }, { 11, 1 } } },
+ {{ { 11, 4 }, { 10, 1 } } },
+ {{ { 12, 4 }, { 7, 1 } } },
+ {{ { 14, 4 }, { 0, 0 } } },
+ {{ { 12, 4 }, { 9, 1 } } },
+ {{ { 12, 4 }, { 10, 1 } } },
+ {{ { 12, 4 }, { 11, 1 } } },
+ {{ { 13, 4 }, { 8, 1 } } },
+ {{ { 13, 4 }, { 9, 1 } } },
+ {{ { 13, 4 }, { 10, 1 } } },
+ {{ { 13, 4 }, { 11, 1 } } },
+ {{ { 11, 5 }, { 9, 1 } } },
+ {{ { 13, 5 }, { 0, 0 } } },
+ {{ { 14, 4 }, { 10, 1 } } },
+ {{ { 12, 5 }, { 7, 1 } } },
+ {{ { 12, 5 }, { 8, 1 } } },
+ {{ { 14, 4 }, { 13, 1 } } },
+ {{ { 14, 5 }, { 0, 0 } } },
+ {{ { 12, 5 }, { 11, 1 } } },
+ {{ { 13, 5 }, { 7, 1 } } },
+ {{ { 11, 6 }, { 7, 1 } } },
+ {{ { 13, 5 }, { 9, 1 } } },
+ {{ { 13, 5 }, { 10, 1 } } },
+ {{ { 13, 5 }, { 11, 1 } } },
+ {{ { 13, 5 }, { 12, 1 } } },
+ {{ { 13, 6 }, { 0, 0 } } },
+ {{ { 12, 6 }, { 7, 1 } } },
+ {{ { 12, 6 }, { 8, 1 } } },
+ {{ { 12, 6 }, { 9, 1 } } },
+ {{ { 12, 6 }, { 10, 1 } } },
+ {{ { 12, 6 }, { 11, 1 } } },
+ {{ { 12, 7 }, { 0, 0 } } },
+ {{ { 13, 6 }, { 7, 1 } } },
+ {{ { 13, 6 }, { 8, 1 } } },
+ {{ { 13, 6 }, { 9, 1 } } },
+ {{ { 13, 6 }, { 10, 1 } } },
+ {{ { 13, 6 }, { 11, 1 } } },
+ {{ { 13, 6 }, { 12, 1 } } },
+ {{ { 13, 7 }, { 0, 0 } } },
+ {{ { 12, 7 }, { 8, 1 } } },
+ {{ { 12, 7 }, { 9, 1 } } },
+ {{ { 14, 6 }, { 10, 1 } } },
+ {{ { 12, 7 }, { 11, 1 } } },
+ {{ { 13, 7 }, { 5, 1 } } },
+ {{ { 13, 7 }, { 6, 1 } } },
+ {{ { 13, 7 }, { 7, 1 } } },
+ {{ { 13, 7 }, { 8, 1 } } },
+ {{ { 13, 7 }, { 9, 1 } } },
+ {{ { 13, 7 }, { 10, 1 } } },
+ {{ { 13, 7 }, { 11, 1 } } },
+ {{ { 13, 7 }, { 12, 1 } } },
+ {{ { 12, 8 }, { 8, 1 } } },
+ {{ { 12, 8 }, { 9, 1 } } },
+ {{ { 12, 8 }, { 10, 1 } } },
+ {{ { 12, 8 }, { 11, 1 } } },
+ {{ { 12, 9 }, { 0, 0 } } },
+ {{ { 11, 9 }, { 10, 1 } } },
+ {{ { 13, 8 }, { 6, 1 } } },
+ {{ { 13, 8 }, { 7, 1 } } },
+ {{ { 13, 8 }, { 8, 1 } } },
+ {{ { 13, 8 }, { 9, 1 } } },
+ {{ { 13, 8 }, { 10, 1 } } },
+ {{ { 13, 8 }, { 11, 1 } } },
+ {{ { 12, 9 }, { 8, 1 } } },
+ {{ { 13, 9 }, { 0, 0 } } },
+ {{ { 12, 9 }, { 10, 1 } } },
+ {{ { 12, 9 }, { 11, 1 } } },
+ {{ { 12, 10 }, { 0, 0 } } }
+ }
+ };
+};
+constexpr array<KernelInfo::knl_ptr, 15> KernelInfo::kernel;
+constexpr array<array<pair<int, int>, 2>, 121 > KernelInfo::partition;
+
+// autotuned kernel splits for various cases m = 1:mb_max
+void
+cblas_gemm_compute(const matrix_op_t transa, const int m, const float *A,
+ const PackedGemmMatrixFP16 &Bp, const float beta,
+ float *C) {
+ // ground truth
+ assert(cpuinfo_initialize());
+ assert(cpuinfo_has_x86_fma3());
+ assert(cpuinfo_has_x86_f16c());
+ assert(transa == matrix_op_t::NoTranspose);
+
+ // constants
+ const int n = Bp.numCols(), k = Bp.numRows(), ldc = n;
+ const int mb_max = 120;
+ constexpr int simd_width = 8;
+ constexpr int kernel_ncol_blocks = 1;
+ constexpr int kernel_ncols = kernel_ncol_blocks * simd_width;
+
+ // private scratchpad storage
+ static thread_local unique_ptr<std::array<float, 256 * 1024> > scratchpad(
+ new std::array<float, 256 * 1024>());
+
+ GemmParams gp;
+ for (auto m0 = 0; m0 < m; m0 += mb_max) {
+ int mb = std::min(mb_max, m - m0);
+ assert(mb < KernelInfo::partition.size());
+ for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
+
+      // set up proper accumulation to avoid the "NaN" problem
+ float beta_;
+ uint64_t accum;
+ if (k_ind == 0) {
+        // accumulate if beta != 0.0
+        // do not accumulate otherwise
+ beta_ = beta;
+ accum = (beta_ == 0.0f) ? 0 : 1;
+ } else {
+ // always accumulate with beta_ = 1.0f
+ beta_ = 1.0f;
+ accum = 1;
+ }
+
+ const int kb = std::min(Bp.blockRowSize(), Bp.numRows() - k_ind);
+
+ auto m1 = 0;
+ for (auto c = 0; c < 2; c++) {
+
+ auto kernel_nrows = KernelInfo::partition[mb][c].first;
+ auto nkernel_nrows = KernelInfo::partition[mb][c].second;
+
+ auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows;
+ for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
+ assert(kernel_nrows * kb < scratchpad->size());
+ PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
+
+ int nbcol = n / Bp.blockColSize();
+ gp.k = kb;
+ gp.A = scratchpad->data();
+ gp.B = &(Bp(k_ind, 0));
+ gp.beta = &beta_;
+ gp.accum = accum;
+ gp.C = &C[m2 * ldc];
+ gp.ldc = ldc * sizeof(C[0]);
+ gp.b_block_cols = nbcol;
+ gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]);
+ if ((n % Bp.blockColSize()) == 0) {
+ KernelInfo::kernel[kernel_nrows](&gp);
+ } else {
+ int last_blk_col = nbcol * Bp.blockColSize();
+ if (nbcol) {
+ KernelInfo::kernel[kernel_nrows](&gp);
+ }
+
+ // leftover
+ int rem = n - last_blk_col;
+ assert(rem < kernel_ncols);
+ int b = (rem % simd_width) ? ((rem + simd_width) / simd_width)
+ : (rem / simd_width);
+ assert(b == 1);
+ if ((rem % simd_width) == 0) {
+ gp.B = &(Bp(k_ind, last_blk_col));
+ gp.C = &C[m2 * ldc + last_blk_col];
+ gp.b_block_cols = 1;
+ KernelInfo::kernel[kernel_nrows](&gp);
+ } else {
+ // small temporary buffer
+ float c_tmp[16 * 24] = { 0 };
+ assert((16 * 24) > kernel_nrows * kernel_ncols);
+
+ gp.B = &(Bp(k_ind, last_blk_col));
+ gp.C = c_tmp;
+ gp.ldc = 8 * sizeof(C[0]);
+ gp.b_block_cols = 1;
+ KernelInfo::kernel[kernel_nrows](&gp);
+ for (int i = 0; i < kernel_nrows; i++) {
+ // Todo: use assembly
+ for (int j = last_blk_col; j < n; j++) {
+ assert(i * 8 + (j - last_blk_col) <
+ sizeof(c_tmp) / sizeof(c_tmp[0]));
+ if (accum == 0) {
+ C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)];
+ } else {
+ C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] +
+ c_tmp[i * 8 + (j - last_blk_col)];
+ }
+ }
+ }
+ }
+ }
+ }
+ m1 += kernel_nrows * nkernel_nrows;
+ }
+ }
+ }
+}
+
+} // namespace fbgemm2
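The KernelInfo::partition table above encodes, for each tile height mb, up to two (kernel_nrows, count) pairs whose products sum to mb, so the m1/m2 loops of cblas_gemm_compute cover every row of the tile exactly once. The standalone check below (a few entries transcribed from the table; not FBGEMM code) walks that tiling the same way those loops do:

#include <cassert>
#include <cstdio>
#include <utility>

int main() {
  // A few entries copied from KernelInfo::partition: partition[mb] holds up to
  // two (kernel_nrows, count) pairs used to tile an mb-row block of A.
  struct Entry {
    int mb;
    std::pair<int, int> p[2];
  };
  const Entry samples[] = {
      {14, {{14, 1}, {0, 0}}}, // one gemmkernel_14x1 call
      {15, {{8, 1}, {7, 1}}},  // an 8-row call followed by a 7-row call
      {22, {{11, 2}, {0, 0}}}, // two 11-row calls
      {29, {{11, 2}, {7, 1}}}, // 11 + 11 + 7 = 29
  };
  for (const Entry& e : samples) {
    int covered = 0;
    for (const auto& pr : e.p) {
      for (int c = 0; c < pr.second; ++c) {
        printf("mb=%d: gemmkernel_%dx1 on rows [%d, %d)\n",
               e.mb, pr.first, covered, covered + pr.first);
        covered += pr.first;
      }
    }
    assert(covered == e.mb); // every row of the tile is covered exactly once
  }
  return 0;
}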
diff --git a/src/FbgemmFP16UKernels.cc b/src/FbgemmFP16UKernels.cc
new file mode 100644
index 0000000..ec1b297
--- /dev/null
+++ b/src/FbgemmFP16UKernels.cc
@@ -0,0 +1,2203 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "FbgemmFP16UKernels.h"
+
+namespace fbgemm2 {
+
+void __attribute__ ((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm1,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm1\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm1,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm1\t\n"
+"add r11, 32\t\n"
+"add r9,8\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm2,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm2\t\n"
+"vbroadcastss ymm2,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm2\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm2,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm2\t\n"
+"vbroadcastss ymm2,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm2\t\n"
+"add r11, 32\t\n"
+"add r9,16\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm3\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm3\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm3\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm3\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm3\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm3,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm3\t\n"
+"add r9,24\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm4\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm4\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm4\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm4,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm4\t\n"
+"add r9,32\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm5\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm5\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm5\t\n"
+"vbroadcastss ymm5,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm5\t\n"
+"add r9,40\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm6\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm6\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm6\t\n"
+"vbroadcastss ymm6,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm6\t\n"
+"add r9,48\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm7\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm7\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm7\t\n"
+"vbroadcastss ymm7,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm7\t\n"
+"add r9,56\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm8\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm8\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm8\t\n"
+"vbroadcastss ymm8,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm8\t\n"
+"add r9,64\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm9\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm9\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm9\t\n"
+"vbroadcastss ymm9,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm9\t\n"
+"add r9,72\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm10\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm10\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+72]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm10\t\n"
+"vbroadcastss ymm10,DWORD PTR [r9+76]\t\n"
+"vfmadd231ps ymm9,ymm14,ymm10\t\n"
+"add r9,80\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+"vxorps ymm10,ymm10,ymm10\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm10,ymm15,ymm11\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm11\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+72]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+76]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+80]\t\n"
+"vfmadd231ps ymm9,ymm14,ymm11\t\n"
+"vbroadcastss ymm11,DWORD PTR [r9+84]\t\n"
+"vfmadd231ps ymm10,ymm14,ymm11\t\n"
+"add r9,88\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+"vxorps ymm10,ymm10,ymm10\t\n"
+"vxorps ymm11,ymm11,ymm11\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm10,ymm15,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm11,ymm15,ymm12\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+72]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm12\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+76]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+80]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+84]\t\n"
+"vfmadd231ps ymm9,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+88]\t\n"
+"vfmadd231ps ymm10,ymm14,ymm12\t\n"
+"vbroadcastss ymm12,DWORD PTR [r9+92]\t\n"
+"vfmadd231ps ymm11,ymm14,ymm12\t\n"
+"add r9,96\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+"vxorps ymm10,ymm10,ymm10\t\n"
+"vxorps ymm11,ymm11,ymm11\t\n"
+"vxorps ymm12,ymm12,ymm12\t\n"
+
+"vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n"
+"mov r11, 16\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm10,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm11,ymm15,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm12,ymm15,ymm13\t\n"
+"cmp r14, r8\t\n"
+"jge L_exit%=\t\n"
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n"
+"inc r14\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm0,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+56]\t\n"
+"vfmadd231ps ymm1,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+60]\t\n"
+"vfmadd231ps ymm2,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+64]\t\n"
+"vfmadd231ps ymm3,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+68]\t\n"
+"vfmadd231ps ymm4,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+72]\t\n"
+"vfmadd231ps ymm5,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+76]\t\n"
+"vfmadd231ps ymm6,ymm14,ymm13\t\n"
+"add r11, 32\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+80]\t\n"
+"vfmadd231ps ymm7,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+84]\t\n"
+"vfmadd231ps ymm8,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+88]\t\n"
+"vfmadd231ps ymm9,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+92]\t\n"
+"vfmadd231ps ymm10,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+96]\t\n"
+"vfmadd231ps ymm11,ymm14,ymm13\t\n"
+"vbroadcastss ymm13,DWORD PTR [r9+100]\t\n"
+"vfmadd231ps ymm12,ymm14,ymm13\t\n"
+"add r9,104\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+
+"L_exit%=:\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+void __attribute__ ((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams *gp)
+{
+asm volatile
+(
+#if !defined(__clang__)
+"mov r14, %[gp]\t\n"
+#else
+"mov %[gp], %%r14\t\n"
+".intel_syntax noprefix\t\n"
+#endif
+
+// Copy parameters
+// k
+"mov r8, [r14 + 0]\t\n"
+// A
+"mov r9, [r14 + 8]\t\n"
+// B
+"mov r10, [r14 + 16]\t\n"
+// beta
+"mov r15, [r14 + 24]\t\n"
+// accum
+"mov rdx, [r14 + 32]\t\n"
+// C
+"mov r12, [r14 + 40]\t\n"
+// ldc
+"mov r13, [r14 + 48]\t\n"
+// b_block_cols
+"mov rdi, [r14 + 56]\t\n"
+// b_block_size
+"mov rsi, [r14 + 64]\t\n"
+// Make copies of A and C
+"mov rax, r9\t\n"
+"mov rcx, r12\t\n"
+
+
+"mov rbx, 0\t\n"
+"loop_outter%=:\t\n"
+"mov r14, 0\t\n"
+"vxorps ymm0,ymm0,ymm0\t\n"
+"vxorps ymm1,ymm1,ymm1\t\n"
+"vxorps ymm2,ymm2,ymm2\t\n"
+"vxorps ymm3,ymm3,ymm3\t\n"
+"vxorps ymm4,ymm4,ymm4\t\n"
+"vxorps ymm5,ymm5,ymm5\t\n"
+"vxorps ymm6,ymm6,ymm6\t\n"
+"vxorps ymm7,ymm7,ymm7\t\n"
+"vxorps ymm8,ymm8,ymm8\t\n"
+"vxorps ymm9,ymm9,ymm9\t\n"
+"vxorps ymm10,ymm10,ymm10\t\n"
+"vxorps ymm11,ymm11,ymm11\t\n"
+"vxorps ymm12,ymm12,ymm12\t\n"
+"vxorps ymm13,ymm13,ymm13\t\n"
+
+"mov r11, 0\t\n"
+
+"loop_inner%=:\t\n"
+
+"vcvtph2ps ymm15,XMMWORD PTR [r10 + r11]\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+0]\t\n"
+"vfmadd231ps ymm0,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+4]\t\n"
+"vfmadd231ps ymm1,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+8]\t\n"
+"vfmadd231ps ymm2,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+12]\t\n"
+"vfmadd231ps ymm3,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+16]\t\n"
+"vfmadd231ps ymm4,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+20]\t\n"
+"vfmadd231ps ymm5,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+24]\t\n"
+"vfmadd231ps ymm6,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+28]\t\n"
+"vfmadd231ps ymm7,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+32]\t\n"
+"vfmadd231ps ymm8,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+36]\t\n"
+"vfmadd231ps ymm9,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+40]\t\n"
+"vfmadd231ps ymm10,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+44]\t\n"
+"vfmadd231ps ymm11,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+48]\t\n"
+"vfmadd231ps ymm12,ymm15,ymm14\t\n"
+"vbroadcastss ymm14,DWORD PTR [r9+52]\t\n"
+"vfmadd231ps ymm13,ymm15,ymm14\t\n"
+"add r9,56\t\n"
+"add r11, 16\t\n"
+"inc r14\t\n"
+"cmp r14, r8\t\n"
+"jl loop_inner%=\t\n"
+"add r10, rsi\t\n"
+
+"cmp rdx, 1\t\n"
+"je L_accum%=\t\n"
+// Dump C
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n"
+"add r12, r13\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm13\t\n"
+"add r12, r13\t\n"
+"jmp L_done%=\t\n"
+
+
+"L_accum%=:\t\n"
+// Dump C with accumulate
+"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm1\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm3\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm5\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm7\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm9\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm11\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm12\t\n"
+"add r12, r13\t\n"
+"vfmadd231ps ymm13,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+"vmovups YMMWORD PTR [r12 + 0], ymm13\t\n"
+"add r12, r13\t\n"
+
+"L_done%=:\t\n"
+
+// next outer iteration
+"add rcx, 32\t\n"
+"mov r12, rcx\t\n"
+"mov r9, rax\t\n"
+"inc rbx\t\n"
+"cmp rbx, rdi\t\n"
+"jl loop_outter%=\t\n"
+:
+:
+[gp] "rm" (gp)
+: "r8", "r9", "r10", "r11", "r15", "r13", "r14",
+"rax", "rcx", "rdx", "rsi", "rdi", "rbx", "r12", "memory"
+);
+}
+
+} // namespace fbgemm2
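
For orientation, a minimal scalar sketch of what each gemmkernel_<m>x1_AVX2_fA0fB0fC0 above computes, read off the asm's pointer arithmetic: A is consumed as k groups of m floats, B as a packed fp16 panel of k x 8 per B-block column, C is written 8 floats per row, and ldc / b_block_size behave as byte strides. The helper name gemm_mx1_scalar_model and the use of _cvtsh_ss in place of vcvtph2ps are illustrative assumptions, not part of this diff.

#include <immintrin.h>   // _cvtsh_ss stands in for vcvtph2ps (needs F16C)
#include <cstdint>
#include "FbgemmFP16UKernels.h"   // GemmParams (declared in the next file)

namespace fbgemm2 {

// Rough scalar model of gemmkernel_<m>x1_AVX2_fA0fB0fC0 for m in [1, 14]:
// per B-block column, an m x 8 tile of C accumulates A (k groups of m
// floats) times a packed fp16 panel of k x 8, then is either stored or
// blended with beta * C as in the "Dump C with accumulate" path.
inline void gemm_mx1_scalar_model(const GemmParams* gp, int m) {
  const char* b_col = reinterpret_cast<const char*>(gp->B);
  char* c_col = reinterpret_cast<char*>(gp->C);
  for (uint64_t col = 0; col < gp->b_block_cols; ++col) {
    float acc[14][8] = {};  // up to 14 rows of one 8-wide column tile
    const uint16_t* b = reinterpret_cast<const uint16_t*>(b_col);
    for (uint64_t p = 0; p < gp->k; ++p) {
      for (int i = 0; i < m; ++i) {
        for (int j = 0; j < 8; ++j) {
          acc[i][j] += gp->A[p * m + i] * _cvtsh_ss(b[p * 8 + j]);
        }
      }
    }
    for (int i = 0; i < m; ++i) {
      float* c_row = reinterpret_cast<float*>(c_col + i * gp->ldc);
      for (int j = 0; j < 8; ++j) {
        // accum == 1 selects the accumulate path: C = acc + beta * C_old.
        c_row[j] = acc[i][j] + (gp->accum == 1 ? *gp->beta * c_row[j] : 0.0f);
      }
    }
    b_col += gp->b_block_size;   // "add r10, rsi"
    c_col += 8 * sizeof(float);  // "add rcx, 32": next 8-column tile of C
  }
}

}  // namespace fbgemm2

Most of the generated kernels unroll exactly this computation while software-pipelining two depth steps per inner-loop iteration; the 14x1 variant, which has no spare register left, handles one depth step per iteration.
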
diff --git a/src/FbgemmFP16UKernels.h b/src/FbgemmFP16UKernels.h
new file mode 100644
index 0000000..bf7f247
--- /dev/null
+++ b/src/FbgemmFP16UKernels.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#ifndef FBGEMM_UKERNELS
+#define FBGEMM_UKERNELS
+#include <cstdint>
+#include <tuple>
+#include <vector>
+#include "fbgemm/Types.h"
+
+namespace fbgemm2 {
+
+using fp16 = float16;
+using fp32 = float;
+struct GemmParams {
+  uint64_t k;
+  float *A;
+  const fp16 *B;
+  float *beta;
+  uint64_t accum;
+  float *C;
+  uint64_t ldc;
+  uint64_t b_block_cols;
+  uint64_t b_block_size;
+};
+void __attribute__ ((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams *gp);
+void __attribute__ ((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams *gp);
+typedef void (* funcptr_fp16) (GemmParams *gp);
+
+} // namespace fbgemm2
+
+#endif
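
As a usage note, a hedged sketch of how a caller might drive these declarations: select the kernel for the m rows of the current tile and pass a filled-in GemmParams. The dispatch table and run_tile below are illustrative only; the real driver elsewhere in the library is not shown in this section.

#include "FbgemmFP16UKernels.h"

namespace fbgemm2 {

// Hypothetical dispatch: index 0 is unused so kernels[m] handles m rows.
static const funcptr_fp16 kernels[15] = {
    nullptr,
    gemmkernel_1x1_AVX2_fA0fB0fC0,  gemmkernel_2x1_AVX2_fA0fB0fC0,
    gemmkernel_3x1_AVX2_fA0fB0fC0,  gemmkernel_4x1_AVX2_fA0fB0fC0,
    gemmkernel_5x1_AVX2_fA0fB0fC0,  gemmkernel_6x1_AVX2_fA0fB0fC0,
    gemmkernel_7x1_AVX2_fA0fB0fC0,  gemmkernel_8x1_AVX2_fA0fB0fC0,
    gemmkernel_9x1_AVX2_fA0fB0fC0,  gemmkernel_10x1_AVX2_fA0fB0fC0,
    gemmkernel_11x1_AVX2_fA0fB0fC0, gemmkernel_12x1_AVX2_fA0fB0fC0,
    gemmkernel_13x1_AVX2_fA0fB0fC0, gemmkernel_14x1_AVX2_fA0fB0fC0};

// gp must already carry k, A, B, beta, accum, C, ldc, b_block_cols and
// b_block_size, matching the offsets the asm kernels read.
inline void run_tile(int m, GemmParams* gp) {
  kernels[m](gp);
}

}  // namespace fbgemm2
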
diff --git a/src/FbgemmI8Depthwise.cc b/src/FbgemmI8Depthwise.cc
new file mode 100644
index 0000000..54e2272
--- /dev/null
+++ b/src/FbgemmI8Depthwise.cc
@@ -0,0 +1,1953 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "FbgemmI8Depthwise.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <limits>
+#include <tuple>
+#include <vector>
+
+#include <x86intrin.h>
+
+using namespace std;
+
+namespace fbgemm2
+{
+
+static array<array<int, 8>, 8> masks = {{
+ { 0, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, -1, 0, },
+}};
+
+template <int KERNEL_PROD>
+PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
+ int K, const int8_t *smat)
+ : K_(K) {
+ // Transpose the input matrix to make packing faster.
+ vector<int8_t> smat_transposed(K * KERNEL_PROD);
+ for (int i = 0; i < KERNEL_PROD; ++i) {
+ for (int j = 0; j < K; ++j) {
+ smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD];
+ }
+ }
+
+ // Allocate packed arrays
+ constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
+ pmat_ = static_cast<int8_t *>(aligned_alloc(
+ 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+
+ // Pack input matrix
+ // The layout is optimized to use vpmaddubsw efficiently (see
+ // madd_epi16x4_packed function).
+  // For each group of 32 channels, we have KERNEL_PROD_ALIGNED 32B SIMD
+  // registers (10 when KERNEL_PROD == 9).
+ // Denote ith channel jth filter as (i, j)
+ // 0th SIMD register:
+ // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
+ // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
+ // 1st SIMD register:
+ // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
+ // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
+ // 2nd SIMD register:
+ // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
+ // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
+ // 3rd SIMD register:
+ // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
+ // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
+ // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter
+ // coefficients
+ // ...
+ //
+ // REMAINDER
+  // If KERNEL_PROD % 4 == 1, for example when KERNEL_PROD == 9:
+ // 8th SIMD register:
+ // (0, 8), zero, ..., (7, 8), zero
+ // (16, 8), zero, ..., (23, 8), zero
+ // 9th SIMD register:
+ // (8, 8), zero, ..., (15, 8), zero
+ // (24, 8), zero, ..., (31, 8), zero
+ // We use madd_epi16_packed for this case
+ //
+  // If KERNEL_PROD % 4 == 2, for example when KERNEL_PROD == 10:
+ // 8th SIMD register:
+ // (0, 8), (0, 9), ..., (7, 8), (7, 9)
+ // (16, 8), (16, 9), ..., (23, 8), (23, 9)
+ // 9th SIMD register:
+ // (8, 8), (8, 9), ..., (15, 8), (15, 9)
+ // (24, 8), (24, 9), ..., (31, 8), (31, 9)
+ //
+  // If KERNEL_PROD % 4 == 3, for example when KERNEL_PROD == 11:
+ // 8th SIMD register:
+ // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
+ // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
+ // 9th SIMD register:
+ // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
+ // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
+ // 10th SIMD register:
+ // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
+ // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
+ // 11th SIMD register:
+ // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
+ // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
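+  //
+  // Worked example of the enumeration above (an illustrative reading, not
+  // additional behavior): for the full groups of 4 filter coefficients,
+  // i.e. j < KERNEL_PROD / 4 * 4, element (channel i within a group of 32,
+  // filter j) appears to land in packed 32B register
+  //   r = (j / 4) * 4 + (i % 16) / 4
+  // at byte
+  //   (i / 16) * 16 + (i % 4) * 4 + (j % 4)
+  // of that register, i.e. at
+  //   pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + r) * 32 + byte].
+  // For instance, (channel 5, filter 2) sits in register 1 at byte 6,
+  // matching the (5, 2) slot in the 1st SIMD register listed above. The
+  // registers holding the remainder coefficients follow the REMAINDER
+  // cases instead.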
+ for (int k1 = 0; k1 < K; k1 += 32) {
+ array<__m256i, KERNEL_PROD> b_v;
+ int remainder = K - k1;
+ if (remainder < 32) {
+ __m256i mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ for (int i = 0; i < KERNEL_PROD; ++i) {
+ b_v[i] = _mm256_maskload_epi32(
+ reinterpret_cast<const int *>(smat_transposed.data() + i * K + k1),
+ mask_v);
+ }
+ } else {
+ for (int i = 0; i < KERNEL_PROD; ++i) {
+ b_v[i] = _mm256_lddqu_si256(reinterpret_cast<const __m256i *>(
+ smat_transposed.data() + i * K + k1));
+ }
+ }
+
+ // Interleave 2 SIMD registers
+ array<__m256i, KERNEL_PROD_ALIGNED> b_interleaved_epi16;
+ __m256i zero_v = _mm256_setzero_si256();
+ for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) {
+ if (2 * i + 1 >= KERNEL_PROD) {
+ b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
+ } else {
+ b_interleaved_epi16[2 * i] =
+ _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ }
+ }
+
+ // Interleave 4 SIMD registers
+ array<__m256i, KERNEL_PROD_ALIGNED> b_interleaved_epi32;
+ for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) {
+ b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ }
+ for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) {
+ b_interleaved_epi32[i] = b_interleaved_epi16[i];
+ }
+
+ for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(
+ &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]),
+ b_interleaved_epi32[i]);
+ }
+ }
+}
+
+template <int KERNEL_PROD>
+PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix()
+{
+ free(pmat_);
+}
+
+template class PackedDepthWiseConvMatrix<3 * 3>;
+template class PackedDepthWiseConvMatrix<3 * 3 * 3>;
+
+// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline))
+void madd_epi16x4_packed(
+ __m256i a0_v, __m256i a1_v, __m256i a2_v, __m256i a3_v,
+ const __m256i* b,
+ __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
+
+// c = a0 * b0 + a1 * b1 + a2 * b2
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline))
+void madd_epi16x3_packed(
+ __m256i a0_v, __m256i a1_v, __m256i a2_v,
+ const __m256i* b,
+ __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
+
+// c = a0 * b0 + a1 * b1
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void
+madd_epi16x2_packed(__m256i a0_v, __m256i a1_v, const __m256i *b, __m256i *c0_v,
+ __m256i *c1_v, __m256i *c2_v, __m256i *c3_v,
+ __m256i *a_sum = nullptr) {
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// c = a0 * b0
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void
+madd_epi16_packed(__m256i a_v, const __m256i *b, __m256i *c0_v, __m256i *c1_v,
+ __m256i *c2_v, __m256i *c3_v, __m256i *a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// K is the number of accumulations we're doing
+template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
+static inline __attribute__((always_inline)) void
+inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
+ int remainder, __m256i *a_sum = nullptr) {
+ array<__m256i, 4> c, c_temp;
+ array<__m256i, 2> a_sum_temp{};
+
+ int k = 0;
+ if (K >= 4) {
+ madd_epi16x4_packed<SUM_A>(a_v[0], a_v[1], a_v[2], a_v[3], Bp,
+ &c[0], &c[1], &c[2], &c[3], a_sum_temp.data());
+
+ for (k = 4; k < K / 4 * 4; k += 4) {
+ madd_epi16x4_packed<SUM_A>(a_v[k + 0], a_v[k + 1], a_v[k + 2], a_v[k + 3],
+ Bp + k, &c_temp[0], &c_temp[1], &c_temp[2],
+ &c_temp[3], a_sum_temp.data());
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+ } else {
+ c[0] = _mm256_setzero_si256();
+ c[1] = _mm256_setzero_si256();
+ c[2] = _mm256_setzero_si256();
+ c[3] = _mm256_setzero_si256();
+ }
+
+ if (K - k == 3) {
+ madd_epi16x3_packed<SUM_A>(a_v[k], a_v[k + 1], a_v[k + 2], Bp + k,
+ &c_temp[0], &c_temp[1], &c_temp[2], &c_temp[3],
+ a_sum_temp.data());
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
+ c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
+ c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
+ c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
+
+ if (K - k == 0 || K - k == 3) {
+ c[0] = c_temp[0];
+ c[1] = c_temp[1];
+ c[2] = c_temp[2];
+ c[3] = c_temp[3];
+ } else {
+ if (K - k == 1) {
+ madd_epi16_packed<SUM_A>(a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3],
+ a_sum_temp.data());
+ } else if (K - k == 2) {
+ madd_epi16x2_packed<SUM_A>(a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1],
+ &c[2], &c[3], a_sum_temp.data());
+ }
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ if (REMAINDER) {
+ for (int r = 0; r < remainder / 8; ++r) {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C + r * 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + r * 8)),
+ c[r]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + r * 8), c[r]);
+ }
+ }
+ } else {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C),
+ _mm256_add_epi32(_mm256_loadu_si256(reinterpret_cast<__m256i *>(C)),
+ c[0]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C + 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 8)), c[1]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C + 16),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 16)), c[2]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(C + 24),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 24)), c[3]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C), c[0]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 8), c[1]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 16), c[2]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 24), c[3]);
+ }
+ }
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
+ a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
+ a_sum[2] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
+ a_sum[3] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
+ }
+}
+
+template <bool SUM_A = false, bool REMAINDER = false>
+static inline __attribute__((always_inline))
+void inner_prod_3x3_packed_(const __m256i* a_v,
+ const __m256i* Bp,
+ int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
+ return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder,
+ a_sum);
+}
+
+// Almost the same as ReQuantizeOutput in OutputProcessing-inh.h, but with
+// different row_offsets for each row because of depth-wise convolution
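+//
+// Per output element j this computes (see the scalar tail loop at the end of
+// the function for the reference form):
+//   raw = C_int32[j] - A_zero_point * col_offsets[j] - row_offsets[j] (+ bias[j])
+//   C_uint8[j] = clamp(round(raw * C_multiplier[j or 0]) + C_zero_point, 0, 255)
+// with the lower clamp bound raised to C_zero_point when FUSE_RELU is set.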
+template <bool FUSE_RELU, bool HAS_BIAS, bool PER_CHANNEL_QUANTIZATION>
+static inline __attribute__((always_inline)) void requantize_(
+ int32_t A_zero_point,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ const int32_t* C_int32,
+ uint8_t* C_uint8,
+ int n,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ __m256 multiplier_v = _mm256_setzero_ps();
+ if (!PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_set1_ps(*C_multiplier);
+ }
+
+ __m256i min_v = _mm256_set1_epi8(numeric_limits<uint8_t>::min());
+ __m256i max_v = _mm256_set1_epi8(numeric_limits<uint8_t>::max());
+
+ __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
+ __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
+ __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
+
+ __m256i permute_mask_v =
+ _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
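+  // _mm256_packs_epi32 / _mm256_packus_epi16 operate within 128-bit lanes, so
+  // the packed bytes come out lane-interleaved; this mask restores the
+  // original element order before the final store.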
+
+ constexpr int VLEN = 8;
+ int j = 0;
+ for ( ; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+ __m256i y_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
+ __m256i z_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
+ __m256i w_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));
+
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ __m256i row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j));
+ x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v);
+
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
+ y_v = _mm256_sub_epi32(_mm256_sub_epi32(y_v, col_off_v), row_offset_v);
+
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ col_offsets + j + 2 * VLEN)));
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
+ z_v = _mm256_sub_epi32(_mm256_sub_epi32(z_v, col_off_v), row_offset_v);
+
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+ col_offsets + j + 3 * VLEN)));
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
+ w_v = _mm256_sub_epi32(_mm256_sub_epi32(w_v, col_off_v), row_offset_v);
+
+ if (HAS_BIAS) { // static if
+ x_v = _mm256_add_epi32(
+ x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + VLEN);
+ }
+ __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
+ }
+ __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
+ }
+ __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+ __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
+ __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
+ __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
+
+ __m256i xy_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
+ __m256i zw_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
+ __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
+ __m256i xyzw_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(xyzw_packed_v, max_v));
+
+ xyzw_clamped_v =
+ _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
+ } // j loop vectorized and unrolled 4x
+
+ for ( ; j < n / VLEN * VLEN; j += VLEN) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ __m256i row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j));
+ x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v);
+
+ if (HAS_BIAS) { // static if
+ x_v = _mm256_add_epi32(
+ x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j)));
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+
+ __m256i x_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
+ C_zero_point_epi16_v);
+ x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
+ __m256i x_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(x_packed_v, max_v));
+
+ x_clamped_v =
+ _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(C_uint8 + j),
+ _mm256_castsi256_si128(x_clamped_v));
+ } // j loop vectorized
+
+ for ( ; j < n; ++j) {
+ int32_t raw = C_int32[j] - A_zero_point * col_offsets[j] - row_offsets[j];
+ if (HAS_BIAS) { // static if
+ raw += bias[j];
+ }
+
+ float ab = raw * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0];
+ long rounded = lrintf(ab) + C_zero_point;
+
+ C_uint8[j] = std::max(
+ FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
+ std::min(255l, rounded));
+ }
+}
+
+template <bool FUSE_RELU, bool HAS_BIAS>
+static inline __attribute__((always_inline)) void
+requantize_(int32_t A_zero_point, float C_multiplier,
+ int32_t C_zero_point, const int32_t *C_int32, uint8_t *C_uint8,
+ int n, const int32_t *row_offsets, const int32_t *col_offsets,
+ const int32_t *bias) {
+ requantize_<FUSE_RELU, HAS_BIAS, false /* PER_CHANNEL_QUANTIZATION */>(
+ A_zero_point,
+ &C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8,
+ n,
+ row_offsets,
+ col_offsets,
+ bias);
+}
+
+template <bool FUSE_RELU, bool HAS_BIAS>
+static inline __attribute__((always_inline)) void
+requantize_per_channel_(int32_t A_zero_point, const float *C_multiplier,
+ int32_t C_zero_point, const int32_t *C_int32,
+ uint8_t *C_uint8, int n, const int32_t *row_offsets,
+ const int32_t *col_offsets, const int32_t *bias) {
+ requantize_<FUSE_RELU, HAS_BIAS, true /* PER_CHANNEL_QUANTIZATION */>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8,
+ n,
+ row_offsets,
+ col_offsets,
+ bias);
+}
+
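+// Load 32 bytes of A. With REMAINDER, a 32-bit-granular mask load is used so
+// that only the first `remainder` channels are read (mask_v is built by the
+// callers from the masks table).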
+template <bool REMAINDER>
+static inline __attribute__((always_inline)) __m256i
+load_a(const uint8_t* A, __m256i mask_v) {
+ if (REMAINDER) {
+ return _mm256_maskload_epi32(reinterpret_cast<const int *>(A), mask_v);
+ } else {
+ return _mm256_lddqu_si256(reinterpret_cast<const __m256i *>(A));
+ }
+}
+
+template <bool SUM_A, bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void
+inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in,
+ const uint8_t *A, int32_t A_zero_point, const int8_t *Bp,
+ const int32_t *B_zero_point, int32_t *C, int remainder,
+ int32_t *row_offsets) {
+ __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
+ __m256i mask_v = _mm256_setzero_si256();
+ if (REMAINDER) {
+ mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ }
+
+  // The code below could be written as a simple R*S loop, but the compiler
+  // doesn't unroll it, so we unroll it manually.
+ // constexpr int R = 3, S = 3;
+ // array<__m256i, R * S> a_v;
+ // for (int r = 0; r < R; ++r) {
+ // for (int s = 0; s < S; ++s) {
+ // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
+ // if (REMAINDER) {
+ // a_v[r * S + s] =
+ // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
+ // mask_v);
+ // } else {
+ // a_v[r * S + s] =
+ // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
+ // }
+ // } else {
+ // a_v[r * S + s] = A_zero_point_v;
+ // }
+ // }
+ // }
+ array<__m256i, 9> a_v = {
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ };
+
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[3] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[4] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[5] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[6] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[7] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[8] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v);
+ }
+ }
+
+ array<__m256i, 4> a_sum;
+ inner_prod_3x3_packed_<SUM_A, REMAINDER>(
+ a_v.data(), reinterpret_cast<const __m256i *>(Bp), C, remainder,
+ a_sum.data());
+ if (SUM_A) {
+ __m256i B_zero_point_v;
+ for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
+ if (PER_CHANNEL_QUANTIZATION) {
+ B_zero_point_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(B_zero_point + i * 8));
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
+ }
+ _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ }
+ }
+}
+
+template <bool SUM_A, bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void
+inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
+ int w_in, const uint8_t *A, int32_t A_zero_point,
+ const int8_t *Bp, const int32_t *B_zero_point,
+ int32_t *C, int remainder, int32_t *row_offsets) {
+ __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
+ __m256i mask_v = _mm256_setzero_si256();
+ if (REMAINDER) {
+ mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ }
+
+  // The code below could be written as a simple R*S loop, but the compiler
+  // doesn't unroll it, so we unroll it manually.
+ // constexpr int R = 3, S = 3;
+ // array<__m256i, R * S> a_v;
+ // for (int r = 0; r < R; ++r) {
+ // for (int s = 0; s < S; ++s) {
+ // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
+ // if (REMAINDER) {
+ // a_v[r * S + s] =
+ // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
+ // mask_v);
+ // } else {
+ // a_v[r * S + s] =
+ // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
+ // }
+ // } else {
+ // a_v[r * S + s] = A_zero_point_v;
+ // }
+ // }
+ // }
+ array<__m256i, 8> a_v;
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
+ }
+ }
+ }
+
+ array<__m256i, 4> a_sum;
+ inner_prod_packed_<8, SUM_A, REMAINDER>(a_v.data(),
+ reinterpret_cast<const __m256i *>(Bp),
+ C, remainder, a_sum.data());
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
+ }
+ }
+ }
+
+ array<__m256i, 4> a_sum_temp;
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 8, C, remainder,
+ a_sum_temp.data());
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 16, C, remainder,
+ a_sum_temp.data());
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
+ a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 24, C, remainder,
+ a_sum_temp.data());
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+
+ __m256i B_zero_point_v;
+ for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
+ if (PER_CHANNEL_QUANTIZATION) {
+ B_zero_point_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(B_zero_point + i * 8));
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
+ }
+ _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ }
+ }
+}
+
+template <bool SUM_A, bool FUSE_RELU>
+static inline __attribute__((always_inline))
+void depthwise_3x3_kernel_(int H, int W, int K, int h, int w,
+ int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32, uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t *col_offsets,
+ const int32_t *bias)
+{
+ constexpr int S = 3;
+ constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
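+  // Bp + k * 10: the packed 3x3 filter presumably stores 10 bytes per channel,
+  // the 9 taps padded by one so vpmaddubsw can consume them in pairs (an
+  // assumption about the packing layout, which lives elsewhere in this file).
+  // The 3x3x3 kernel below uses k * 28 for the same reason (27 taps + 1 pad).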
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3_packed_<SUM_A>(
+ H, W, K, h_in, w_in,
+ A + (h_in * W + w_in) * K + k, A_zero_point,
+ Bp + k * 10, &B_zero_point,
+ C_int32 + k, 0, &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3_packed_<SUM_A, true>(
+ H, W, K, h_in, w_in,
+ A + (h_in * W + w_in) * K + k, A_zero_point,
+ Bp + k * 10, &B_zero_point,
+ C_int32 + k, remainder, &row_offsets[k]);
+ }
+ if (SUM_A) {
+ requantize_<FUSE_RELU, true>
+ (
+ A_zero_point, C_multiplier, C_zero_point,
+ C_int32, C_uint8 + (h * W_OUT + w) * K, K,
+ row_offsets,
+ col_offsets, bias
+ );
+ }
+}
+
+template <bool SUM_A, bool FUSE_RELU>
+static inline __attribute__((always_inline))
+void depthwise_3x3x3_kernel_(int T, int H, int W, int K, int t, int h, int w,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32, uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t *col_offsets,
+ const int32_t *bias)
+{
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int t_in = -PAD_P + t * stride_t;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3x3_packed_<SUM_A>(
+ T, H, W, K, t_in, h_in, w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point,
+ Bp + k * 28, &B_zero_point,
+ C_int32 + k, 0, &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3x3_packed_<SUM_A, true>(
+ T, H, W, K, t_in, h_in, w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point,
+ Bp + k * 28, &B_zero_point,
+ C_int32 + k, remainder, &row_offsets[k]);
+ }
+ if (SUM_A) {
+ requantize_<FUSE_RELU, true>
+ (
+ A_zero_point, C_multiplier, C_zero_point,
+ C_int32, C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, K,
+ row_offsets,
+ col_offsets, bias
+ );
+ }
+}
+
+template <bool SUM_A>
+static inline __attribute__((always_inline)) void
+depthwise_3x3_per_channel_quantization_kernel_(
+ int H, int W, int K, int h, int w, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t *A,
+ const int32_t *B_zero_point, const int8_t *Bp,
+ const float *C_multiplier, int32_t C_zero_point,
+ int32_t *C_int32, uint8_t *C_uint8,
+ int32_t *row_offsets, const int32_t *col_offsets, const int32_t *bias) {
+ constexpr int S = 3;
+ constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3_packed_<SUM_A, false/*remainder*/, true/*per-channel*/>(
+ H, W, K, h_in, w_in,
+ A + (h_in * W + w_in) * K + k, A_zero_point,
+ Bp + k * 10, B_zero_point + k,
+ C_int32 + k, 0, &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3_packed_<SUM_A, true/*remainder*/, true/*per-channel*/>(
+ H, W, K, h_in, w_in,
+ A + (h_in * W + w_in) * K + k, A_zero_point,
+ Bp + k * 10, B_zero_point + k,
+ C_int32 + k, remainder, &row_offsets[k]);
+ }
+ if (SUM_A) {
+ requantize_per_channel_<false, true>
+ (
+ A_zero_point, C_multiplier, C_zero_point,
+ C_int32, C_uint8 + (h * W_OUT + w) * K, K,
+ row_offsets,
+ col_offsets, bias
+ );
+ }
+}
+
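+// Returns the pair of factors of n that are closest to each other, e.g.
+// closest_factors_(12) == {3, 4}; used below to split the threads assigned to
+// one image into a 2D grid.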
+static pair<int, int> closest_factors_(int n) {
+ int a = (int)std::sqrt(n);
+ while (n % a != 0) {
+ a--;
+ }
+ return { a, n / a }; // a <= n / a
+}
+
+// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
+// This implementation should be general enough to handle not just 3x3 but
+// other filter shapes by parameterizing with R and S, but we restrict it to
+// 3x3 for now.
+template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
+static inline __attribute__((always_inline))
+void depthwise_3x3_pad_1_(int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t *A,
+ int32_t B_zero_point, const Packed3x3ConvMatrix &B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32, uint8_t* C_uint8,
+ const int32_t *col_offsets, const int32_t *bias,
+ int thread_id, int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
+ int32_t *C_temp;
+
+ int n_begin, n_end;
+ int h_begin, h_end, w_begin, w_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ h_begin = 0;
+ h_end = H_OUT;
+ w_begin = 0;
+ w_end = W_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+    // n is processed by a num_threads_h * num_threads_w 2D grid of threads
+ int num_threads_h, num_threads_w;
+ // num_threads_w <= num_threads_h
+ tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_h = tid_within_n / num_threads_w;
+ int tid_w = tid_within_n % num_threads_w;
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+
+ int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w;
+ w_begin = std::min(tid_w * w_per_thread, W_OUT);
+ w_end = std::min(w_begin + w_per_thread, W_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
+
+ int h = 0;
+ int w = 0;
+
+ if (h_begin == 0) {
+ if (w_begin == 0) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+
+    for (h = std::max(1, h_begin); h < std::min(H_OUT - 1, h_end); ++h) {
+ if (w_begin == 0) {
+ w = 0;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+
+ if (h_end == H_OUT) {
+ h = H_OUT - 1;
+ w = 0;
+ if (w_begin == 0) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+ } // for each n
+}
+
+template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
+static inline __attribute__((always_inline))
+void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t *A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix &B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32, uint8_t* C_uint8,
+ const int32_t *col_offsets, const int32_t *bias,
+ int thread_id, int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
+ int32_t *C_temp;
+
+ int n_begin, n_end;
+ int t_begin, t_end, h_begin, h_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ t_begin = 0;
+ t_end = T_OUT;
+ h_begin = 0;
+ h_end = H_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+    // n is processed by a num_threads_t * num_threads_h 2D grid of threads
+    int num_threads_t, num_threads_h;
+    // num_threads_t <= num_threads_h
+ tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_t = tid_within_n / num_threads_h;
+ int tid_h = tid_within_n % num_threads_h;
+
+ int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
+ t_begin = std::min(tid_t * t_per_thread, T_OUT);
+ t_end = std::min(t_begin + t_per_thread, T_OUT);
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * T * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+
+ for (int t = t_begin; t < t_end; ++t) {
+ for (int h = h_begin; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ C_temp =
+ FUSE_RESCALE
+ ? C_int32
+ : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ T, H, W, K, t, h, w, stride_t, stride_h, stride_w, A_zero_point,
+ A_base, B_zero_point, Bp, C_multiplier,
+ C_zero_point, C_temp, C_uint8_base, row_offsets, col_offsets,
+ bias);
+ } // w
+ } // h
+ } // t
+ } // for each n
+}
+
+template <bool FUSE_RESCALE = true>
+static inline __attribute__((always_inline)) void
+depthwise_3x3_per_channel_quantization_pad_1_(
+ int N, int H, int W, int K, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t *A, const int32_t *B_zero_point,
+ const Packed3x3ConvMatrix &B, const float *C_multiplier,
+ int32_t C_zero_point, int32_t *C_int32, uint8_t *C_uint8,
+ const int32_t *col_offsets, const int32_t *bias, int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
+ int32_t *C_temp;
+
+ int n_begin, n_end;
+ int h_begin, h_end, w_begin, w_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ h_begin = 0;
+ h_end = H_OUT;
+ w_begin = 0;
+ w_end = W_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+    // n is processed by a num_threads_h * num_threads_w 2D grid of threads
+ int num_threads_h, num_threads_w;
+ // num_threads_w <= num_threads_h
+ tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_h = tid_within_n / num_threads_w;
+ int tid_w = tid_within_n % num_threads_w;
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+
+ int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w;
+ w_begin = std::min(tid_w * w_per_thread, W_OUT);
+ w_end = std::min(w_begin + w_per_thread, W_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
+
+ int h = 0;
+ int w = 0;
+
+ if (h_begin == 0) {
+ if (w_begin == 0) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+
+    for (h = std::max(1, h_begin); h < std::min(H_OUT - 1, h_end); ++h) {
+ if (w_begin == 0) {
+ w = 0;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+
+ if (h_end == H_OUT) {
+ h = H_OUT - 1;
+ w = 0;
+ if (w_begin == 0) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+
+ if (w_end == W_OUT) {
+ w = W_OUT - 1;
+ C_temp = FUSE_RESCALE ? C_int32
+ : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ H, W, K, h, w, stride_h, stride_w,
+ A_zero_point, A_base,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_temp, C_uint8_base,
+ row_offsets, col_offsets, bias);
+ }
+ }
+ } // for each n
+}
+
+// assumption: W > 3 and H > 3
+void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ const Packed3x3ConvMatrix& B,
+ int32_t* C,
+ int thread_id,
+ int num_threads) {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ } else {
+ depthwise_3x3_pad_1_<false>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+ }
+}
+
+void depthwise_3x3_pad_1(
+ int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const Packed3x3ConvMatrix& B,
+ float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id, int num_threads,
+ bool fuse_relu) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (fuse_relu) {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else {
+ depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ }
+ } else {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else {
+ depthwise_3x3_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ }
+ }
+}
+
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ const Packed3x3x3ConvMatrix& B,
+ int32_t* C,
+ int thread_id,
+ int num_threads) {
+ depthwise_3x3x3_pad_1_<false /* FUSE_RESCALE */>(
+ N, T, H, W, K,
+ stride_t, stride_h, stride_w,
+ A_zero_point, A,
+ 0, B,
+ 0.0f, 0, C, nullptr,
+ nullptr, nullptr,
+ thread_id, num_threads);
+}
+
+static void depthwise_3x3x3_pad_1_(
+ int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
+ float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id, int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, false /* FUSE_RELU */>(
+ N, T, H, W, K,
+ stride_t, stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+}
+
+static void depthwise_3x3x3_pad_1_relu_fused_(
+ int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
+ float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id, int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ N, T, H, W, K,
+ stride_t, stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, B,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+}
+
+void depthwise_3x3x3_pad_1(
+ int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
+ float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id, int num_threads) {
+  // If we inline the following two functions, we hit a stack overflow.
+ if (fuse_relu) {
+ depthwise_3x3x3_pad_1_relu_fused_(
+ N, T, H, W, K, stride_t, stride_h, stride_w, A_zero_point, A,
+ B_zero_point, B, C_multiplier, C_zero_point, C,
+ col_offsets, bias, thread_id, num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_(N, T, H, W, K, stride_t, stride_h, stride_w,
+ A_zero_point, A, B_zero_point, B, C_multiplier,
+ C_zero_point, C, col_offsets, bias,
+ thread_id, num_threads);
+ }
+}
+
+void depthwise_3x3_per_channel_quantization_pad_1(
+ int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ int32_t A_zero_point, const uint8_t* A,
+ const int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp,
+ const float *C_multiplier, int32_t C_zero_point, uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id, int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ } else {
+ depthwise_3x3_per_channel_quantization_pad_1_(
+ N, H, W, K,
+ stride_h, stride_w,
+ A_zero_point, A,
+ B_zero_point, Bp,
+ C_multiplier, C_zero_point, C_int32_temp, C,
+ col_offsets, bias,
+ thread_id, num_threads);
+ }
+}
+
+} // namespace fbgemm2
diff --git a/src/FbgemmI8Depthwise.h b/src/FbgemmI8Depthwise.h
new file mode 100644
index 0000000..bc62c84
--- /dev/null
+++ b/src/FbgemmI8Depthwise.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <cstdint>
+
+namespace fbgemm2
+{
+
+// KERNEL_PROD is the product of all kernels.
+// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3.
+template <int KERNEL_PROD>
+class PackedDepthWiseConvMatrix
+{
+ public:
+ // smat in RSG layout
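+  // (Here R/S are presumably the filter height/width and G the groups, i.e.
+  // the K channels; see the packing code for the exact layout.)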
+ PackedDepthWiseConvMatrix(int K, const std::int8_t *smat);
+ virtual ~PackedDepthWiseConvMatrix();
+
+ const std::int8_t* PackedMat() const {
+ return pmat_;
+ }
+
+ private:
+ int K_;
+ std::int8_t* pmat_;
+}; // PackedDepthWiseConvMatrix
+
+using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>;
+using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>;
+
+/**
+ * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8
+ * @param A The input image in NHWK layout
+ * @param Bp The pre-packed filter
+ */
+void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ const Packed3x3ConvMatrix& Bp,
+ std::int32_t* C,
+ int thread_id = 0, int num_threads = 1);
+
+/**
+ * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8.
+ * This version is fused with requantization.
+ */
+void depthwise_3x3_pad_1(
+ int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ std::int32_t B_zero_point, const Packed3x3ConvMatrix& Bp,
+ float C_multiplier, std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets, const std::int32_t* bias,
+ int thread_id = 0, int num_threads = 1, bool fuse_relu = false);
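+
+// A minimal usage sketch for the fused version above (hypothetical shapes and
+// variable names; the quantization parameters and the packed filter are
+// assumed to be prepared elsewhere):
+//
+//   Packed3x3ConvMatrix Bp(K, B_int8);  // B_int8 in RSG layout
+//   depthwise_3x3_pad_1(N, H, W, K, /*stride_h=*/1, /*stride_w=*/1,
+//                       A_zero_point, A, B_zero_point, Bp,
+//                       C_multiplier, C_zero_point, C_uint8,
+//                       col_offsets, bias);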
+
+/**
+ * Depth-wise 3x3 convolution with pad=1 and K a multiple of 8.
+ * This version is fused with requantization and uses per-channel quantization.
+ */
+void depthwise_3x3_per_channel_quantization_pad_1(
+ int N, int H, int W, int K,
+ int stride_h, int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ const std::int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp,
+ const float *C_multiplier, std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets, const std::int32_t* bias,
+ int thread_id = 0, int num_threads = 1);
+
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ const Packed3x3x3ConvMatrix& Bp,
+ std::int32_t* C,
+ int thread_id = 0, int num_threads = 1);
+
+void depthwise_3x3x3_pad_1(
+ int N, int T, int H, int W, int K,
+ int stride_t, int stride_h, int stride_w,
+ std::int32_t A_zero_point, const std::uint8_t* A,
+ std::int32_t B_zero_point, const Packed3x3x3ConvMatrix& Bp,
+ float C_multiplier, std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets, const std::int32_t* bias,
+ bool fuse_relu = false, int thread_id = 0, int num_threads = 1);
+
+} // namespace fbgemm2
diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc
new file mode 100644
index 0000000..723a467
--- /dev/null
+++ b/src/FbgemmI8Spmdm.cc
@@ -0,0 +1,508 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8Spmdm.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cmath>
+#include <cstring>
+
+#include <immintrin.h>
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+double spmdm_initial_time = 0.0;
+double spmdm_transpose_uint8_time = 0.0;
+double spmdm_transpose_32xN_time = 0.0;
+double spmdm_compute_time = 0.0;
+double spmdm_transpose_Nx32_time = 0.0;
+double spmdm_run_time = 0.0;
+#endif
+
+using namespace std;
+
+namespace fbgemm2 {
+
+CompressedSparseColumn::CompressedSparseColumn(int num_of_rows, int num_of_cols)
+ : num_rows_(num_of_rows),
+ colptr_(num_of_cols + 1),
+ hyper_sparse_(false),
+ old_nnz_(-1) {}
+
+double CompressedSparseColumn::Density() const {
+ return (double)NumOfNonZeros() / (NumOfRows() * NumOfCols());
+}
+
+bool CompressedSparseColumn::IsHyperSparse() const {
+ if (NumOfNonZeros() != old_nnz_) {
+ old_nnz_ = NumOfNonZeros();
+    // The average number of non-zeros per row is very small.
+ hyper_sparse_ = (double)old_nnz_ / NumOfRows() < 0.08;
+ }
+
+ return hyper_sparse_;
+}
+
+static void transpose_8rows(
+ int N,
+ const uint8_t* src,
+ int ld_src,
+ uint8_t* dst,
+ int ld_dst) {
+ constexpr int M = 8;
+ int j;
+ // vectorized loop
+ for (j = 0; j < N / 32 * 32; j += 32) {
+ // a : a0 a1 ... a31
+ // b : b0 b1 ... b31
+ // c : c0 c1 ... c31
+ // d : d0 d1 ... d31
+ __m256i a = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 0 * ld_src));
+ __m256i b = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 1 * ld_src));
+ __m256i c = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 2 * ld_src));
+ __m256i d = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 3 * ld_src));
+ __m256i e = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 4 * ld_src));
+ __m256i f = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 5 * ld_src));
+ __m256i g = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 6 * ld_src));
+ __m256i h = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(src + j + 7 * ld_src));
+
+ // even-odd interleaving
+ // ab_lo : a0 b0 a1 b1 ... a7 b7 | a16 b16 ... a23 b23
+ // ab_hi : a8 b8 a9 b9 ... a15 b15 | a24 b24 ... a31 b31
+ // cd_lo : c0 d0 c1 d1 ... c7 d7 | c16 d16 ... c23 d23
+ // cd_hi : c8 d8 c9 d9 ... c15 d15 | c24 d24 ... c31 d31
+ __m256i ab_lo = _mm256_unpacklo_epi8(a, b);
+ __m256i ab_hi = _mm256_unpackhi_epi8(a, b);
+ __m256i cd_lo = _mm256_unpacklo_epi8(c, d);
+ __m256i cd_hi = _mm256_unpackhi_epi8(c, d);
+ __m256i ef_lo = _mm256_unpacklo_epi8(e, f);
+ __m256i ef_hi = _mm256_unpackhi_epi8(e, f);
+ __m256i gh_lo = _mm256_unpacklo_epi8(g, h);
+ __m256i gh_hi = _mm256_unpackhi_epi8(g, h);
+
+ // 4-row interleaving but permuted at 128-bit granularity
+ // abcd0 : a0 b0 c0 d0 ... a-d3 | a-d16 ... a-d19
+ // abcd1 : a4 b4 c4 d4 ... a-d7 | a-d20 ... a-d23
+ // abcd2 : a8 b8 c8 d8 ... a-d11 | a-d24 ... a-d27
+ // abcd3 : a12 b12 c12 d12 ... a-d15 | a-d28 ... a-d31
+ __m256i abcd0 = _mm256_unpacklo_epi16(ab_lo, cd_lo);
+ __m256i abcd1 = _mm256_unpackhi_epi16(ab_lo, cd_lo);
+ __m256i abcd2 = _mm256_unpacklo_epi16(ab_hi, cd_hi);
+ __m256i abcd3 = _mm256_unpackhi_epi16(ab_hi, cd_hi);
+ __m256i efgh0 = _mm256_unpacklo_epi16(ef_lo, gh_lo);
+ __m256i efgh1 = _mm256_unpackhi_epi16(ef_lo, gh_lo);
+ __m256i efgh2 = _mm256_unpacklo_epi16(ef_hi, gh_hi);
+ __m256i efgh3 = _mm256_unpackhi_epi16(ef_hi, gh_hi);
+
+ // 8-row interleaving
+ __m256i y0 = _mm256_unpacklo_epi32(abcd0, efgh0);
+ __m256i y1 = _mm256_unpackhi_epi32(abcd0, efgh0);
+ __m256i y2 = _mm256_unpacklo_epi32(abcd1, efgh1);
+ __m256i y3 = _mm256_unpackhi_epi32(abcd1, efgh1);
+ __m256i y4 = _mm256_unpacklo_epi32(abcd2, efgh2);
+ __m256i y5 = _mm256_unpackhi_epi32(abcd2, efgh2);
+ __m256i y6 = _mm256_unpacklo_epi32(abcd3, efgh3);
+ __m256i y7 = _mm256_unpackhi_epi32(abcd3, efgh3);
+
+    // Store the results; the 128-bit lanes were permuted by the unpacks above,
+    // so the 64-bit pieces are written out in the order that restores row order.
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 0) * ld_dst),
+ _mm256_castsi256_si128(y0));
+ *reinterpret_cast<int64_t*>(dst + (j + 1) * ld_dst) =
+ _mm256_extract_epi64(y0, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 2) * ld_dst),
+ _mm256_castsi256_si128(y1));
+ *reinterpret_cast<int64_t*>(dst + (j + 3) * ld_dst) =
+ _mm256_extract_epi64(y1, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 4) * ld_dst),
+ _mm256_castsi256_si128(y2));
+ *reinterpret_cast<int64_t*>(dst + (j + 5) * ld_dst) =
+ _mm256_extract_epi64(y2, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 6) * ld_dst),
+ _mm256_castsi256_si128(y3));
+ *reinterpret_cast<int64_t*>(dst + (j + 7) * ld_dst) =
+ _mm256_extract_epi64(y3, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 8) * ld_dst),
+ _mm256_castsi256_si128(y4));
+ *reinterpret_cast<int64_t*>(dst + (j + 9) * ld_dst) =
+ _mm256_extract_epi64(y4, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 10) * ld_dst),
+ _mm256_castsi256_si128(y5));
+ *reinterpret_cast<int64_t*>(dst + (j + 11) * ld_dst) =
+ _mm256_extract_epi64(y5, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 12) * ld_dst),
+ _mm256_castsi256_si128(y6));
+ *reinterpret_cast<int64_t*>(dst + (j + 13) * ld_dst) =
+ _mm256_extract_epi64(y6, 1);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(dst + (j + 14) * ld_dst),
+ _mm256_castsi256_si128(y7));
+ *reinterpret_cast<int64_t*>(dst + (j + 15) * ld_dst) =
+ _mm256_extract_epi64(y7, 1);
+ *reinterpret_cast<int64_t*>(dst + (j + 16) * ld_dst) =
+ _mm256_extract_epi64(y0, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 17) * ld_dst) =
+ _mm256_extract_epi64(y0, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 18) * ld_dst) =
+ _mm256_extract_epi64(y1, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 19) * ld_dst) =
+ _mm256_extract_epi64(y1, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 20) * ld_dst) =
+ _mm256_extract_epi64(y2, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 21) * ld_dst) =
+ _mm256_extract_epi64(y2, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 22) * ld_dst) =
+ _mm256_extract_epi64(y3, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 23) * ld_dst) =
+ _mm256_extract_epi64(y3, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 24) * ld_dst) =
+ _mm256_extract_epi64(y4, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 25) * ld_dst) =
+ _mm256_extract_epi64(y4, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 26) * ld_dst) =
+ _mm256_extract_epi64(y5, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 27) * ld_dst) =
+ _mm256_extract_epi64(y5, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 28) * ld_dst) =
+ _mm256_extract_epi64(y6, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 29) * ld_dst) =
+ _mm256_extract_epi64(y6, 3);
+ *reinterpret_cast<int64_t*>(dst + (j + 30) * ld_dst) =
+ _mm256_extract_epi64(y7, 2);
+ *reinterpret_cast<int64_t*>(dst + (j + 31) * ld_dst) =
+ _mm256_extract_epi64(y7, 3);
+ }
+
+ // scalar loop for remainder
+ for (; j < N; ++j) {
+ for (int i = 0; i < M; ++i) {
+ dst[j * ld_dst + i] = src[j + i * ld_src];
+ }
+ }
+}
+
+// TODO: fallback when AVX2 is not available
+void CompressedSparseColumn::SpMDM(
+ const block_type_t& block,
+ const uint8_t* A,
+ int lda,
+ bool accumulation,
+ int32_t* C,
+ int ldc) const {
+ int K = NumOfRows();
+ int N = block.col_size;
+
+ if (K == 0 || N == 0) {
+ return;
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start,
+ t_start, t_end;
+ double dt;
+ t_start = std::chrono::high_resolution_clock::now();
+ t_very_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ uint8_t A_buffer[K * 32] __attribute__((aligned(64)));
+ int32_t C_buffer[N * 32] __attribute__((aligned(64)));
+
+ // If we compute C = C + A * B, where B is a sparse matrix in CSC format, for
+ // each non-zero in B, we'd need to access the corresponding column in A.
+ // This results in strided access, which we want to avoid.
+ // Instead, we pre-transpose A and C, and compute C = (C^T + B^T * A^T)^T
+
+ if (IsHyperSparse()) {
+ // The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications.
+ // If NNZ/K is small, it's not worth doing transpose so we just use this
+ // scalar loop.
+ if (!accumulation) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size;
+ ++j) {
+ C[(i - block.row_start) * ldc + j - block.col_start] = 0;
+ }
+ }
+ }
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) {
+ int row = rowidx_[k];
+ int w = values_[k];
+ for (int i = block.row_start; i < block.row_start + block.row_size;
+ ++i) {
+ C[(i - block.row_start) * ldc + j - block.col_start] +=
+ A[i * lda + row] * w;
+ }
+ }
+ } // for each column of B
+ return;
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_initial_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ // Take 32 rows at a time
+ int i_end = block.row_start + block.row_size;
+ for (int i1 = block.row_start; i1 < i_end; i1 += 32) {
+ // Transpose 32 x K submatrix of A
+ if (i_end - i1 < 32) {
+ uint8_t A_temp_buffer[K * 32] __attribute__((aligned(64)));
+ for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) {
+ transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
+ }
+
+ for (int i2 = (i_end - i1) / 8 * 8; i2 < i_end - i1; ++i2) {
+ memcpy(
+ A_temp_buffer + i2 * K, A + (i1 + i2) * lda, K * sizeof(uint8_t));
+ }
+ memset(
+ A_temp_buffer + (i_end - i1) * K,
+ 0,
+ (32 - (i_end - i1)) * K * sizeof(uint8_t));
+ for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) {
+ transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32);
+ }
+ } else {
+ for (int i2 = 0; i2 < 32; i2 += 8) {
+ transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
+ }
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_transpose_uint8_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ if (accumulation) {
+ // Transpose 32 x N submatrix of C to fill N x 32 C_buffer
+ transpose_simd(
+ std::min(32, i_end - i1),
+ N,
+ reinterpret_cast<const float*>(C + (i1 - block.row_start) * ldc),
+ ldc,
+ reinterpret_cast<float*>(C_buffer),
+ 32);
+ } else {
+ memset(C_buffer, 0, N * 32 * sizeof(int32_t));
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_transpose_32xN_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ for (int j = 0; j < block.col_size; ++j) {
+ int j_start = j + block.col_start;
+ int k = colptr_[j_start];
+ int k_end_aligned =
+ colptr_[j_start] + (colptr_[j_start + 1] - colptr_[j_start]) / 4 * 4;
+
+ for (; k < k_end_aligned; k += 4) {
+ __m256i w =
+ _mm256_set1_epi32(*(reinterpret_cast<const int32_t*>(&values_[k])));
+ array<__m256i, 4> a;
+ a[0] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 0] * 32]));
+ a[1] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 1] * 32]));
+ a[2] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 2] * 32]));
+ a[3] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 3] * 32]));
+
+ __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]);
+ __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]);
+ __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]);
+ __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]);
+
+ a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo);
+ a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo);
+ a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi);
+ a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi);
+
+ array<__m256i, 4> ab;
+ ab[0] = _mm256_maddubs_epi16(a[0], w);
+ ab[1] = _mm256_maddubs_epi16(a[1], w);
+ ab[2] = _mm256_maddubs_epi16(a[2], w);
+ ab[3] = _mm256_maddubs_epi16(a[3], w);
+
+ __m256i one = _mm256_set1_epi16(1);
+ ab[0] = _mm256_madd_epi16(ab[0], one);
+ ab[1] = _mm256_madd_epi16(ab[1], one);
+ ab[2] = _mm256_madd_epi16(ab[2], one);
+ ab[3] = _mm256_madd_epi16(ab[3], one);
+
+ array<__m256i, 4> t;
+ t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20);
+ t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20);
+ t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31);
+ t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31);
+
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 0 * 8])),
+ t[0]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 1 * 8])),
+ t[1]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 2 * 8])),
+ t[2]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 3 * 8])),
+ t[3]));
+ }
+
+ int remainder = colptr_[j_start + 1] - k;
+ assert(remainder < 4);
+ if (remainder > 0) {
+ int32_t temp_w = 0;
+ for (int r = 0; r < remainder; ++r) {
+ (reinterpret_cast<int8_t*>(&temp_w))[r] = values_[k + r];
+ }
+ __m256i w = _mm256_set1_epi32(temp_w);
+ array<__m256i, 4> a;
+ a[0] = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(&A_buffer[rowidx_[k + 0] * 32]));
+ a[1] = remainder > 1
+ ? _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &A_buffer[rowidx_[k + 1] * 32]))
+ : _mm256_setzero_si256();
+ a[2] = remainder > 2
+ ? _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &A_buffer[rowidx_[k + 2] * 32]))
+ : _mm256_setzero_si256();
+ a[3] = _mm256_setzero_si256();
+
+ __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]);
+ __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]);
+ __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]);
+ __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]);
+
+ a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo);
+ a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo);
+ a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi);
+ a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi);
+
+ array<__m256i, 4> ab;
+ ab[0] = _mm256_maddubs_epi16(a[0], w);
+ ab[1] = _mm256_maddubs_epi16(a[1], w);
+ ab[2] = _mm256_maddubs_epi16(a[2], w);
+ ab[3] = _mm256_maddubs_epi16(a[3], w);
+
+ __m256i one = _mm256_set1_epi16(1);
+ ab[0] = _mm256_madd_epi16(ab[0], one);
+ ab[1] = _mm256_madd_epi16(ab[1], one);
+ ab[2] = _mm256_madd_epi16(ab[2], one);
+ ab[3] = _mm256_madd_epi16(ab[3], one);
+
+ array<__m256i, 4> t;
+ t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20);
+ t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20);
+ t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31);
+ t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31);
+
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 0 * 8])),
+ t[0]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 1 * 8])),
+ t[1]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 2 * 8])),
+ t[2]));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]),
+ _mm256_add_epi32(
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(
+ &C_buffer[j * 32 + 3 * 8])),
+ t[3]));
+ }
+ } // for each column of B
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_compute_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ // Transpose N x 32 C_buffer to fill 32 x N submatrix of C
+ transpose_simd(
+ N,
+ std::min(32, i_end - i1),
+ reinterpret_cast<const float*>(C_buffer),
+ 32,
+ reinterpret_cast<float*>(C + (i1 - block.row_start) * ldc),
+ ldc);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ spmdm_transpose_Nx32_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt =
+ std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start)
+ .count();
+ spmdm_run_time += (dt);
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+}
+
+} // namespace fbgemm2
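The SpMDM routine above avoids strided access by pre-transposing 32-row slices of A and C and then driving the CSC structure of B through AVX2 integer multiply-adds. For readers working through the intrinsics, a minimal scalar sketch of the same computation (a hypothetical free function with no transposition or vectorization; the real code is a member of CompressedSparseColumn) is:

    #include <cstdint>
    #include <vector>

    // Scalar reference for C += A * B with B stored in CSC form
    // (colptr/rowidx/values), matching the hyper-sparse path above.
    void spmdm_ref(
        int m,                              // rows of A and C
        int n,                              // columns of B and C
        const uint8_t* A, int lda,          // dense m x K matrix, K = rows of B
        const std::vector<int>& colptr,     // size n + 1
        const std::vector<int>& rowidx,     // size nnz
        const std::vector<int8_t>& values,  // size nnz
        int32_t* C, int ldc) {
      for (int j = 0; j < n; ++j) {
        for (int k = colptr[j]; k < colptr[j + 1]; ++k) {
          int row = rowidx[k];
          int32_t w = values[k];
          for (int i = 0; i < m; ++i) {
            C[i * ldc + j] += static_cast<int32_t>(A[i * lda + row]) * w;
          }
        }
      }
    }

The vectorized path is equivalent, but the transposed buffers turn the innermost loop over i into contiguous 32-byte loads.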
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
new file mode 100644
index 0000000..30160d1
--- /dev/null
+++ b/src/GenerateKernel.h
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <asmjit/asmjit.h>
+#include <cpuinfo.h>
+#include <map>
+#include <tuple>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * @brief AVX2/AVX512 JIT assembly code generator.
+ * @tparam TA Type of matrix A.
+ * @tparam TB Type of matrix B.
+ * @tparam TC Type of matrix C.
+ * @tparam accT Accumulation type, currently we support 16-bit (std::int16_t) or
+ * 32-bit (std::int32_t) accumulation.
+ */
+template <typename TA, typename TB, typename TC, typename accT>
+class CodeGenBase {
+ public:
+ using jit_micro_kernel_fp = void (*)(
+ TA* bufferA,
+ TB* bufferB,
+ TB* b_pf,
+ TC* bufferC,
+ int kc,
+ int ldc);
+
+ /**
+ * @brief Constructor for initializing AVX2/AVX512 registers.
+ */
+ CodeGenBase()
+ : CRegs_avx2_{x86::ymm0,
+ x86::ymm1,
+ x86::ymm2,
+ x86::ymm3,
+ x86::ymm4,
+ x86::ymm5,
+ x86::ymm6,
+ x86::ymm7,
+ x86::ymm8,
+ x86::ymm9,
+ x86::ymm10,
+ x86::ymm11},
+ CRegs_avx512_{
+ x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, x86::zmm4,
+ x86::zmm5, x86::zmm6, x86::zmm7, x86::zmm8, x86::zmm9,
+ x86::zmm10, x86::zmm11, x86::zmm12, x86::zmm13, x86::zmm14,
+ x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
+ x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24,
+ x86::zmm25, x86::zmm26, x86::zmm27,
+ } {
+ // vector width in bits
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ vectorWidth_ = 512;
+ } else if (cpuinfo_has_x86_avx2()) {
+ vectorWidth_ = 256;
+ } else {
+ // TODO: Have default path
+ assert(0 && "unsupported architecture");
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+ // vector width in elements
+ VLEN_ = vectorWidth_ / 8 * sizeof(TA);
+ }
+
+ /**
+ * @brief Get or Create the instructions for macro-kernel.
+ *
+ * If the problem size (mc, nc) and accumulation flag (accum) can be found in
+ * the code cache (a std::map keyed on (accum, mc, nc)), then the macro-kernel
+ * instructions are fetched directly from it. Otherwise, the instructions are
+ * generated and stored in the code cache.
+ */
+ template <inst_set_t instSet>
+ jit_micro_kernel_fp
+ getOrCreate(bool accum, int32_t mc, int32_t nc, int32_t kc, int32_t ldc);
+
+ /**
+ * @brief Generate instructions for initializing the C registers to 0.
+ */
+ template <inst_set_t instSet>
+ void initCRegs(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCRegAssign = 4);
+
+ /**
+ * @brief Generate instructions for computing block in the rank-k update.
+ */
+ template <inst_set_t instSet>
+ void genComputeBlock(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign = 4);
+
+ /**
+ * @brief Generate instructions for storing the C registers back to the
+ * memory.
+ */
+ template <inst_set_t instSet>
+ void storeCRegs(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign = 4);
+
+ private:
+ asmjit::X86Ymm
+ CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
+ asmjit::X86Zmm
+ CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
+ int vectorWidth_; ///< Vector width in bits.
+ int VLEN_; ///< Vector width in elements.
+ static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
+ static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+ static thread_local std::map<std::tuple<bool, int, int>, jit_micro_kernel_fp>
+ codeCache_; ///< JIT Code Cache for reuse.
+};
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
+
+template <typename TA, typename TB, typename TC, typename accT>
+thread_local std::map<
+ std::tuple<bool, int, int>,
+ typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
+ CodeGenBase<TA, TB, TC, accT>::codeCache_;
+
+} // namespace fbgemm2
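CodeGenBase::getOrCreate() documented above memoizes generated micro-kernels in a thread_local std::map keyed on (accum, mc, nc). A stripped-down sketch of that caching pattern, with a hypothetical generate() stand-in for the asmjit code generation and a simplified kernel signature, looks like:

    #include <cstdint>
    #include <map>
    #include <tuple>

    // Simplified kernel signature; the real jit_micro_kernel_fp also takes a
    // prefetch pointer and the k dimension.
    using kernel_fp = void (*)(const uint8_t*, const int8_t*, int32_t*, int, int);

    // Hypothetical stand-in for the runtime AVX2/AVX512 code generation.
    kernel_fp generate(bool /*accum*/, int32_t /*mc*/, int32_t /*nc*/) {
      return nullptr;  // placeholder only
    }

    kernel_fp get_or_create(bool accum, int32_t mc, int32_t nc) {
      // One cache per thread, mirroring the thread_local codeCache_ member.
      static thread_local std::map<std::tuple<bool, int, int>, kernel_fp> cache;
      auto key = std::make_tuple(accum, mc, nc);
      auto it = cache.find(key);
      if (it != cache.end()) {
        return it->second;  // reuse a previously generated kernel
      }
      kernel_fp fn = generate(accum, mc, nc);  // JIT a new micro-kernel
      cache[key] = fn;
      return fn;
    }

Keeping the runtime, code holder, and cache thread_local means each thread JITs and owns its own kernels, so no locking is needed around code generation.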
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
new file mode 100644
index 0000000..2ffe3ab
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -0,0 +1,292 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX2 instructions for initializing the C registers to 0 in 16-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCRegAssign) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ CRegs_avx2_[i * leadingDimCRegAssign + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX2 instructions for computing block in the rank-k update of 16-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign) {
+ // used for matrix A
+ asmjit::X86Ymm AReg = x86::ymm12;
+
+ asmjit::X86Ymm tmpReg = x86::ymm14;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ // broadcast A
+ a->vpbroadcastw(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ for (int j = 0; j < colRegs; ++j) {
+ a->vpmaddubsw(
+ tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ a->vpaddsw(
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ tmpReg,
+ CRegs_avx2_[i * leadingDimCRegAssign + j]);
+      // Prefetching hurts performance in some cases because the prefetch
+      // instructions themselves consume a slot in the issue pipeline,
+      // slowing down the kernel.
+ // if((i == rowRegs - 1) && j % 2 == 0){
+ // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t)));
+ //}
+ }
+ }
+}
+
+/**
+ * Generate AVX2 instructions for storing the C registers back to the memory in
+ * 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign) {
+ asmjit::X86Xmm extractDest128 = x86::xmm15;
+ asmjit::X86Ymm extractDest256 = x86::ymm15;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ a->imul(C_Offset, ldcReg, i * sizeof(int32_t));
+ for (int j = 0; j < colRegs; ++j) {
+ for (int idx = 0; idx < 2; ++idx) {
+ a->vextracti128(
+ extractDest128, CRegs_avx2_[i * leadingDimCRegAssign + j], idx);
+ a->vpmovsxwd(extractDest256, extractDest128);
+ asmjit::X86Mem destAddr = x86::dword_ptr(
+ a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
+ if (accum) {
+ a->vpaddd(extractDest256, extractDest256, destAddr);
+ }
+ a->vmovups(destAddr, extractDest256);
+ }
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX2 instructions for 16-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ auto kernelSig = std::make_tuple(accum, mc, nc);
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+ // ToDo: Dump in a file for debugging
+ // code dumping/logging
+ // asmjit::FileLogger logger(stderr);
+ // code_.setLogger(&logger);
+
+ constexpr int kBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::KCB;
+ constexpr int mRegBlockSize =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx2>::MR;
+ // constexpr int nRegBlockSize =
+ // PackingTraits<int8_t, int16_t, inst_set_t::avx2>::NR;
+ constexpr int row_interleave =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx2>::ROW_INTERLEAVE;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ // assert((nc == nRegBlockSize) &&
+ //"nc must be equal to the number of register blocks");
+
+ // arguments to the function created
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t));
+ a->add(CBase, C_Offset);
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ }
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+ return fn;
+}
+
+} // namespace fbgemm2
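The 16-bit accumulation kernel above leans on vpmaddubsw followed by vpaddsw: each uint8 activation is multiplied by an int8 weight, adjacent pairs along k are summed with 16-bit saturation, and the result is added into the 16-bit C registers, again with saturation. A scalar sketch of that per-element behaviour (assumed instruction semantics, not the JIT output):

    #include <algorithm>
    #include <cstdint>

    // Saturating 16-bit add, modelling one lane of vpaddsw.
    static int16_t sat_add_i16(int16_t x, int16_t y) {
      int32_t s = static_cast<int32_t>(x) + static_cast<int32_t>(y);
      s = std::min<int32_t>(std::max<int32_t>(s, INT16_MIN), INT16_MAX);
      return static_cast<int16_t>(s);
    }

    // One output element with 16-bit accumulation; vpmaddubsw consumes two
    // k values at a time and saturates the pairwise sum to int16.
    int16_t acc16_ref(const uint8_t* a, const int8_t* b, int k) {
      int16_t acc = 0;
      for (int i = 0; i < k; i += 2) {
        int32_t pair = static_cast<int32_t>(a[i]) * b[i] +
                       static_cast<int32_t>(a[i + 1]) * b[i + 1];
        pair = std::min<int32_t>(std::max<int32_t>(pair, INT16_MIN), INT16_MAX);
        acc = sat_add_i16(acc, static_cast<int16_t>(pair));
      }
      return acc;
    }

Because both steps saturate, 16-bit accumulation can overflow for large k or large magnitudes, which is the usual motivation for the 32-bit kernels later in this diff.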
diff --git a/src/GenerateKernelU8S8S32ACC16_avx512.cc b/src/GenerateKernelU8S8S32ACC16_avx512.cc
new file mode 100644
index 0000000..e613cf1
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC16_avx512.cc
@@ -0,0 +1,295 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCRegAssign) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ CRegs_avx512_[i * leadingDimCRegAssign + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign) {
+ // used for matrix A
+ asmjit::X86Zmm AReg = x86::zmm29;
+
+ asmjit::X86Zmm tmpReg = x86::zmm30;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ // broadcast A
+ a->vpbroadcastw(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ for (int j = 0; j < colRegs; ++j) {
+ a->vpmaddubsw(
+ tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ a->vpaddsw(
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ tmpReg,
+ CRegs_avx512_[i * leadingDimCRegAssign + j]);
+      // Prefetching hurts performance in some cases because the prefetch
+      // instructions themselves consume a slot in the issue pipeline,
+      // slowing down the kernel.
+ // if((i == rowRegs - 1) && j % 2 == 0){
+ // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t)));
+ //}
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign) {
+ asmjit::X86Ymm extractDest256 = x86::ymm31;
+ asmjit::X86Zmm extractDest512 = x86::zmm31;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ a->imul(C_Offset, ldcReg, i * sizeof(int32_t));
+ for (int j = 0; j < colRegs; ++j) {
+ for (int idx = 0; idx < 2; ++idx) {
+ a->vextracti32x8(
+ extractDest256, CRegs_avx512_[i * leadingDimCRegAssign + j], idx);
+ a->vpmovsxwd(extractDest512, extractDest256);
+ asmjit::X86Mem destAddr = x86::dword_ptr(
+ a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
+ if (accum) {
+ a->vpaddd(extractDest512, extractDest512, destAddr);
+ }
+ a->vmovups(destAddr, extractDest512);
+ }
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ auto kernelSig = std::make_tuple(accum, mc, nc);
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+ // ToDo: Dump in a file for debugging
+ // code dumping/logging
+ // asmjit::FileLogger logger(stderr);
+ // code_.setLogger(&logger);
+
+ constexpr int kBlock =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx512>::KCB;
+ constexpr int mRegBlockSize =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR;
+ // constexpr int nRegBlockSize =
+ // PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR;
+ constexpr int row_interleave =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ // assert((nc == nRegBlockSize) &&
+ //"nc must be equal to the number of register blocks");
+
+ // arguments to the function created
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t));
+ a->add(CBase, C_Offset);
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ }
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+ return fn;
+}
+
+} // namespace fbgemm2
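storeCRegs in both 16-bit files widens the saturated int16 accumulators to int32 before optionally adding the existing C values and storing (vextracti128/vextracti32x8, then vpmovsxwd, vpaddd, vmovups). A scalar sketch of that store step, under the same assumed lane semantics:

    #include <cstdint>

    // Widen a row of int16_t accumulators and either overwrite or accumulate
    // into the int32_t C row, mirroring vpmovsxwd + optional vpaddd + vmovups.
    void store_acc16_row(const int16_t* acc, int n, bool accum, int32_t* c_row) {
      for (int j = 0; j < n; ++j) {
        int32_t widened = static_cast<int32_t>(acc[j]);  // sign extension
        c_row[j] = accum ? c_row[j] + widened : widened;
      }
    }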
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
new file mode 100644
index 0000000..dc8c6d3
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -0,0 +1,310 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX2 instructions for initializing the C registers to 0 in 32-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs_avx2_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX2 instructions for computing block in the rank-k update of 32-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign) {
+ // used for matrix A
+ asmjit::X86Ymm AReg = x86::ymm12;
+
+ // used for matrix B
+ asmjit::X86Ymm BReg = x86::ymm13;
+
+ // Contains 16-bit 1s
+ asmjit::X86Ymm oneReg = x86::ymm15;
+
+ // temporary register
+ asmjit::X86Ymm res1 = x86::ymm14;
+
+ for (int j = 0; j < colRegs; ++j) {
+ // load B
+ a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ // load A, broadcast and fmas
+ for (int i = 0; i < rowRegs; ++i) {
+ a->vpbroadcastd(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ a->vpmaddubsw(res1, AReg, BReg);
+ a->vpmaddwd(res1, oneReg, res1);
+ a->vpaddd(
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ res1,
+ CRegs_avx2_[i * leadingDimCRegAssign + j]);
+ }
+ a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
+ }
+}
+
+/**
+ * Generate AVX2 instructions for storing the C registers back to the memory in
+ * 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
+ inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign) {
+ // temp register
+ asmjit::X86Ymm tmpReg = x86::ymm14;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ if (i != 0) {
+ a->add(C_Offset, ldcReg);
+ }
+ for (int j = 0; j < colRegs; ++j) {
+ if (accum) {
+ a->vpaddd(
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ CRegs_avx2_[i * leadingDimCRegAssign + j],
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)));
+ }
+ a->vmovups(
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 8 * sizeof(int32_t)),
+ CRegs_avx2_[i * leadingDimCRegAssign + j]);
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX2 instructions for 32-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ auto kernelSig = std::make_tuple(accum, mc, nc);
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+ // ToDo: Dump in a file for debugging
+ // code dumping/logging
+ // asmjit::FileLogger logger(stderr);
+ // code_.setLogger(&logger);
+
+ constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB;
+ constexpr int mRegBlockSize =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR;
+ constexpr int row_interleave =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE;
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
+ // asmjit::X86Gp B_pf = a->gpzRef(8);
+
+ asmjit::X86Ymm oneReg = x86::ymm15;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, sizeof(int32_t));
+ a->mov(C_Offset, 0);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ // a->add(B_pf, 32*sizeof(float));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, rowRegs);
+ a->add(CBase, C_Offset);
+ a->mov(C_Offset, 0);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx2>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+ }
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+ return fn;
+}
+
+} // namespace fbgemm2
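In the 32-bit kernel above, oneReg is filled with 16-bit ones (vpcmpeqw sets every bit, vpsrlw by 15 leaves 0x0001 per lane), and each vpmaddubsw result is widened by vpmaddwd against that register before the non-saturating vpaddd into the C registers. A scalar sketch of the resulting four-wide multiply-accumulate per output element (assumed instruction semantics, not the JIT output):

    #include <algorithm>
    #include <cstdint>

    // One output element with 32-bit accumulation. vpmaddubsw sums two
    // (uint8, int8) products with 16-bit saturation; vpmaddwd against a vector
    // of 16-bit ones then adds two such results into a 32-bit lane, so four
    // k values feed each accumulation step.
    int32_t acc32_ref(const uint8_t* a, const int8_t* b, int k) {
      int32_t acc = 0;
      for (int i = 0; i < k; i += 4) {
        int32_t p0 = static_cast<int32_t>(a[i]) * b[i] +
                     static_cast<int32_t>(a[i + 1]) * b[i + 1];
        int32_t p1 = static_cast<int32_t>(a[i + 2]) * b[i + 2] +
                     static_cast<int32_t>(a[i + 3]) * b[i + 3];
        // vpmaddubsw saturates its pairwise sums to the int16 range
        p0 = std::min<int32_t>(std::max<int32_t>(p0, INT16_MIN), INT16_MAX);
        p1 = std::min<int32_t>(std::max<int32_t>(p1, INT16_MIN), INT16_MAX);
        acc += p0 + p1;  // vpmaddwd with 0x0001 lanes, then vpaddd: no saturation
      }
      return acc;
    }

The multiply-by-ones step is what upgrades the 16-bit intermediates to a 32-bit accumulator; only the pairwise vpmaddubsw sums can still saturate.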
diff --git a/src/GenerateKernelU8S8S32ACC32_avx512.cc b/src/GenerateKernelU8S8S32ACC32_avx512.cc
new file mode 100644
index 0000000..5cd5684
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC32_avx512.cc
@@ -0,0 +1,312 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm2 {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCRegAssign) {
+ // used for matrix A
+ asmjit::X86Zmm AReg = x86::zmm31;
+
+ // used for matrix B
+ asmjit::X86Zmm BReg = x86::zmm30;
+
+ // Contains 16-bit 1s
+ asmjit::X86Zmm oneReg = x86::zmm29;
+
+ // temporary register
+ asmjit::X86Zmm res1 = x86::zmm28;
+
+ for (int j = 0; j < colRegs; ++j) {
+ // load B
+ a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ // load A, broadcast and fmas
+ for (int i = 0; i < rowRegs; ++i) {
+ a->vpbroadcastd(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ a->vpmaddubsw(res1, AReg, BReg);
+ a->vpmaddwd(res1, oneReg, res1);
+ a->vpaddd(
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ res1,
+ CRegs_avx512_[i * leadingDimCRegAssign + j]);
+ }
+ a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
+ }
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
+ inst_set_t::avx512>(
+ asmjit::X86Emitter* a,
+ int rowRegs,
+ int colRegs,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
+ bool accum,
+ int leadingDimCRegAssign) {
+ // temp register
+ asmjit::X86Zmm tmpReg = x86::zmm28;
+
+ for (int i = 0; i < rowRegs; ++i) {
+ if (i != 0) {
+ a->add(C_Offset, ldcReg);
+ }
+ for (int j = 0; j < colRegs; ++j) {
+ if (accum) {
+ a->vpaddd(
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ CRegs_avx512_[i * leadingDimCRegAssign + j],
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
+ }
+ a->vmovups(
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
+ CRegs_avx512_[i * leadingDimCRegAssign + j]);
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ auto kernelSig = std::make_tuple(accum, mc, nc);
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+ // ToDo: Dump in a file for debugging
+ // code dumping/logging
+ // asmjit::FileLogger logger(stderr);
+ // code_.setLogger(&logger);
+
+ constexpr int kBlock =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::KCB;
+ constexpr int mRegBlockSize =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR;
+ constexpr int row_interleave =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE;
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
+ // asmjit::X86Gp B_pf = a->gpzRef(8);
+
+ asmjit::X86Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, sizeof(int32_t));
+ a->mov(C_Offset, 0);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ // a->add(B_pf, 32*sizeof(float));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+
+ // increment C for next block
+ a->imul(C_Offset, ldcReg, rowRegs);
+ a->add(CBase, C_Offset);
+ a->mov(C_Offset, 0);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, row_interleave);
+
+ genComputeBlock<inst_set_t::avx512>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A, row_interleave * sizeof(uint8_t));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+ }
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+ return fn;
+}
+
+} // namespace fbgemm2
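The only structural difference from the AVX2 version is how oneReg is built: vpcmpeqw is replaced by vpternlogd with immediate 0xff, whose truth table outputs 1 for every input combination and therefore sets all bits of the zmm register. A one-line scalar model of that setup:

    #include <cstdint>

    // Per-16-bit-lane model of the AVX512 oneReg setup above.
    uint16_t one_lane() {
      uint16_t all_ones = 0xFFFF;  // vpternlogd zmm, zmm, zmm, 0xff -> all bits set
      return all_ones >> 15;       // vpsrlw zmm, zmm, 15 -> 0x0001 per lane
    }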
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
new file mode 100644
index 0000000..543d99b
--- /dev/null
+++ b/src/PackAMatrix.cc
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include <iostream>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackAMatrix<T, accT>::PackAMatrix(
+ matrix_op_t trans,
+ int32_t nRow,
+ int32_t nCol,
+ const T* smat,
+ int32_t ld,
+ inpType* pmat,
+ int32_t groups,
+ accT zero_pt)
+ : PackMatrix<PackAMatrix<T, accT>, T, accT>(nRow, nCol, pmat, zero_pt),
+ trans_(trans),
+ smat_(smat),
+ ld_(ld),
+ G_(groups) {
+ assert(G_ == 1 && "Groups != 1 not supported yet");
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+    assert(0 && "unsupported architecture");
+ }
+ if (!pmat) {
+ BaseType::buf_ =
+ (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
+ }
+}
+
+template <typename T, typename accT>
+void PackAMatrix<T, accT>::pack(const block_type_t& block) {
+ block_type_t block_p = {block.row_start,
+ block.row_size,
+ block.col_start,
+ (block.col_size + row_interleave_B_ - 1) /
+ row_interleave_B_ * row_interleave_B_};
+
+ BaseType::packedBlock(block_p);
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ T* out = BaseType::getBuf();
+ if (tr) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ T val = smat_[i + ld_ * j];
+ out[addr(i, j) - addr(block.row_start, block.col_start)] = val;
+ }
+ // zero fill
+      // Please note that we zero fill, not zero_pt fill, because
+      // requantization uses the original (i.e., not padded) dimensions. If we
+      // were to use the padded dimensions for requantization, we would
+      // zero_pt fill.
+ // For example, consider the following dot product:
+ // A = .3(5-15), .3(20-15) //.3 is scale and 15 is zero_pt
+ // B = .4(1+10), .4(4+10) // .4 is scale and -10 is zero_pt
+ //
+ // numElements(A) = 2 and numElements(B) = 2
+ //
+ // Dot product is (real): -3*4.4+1.5*5.6 = -4.8
+ // Dot product is (quantized): 5*1+20*4 = 85
+ //
+ // requantization: .3*.4(85 - (5+20)*(-10) - (1+4)*(15) +
+ // numElements(A)*(15)(-10)) = -4.8
+ //
+ // In the above adding one more element zero in the quantized domain,
+ // i.e., the quantized vectors become:
+ // A_q = 5, 20, 0
+ // B_q = 1, 4, 0
+ //
+ // and requantization with numElements(A) = 2 will produce the same
+ // answer (-4.8).
+ //
+ // Also in the above adding one more element zero_pt in the quantized
+ // domain, i.e., the quantized vectors become:
+ // A_q = 5, 20, 15
+ // B_q = 1, 4, -10
+ //
+ // and requantization with numElements(A) = 3 will produce the same
+ // answer (-4.8).
+ for (int j = block.col_start + block.col_size;
+ j < block_p.col_start + block_p.col_size;
+ ++j) {
+ out[addr(i, j) - addr(block.row_start, block.col_start)] = 0;
+ }
+ }
+ } else {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int buf_idx = i - block.row_start;
+ memcpy(
+ out + buf_idx * BaseType::blockColSize(),
+ smat_ + i * ld_ + block.col_start,
+ block.col_size * sizeof(T));
+ // zero fill
+ for (int j = block.col_size; j < block_p.col_size; ++j) {
+ out[buf_idx * BaseType::blockColSize() + j] = 0;
+ }
+ }
+ }
+}
+
+template <typename T, typename accT>
+int32_t PackAMatrix<T, accT>::addr(int32_t r, int32_t c) const {
+ int32_t block_row_id = r / BaseType::blockRowSize();
+ int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize());
+
+ int32_t block_col_id = c / BaseType::blockColSize();
+ int32_t bcol_offset =
+ block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
+ int32_t block_offset = brow_offset + bcol_offset;
+ int32_t inblock_offset =
+ (r % BaseType::blockRowSize()) * BaseType::blockColSize() +
+ (c % BaseType::blockColSize());
+
+ int32_t index = block_offset + inblock_offset;
+
+ return index;
+}
+
+template <typename T, typename accT>
+void PackAMatrix<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
+ for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
+ T val = out[addr(r, c)];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+}
+
+template class PackAMatrix<uint8_t, int32_t>;
+template class PackAMatrix<uint8_t, int16_t>;
+} // namespace fbgemm2
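The zero-fill comment inside PackAMatrix::pack() works through a small requantization example. A few lines of C++ reproducing that arithmetic (same numbers as the comment, just spelled out) make it easy to check:

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    int main() {
      // Quantization parameters from the comment in PackAMatrix::pack():
      // A: scale 0.3, zero point 15;  B: scale 0.4, zero point -10.
      const float a_scale = 0.3f, b_scale = 0.4f;
      const int32_t a_zp = 15, b_zp = -10;

      const int32_t a_q[] = {5, 20};  // quantized A
      const int32_t b_q[] = {1, 4};   // quantized B
      const int n = 2;

      int32_t dot_q = 0, a_sum = 0, b_sum = 0;
      for (int i = 0; i < n; ++i) {
        dot_q += a_q[i] * b_q[i];  // 85
        a_sum += a_q[i];           // 25
        b_sum += b_q[i];           // 5
      }

      // requantization: sA*sB*(dot - sum(A)*zpB - sum(B)*zpA + n*zpA*zpB)
      float real = a_scale * b_scale *
          (dot_q - a_sum * b_zp - b_sum * a_zp + n * a_zp * b_zp);
      assert(std::fabs(real - (-4.8f)) < 1e-4f);  // matches the comment
      return 0;
    }

Appending a quantized zero to both vectors changes none of dot_q, a_sum, or b_sum, so requantizing with the original n still gives -4.8; that is exactly why the padding is filled with zero rather than with zero_pt.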
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
new file mode 100644
index 0000000..7012289
--- /dev/null
+++ b/src/PackAWithIm2Col.cc
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include <iostream>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackAWithIm2Col<T, accT>::PackAWithIm2Col(
+ const conv_param_t& conv_p,
+ const T* sdata,
+ inpType* pmat,
+ int32_t zero_pt,
+ int32_t* row_offset)
+ : PackMatrix<PackAWithIm2Col<T, accT>, T, accT>(
+ conv_p.MB * conv_p.OH * conv_p.OW,
+ conv_p.KH * conv_p.KW * conv_p.IC,
+ pmat,
+ zero_pt),
+ conv_p_(conv_p),
+ sdata_(sdata) {
+ assert(conv_p.G == 1 && "Groups != 1 not supported yet");
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+    assert(0 && "unsupported architecture");
+ }
+ if (pmat) {
+ BaseType::buf_ = pmat;
+ } else {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ = static_cast<T*>(
+ aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
+ }
+ if (row_offset) {
+ row_offset_ = row_offset;
+ } else {
+ rowOffsetAllocatedHere = true;
+ row_offset_ = static_cast<int32_t*>(
+ aligned_alloc(64, BaseType::brow_ * sizeof(int32_t)));
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
+ block_type_t block_p = {block.row_start,
+ block.row_size,
+ block.col_start,
+ (block.col_size + row_interleave_B_ - 1) /
+ row_interleave_B_ * row_interleave_B_};
+
+ BaseType::packedBlock(block_p);
+ T* out = BaseType::getBuf();
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int n = i / (conv_p_.OH * conv_p_.OW);
+ int hw = i % (conv_p_.OH * conv_p_.OW);
+ int w = hw % conv_p_.OW;
+ int h = hw / conv_p_.OW;
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ int c = j % conv_p_.IC;
+ int rs = j / conv_p_.IC;
+ int s = rs % conv_p_.KW;
+ int r = rs / conv_p_.KW;
+
+ int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s;
+ int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r;
+ // Please note that padding for convolution should be filled with zero_pt
+ if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = BaseType::zeroPoint();
+ } else {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = sdata_
+ [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC + c];
+ }
+ }
+ // zero fill
+ // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
+ for (int j = block.col_start + block.col_size;
+ j < block_p.col_start + block_p.col_size;
+ ++j) {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = 0;
+ }
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
+ for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
+ T val = out[ r * BaseType::blockColSize() + c ];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+}
+
+template <typename T, typename accT>
+int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() {
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template class PackAWithIm2Col<uint8_t, int32_t>;
+template class PackAWithIm2Col<uint8_t, int16_t>;
+} // namespace fbgemm2
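
The pack() routine above decodes the packed row index into an output pixel (n, h, w) and the packed column index into a filter tap and input channel (r, s, c). A minimal sketch of that index arithmetic, assuming the same conv_param_t fields used above (the helper name is illustrative, not part of the library):

struct Im2ColCoord {
  int n, h_in, w_in, c; // input element read for this packed entry
  bool pad;             // true if the entry falls into padding
};

// Mirrors the index decomposition in PackAWithIm2Col<T, accT>::pack.
inline Im2ColCoord im2col_coord(const conv_param_t& p, int i, int j) {
  int n = i / (p.OH * p.OW);
  int hw = i % (p.OH * p.OW);
  int h = hw / p.OW, w = hw % p.OW;
  int c = j % p.IC;
  int rs = j / p.IC;
  int s = rs % p.KW, r = rs / p.KW;
  int h_in = -p.pad_h + h * p.stride_h + r;
  int w_in = -p.pad_w + w * p.stride_w + s;
  bool pad = h_in < 0 || h_in >= p.IH || w_in < 0 || w_in >= p.IW;
  return {n, h_in, w_in, c, pad}; // padded entries are filled with zeroPoint(), not 0
}
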
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
new file mode 100644
index 0000000..30d94f8
--- /dev/null
+++ b/src/PackBMatrix.cc
@@ -0,0 +1,144 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include <iostream>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackBMatrix<T, accT>::PackBMatrix(
+ matrix_op_t trans,
+ int32_t nRow,
+ int32_t nCol,
+ const T* smat,
+ int32_t ld,
+ inpType* pmat,
+ int32_t groups,
+ accT zero_pt)
+ : PackMatrix<PackBMatrix<T, accT>, T, accT>(nRow, nCol, pmat, zero_pt),
+ trans_(trans),
+ smat_(smat),
+ ld_(ld),
+ G_(groups) {
+ assert(G_ == 1 && "Groups != 1 not supported yet");
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB;
+ row_interleave_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // Error
+    assert(0 && "unknown architecture");
+ }
+ block_type_t block{0, BaseType::numRows(), 0, BaseType::numCols()};
+ BaseType::packedBlock(block);
+ if (!pmat) {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ = (T*)aligned_alloc(
+ 64,
+ BaseType::blockRows() * BaseType::brow_ * BaseType::blockCols() *
+ BaseType::bcol_ * sizeof(T));
+ }
+ pack(block);
+}
+
+template <typename T, typename accT>
+void PackBMatrix<T, accT>::pack(const block_type_t& block) {
+ assert((BaseType::blockRowSize() % row_interleave_) == 0);
+
+ BaseType::packedBlock(block);
+ T* out = BaseType::getBuf();
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ T val = tr ? smat_[i + ld_ * j] : smat_[i * ld_ + j];
+ out[addr(i, j) - addr(block.row_start, block.col_start)] =
+ tconv(val, out[addr(i, j)]);
+ }
+ }
+ // fill the remaining with zero.
+ // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
+ for (int i = block.row_start + block.row_size;
+ i < (block.row_start + block.row_size + row_interleave_ - 1) /
+ row_interleave_ * row_interleave_;
+ ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size; j++) {
+ out[addr(i, j) - addr(block.row_start, block.col_start)] =
+ tconv(0, out[addr(i, j)]);
+ }
+ }
+}
+
+template <typename T, typename accT>
+int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
+ int32_t block_row_id = r / BaseType::blockRowSize();
+ int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize());
+
+ int32_t block_col_id = c / BaseType::blockColSize();
+ int32_t bcol_offset =
+ block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
+ int32_t block_offset = brow_offset + bcol_offset;
+ int32_t inblock_offset = (r % BaseType::blockRowSize() / row_interleave_) *
+ BaseType::blockColSize() * row_interleave_ +
+ (c % BaseType::blockColSize()) * row_interleave_ + r % row_interleave_;
+
+ int32_t index = block_offset + inblock_offset;
+
+ return index;
+}
+
+template <typename T, typename accT>
+void PackBMatrix<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+ std::cout << "block size:"
+ << "[" << BaseType::blockRowSize() << ", "
+ << BaseType::blockColSize() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto nr = 0; nr < BaseType::blockRows(); ++nr) {
+ auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow()
+ : BaseType::blockRowSize();
+ for (auto nc = 0; nc < BaseType::blockCols(); ++nc) {
+ std::cout << "block:" << nr << ", " << nc << std::endl;
+ auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol()
+ : BaseType::blockColSize();
+ for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_;
+ ++r) {
+ for (auto c = 0; c < cols * row_interleave_; ++c) {
+ T val =
+ out[nr * BaseType::blockCols() * BaseType::blockRowSize() *
+ BaseType::blockColSize() +
+ nc * BaseType::blockRowSize() * BaseType::blockColSize() +
+ r * BaseType::blockColSize() * row_interleave_ + c];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+ }
+ }
+}
+
+template class PackBMatrix<int8_t, int32_t>;
+template class PackBMatrix<int8_t, int16_t>;
+} // namespace fbgemm2
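
Within each blockRowSize x blockColSize tile, addr() above stores groups of row_interleave_ consecutive rows of the same column contiguously, which is what lets the packed B feed a vectorized multiply-add over several K-rows at once. A sketch of the in-tile offset, with r and c already reduced modulo the tile dimensions (illustrative helper, not library API):

// Matches the inblock_offset term of PackBMatrix<T, accT>::addr.
inline int in_tile_offset(int r, int c, int tile_cols, int row_interleave) {
  return (r / row_interleave) * tile_cols * row_interleave +
      c * row_interleave + (r % row_interleave);
}
// Example: with row_interleave = 4 and tile_cols = 16, rows 0..3 of column 0
// land at offsets 0..3, rows 0..3 of column 1 at offsets 4..7, and so on.
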
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
new file mode 100644
index 0000000..85000ac
--- /dev/null
+++ b/src/PackMatrix.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include <stdexcept>
+#include <type_traits>
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename PT, typename inpType, typename accType>
+PackMatrix<PT, inpType, accType>::PackMatrix(
+ int32_t rows,
+ int32_t cols,
+ inpType* buf,
+ int32_t zero_pt)
+ : buf_(buf), nrows_(rows), ncols_(cols), zero_pt_(zero_pt) {
+ bufAllocatedHere_ = false;
+ if (!cpuinfo_initialize()) {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template <typename PT, typename inpType, typename accType>
+int PackMatrix<PT, inpType, accType>::packedBufferSize(int rows, int cols) {
+ if (cpuinfo_has_x86_avx512f()) {
+ if (isA) {
+ return PackingTraits<inpType, accType, inst_set_t::avx512>::MCB *
+ PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ } else {
+ int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ int colBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
+ return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
+ (((cols + colBlock - 1) / colBlock) * colBlock);
+ }
+ } else if (cpuinfo_has_x86_avx2()) {
+ if (isA) {
+ return PackingTraits<inpType, accType, inst_set_t::avx2>::MCB *
+ PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ } else {
+ int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ int colBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
+ return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
+ (((cols + colBlock - 1) / colBlock) * colBlock);
+ }
+ } else {
+ // TODO: Have default slower path
+    assert(0 && "unsupported architecture");
+ }
+ return -1;
+}
+
+// int32 accumulation
+template class PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>;
+
+template class PackMatrix<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ uint8_t,
+ int32_t>;
+
+template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>;
+
+template class PackMatrix<
+ PackAWithQuantRowOffset<uint8_t, int32_t>,
+ uint8_t,
+ int32_t>;
+
+template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>;
+
+// int16 accumulation
+template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>;
+
+template class PackMatrix<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ uint8_t,
+ int16_t>;
+
+template class PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>;
+
+template class PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>;
+} // namespace fbgemm2
diff --git a/src/PackWithQuantRowOffset.cc b/src/PackWithQuantRowOffset.cc
new file mode 100644
index 0000000..74eaade
--- /dev/null
+++ b/src/PackWithQuantRowOffset.cc
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <immintrin.h>
+#include <cassert>
+#include <cmath>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <limits>
+#include <stdexcept>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
+ matrix_op_t trans,
+ int32_t nRow,
+ int32_t nCol,
+ const float* smat,
+ int32_t ld,
+ inpType* pmat,
+ float scale,
+ int32_t zero_pt,
+ int32_t groups,
+ int32_t* row_offset)
+ : PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT>(
+ nRow,
+ nCol,
+ pmat,
+ zero_pt),
+ trans_(trans),
+ smat_(smat),
+ ld_(ld),
+ scale_(scale),
+ G_(groups),
+ row_offset_(row_offset) {
+ assert(G_ == 1 && "Groups != 1 not supported yet");
+
+ rowOffsetAllocatedHere = false;
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+    assert(0 && "unknown architecture");
+ }
+ if (pmat) {
+ BaseType::buf_ = pmat;
+ } else {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ =
+ (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
+ }
+ if (!row_offset_) {
+ rowOffsetAllocatedHere = true;
+ row_offset_ = reinterpret_cast<int32_t*>(
+        aligned_alloc(64, BaseType::brow_ * sizeof(int32_t)));
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithQuantRowOffset<T, accT>::pack(const block_type_t& block) {
+ assert(block.row_start % BaseType::blockRowSize() == 0);
+ assert(block.col_start % BaseType::blockColSize() == 0);
+ assert(block.row_size <= BaseType::blockRowSize());
+ assert(block.col_size <= BaseType::blockColSize());
+
+ block_type_t block_p = {block.row_start,
+ block.row_size,
+ block.col_start,
+ (block.col_size + row_interleave_B_ - 1) /
+ row_interleave_B_ * row_interleave_B_};
+ assert(block_p.col_size <= BaseType::blockColSize());
+ BaseType::packedBlock(block_p);
+
+ T* out = BaseType::getBuf();
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ // accumulate into row offset?
+ bool row_offset_acc = (block.col_start != 0);
+ int32_t* row_offset_buf = getRowOffsetBuffer();
+
+ float smat_transposed[block.row_size * block.col_size];
+ if (tr) {
+ transpose_simd(
+ block.col_size,
+ block.row_size,
+ smat_ + block.col_start * ld_ + block.row_start,
+ ld_,
+ smat_transposed,
+ block.col_size);
+ }
+ const float* smat_temp =
+ tr ? smat_transposed : smat_ + block.row_start * ld_ + block.col_start;
+ int32_t ld_temp = tr ? block.col_size : ld_;
+
+#if defined(__AVX2__) && defined(__FMA__)
+ constexpr int VLEN = 8;
+ __m256 inverse_scale_v = _mm256_set1_ps(1.0f / scale_);
+ __m256i shuffle_mask_v = _mm256_set_epi8(
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);
+ __m256i permute_mask_v = _mm256_set_epi32(
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
+#endif
+
+ for (int i = 0; i < block.row_size; ++i) {
+ int32_t row_sum = row_offset_acc ? row_offset_buf[i] : 0;
+ int j = 0;
+#if defined(__AVX2__) && defined(__FMA__)
+ static_assert(
+ std::is_same<T, uint8_t>::value,
+ "PackAWithQuantRowOffset<T, accT>::pack only works for T == uint8_t");
+ for (; j < block.col_size / VLEN * VLEN; j += VLEN) {
+ __m256 val_v = _mm256_loadu_ps(smat_temp + i * ld_temp + j);
+ __m256 transformed_v = _mm256_fmadd_ps(
+ val_v, inverse_scale_v, _mm256_set1_ps(BaseType::zeroPoint()));
+ __m256 clipped_v = _mm256_max_ps(
+ _mm256_set1_ps(std::numeric_limits<uint8_t>::min()),
+ _mm256_min_ps(
+ transformed_v,
+ _mm256_set1_ps(std::numeric_limits<uint8_t>::max())));
+ __m256i res_v = _mm256_cvtps_epi32(clipped_v);
+
+ // An instruction sequence to save 8 32-bit integers as 8 8-bit integers
+ res_v = _mm256_shuffle_epi8(res_v, shuffle_mask_v);
+ res_v = _mm256_permutevar8x32_epi32(res_v, permute_mask_v);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(out + i * BaseType::blockColSize() + j),
+ _mm256_castsi256_si128(res_v));
+
+ for (int j2 = j; j2 < j + VLEN; ++j2) {
+ row_sum += out[i * BaseType::blockColSize() + j2];
+ }
+ }
+#endif
+ for (; j < block.col_size; ++j) {
+ float val = smat_temp[i * ld_temp + j];
+ float transformed = val / scale_ + BaseType::zeroPoint();
+ float clipped = std::min<float>(
+ std::max<float>(transformed, std::numeric_limits<uint8_t>::min()),
+ std::numeric_limits<uint8_t>::max());
+ T res = round(clipped);
+ row_sum += res;
+ out[i * BaseType::blockColSize() + j] = res;
+ }
+ // zero fill
+ // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
+ for (; j < block_p.col_size; ++j) {
+ out[i * BaseType::blockColSize() + j] = 0;
+ }
+ row_offset_buf[i] = row_sum;
+ }
+}
+
+template <typename T, typename accT>
+int32_t PackAWithQuantRowOffset<T, accT>::addr(int32_t r, int32_t c) const {
+ int32_t block_row_id = r / BaseType::blockRowSize();
+ int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize());
+
+ int32_t block_col_id = c / BaseType::blockColSize();
+ int32_t bcol_offset =
+ block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
+ int32_t block_offset = brow_offset + bcol_offset;
+ int32_t inblock_offset =
+ (r % BaseType::blockRowSize()) * BaseType::blockColSize() +
+ (c % BaseType::blockColSize());
+
+ int32_t index = block_offset + inblock_offset;
+
+ return index;
+}
+
+template <typename T, typename accT>
+void PackAWithQuantRowOffset<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
+ for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
+ T val = out[addr(r, c)];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+}
+
+template <typename T, typename accT>
+int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize() {
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ // TODO: avx512 path
+ // Currently use avx2 code
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else {
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template class PackAWithQuantRowOffset<uint8_t, int32_t>;
+
+} // namespace fbgemm2
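
The scalar tail of pack() above quantizes each float with the usual affine mapping, and the per-row sum of the quantized values is what lands in the row-offset buffer; the AVX2 path computes the same thing eight elements at a time. A minimal sketch of the per-element step (hypothetical helper mirroring the scalar loop):

#include <algorithm>
#include <cmath>
#include <cstdint>

// q = round(clamp(x / scale + zero_pt, 0, 255)), as in the scalar tail above.
inline uint8_t quantize_one(float x, float scale, int32_t zero_pt) {
  float transformed = x / scale + static_cast<float>(zero_pt);
  float clipped = std::min(255.0f, std::max(0.0f, transformed));
  return static_cast<uint8_t>(std::round(clipped));
}
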
diff --git a/src/PackWithRowOffset.cc b/src/PackWithRowOffset.cc
new file mode 100644
index 0000000..8722723
--- /dev/null
+++ b/src/PackWithRowOffset.cc
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <immintrin.h>
+#include <array>
+#include <cassert>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <stdexcept>
+#include <cpuinfo.h>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackAWithRowOffset<T, accT>::PackAWithRowOffset(
+ matrix_op_t trans,
+ uint32_t nRow,
+ uint32_t nCol,
+ const T* smat,
+ uint32_t ld,
+ inpType* pmat,
+ uint32_t groups,
+ int32_t zero_pt,
+ int32_t* row_offset)
+ : PackMatrix<PackAWithRowOffset<T, accT>, T, accT>(
+ nRow,
+ nCol,
+ pmat,
+ zero_pt),
+ trans_(trans),
+ smat_(smat),
+ ld_(ld),
+ G_(groups),
+ row_offset_(row_offset) {
+ assert(G_ == 1 && "Groups != 1 not supported yet");
+
+ rowOffsetAllocatedHere = false;
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+    // TODO: Have default slower path
+    assert(0 && "unknown architecture");
+ }
+ if (pmat) {
+ BaseType::buf_ = pmat;
+ } else {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ =
+ (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
+ }
+ if (!row_offset_) {
+ rowOffsetAllocatedHere = true;
+ row_offset_ = static_cast<int32_t*>(aligned_alloc(64,
+ BaseType::brow_ * sizeof(int32_t)));
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithRowOffset<T, accT>::pack(const block_type_t& block) {
+ assert(block.row_start % BaseType::blockRowSize() == 0);
+ assert(block.col_start % BaseType::blockColSize() == 0);
+ assert(block.row_size <= BaseType::blockRowSize());
+ assert(block.col_size <= BaseType::blockColSize());
+
+ block_type_t block_p = {block.row_start,
+ block.row_size,
+ block.col_start,
+ (block.col_size + row_interleave_B_ - 1) /
+ row_interleave_B_ * row_interleave_B_};
+ assert(block_p.col_size <= BaseType::blockColSize());
+ BaseType::packedBlock(block_p);
+
+ T* out = BaseType::getBuf();
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ // accumulate into row offset?
+ bool row_offset_acc = (block.col_start != 0);
+ int32_t* row_offset_buf = getRowOffsetBuffer();
+ if (tr) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int32_t row_sum = row_offset_acc ?
+ row_offset_buf[i - block.row_start] : 0;
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ T val = smat_[i + ld_ * j];
+ row_sum += val;
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = val;
+ }
+ row_offset_buf[i - block.row_start] = row_sum;
+ // zero fill
+ // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
+ for (int j = block.col_start + block.col_size;
+ j < block_p.col_start + block_p.col_size; ++j) {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = 0;
+ }
+ }
+ } else {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int buf_idx = i - block.row_start;
+ memcpy(
+ out + buf_idx * BaseType::blockColSize(),
+ smat_ + i * ld_ + block.col_start,
+ block.col_size * sizeof(T));
+ // zero fill
+ for (int j = block.col_size; j < block_p.col_size; ++j) {
+ out[buf_idx * BaseType::blockColSize() + j] = 0;
+ }
+ int32_t row_sum = row_offset_acc ?
+ row_offset_buf[i - block.row_start] : 0;
+ __m256i sum_v = _mm256_setzero_si256();
+ __m256i one_epi16_v = _mm256_set1_epi16(1);
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ for (int j = block.col_start;
+ j < block.col_start + block.col_size / 32 * 32;
+ j += 32) {
+ __m256i src_v = _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(smat_ + i * ld_ + j));
+ sum_v = _mm256_add_epi32(
+ sum_v,
+ _mm256_madd_epi16(
+ _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v));
+ }
+ for (int j = block.col_start + block.col_size / 32 * 32;
+ j < block.col_start + block.col_size;
+ ++j) {
+ row_sum += smat_[i * ld_ + j];
+ }
+ alignas(64) std::array<int32_t, 8> temp;
+ _mm256_store_si256(reinterpret_cast<__m256i*>(temp.data()), sum_v);
+ for (int k = 0; k < 8; ++k) {
+ row_sum += temp[k];
+ }
+ row_offset_buf[i - block.row_start] = row_sum;
+ }
+ }
+}
+
+template <typename T, typename accT>
+int32_t PackAWithRowOffset<T, accT>::addr(int32_t r, int32_t c) const {
+ int32_t block_row_id = r / BaseType::blockRowSize();
+ int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
+ (BaseType::blockRowSize() * BaseType::blockColSize());
+
+ int32_t block_col_id = c / BaseType::blockColSize();
+ int32_t bcol_offset =
+ block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
+ int32_t block_offset = brow_offset + bcol_offset;
+ int32_t inblock_offset =
+ (r % BaseType::blockRowSize()) * BaseType::blockColSize() +
+ (c % BaseType::blockColSize());
+
+ int32_t index = block_offset + inblock_offset;
+
+ return index;
+}
+
+template <typename T, typename accT>
+void PackAWithRowOffset<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
+ for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
+ T val = out[addr(r, c)];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+}
+
+template <typename T, typename accT>
+int PackAWithRowOffset<T, accT>::rowOffsetBufferSize() {
+  if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else {
+      // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template class PackAWithRowOffset<uint8_t, int32_t>;
+template class PackAWithRowOffset<uint8_t, int16_t>;
+
+} // namespace fbgemm2
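
The vectorized row sum in pack() above is a plain widening sum of unsigned bytes: _mm256_maddubs_epi16 with a vector of ones adds adjacent uint8 pairs into int16 lanes (at most 255 + 255, so no saturation), and _mm256_madd_epi16 with ones widens adjacent int16 pairs into the int32 lanes that are reduced at the end. Its scalar equivalent is simply:

#include <cstdint>

// Scalar form of the AVX2 row-sum reduction in PackAWithRowOffset::pack.
inline int32_t row_sum_scalar(const uint8_t* row, int len) {
  int32_t sum = 0;
  for (int j = 0; j < len; ++j) {
    sum += row[j];
  }
  return sum;
}
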
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
new file mode 100644
index 0000000..9aedc88
--- /dev/null
+++ b/src/RefImplementations.cc
@@ -0,0 +1,608 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "RefImplementations.h"
+
+#include <cassert>
+#include <cmath>
+#include <limits>
+#include <vector>
+
+using namespace std;
+
+namespace fbgemm2 {
+
+void requantize_u8acc32_ref(
+ int M,
+ int N,
+ int ld,
+ const int32_t* inp,
+ uint8_t* out,
+ int32_t C_multiplier,
+ int32_t C_right_shift,
+ int32_t C_zero_point,
+ int32_t A_zero_point,
+ int32_t B_zero_point,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu) {
+ int64_t nudge = 1ll << std::max(0, C_right_shift - 1);
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ int32_t raw = inp[i * ld + j];
+ raw -= A_zero_point * col_offsets[j];
+ raw -= B_zero_point * row_offsets[i];
+ if (bias) {
+ raw += bias[j];
+ }
+
+ int64_t ab_64 =
+ static_cast<int64_t>(raw) * static_cast<int64_t>(C_multiplier);
+ int64_t rounded = ((ab_64 + nudge) >> C_right_shift) + C_zero_point;
+
+ out[i * ld + j] = std::max(
+ fuse_relu ? static_cast<int64_t>(C_zero_point) : 0l,
+ std::min(255l, rounded));
+ }
+ }
+}
+
+void requantize_u8acc32_ref(
+ int M,
+ int N,
+ int ld,
+ const int32_t* inp,
+ uint8_t* out,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t A_zero_point,
+ int32_t B_zero_point,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu) {
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ int32_t raw = inp[i * ld + j];
+ raw -= A_zero_point * col_offsets[j];
+ raw -= B_zero_point * row_offsets[i];
+ if (bias) {
+ raw += bias[j];
+ }
+
+ float result = raw * C_multiplier;
+ long rounded = lrintf(result) + C_zero_point;
+ out[i * ld + j] = std::max(
+ fuse_relu ? static_cast<long>(C_zero_point) : 0l,
+ std::min(255l, rounded));
+ }
+ }
+}
+
+void matmul_u8i8acc32_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ const uint8_t* Aint8,
+ const int8_t* Bint8,
+ int32_t* Cint32) {
+ for (int j = 0; j < N; ++j) {
+ for (int i = 0; i < M; ++i) {
+ int32_t sum = 0;
+ for (int k = 0; k < K; ++k) {
+ sum += static_cast<int32_t>(Aint8[i * lda + k]) *
+ static_cast<int32_t>(Bint8[k * ldb + j]);
+ }
+ Cint32[i * ldc + j] = sum;
+ }
+ }
+}
+
+void matmul_u8i8acc16_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ int brow,
+ const uint8_t* Aint8,
+ const int8_t* Bint8,
+ int32_t* Cint32) {
+ for (int j = 0; j < N; ++j) {
+ for (int i = 0; i < M; ++i) {
+ int32_t sum = 0, sum_32bit = 0;
+ for (int k = 0; k < K; k += 2) {
+ int a0 = Aint8[i * lda + k];
+ int b0 = Bint8[k * ldb + j];
+ int a1 = 0, b1 = 0;
+ if (k + 1 < K) {
+ a1 = Aint8[i * lda + k + 1];
+ b1 = Bint8[(k + 1) * ldb + j];
+ }
+ sum = clip_16bit(sum + clip_16bit(a0 * b0 + a1 * b1));
+ if ((k % brow) == (brow - 2)) {
+ sum_32bit += sum;
+ sum = 0;
+ }
+ }
+ Cint32[i * ldc + j] = sum_32bit + sum;
+ }
+ }
+}
+
+void matmul_fp_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ const float* Afp32,
+ const float* Bfp32,
+ float* Cfp32) {
+ for (int j = 0; j < N; ++j) {
+ for (int i = 0; i < M; ++i) {
+ float sum = 0;
+ for (int k = 0; k < K; ++k) {
+ sum += Afp32[i * lda + k] * Bfp32[k * ldb + j];
+ }
+ Cfp32[i * ldc + j] = sum;
+ }
+ }
+}
+
+void row_offsets_u8acc32_ref(
+ int M,
+ int K,
+ int ld,
+ const uint8_t* Aint8,
+ int32_t* row_offsets) {
+ // row offset
+ for (int i = 0; i < M; ++i) {
+ int32_t sum = 0;
+ for (int k = 0; k < K; ++k) {
+ sum += static_cast<int32_t>(Aint8[i * ld + k]);
+ }
+ row_offsets[i] = sum;
+ }
+}
+
+void col_offsets_with_zero_pt_s8acc32_ref(
+ int K,
+ int N,
+ int ld,
+ const int8_t* Bint8,
+ int32_t B_zero_point,
+ int32_t* col_offsets) {
+ for (int j = 0; j < N; ++j) {
+ int32_t sum = 0;
+ for (int k = 0; k < K; ++k) {
+ sum += Bint8[k * ld + j];
+ }
+ col_offsets[j] = sum - B_zero_point * K;
+ }
+}
+
+void spmdm_ref(
+ int M,
+ const uint8_t* A,
+ int lda,
+ fbgemm2::CompressedSparseColumn& B,
+ bool accumulation,
+ int32_t* C,
+ int ldc) {
+ int N = B.NumOfCols();
+ if (!accumulation) {
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < N; ++j) {
+ C[i * ldc + j] = 0;
+ }
+ }
+ }
+ for (int j = 0; j < N; ++j) {
+ for (int k = B.ColPtr()[j]; k < B.ColPtr()[j + 1]; ++k) {
+ int row = B.RowIdx()[k];
+ int w = B.Values()[k];
+ for (int i = 0; i < M; ++i) {
+ C[i * ldc + j] += A[i * lda + row] * w;
+ }
+ }
+ } // for each column of B
+}
+
+int32_t clip_16bit(int32_t x) {
+ if (x > std::numeric_limits<int16_t>::max()) {
+ return std::min<int>(std::numeric_limits<int16_t>::max(), x);
+ } else if (x < std::numeric_limits<int16_t>::min()) {
+ return std::max<int>(std::numeric_limits<int16_t>::min(), x);
+ } else {
+ return x;
+ }
+}
+
+void im2col_ref(
+ const conv_param_t& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ std::uint8_t* Ao) {
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int h = 0; h < conv_p.OH; ++h) {
+ for (int w = 0; w < conv_p.OW; ++w) {
+ for (int r = 0; r < conv_p.KH; ++r) {
+ int h_in = -conv_p.pad_h + h * conv_p.stride_h + r;
+ for (int s = 0; s < conv_p.KW; ++s) {
+ int w_in = -conv_p.pad_w + w * conv_p.stride_w + s;
+ for (int c = 0; c < conv_p.IC; ++c) {
+ // Ai: NHWC: NH_0W_0 x C_0
+ std::uint8_t val =
+ h_in < 0 || h_in >= conv_p.IH || w_in < 0 || w_in >= conv_p.IW
+ ? A_zero_point
+ : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC +
+ c];
+ // Ao: NHWC: NH_1W_1 x RSC_0
+ Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) *
+ conv_p.KW +
+ s) *
+ conv_p.IC +
+ c] = val;
+ } // for each c
+ } // for each s
+ } // for each r
+ } // for each w
+ } // for each h
+ } // for each n
+}
+
+void conv_ref(
+ const conv_param_t& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ const std::int8_t* B,
+ std::int32_t* C) {
+ // filters are assumed to be in RSCK format
+ assert(conv_p.G == 1 && "Groups != 1 not supported yet");
+
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int h = 0; h < conv_p.OH; ++h) {
+ for (int w = 0; w < conv_p.OW; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int sum = 0;
+ for (int r = 0; r < conv_p.KH; ++r) {
+ int h_in = -conv_p.pad_h + h * conv_p.stride_h + r;
+ for (int s = 0; s < conv_p.KW; ++s) {
+ int w_in = -conv_p.pad_w + w * conv_p.stride_w + s;
+ for (int c = 0; c < conv_p.IC; ++c) {
+ int a = h_in < 0 || h_in >= conv_p.IH || w_in < 0 ||
+ w_in >= conv_p.IW
+ ? A_zero_point
+ : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) *
+ conv_p.IC +
+ c];
+ int b =
+ B[((r * conv_p.KW + s) * conv_p.IC + c) * conv_p.OC + k];
+ sum += a * b;
+ } // for each c
+ } // for each s
+ } // for each r
+ C[((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k] = sum;
+ } // for each k
+ } // for each w
+ } // for each h
+ } // for each n
+}
+
+void depthwise_3x3_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int8_t* B,
+ int32_t* C) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int r = 0; r < R; ++r) {
+ int h_in = -PAD_T + h * stride_h + r;
+ for (int s = 0; s < S; ++s) {
+ int w_in = -PAD_L + w * stride_w + s;
+ int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[((n * H + h_in) * W + w_in) * K + k];
+ int b = B[(k * R + r) * S + s];
+ sum += a * b;
+ }
+ }
+ C[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
+ }
+ }
+ }
+ } // for each n
+}
+
+void depthwise_3x3_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+
+ vector<int32_t> C_int32(N * H_OUT * W_OUT * K);
+ depthwise_3x3_pad_1_ref(
+ N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data());
+
+ vector<int32_t> row_offsets(N * H_OUT * W_OUT * K);
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int r = 0; r < R; ++r) {
+ int h_in = -PAD_T + h * stride_h + r;
+ for (int s = 0; s < S; ++s) {
+ int w_in = -PAD_L + w * stride_w + s;
+ int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[((n * H + h_in) * W + w_in) * K + k];
+ sum += a;
+ }
+ }
+ row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
+ }
+ }
+ }
+ } // for each n
+
+ for (int i = 0; i < N * H_OUT * W_OUT; ++i) {
+ for (int k = 0; k < K; ++k) {
+ requantize_u8acc32_ref(
+ 1,
+ 1,
+ 1,
+ C_int32.data() + i * K + k,
+ C + i * K + k,
+ C_multiplier,
+ C_zero_point,
+ A_zero_point,
+ B_zero_point,
+ &row_offsets[i * K + k],
+ col_offsets + k,
+ bias ? bias + k : nullptr);
+ }
+ }
+}
+
+void depthwise_3x3_per_channel_quantization_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const int8_t* B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+
+ vector<int32_t> C_int32(N * H_OUT * W_OUT * K);
+ depthwise_3x3_pad_1_ref(
+ N, H, W, K, stride_h, stride_w, A_zero_point, A, B, C_int32.data());
+
+ vector<int32_t> row_offsets(N * H_OUT * W_OUT * K);
+ for (int n = 0; n < N; ++n) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int r = 0; r < R; ++r) {
+ int h_in = -PAD_T + h * stride_h + r;
+ for (int s = 0; s < S; ++s) {
+ int w_in = -PAD_L + w * stride_w + s;
+ int a = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[((n * H + h_in) * W + w_in) * K + k];
+ sum += a;
+ }
+ }
+ row_offsets[((n * H_OUT + h) * W_OUT + w) * K + k] = sum;
+ }
+ }
+ }
+ } // for each n
+
+ for (int i = 0; i < N * H_OUT * W_OUT; ++i) {
+ for (int k = 0; k < K; ++k) {
+ requantize_u8acc32_ref(
+ 1,
+ 1,
+ 1,
+ C_int32.data() + i * K + k,
+ C + i * K + k,
+ C_multiplier[k],
+ C_zero_point,
+ A_zero_point,
+ B_zero_point[k],
+ &row_offsets[i * K + k],
+ col_offsets + k,
+ bias ? bias + k : nullptr);
+ }
+ }
+}
+
+void depthwise_3x3x3_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int8_t* B,
+ int32_t* C) {
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int k_t = 0; k_t < K_T; ++k_t) {
+ int t_in = -PAD_P + t * stride_t + k_t;
+ for (int k_h = 0; k_h < K_H; ++k_h) {
+ int h_in = -PAD_T + h * stride_h + k_h;
+ for (int k_w = 0; k_w < K_W; ++k_w) {
+ int w_in = -PAD_L + w * stride_w + k_w;
+ int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
+ w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
+ int b = B[((k * K_T + k_t) * K_H + k_h) * K_W + k_w];
+ sum += a * b;
+ }
+ }
+ }
+ C[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = sum;
+ }
+ } // w
+ } // h
+ } // t
+ } // for each n
+}
+
+void depthwise_3x3x3_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K);
+ depthwise_3x3x3_pad_1_ref(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B,
+ C_int32.data());
+
+ vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K);
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int k_t = 0; k_t < K_T; ++k_t) {
+ int t_in = -PAD_P + t * stride_t + k_t;
+ for (int k_h = 0; k_h < K_H; ++k_h) {
+ int h_in = -PAD_T + h * stride_h + k_h;
+ for (int k_w = 0; k_w < K_W; ++k_w) {
+ int w_in = -PAD_L + w * stride_w + k_w;
+ int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
+ w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
+ sum += a;
+ }
+ }
+ }
+ row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] =
+ sum;
+ }
+ } // w
+ } // h
+ } // t
+ } // for each n
+
+ for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) {
+ for (int k = 0; k < K; ++k) {
+ requantize_u8acc32_ref(
+ 1,
+ 1,
+ 1,
+ C_int32.data() + i * K + k,
+ C + i * K + k,
+ C_multiplier,
+ C_zero_point,
+ A_zero_point,
+ B_zero_point,
+ &row_offsets[i * K + k],
+ col_offsets + k,
+ bias ? bias + k : nullptr);
+ }
+ }
+}
+
+} // namespace fbgemm2
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
new file mode 100644
index 0000000..e9eaeed
--- /dev/null
+++ b/src/RefImplementations.h
@@ -0,0 +1,268 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/FbgemmI8Spmdm.h"
+
+namespace fbgemm2 {
+
+/**
+ * @brief Reference implementation of requantization step.
+ * int32 multiplier
+ * @param bias Can be nullptr.
+ */
+void requantize_u8acc32_ref(
+ int M,
+ int N,
+ int ld,
+ const std::int32_t* inp,
+ std::uint8_t* out,
+ std::int32_t C_multiplier,
+ std::int32_t C_right_shift,
+ std::int32_t C_zero_point,
+ std::int32_t A_zero_point,
+ std::int32_t B_zero_point,
+ const std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false);
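
Per element, this overload computes the fixed-point requantization below (a scalar sketch matching the definition in RefImplementations.cc; the helper is illustrative only). The float-multiplier overload that follows replaces the multiply/round/shift with lrintf(raw * C_multiplier) before adding C_zero_point and clamping the same way.

#include <algorithm>
#include <cstdint>

// raw = inp[i * ld + j] - A_zero_point * col_offsets[j]
//       - B_zero_point * row_offsets[i] (+ bias[j] if bias is given)
inline uint8_t requantize_one_i32(
    int32_t raw, int32_t multiplier, int32_t right_shift,
    int32_t zero_point, bool fuse_relu) {
  int64_t nudge = 1ll << std::max(0, right_shift - 1);
  int64_t rounded =
      ((static_cast<int64_t>(raw) * multiplier + nudge) >> right_shift) +
      zero_point;
  int64_t lo = fuse_relu ? zero_point : 0;
  return static_cast<uint8_t>(std::max(lo, std::min<int64_t>(255, rounded)));
}
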
+
+/**
+ * @brief Reference implementation of requantization step.
+ * float multiplier
+ * @param bias Can be nullptr.
+ */
+void requantize_u8acc32_ref(
+ int M,
+ int N,
+ int ld,
+ const std::int32_t* inp,
+ std::uint8_t* out,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::int32_t A_zero_point,
+ std::int32_t B_zero_point,
+ const std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false);
+
+/**
+ * @brief Reference implementation of matrix multiply with uint8 for A,
+ * int8 for B, and 32-bit accumulation.
+ */
+void matmul_u8i8acc32_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ const std::uint8_t* Aint8,
+ const std::int8_t* Bint8,
+ std::int32_t* Cint32);
+
+/**
+ * @brief Reference implementation of matrix multiply with uint8 for A,
+ * int8 for B, and 16-bit accumulation.
+ */
+void matmul_u8i8acc16_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ int brow,
+ const std::uint8_t* Aint8,
+ const std::int8_t* Bint8,
+ std::int32_t* Cint32);
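
The brow argument models the 16-bit-accumulation kernels: products are summed in a saturating 16-bit accumulator (see clip_16bit below) and spilled into a 32-bit accumulator every brow rows of K. A small usage sketch with illustrative sizes, assuming this header is included:

#include <cstdint>
#include <vector>

void matmul_acc16_example() {
  std::vector<uint8_t> A(4 * 512, 1);
  std::vector<int8_t> B(512 * 8, 1);
  std::vector<int32_t> C(4 * 8);
  // 4x8 output, K = 512; the 16-bit partial sum is spilled every brow = 256 rows.
  fbgemm2::matmul_u8i8acc16_ref(
      4, 8, 512, /*lda=*/512, /*ldb=*/8, /*ldc=*/8, /*brow=*/256,
      A.data(), B.data(), C.data());
  // Every entry of C is 512 here; with larger inputs the partial sums would
  // saturate at the int16 limits between spills, which the reference models.
}
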
+
+/**
+ * @brief Reference implementation of matrix multiply with fp32 (single
+ * precision) floating point number.
+ */
+void matmul_fp_ref(
+ int M,
+ int N,
+ int K,
+ int lda,
+ int ldb,
+ int ldc,
+ const float* Afp32,
+ const float* Bfp32,
+ float* Cfp32);
+
+/**
+ * @brief Reference implementation to compute row_offsets (sums of rows of A).
+ */
+void row_offsets_u8acc32_ref(
+ int M,
+ int K,
+ int ld,
+ const std::uint8_t* Aint8,
+ std::int32_t* row_offsets);
+
+/**
+ * @brief Reference implementation to compute adjusted col_offsets (sum of
+ * columns of B and adjusted with B_zero_point)
+ */
+void col_offsets_with_zero_pt_s8acc32_ref(
+ int K,
+ int N,
+ int ld,
+ const std::int8_t* Bint8,
+ std::int32_t B_zero_point,
+ std::int32_t* col_offsets);
+
+/**
+ * @brief Reference implementation of SPMDM (sparse matrix times dense matrix).
+ */
+void spmdm_ref(
+ int M,
+ const std::uint8_t* A,
+ int lda,
+ CompressedSparseColumn& B,
+ bool accumulation,
+ std::int32_t* C,
+ int ldc);
+
+/*
+ * @brief Clip (saturate) a 32-bit integer to the 16-bit integer range.
+ */
+int32_t clip_16bit(int32_t x);
+
+/*
+ * @brief Reference implementation of convolution operation.
+ * The activations A are assumed to be in NHiWiC format.
+ * The filters B are assumed to be in RSCK format.
+ * The output C is assumed to be in NHoWoC format.
+ */
+void conv_ref(
+ const conv_param_t& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ const std::int8_t* B,
+ std::int32_t* C);
+
+/*
+ * @brief Reference implementation of im2col operation.
+ * The input A is assumed to be in NHiWiC format.
+ * The output Ao is assumed to be in NHoWoRSC format.
+ */
+void im2col_ref(
+ const conv_param_t& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ std::uint8_t* Ao);
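
conv_ref and im2col_ref above are consistent with each other: lowering the activations with im2col_ref and multiplying by the filters viewed as a (KH*KW*IC) x OC row-major matrix reproduces conv_ref. A sketch of that equivalence, assuming this header is included (the wrapper name is illustrative):

#include <cstdint>
#include <vector>

void conv_via_im2col(
    const fbgemm2::conv_param_t& conv_p,
    const uint8_t* A, int32_t A_zero_point, const int8_t* B, int32_t* C) {
  const int M = conv_p.MB * conv_p.OH * conv_p.OW;  // output pixels
  const int K = conv_p.KH * conv_p.KW * conv_p.IC;  // filter taps per output
  std::vector<uint8_t> Ao(static_cast<size_t>(M) * K);
  fbgemm2::im2col_ref(conv_p, A, A_zero_point, Ao.data());
  // B in RSCK layout is exactly K x OC row-major, so a plain matmul over the
  // lowered activations produces the same C as conv_ref.
  fbgemm2::matmul_u8i8acc32_ref(
      M, conv_p.OC, K, /*lda=*/K, /*ldb=*/conv_p.OC, /*ldc=*/conv_p.OC,
      Ao.data(), B, C);
}
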
+
+/*
+ * @brief Reference implementation of depthwise convolution with a 3x3 filter
+ * and padding size 1.
+ */
+void depthwise_3x3_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int8_t* B,
+ std::int32_t* C);
+
+/*
+ * @brief Reference implementation of depthwise convolution with a 3x3 filter
+ * and padding size 1, followed by requantization (the same scaling factor and
+ * zero point are used for every channel).
+ */
+void depthwise_3x3_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const std::int8_t* B,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias);
+
+/*
+ * @brief Reference implementation of depthwise convolution with a 3x3 filter
+ * and padding size 1, followed by requantization (per-channel scaling factors
+ * and zero points).
+ */
+void depthwise_3x3_per_channel_quantization_pad_1_ref(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const std::int8_t* B,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias);
+
+/*
+ * @brief Reference implementation of 3D depthwise convolution with a 3x3x3
+ * filter and padding size 1.
+ */
+void depthwise_3x3x3_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int8_t* B,
+ std::int32_t* C);
+
+/*
+ * @brief Reference implementation of 3D depthwise convolution with a 3x3x3
+ * filter and padding size 1, followed by requantization.
+ */
+void depthwise_3x3x3_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const std::int8_t* B,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias);
+
+} // namespace fbgemm2
diff --git a/src/Utils.cc b/src/Utils.cc
new file mode 100644
index 0000000..10ab469
--- /dev/null
+++ b/src/Utils.cc
@@ -0,0 +1,357 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/Utils.h"
+#include <cpuinfo.h>
+#include <immintrin.h>
+#include <cassert>
+#include <cinttypes>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <limits>
+#include <stdexcept>
+#include <type_traits>
+
+namespace fbgemm2 {
+
+/**
+ * @brief Compare the reference and test result matrices to check correctness.
+ * @param ref The buffer for the reference result matrix.
+ * @param test The buffer for the test result matrix.
+ * @param m The height of the reference and test result matrix.
+ * @param n The width of the reference and test result matrix.
+ * @param ld The leading dimension of the reference and test result matrix.
+ * @param max_mismatches_to_report The maximum number of tolerable mismatches to
+ * report.
+ * @param atol The absolute tolerance for an element-wise mismatch.
+ * @retval 1 If the number of mismatches between the reference and test result
+ * matrices exceeds max_mismatches_to_report.
+ * @retval 0 Otherwise.
+ */
+template <typename T>
+int compare_buffers(
+ const T* ref,
+ const T* test,
+ int m,
+ int n,
+ int ld,
+ int max_mismatches_to_report,
+ float atol /*=1e-3*/) {
+ size_t mismatches = 0;
+ for (int i = 0; i < m; ++i) {
+ for (int j = 0; j < n; ++j) {
+ T reference = ref[i * ld + j], actual = test[i * ld + j];
+ if (std::abs(reference - actual) > atol) {
+ std::cout << "\tmismatch at (" << i << ", " << j << ")" << std::endl;
+ if (std::is_integral<T>::value) {
+ std::cout << "\t reference:" << static_cast<int64_t>(reference)
+ << " test:" << static_cast<int64_t>(actual) << std::endl;
+ } else {
+ std::cout << "\t reference:" << reference << " test:" << actual
+ << std::endl;
+ }
+
+ mismatches++;
+ if (mismatches > max_mismatches_to_report) {
+ return 1;
+ }
+ }
+ }
+ }
+ return 0;
+}
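
A minimal usage sketch for compare_buffers (illustrative values; the declaration is assumed to live in fbgemm/Utils.h, as the definition above suggests):

#include "fbgemm/Utils.h"

void compare_buffers_example() {
  float ref[6]  = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  float test[6] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.5f};
  // 2x3 matrices with leading dimension 3; tolerate up to 5 reported mismatches.
  int rc = fbgemm2::compare_buffers(ref, test, 2, 3, 3, 5, 1e-3f);
  // rc == 0 here: the single mismatch at (1, 2) is printed, but the count
  // does not exceed max_mismatches_to_report.
  (void)rc;
}
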
+
+
+/**
+ * @brief Print the matrix.
+ * @param op Transpose type of the matrix.
+ * @param inp The input matrix buffer.
+ * @param R The height of the matrix.
+ * @param C The width of the matrix.
+ * @param ld The leading dimension of the matrix.
+ * @param name The prefix string before printing the matrix.
+ */
+template <typename T>
+void printMatrix(
+ matrix_op_t op,
+ const T* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name) {
+ // R: number of rows in op(inp)
+ // C: number of cols in op(inp)
+ // ld: leading dimension in inp
+ std::cout << name << ":"
+ << "[" << R << ", " << C << "]" << std::endl;
+ bool tr = (op == matrix_op_t::Transpose);
+ for (auto r = 0; r < R; ++r) {
+ for (auto c = 0; c < C; ++c) {
+ T res = tr ? inp[c * ld + r] : inp[r * ld + c];
+ if (std::is_integral<T>::value) {
+ std::cout << std::setw(5) << static_cast<int64_t>(res) << " ";
+ } else {
+ std::cout << std::setw(5) << res << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+}
+
+template int compare_buffers<float>(
+ const float* ref,
+ const float* test,
+ int m,
+ int n,
+ int ld,
+ int max_mismatches_to_report,
+ float atol);
+
+template int compare_buffers<int32_t>(
+ const int32_t* ref,
+ const int32_t* test,
+ int m,
+ int n,
+ int ld,
+ int max_mismatches_to_report,
+ float atol);
+
+template int compare_buffers<uint8_t>(
+ const uint8_t* ref,
+ const uint8_t* test,
+ int m,
+ int n,
+ int ld,
+ int max_mismatches_to_report,
+ float atol);
+
+template void printMatrix<float>(
+ matrix_op_t op,
+ const float* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name);
+template void printMatrix<int8_t>(
+ matrix_op_t op,
+ const int8_t* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name);
+template void printMatrix<uint8_t>(
+ matrix_op_t op,
+ const uint8_t* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name);
+template void printMatrix<int32_t>(
+ matrix_op_t op,
+ const int32_t* inp,
+ size_t R,
+ size_t C,
+ size_t ld,
+ std::string name);
+
+
+/**
+ * @brief Reference implementation of matrix transposition: B = A^T.
+ * @param M The height of the matrix.
+ * @param N The width of the matrix.
+ * @param src The memory buffer of the source matrix A.
+ * @param ld_src The leading dimension of the source matrix A.
+ * @param dst The memory buffer of the destination matrix B.
+ * @param ld_dst The leading dimension of the destination matrix B.
+ */
+inline void transpose_ref(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ dst[i + j * ld_dst] = src[i * ld_src + j];
+ }
+ }
+}
+
+inline void
+transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) {
+ // load from src to registers
+ // a : a0 a1 a2 a3
+ // b : b0 b1 b2 b3
+ // c : c0 c1 c2 c3
+ // d : d0 d1 d2 d3
+ __m128 a = _mm_loadu_ps(&src[0 * ld_src]);
+ __m128 b = _mm_loadu_ps(&src[1 * ld_src]);
+ __m128 c = _mm_loadu_ps(&src[2 * ld_src]);
+ __m128 d = _mm_loadu_ps(&src[3 * ld_src]);
+
+ // transpose the 4x4 matrix formed by 32-bit elements: Macro from SSE
+ // a : a0 b0 c0 d0
+ // b : a1 b1 c1 d1
+ // c : a2 b2 c2 d2
+ // d : a3 b3 c3 d3
+ _MM_TRANSPOSE4_PS(a, b, c, d);
+
+ // store from registers to dst
+ _mm_storeu_ps(&dst[0 * ld_dst], a);
+ _mm_storeu_ps(&dst[1 * ld_dst], b);
+ _mm_storeu_ps(&dst[2 * ld_dst], c);
+ _mm_storeu_ps(&dst[3 * ld_dst], d);
+}
+inline void transpose_4x4(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ int ib = 0, jb = 0;
+ for (ib = 0; ib + 4 <= M; ib += 4) {
+ for (jb = 0; jb + 4 <= N; jb += 4) {
+ transpose_kernel_4x4_sse(
+ &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+ }
+ }
+ transpose_ref(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
+ transpose_ref(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
+}
+
+inline void transpose_kernel_8x8_avx2(
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ // load from src to registers
+ // a : a0 a1 a2 a3 a4 a5 a6 a7
+ // b : b0 b1 b2 b3 b4 b5 b6 b7
+ // c : c0 c1 c2 c3 c4 c5 c6 c7
+ // d : d0 d1 d2 d3 d4 d5 d6 d7
+ // e : e0 e1 e2 e3 e4 e5 e6 e7
+ // f : f0 f1 f2 f3 f4 f5 f6 f7
+ // g : g0 g1 g2 g3 g4 g5 g6 g7
+ // h : h0 h1 h2 h3 h4 h5 h6 h7
+ __m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
+ __m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
+ __m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
+ __m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
+ __m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
+ __m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
+ __m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
+ __m256 h = _mm256_loadu_ps(&src[7 * ld_src]);
+
+ __m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
+ __m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
+ // unpacking and interleaving 32-bit elements
+ // ab0145 : a0 b0 a1 b1 a4 b4 a5 b5
+ // ab2367 : a2 b2 a3 b3 a6 b6 a7 b7
+ // cd0145 : c0 d0 c1 d1 c4 d4 c5 d5
+ // cd2367 : c2 d2 c3 d3 c6 d6 c7 d7
+ // ef0145 : e0 f0 e1 f1 e4 f4 e5 f5
+ // ef2367 : e2 f2 e3 f3 e6 f6 e7 f7
+ // gh0145 : g0 h0 g1 h1 g4 h4 g5 h5
+ // gh2367 : g2 h2 g3 h3 g6 h6 g7 h7
+ ab0145 = _mm256_unpacklo_ps(a, b);
+ ab2367 = _mm256_unpackhi_ps(a, b);
+ cd0145 = _mm256_unpacklo_ps(c, d);
+ cd2367 = _mm256_unpackhi_ps(c, d);
+ ef0145 = _mm256_unpacklo_ps(e, f);
+ ef2367 = _mm256_unpackhi_ps(e, f);
+ gh0145 = _mm256_unpacklo_ps(g, h);
+ gh2367 = _mm256_unpackhi_ps(g, h);
+
+ // shuffling the 32-bit elements
+ // abcd04 : a0 b0 c0 d0 a4 b4 c4 d4
+ // abcd15 : a1 b1 c1 d1 a5 b5 c5 d5
+ // efgh04 : e0 f0 g0 h0 e4 f4 g4 h4
+  // efgh15 : e1 f1 g1 h1 e5 f5 g5 h5
+ // abcd26 : a2 b2 c2 d2 a6 b6 c6 d6
+ // abcd37 : a3 b3 c3 d3 a7 b7 c7 d7
+ // efgh26 : e2 f2 g2 h2 e6 f6 g6 h6
+ // efgh37 : e3 f3 g3 h3 e7 f7 g7 h7
+ abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44);
+ abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee);
+ efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44);
+ efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee);
+ abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44);
+ abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee);
+ efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44);
+ efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee);
+
+ // shuffling 128-bit elements
+ // a : a0 b0 c0 d0 e0 f0 g0 h0
+ // b : a1 b1 c1 d1 e1 f1 g1 h1
+ // c : a2 b2 c2 d2 e2 f2 g2 h2
+ // d : a3 b3 c3 d3 e3 f3 g3 h3
+ // e : a4 b4 c4 d4 e4 f4 g4 h4
+ // f : a5 b5 c5 d5 e5 f5 g5 h5
+ // g : a6 b6 c6 d6 e6 f6 g6 h6
+ // h : a7 b7 c7 d7 e7 f7 g7 h7
+ a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02);
+ b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02);
+ c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02);
+ d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02);
+ e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13);
+ f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13);
+ g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13);
+ h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13);
+
+ // store from registers to dst
+ _mm256_storeu_ps(&dst[0 * ld_dst], a);
+ _mm256_storeu_ps(&dst[1 * ld_dst], b);
+ _mm256_storeu_ps(&dst[2 * ld_dst], c);
+ _mm256_storeu_ps(&dst[3 * ld_dst], d);
+ _mm256_storeu_ps(&dst[4 * ld_dst], e);
+ _mm256_storeu_ps(&dst[5 * ld_dst], f);
+ _mm256_storeu_ps(&dst[6 * ld_dst], g);
+ _mm256_storeu_ps(&dst[7 * ld_dst], h);
+}
+
+void transpose_8x8(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ int ib = 0, jb = 0;
+ for (ib = 0; ib + 8 <= M; ib += 8) {
+ for (jb = 0; jb + 8 <= N; jb += 8) {
+ transpose_kernel_8x8_avx2(
+ &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+ }
+ }
+ transpose_4x4(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
+ transpose_4x4(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
+}
+
+void transpose_simd(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ // Run time CPU detection
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ transpose_16x16(M, N, src, ld_src, dst, ld_dst);
+ } else if (cpuinfo_has_x86_avx2()) {
+ transpose_8x8(M, N, src, ld_src, dst, ld_dst);
+ } else {
+ transpose_ref(M, N, src, ld_src, dst, ld_dst);
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+} // namespace fbgemm2
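
A small usage sketch for the dispatching transpose above; note the output convention dst(j, i) = src(i, j), so dst is an N x M matrix whose leading dimension must be at least M (transpose_simd is assumed to be declared in fbgemm/Utils.h):

#include <vector>
#include "fbgemm/Utils.h"

void transpose_example() {
  const int M = 3, N = 5;
  std::vector<float> src(M * N), dst(N * M);
  for (int i = 0; i < M * N; ++i) {
    src[i] = static_cast<float>(i);
  }
  // Picks the AVX-512, AVX2 or scalar path at runtime.
  fbgemm2::transpose_simd(M, N, src.data(), /*ld_src=*/N, dst.data(), /*ld_dst=*/M);
}
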
diff --git a/src/Utils_avx512.cc b/src/Utils_avx512.cc
new file mode 100644
index 0000000..b6bf413
--- /dev/null
+++ b/src/Utils_avx512.cc
@@ -0,0 +1,243 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm/Utils.h"
+
+#include <immintrin.h>
+
+namespace fbgemm2 {
+
+inline void transpose_kernel_16x16_avx512(
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ // load from src to registers
+ // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
+ // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
+ // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
+ // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
+ // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
+ // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
+ // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
+ // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
+ // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
+ // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
+ // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
+ // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
+ // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
+ // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
+ // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
+ // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
+ __m512 a = _mm512_loadu_ps(&src[0 * ld_src]);
+ __m512 b = _mm512_loadu_ps(&src[1 * ld_src]);
+ __m512 c = _mm512_loadu_ps(&src[2 * ld_src]);
+ __m512 d = _mm512_loadu_ps(&src[3 * ld_src]);
+ __m512 e = _mm512_loadu_ps(&src[4 * ld_src]);
+ __m512 f = _mm512_loadu_ps(&src[5 * ld_src]);
+ __m512 g = _mm512_loadu_ps(&src[6 * ld_src]);
+ __m512 h = _mm512_loadu_ps(&src[7 * ld_src]);
+ __m512 i = _mm512_loadu_ps(&src[8 * ld_src]);
+ __m512 j = _mm512_loadu_ps(&src[9 * ld_src]);
+ __m512 k = _mm512_loadu_ps(&src[10 * ld_src]);
+ __m512 l = _mm512_loadu_ps(&src[11 * ld_src]);
+ __m512 m = _mm512_loadu_ps(&src[12 * ld_src]);
+ __m512 n = _mm512_loadu_ps(&src[13 * ld_src]);
+ __m512 o = _mm512_loadu_ps(&src[14 * ld_src]);
+ __m512 p = _mm512_loadu_ps(&src[15 * ld_src]);
+
+ __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq;
+ // unpacking and interleaving 32-bit elements
+ // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13
+ // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15
+ // c0 d0 c1 d1 ...
+ // c2 d2 c3 d3 ...
+ // e0 f0 e1 f1 ...
+ // e2 f2 e3 f3 ...
+ // g0 h0 g1 h1 ...
+ // g2 h2 g3 h3 ...
+ // i0 ...
+ // i2 ...
+ // k0 ...
+ // k2 ...
+ // m0 ...
+ // m2 ...
+ // o0 ...
+ // o2 ...
+ ta = _mm512_unpacklo_ps(a, b);
+ tb = _mm512_unpackhi_ps(a, b);
+ tc = _mm512_unpacklo_ps(c, d);
+ td = _mm512_unpackhi_ps(c, d);
+ te = _mm512_unpacklo_ps(e, f);
+ tf = _mm512_unpackhi_ps(e, f);
+ tg = _mm512_unpacklo_ps(g, h);
+ th = _mm512_unpackhi_ps(g, h);
+ ti = _mm512_unpacklo_ps(i, j);
+ tj = _mm512_unpackhi_ps(i, j);
+ tk = _mm512_unpacklo_ps(k, l);
+ tl = _mm512_unpackhi_ps(k, l);
+ tm = _mm512_unpacklo_ps(m, n);
+ tn = _mm512_unpackhi_ps(m, n);
+ to = _mm512_unpacklo_ps(o, p);
+ tq = _mm512_unpackhi_ps(o, p);
+
+ // unpacking and interleaving 64-bit elements
+ // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12
+ // a1 b1 c1 d1 ...
+ // a2 b2 c2 d2 ...
+ // a3 b3 c3 d3 ...
+ // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12
+ // e1 f1 g1 h1 ...
+ // e2 f2 g2 h2 ...
+ // e3 f3 g3 h3 ...
+ // i0 j0 k0 l0 ...
+ // i1 j1 k1 l1 ...
+ // i2 j2 k2 l2 ...
+ // i3 j3 k3 l3 ...
+ // m0 n0 o0 p0 ...
+ // m1 n1 o1 p1 ...
+ // m2 n2 o2 p2 ...
+ // m3 n3 o3 p3 ...
+ a = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+ b = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+ c = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+ d = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+ e = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+ f = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+ g = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+ h = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+ i = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+ j = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+ k = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+ l = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+ m = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+ n = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+ o = _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+ p = _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+
+ // shuffle 128-bits (composed of 4 32-bit elements)
+ // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8
+ // a1 b1 c1 d1 ...
+ // a2 b2 c2 d2 ...
+ // a3 b3 c3 d3 ...
+ // a4 b4 c4 d4 ...
+ // a5 b5 c5 d5 ...
+ // a6 b6 c6 d6 ...
+ // a7 b7 c7 d7 ...
+ // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8
+ // i1 j1 k1 l1 ...
+ // i2 j2 k2 l2 ...
+ // i3 j3 k3 l3 ...
+ // i4 j4 k4 l4 ...
+ // i5 j5 k5 l5 ...
+ // i6 j6 k6 l6 ...
+ // i7 j7 k7 l7 ...
+ ta = _mm512_shuffle_f32x4(a, e, 0x88);
+ tb = _mm512_shuffle_f32x4(b, f, 0x88);
+ tc = _mm512_shuffle_f32x4(c, g, 0x88);
+ td = _mm512_shuffle_f32x4(d, h, 0x88);
+ te = _mm512_shuffle_f32x4(a, e, 0xdd);
+ tf = _mm512_shuffle_f32x4(b, f, 0xdd);
+ tg = _mm512_shuffle_f32x4(c, g, 0xdd);
+ th = _mm512_shuffle_f32x4(d, h, 0xdd);
+ ti = _mm512_shuffle_f32x4(i, m, 0x88);
+ tj = _mm512_shuffle_f32x4(j, n, 0x88);
+ tk = _mm512_shuffle_f32x4(k, o, 0x88);
+ tl = _mm512_shuffle_f32x4(l, p, 0x88);
+ tm = _mm512_shuffle_f32x4(i, m, 0xdd);
+ tn = _mm512_shuffle_f32x4(j, n, 0xdd);
+ to = _mm512_shuffle_f32x4(k, o, 0xdd);
+ tq = _mm512_shuffle_f32x4(l, p, 0xdd);
+
+ // shuffle 128-bits (composed of 4 32-bit elements)
+ // a0 b0 c0 d0 ... p0
+ // a1 b1 c1 d1 ... p1
+ // a2 b2 c2 d2 ... p2
+ // a3 b3 c3 d3 ... p3
+ // a4 ...
+ // a5 ...
+ // a6 ...
+ // a7 ...
+ // a8 ...
+ // a9 ...
+ // a10 ...
+ // a11 ...
+ // a12 ...
+ // a13 ...
+ // a14 ...
+ // a15 b15 c15 d15 ... p15
+ a = _mm512_shuffle_f32x4(ta, ti, 0x88);
+ b = _mm512_shuffle_f32x4(tb, tj, 0x88);
+ c = _mm512_shuffle_f32x4(tc, tk, 0x88);
+ d = _mm512_shuffle_f32x4(td, tl, 0x88);
+ e = _mm512_shuffle_f32x4(te, tm, 0x88);
+ f = _mm512_shuffle_f32x4(tf, tn, 0x88);
+ g = _mm512_shuffle_f32x4(tg, to, 0x88);
+ h = _mm512_shuffle_f32x4(th, tq, 0x88);
+ i = _mm512_shuffle_f32x4(ta, ti, 0xdd);
+ j = _mm512_shuffle_f32x4(tb, tj, 0xdd);
+ k = _mm512_shuffle_f32x4(tc, tk, 0xdd);
+ l = _mm512_shuffle_f32x4(td, tl, 0xdd);
+ m = _mm512_shuffle_f32x4(te, tm, 0xdd);
+ n = _mm512_shuffle_f32x4(tf, tn, 0xdd);
+ o = _mm512_shuffle_f32x4(tg, to, 0xdd);
+ p = _mm512_shuffle_f32x4(th, tq, 0xdd);
+
+ // store from registers to dst
+ _mm512_storeu_ps(&dst[0 * ld_dst], a);
+ _mm512_storeu_ps(&dst[1 * ld_dst], b);
+ _mm512_storeu_ps(&dst[2 * ld_dst], c);
+ _mm512_storeu_ps(&dst[3 * ld_dst], d);
+ _mm512_storeu_ps(&dst[4 * ld_dst], e);
+ _mm512_storeu_ps(&dst[5 * ld_dst], f);
+ _mm512_storeu_ps(&dst[6 * ld_dst], g);
+ _mm512_storeu_ps(&dst[7 * ld_dst], h);
+ _mm512_storeu_ps(&dst[8 * ld_dst], i);
+ _mm512_storeu_ps(&dst[9 * ld_dst], j);
+ _mm512_storeu_ps(&dst[10 * ld_dst], k);
+ _mm512_storeu_ps(&dst[11 * ld_dst], l);
+ _mm512_storeu_ps(&dst[12 * ld_dst], m);
+ _mm512_storeu_ps(&dst[13 * ld_dst], n);
+ _mm512_storeu_ps(&dst[14 * ld_dst], o);
+ _mm512_storeu_ps(&dst[15 * ld_dst], p);
+}
+
+void transpose_16x16(
+ int M,
+ int N,
+ const float* src,
+ int ld_src,
+ float* dst,
+ int ld_dst) {
+ int ib = 0, jb = 0;
+ for (ib = 0; ib + 16 <= M; ib += 16) {
+ for (jb = 0; jb + 16 <= N; jb += 16) {
+ transpose_kernel_16x16_avx512(
+ &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
+ }
+ }
+ transpose_8x8(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
+ transpose_8x8(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
+}
+
+} // namespace fbgemm2
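The 16x16 kernel transposes a tile in three stages: 32-bit unpacks, 64-bit unpacks through pd casts, and two rounds of 128-bit lane shuffles with _mm512_shuffle_f32x4. Functionally it matches the scalar reference below, a sketch (not part of the diff) with the same argument convention that can serve as a correctness oracle for the tiling in transpose_16x16, where remainders fall back to the 8x8 path.

// Sketch: scalar equivalent of transpose_kernel_16x16_avx512, using the same
// (src, ld_src, dst, ld_dst) convention; illustrative only.
inline void transpose_kernel_16x16_ref(
    const float* src,
    int ld_src,
    float* dst,
    int ld_dst) {
  for (int r = 0; r < 16; ++r) {
    for (int c = 0; c < 16; ++c) {
      // element (r, c) of the source tile lands at (c, r) in the destination
      dst[c * ld_dst + r] = src[r * ld_src + c];
    }
  }
}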
diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc
new file mode 100644
index 0000000..8e36c85
--- /dev/null
+++ b/src/codegen_fp16fp32.cc
@@ -0,0 +1,387 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <assert.h>
+#include <cpuid.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <algorithm>
+#include <array>
+#include <fstream>
+#include <functional>
+#include <iostream>
+#include <map>
+#include <string>
+#include <vector>
+
+using namespace std;
+
+void addi(ofstream& of, string i, bool disable = false) {
+ if (disable == false)
+ of << "\"" + i + "\\t\\n\"" + "\n";
+}
+
+struct ISA {
+ unsigned avx; // 1, 2 or 3
+ string name;
+ vector<vector<unsigned>> shapes;
+};
+
+int main() {
+ bool iaca = false;
+ bool disable = false;
+
+ bool fixedA = true, fixedB = true, fixedC = true;
+
+ int eax, ebx, ecx, edx;
+ __cpuid(1 /* request feature flags in ecx/edx */, eax, ebx, ecx, edx);
+ printf("F16C is %ssupported\n", ((ecx & bit_F16C) ? "" : "not "));
+
+ string comma = ",";
+
+ vector<ISA> isa = {
+ // {1, "AVX", {{4, 1, 0}, {4, 2, 0}, {4, 3, 0}, {3, 1, 0}, {3, 2, 0}, {3,
+ // 3, 0}}},
+ { 2, "AVX2",
+ { { 1, 1, 0 },
+ { 2, 1, 0 },
+ { 3, 1, 0 },
+ { 4, 1, 0 },
+ { 5, 1, 0 },
+ { 6, 1, 0 },
+ { 7, 1, 0 },
+ { 8, 1, 0 },
+ { 9, 1, 0 },
+ { 10, 1, 0 },
+ { 11, 1, 0 },
+ { 12, 1, 0 },
+ { 13, 1, 0 },
+ { 14, 1, 0 },
+ }
+ }
+ };
+
+ // open all files
+ ofstream srcfile;
+ srcfile.open("FbgemmFP16UKernels.cc");
+ srcfile << "#include \"FbgemmFP16UKernels.h\"\n";
+ if (iaca)
+ srcfile << "#include \"iacaMarks.h\"\n";
+
+ ofstream hdrfile;
+ hdrfile.open("FbgemmFP16UKernels.h");
+
+ hdrfile << "#ifndef FBGEMM_UKERNELS\n";
+ hdrfile << "#define FBGEMM_UKERNELS\n";
+ hdrfile << "#include <cstdint>\n";
+ hdrfile << "#include <tuple>\n";
+ hdrfile << "#include <vector>\n";
+ hdrfile << "#include \"fbgemm/Types.h\"\n";
+ hdrfile << "using fp16 = fbgemm2::float16;\n";
+ hdrfile << "using fp32 = float;\n";
+ hdrfile << "struct GemmParams {uint64_t k; float *A; const fp16 *B;\n"
+ "float *beta; uint64_t accum; float *C; uint64_t ldc;\n"
+ "uint64_t b_block_cols; uint64_t b_block_size;};\n";
+
+ std::map<string, string> fptr_typedef;
+ fptr_typedef["fp16"] = "";
+ fptr_typedef["fp32"] = "";
+
+ unsigned labelId = 0;
+#if 1
+ for (auto fixedA : {false})
+ for (auto fixedB : {false})
+ for (auto fixedC : {false})
+#else
+ for (auto fixedA : {true})
+ for (auto fixedB : {true})
+ for (auto fixedC : {true})
+#endif
+ for (auto s : isa) {
+ vector<vector<unsigned>>& ukernel_shape = s.shapes;
+
+ vector<string> funcname(ukernel_shape.size()),
+ fheader(ukernel_shape.size());
+ string fargs;
+
+ for (auto fp16 : {true}) {
+ string B_type = ((fp16) ? "fp16" : "fp32");
+ string prefix = s.name + /*"_" + B_type */ + "_" + "fA" +
+ to_string(fixedA) + "fB" + to_string(fixedB) + "fC" +
+ to_string(fixedC);
+ cout << "Generating code for " << s.name << " " << B_type << "\n";
+
+ for (unsigned k = 0; k < ukernel_shape.size(); k++) {
+ printf(
+ "shape: %d x %d * 32\n",
+ ukernel_shape[k][0],
+ ukernel_shape[k][1]);
+
+ string p1 = "GemmParams *gp";
+
+ funcname[k] = "gemmkernel_" + to_string(ukernel_shape[k][0]) +
+ "x" + to_string(ukernel_shape[k][1]) + "_";
+ funcname[k] += prefix;
+
+ fargs = "(" + p1 + ")";
+
+ fheader[k] =
+ "void __attribute__ ((noinline)) " + funcname[k] + fargs;
+ srcfile << fheader[k] << "\n";
+ srcfile << "{\n";
+
+ unsigned last_free_ymmreg = 0;
+ // produce register block of C
+ vector<vector<string>> vCtile(ukernel_shape[k][0]);
+ for (auto r = 0; r < ukernel_shape[k][0]; r++)
+ for (auto c = 0; c < ukernel_shape[k][1]; c++) {
+ vCtile[r].push_back("ymm" + to_string(last_free_ymmreg));
+ last_free_ymmreg++;
+ }
+ assert(last_free_ymmreg <= 14);
+
+ string vAtmp = "ymm" + to_string(last_free_ymmreg++);
+ // produce register block of B col
+ assert(ukernel_shape[k][1] == 1);
+ vector<string> vBcol(ukernel_shape[k][1]);
+
+ for (auto c = 0; c < ukernel_shape[k][1]; c++) {
+ vBcol[c] = ("ymm" + to_string(last_free_ymmreg));
+ last_free_ymmreg++;
+ }
+
+ assert(last_free_ymmreg <= 16);
+
+ srcfile << "asm volatile\n";
+ srcfile << "(\n";
+
+ srcfile << "#if !defined(__clang__)" << "\n";
+ addi(srcfile, "mov r14, %[gp]");
+ srcfile << "#else\n";
+ addi(srcfile, "mov %[gp], %%r14");
+ addi(srcfile, ".intel_syntax noprefix");
+ srcfile << "#endif\n";
+
+ srcfile << "\n// Copy parameters\n";
+ srcfile << "// k\n";
+ addi(srcfile, "mov r8, [r14 + 0]");
+ srcfile << "// A\n";
+ addi(srcfile, "mov r9, [r14 + 8]");
+ srcfile << "// B\n";
+ addi(srcfile, "mov r10, [r14 + 16]");
+ srcfile << "// beta\n";
+ addi(srcfile, "mov r15, [r14 + 24]");
+ srcfile << "// accum\n";
+ addi(srcfile, "mov rdx, [r14 + 32]");
+ srcfile << "// C\n";
+ addi(srcfile, "mov r12, [r14 + 40]");
+ srcfile << "// ldc\n";
+ addi(srcfile, "mov r13, [r14 + 48]");
+ srcfile << "// b_block_cols\n";
+ addi(srcfile, "mov rdi, [r14 + 56]");
+ srcfile << "// b_block_size\n";
+ addi(srcfile, "mov rsi, [r14 + 64]");
+ srcfile << "// Make copies of A and C\n";
+ addi(srcfile, "mov rax, r9");
+ addi(srcfile, "mov rcx, r12");
+ srcfile << "\n\n";
+
+ addi(srcfile, "mov rbx, 0");
+
+ string exitlabel = "L_exit%=";
+ string label2 = "loop_outter%=";
+ addi(srcfile, label2 + ":");
+ addi(srcfile, "mov r14, 0");
+
+ // set all vCtile regs to zeros
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ addi(
+ srcfile,
+ "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," +
+ vCtile[r][c]);
+ }
+ }
+
+ // start marker
+ if (iaca) {
+ addi(srcfile, "mov ebx, 111");
+ addi(srcfile, ".byte 0x64, 0x67, 0x90");
+ }
+
+ srcfile << "\n";
+
+ if (ukernel_shape[k][0] <= 13) {
+ addi(srcfile, "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]");
+ addi(srcfile, "mov r11, 16");
+ } else {
+ addi(srcfile, "mov r11, 0");
+ }
+
+ srcfile << "\n";
+ string label = "loop_inner%=";
+ addi(srcfile, label + ":");
+ srcfile << "\n";
+
+ if (ukernel_shape[k][0] <= 13) {
+ auto a_offset = 0, unroll_factor = 2;
+ for (auto u = 0; u < unroll_factor; u++) {
+ string breg = (u == 0) ? "ymm14" : "ymm15";
+ string breg_rev = (u == 0) ? "ymm15" : "ymm14";
+
+ addi(srcfile, "vcvtph2ps " + breg +
+ ",XMMWORD PTR [r10 + r11 + " +
+ to_string(u * 16) + "]");
+ addi(srcfile, "inc r14");
+ for (auto r = 0; r < vCtile.size(); r++) {
+ addi(srcfile, "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
+ to_string(a_offset) + "]");
+ addi(srcfile, "vfmadd231ps " + vCtile[r][0] + "," +
+ breg_rev + "," + vAtmp);
+ if (u == 1 && r == vCtile.size() / 2)
+ addi(srcfile, "add r11, 32");
+ a_offset += 4;
+ }
+ if (u < unroll_factor - 1) {
+ addi(srcfile, "cmp r14, r8");
+ addi(srcfile, "jge " + exitlabel);
+ }
+ }
+
+ addi(srcfile, "add r9," + to_string(a_offset));
+ addi(srcfile, "cmp r14, r8");
+ addi(srcfile, "jl " + label);
+
+ srcfile << "\n";
+
+ addi(srcfile, exitlabel + ":");
+ } else {
+ addi(srcfile,
+ "vcvtph2ps " + vBcol[0] + ",XMMWORD PTR [r10 + r11]");
+ for (auto r = 0; r < vCtile.size(); r++) {
+ addi(srcfile, "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
+ to_string(4 * r) + "]");
+ addi(srcfile, "vfmadd231ps " + vCtile[r][0] + "," + vBcol[0] +
+ "," + vAtmp);
+ }
+
+ addi(srcfile, "add r9," + to_string(4 * ukernel_shape[k][0]),
+ fixedA); // move A ptr
+ addi(srcfile, "add r11, 16");
+
+ addi(srcfile, "inc r14");
+ addi(srcfile, "cmp r14, r8");
+ addi(srcfile, "jl " + label);
+ }
+
+ addi(srcfile, "add r10, rsi");
+ srcfile << "\n";
+
+ // end marker
+ if (iaca) {
+ addi(srcfile, "mov ebx, 222");
+ addi(srcfile, ".byte 0x64, 0x67, 0x90");
+ }
+
+
+ addi(srcfile, "cmp rdx, 1");
+ addi(srcfile, "je L_accum%=");
+ srcfile << "// Dump C\n";
+
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ addi(srcfile, "vmovups YMMWORD PTR [r12 + " +
+ to_string(32 * c) + "], " + vCtile[r][c],
+ fixedC);
+ }
+ addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ }
+ addi(srcfile, "jmp L_done%=");
+
+ srcfile << "\n\n";
+ addi(srcfile, "L_accum%=:");
+ srcfile << "// Dump C with accumulate\n";
+
+ string r_spare = (s.avx == 1) ? "ymm14" : "ymm15";
+ addi(srcfile,
+ "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"),
+ fixedC);
+ // store out C
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ switch (s.avx) {
+ case 1:
+ addi(srcfile,
+ string("vmulps ymm15, ") + r_spare + comma +
+ "YMMWORD PTR [r12 + " + to_string(32 * c) + "]",
+ fixedC);
+ addi(srcfile, "vaddps " + vCtile[r][c] + "," +
+ vCtile[r][c] + "," + "ymm15",
+ fixedC);
+ break;
+ case 2:
+ addi(srcfile,
+ "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," +
+ "YMMWORD PTR [r12 + " + to_string(32 * c) + "]",
+ fixedC);
+ break;
+ default:
+ assert(0);
+ }
+ addi(srcfile, "vmovups YMMWORD PTR [r12 + " +
+ to_string(32 * c) + "], " + vCtile[r][c],
+ fixedC);
+ }
+ addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ }
+
+ srcfile << "\n";
+ addi(srcfile, "L_done%=:");
+
+ srcfile << "\n// next outer iteration\n";
+ // C
+ addi(srcfile, "add rcx, " + to_string(32 * ukernel_shape[k][1]),
+ fixedC);
+ addi(srcfile, "mov r12, rcx", fixedC);
+ // A
+ addi(srcfile, "mov r9, rax");
+
+ addi(srcfile, "inc rbx");
+ addi(srcfile, "cmp rbx, rdi");
+ addi(srcfile, "jl " + label2);
+
+ // output
+ srcfile << ":\n";
+ // input
+ srcfile << ":\n";
+ srcfile << "[gp] \"rm\" (gp)\n";
+
+ // clobbered
+ srcfile
+ << (string) ": \"r8\", \"r9\", \"r10\", \"r11\", \"r15\", " +
+ (string) " \"r13\", \"r14\",\n" +
+ (string) "\"rax\", \"rcx\", "
+ "\"rdx\", \"rsi\", \"rdi\", \"rbx\", "
+ "\"r12\", \"memory\"" +
+ (string) "\n";
+ srcfile << ");\n";
+ srcfile << "}\n";
+ }
+
+ for (unsigned k = 0; k < ukernel_shape.size(); k++) {
+ hdrfile << fheader[k] << ";\n";
+ }
+
+ fptr_typedef[B_type] =
+ "typedef void (* funcptr_" + B_type + ") " + fargs;
+ }
+ }
+
+ srcfile.close();
+ hdrfile << fptr_typedef["fp16"] << ";\n";
+ hdrfile << fptr_typedef["fp32"] << ";\n";
+ hdrfile << "#endif\n";
+ hdrfile.close();
+}
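The generated kernels receive a single GemmParams* (moved into r14) and read each field at a fixed 8-byte offset: k at +0, A at +8, B at +16, beta at +24, accum at +32, C at +40, ldc at +48, b_block_cols at +56, b_block_size at +64, so the struct emitted into FbgemmFP16UKernels.h must keep exactly this layout. The sketch below shows how such a kernel could be driven; it is illustrative only, and the buffer names and the particular 14x1 AVX2 kernel chosen are assumptions for the example rather than part of the generator's output.

// Sketch: filling GemmParams for one generated kernel. Assumes the generated
// FbgemmFP16UKernels.h is included; packedA/packedB/C are hypothetical buffers.
#include "FbgemmFP16UKernels.h"

void run_one_block(
    uint64_t k,
    float* packedA,       // fp32 A panel, read via [r14 + 8]
    const fp16* packedB,  // fp16 B panel, read via [r14 + 16]
    float beta,
    float* C,
    uint64_t ldc_elements,
    uint64_t b_block_cols,
    uint64_t b_block_size) {
  GemmParams gp;
  gp.k = k;                              // inner-loop trip count ("cmp r14, r8")
  gp.A = packedA;
  gp.B = packedB;
  gp.beta = &beta;                       // broadcast from [r15] on the accumulate path
  gp.accum = 1;                          // 1: C = A*B + beta*C_old, 0: C = A*B
  gp.C = C;
  gp.ldc = ldc_elements * sizeof(float); // byte stride: "add r12, r13" advances C by ldc
  gp.b_block_cols = b_block_cols;        // outer-loop trip count compared against rbx
  gp.b_block_size = b_block_size;        // bytes added to the B pointer per outer step
  gemmkernel_14x1_AVX2_fA0fB0fC0(&gp);   // name per the scheme generated above
}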